27、ResNet50处理STEW数据集,用于情感三分类+全备的代码

1、数据介绍

IEEE-Datasets-STEW:SIMULTANEOUS TASK EEG WORKLOAD DATASET :

该数据集由48名受试者的原始EEG数据组成,他们参加了利用SIMKAP多任务测试进行的多任务工作负荷实验。受试者在休息时的大脑活动也在测试前被记录下来,也包括在其中。Emotiv EPOC设备,采样频率为128Hz,有14个通道,用于获取数据,每个案例都有2.5分钟的EEG记录。受试者还被要求在每个阶段后以1到9的评分标准对其感知的心理工作量进行评分,评分结果在单独的文件中提供。

说明:每个受试者的数据遵循命名惯例:subno_task.txt。例如,sub01_lo.txt将是受试者1在休息时的原始脑电数据,而sub23_hi.txt将是受试者23在多任务测试中的原始脑电数据。每个数据文件的行对应于记录中的样本,列对应于EEG设备的14个通道: AF3, F7, F3, FC5, T7, P7, O1, O2, P8, T8, FC6, F4, F8, AF4。

数据说明、下载地址:

STEW: Simultaneous Task EEG Workload Data Set | IEEE Journals & Magazine | IEEE Xplore

2、代码

本次使用ResNet50,去做此情感数据的分类工作,数据导入+模型训练+测试代码如下:

import torch
import torchvision.datasets
from torch.utils.data import Dataset        # 继承Dataset类
import os
from PIL import Image
import numpy as np
from torchvision import transforms# 预处理
data_transform = transforms.Compose([transforms.Resize((224,224)),           # 缩放图像transforms.ToTensor(),                  # 转为Tensotransforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))       # 标准化
])path =  r'C:\STEW\test'for root,dirs,files in os.walk(path):print('root',root) #遍历到该目录地址print('dirs',dirs) #遍历到该目录下的子目录名 []print('files',files)  #遍历到该目录下的文件  []def read_txt_files(path):# 创建文件名列表file_names = []# 遍历给定目录及其子目录下的所有文件for root, dirs, files in os.walk(path):# 遍历所有文件for file in files:# 如果是 .txt 文件,则加入文件名列表if file.endswith('.txt'): # endswith () 方法用于判断字符串是否以指定后缀结尾,如果以指定后缀结尾返回True,否则返回False。file_names.append(os.path.join(root, file))# 返回文件名列表return file_namesclass DogCat(Dataset):      # 数据处理def __init__(self,root,transforms = None):                  # 初始化,指定路径,是否预处理等等#['cat.15454.jpg', 'cat.445.jpg', 'cat.46456.jpg', 'cat.656165.jpg', 'dog.123.jpg', 'dog.15564.jpg', 'dog.4545.jpg', 'dog.456465.jpg']imgs = os.listdir(root)self.imgs = [os.path.join(root,img) for img in imgs]    # 取出root下所有的文件self.transforms = data_transform                        # 图像预处理def __getitem__(self, index):       # 读取图片img_path = self.imgs[index]label = 1 if 'dog' in img_path.split('/')[-1] else 0 #然后,就可以根据每个路径的id去做label了。将img_path 路径按照 '/ '分割,-1代表取最后一个字符串,如果里面有dog就为1,cat就为0.data = Image.open(img_path)if self.transforms:     # 图像预处理data = self.transforms(data)return data,labeldef __len__(self):return len(self.imgs)dataset = DogCat('./data/',transforms=True)for img,label in dataset:print('img:',img.size(),'label:',label)
'''
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 1
img: torch.Size([3, 224, 224]) label: 1
img: torch.Size([3, 224, 224]) label: 1
img: torch.Size([3, 224, 224]) label: 1
'''import os# 获取file_path路径下的所有TXT文本内容和文件名
def get_text_list(file_path):files = os.listdir(file_path)text_list = []for file in files:with open(os.path.join(file_path, file), "r", encoding="UTF-8") as f:text_list.append(f.read())return text_list, filesclass ImageFolderCustom(Dataset):# 2. Initialize with a targ_dir and transform (optional) parameterdef __init__(self, targ_dir: str, transform=None) -> None:# 3. Create class attributes# Get all image pathsself.paths = list(pathlib.Path(targ_dir).glob("*/*.jpg")) # note: you'd have to update this if you've got .png's or .jpeg's# Setup transformsself.transform = transform# Create classes and class_to_idx attributesself.classes, self.class_to_idx = find_classes(targ_dir)# 4. Make function to load imagesdef load_image(self, index: int) -> Image.Image:"Opens an image via a path and returns it."image_path = self.paths[index]return Image.open(image_path) # 5. Overwrite the __len__() method (optional but recommended for subclasses of torch.utils.data.Dataset)def __len__(self) -> int:"Returns the total number of samples."return len(self.paths)# 6. Overwrite the __getitem__() method (required for subclasses of torch.utils.data.Dataset)def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:"Returns one sample of data, data and label (X, y)."img = self.load_image(index)class_name  = self.paths[index].parent.name # expects path in data_folder/class_name/image.jpegclass_idx = self.class_to_idx[class_name]# Transform if necessaryif self.transform:return self.transform(img), class_idx # return data, label (X, y)else:return img, class_idx # return data, label (X, y)import torchvision as tv
import numpy as np
import torch
import time
import os
from torch import nn, optim
from torchvision.models import resnet50
from torchvision.transforms import transformsos.environ["CUDA_VISIBLE_DEVICE"] = "0,1,2"# cifar-10进行测验class Cutout(object):"""Randomly mask out one or more patches from an image.Args:n_holes (int): Number of patches to cut out of each image.length (int): The length (in pixels) of each square patch."""def __init__(self, n_holes, length):self.n_holes = n_holesself.length = lengthdef __call__(self, img):"""Args:img (Tensor): Tensor image of size (C, H, W).Returns:Tensor: Image with n_holes of dimension length x length cut out of it."""h = img.size(1)w = img.size(2)mask = np.ones((h, w), np.float32)for n in range(self.n_holes):y = np.random.randint(h)x = np.random.randint(w)y1 = np.clip(y - self.length // 2, 0, h)y2 = np.clip(y + self.length // 2, 0, h)x1 = np.clip(x - self.length // 2, 0, w)x2 = np.clip(x + self.length // 2, 0, w)mask[y1: y2, x1: x2] = 0.mask = torch.from_numpy(mask)mask = mask.expand_as(img)img = img * maskreturn imgdef load_data_cifar10(batch_size=128,num_workers=2):# 操作合集# Data augmentationtrain_transform_1 = transforms.Compose([transforms.Resize((224, 224)),transforms.RandomHorizontalFlip(),  # 随机水平翻转transforms.RandomRotation(degrees=(-80,80)),  # 随机角度翻转transforms.ToTensor(),transforms.Normalize((0.491339968,0.48215827,0.44653124), (0.24703233,0.24348505,0.26158768)  # 两者分别为(mean,std)),Cutout(1, 16),  # 务必放在ToTensor的后面])train_transform_2 = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.491339968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)  # 两者分别为(mean,std))])test_transform = transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),transforms.Normalize((0.491339968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)  # 两者分别为(mean,std))])# 训练集1trainset1 = tv.datasets.CIFAR10(root='data',train=True,download=False,transform=train_transform_1,)# 训练集2trainset2 = tv.datasets.CIFAR10(root='data',train=True,download=False,transform=train_transform_2,)# 测试集testset = tv.datasets.CIFAR10(root='data',train=False,download=False,transform=test_transform,)# 训练数据加载器1trainloader1 = torch.utils.data.DataLoader(trainset1,batch_size=batch_size,shuffle=True,num_workers=num_workers,pin_memory=(torch.cuda.is_available()))# 训练数据加载器2trainloader2 = torch.utils.data.DataLoader(trainset2,batch_size=batch_size,shuffle=True,num_workers=num_workers,pin_memory=(torch.cuda.is_available()))# 测试数据加载器testloader = torch.utils.data.DataLoader(testset,batch_size=batch_size,shuffle=False,num_workers=num_workers,pin_memory=(torch.cuda.is_available()))return trainloader1,trainloader2,testloaderdef main():start = time.time()batch_size = 128cifar_train1,cifar_train2,cifar_test = load_data_cifar10(batch_size=batch_size)model = resnet50().cuda()# model.load_state_dict(torch.load('_ResNet50.pth'))# 存在已保存的参数文件# model = nn.DataParallel(model,device_ids=[0,])  # 又套一层model = nn.DataParallel(model,device_ids=[0,1,2])loss = nn.CrossEntropyLoss().cuda()optimizer = optim.Adam(model.parameters(),lr=0.001)for epoch in range(50):model.train()  # 训练时务必写loss_=0.0num=0.0# train on trainloader1(data augmentation) and trainloader2for i,data in enumerate(cifar_train1,0):x, label = datax, label = x.cuda(),label.cuda()# xp = model(x) #outputl = loss(p,label) #lossoptimizer.zero_grad()l.backward()optimizer.step()loss_ += float(l.mean().item())num+=1for i, data in enumerate(cifar_train2, 0):x, label = datax, label = x.cuda(), label.cuda()# xp = model(x)l = loss(p, label)optimizer.zero_grad()l.backward()optimizer.step()loss_ += float(l.mean().item())num += 1model.eval()  # 评估时务必写print("loss:",float(loss_)/num)# test on trainloader2,testloaderwith torch.no_grad():total_correct = 0total_num = 0for x, label in cifar_train2:# [b, 3, 32, 32]# [b]x, label = x.cuda(), label.cuda()# [b, 10]logits = model(x)# [b]pred = logits.argmax(dim=1)# [b] vs [b] => scalar tensorcorrect = torch.eq(pred, label).float().sum().item()total_correct += correcttotal_num += x.size(0)# print(correct)acc_1 = total_correct / total_num# Testwith torch.no_grad():total_correct = 0total_num = 0for x, label in cifar_test:# [b, 3, 32, 32]# [b]x, label = x.cuda(), label.cuda()# [b, 10]logits = model(x) #output# [b]pred = logits.argmax(dim=1)# [b] vs [b] => scalar tensorcorrect = torch.eq(pred, label).float().sum().item()total_correct += correcttotal_num += x.size(0)# print(correct)acc_2 = total_correct / total_numprint(epoch+1,'train acc',acc_1,'|','test acc:', acc_2)# 保存时只保存model.moduletorch.save(model.module.state_dict(),'resnet50.pth')print("The interval is :",time.time() - start)if __name__ == '__main__':main()

3、对你有帮助的话,给个关注吧~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/222326.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【SpringBoot】之Security进阶使用

🎉🎉欢迎来到我的CSDN主页!🎉🎉 🏅我是君易--鑨,一个在CSDN分享笔记的博主。📚📚 🌟推荐给大家我的博客专栏《SpringBoot开发之Security系列》。&#x1f3af…

变分自动编码器【03/3】:使用 Docker 和 Bash 脚本进行超参数调整

一、说明 在深入研究第 1 部分中的介绍和实现,并在第 2 部分中探索训练过程之后,我们现在将重点转向在第 3 部分中通过超参数调整来优化模型的性能。要访问本系列的完整代码,请访问我们的 GitHub 存储库在GitHub - asokraju/ImageAutoEncoder…

直播电商“去网红化”势在必行,AI数字人打造品牌专属IP

近年来,网红直播带货“翻车”事件频发,给品牌商带来了信任危机和负面口碑的困扰,严重损害了企业的声誉。这证明强大的个人IP,对于吸引粉丝和流量确实能起到巨大的好处,堪称“金牌销售”,但太过强势的个人IP属性也会给企业带来一定风险&#x…

计算机网络:应用层

0 本节主要内容 问题描述 解决思路 1 问题描述 不同的网络服务: DNS:用来把人们使用的机器名字(域名)转换为 IP 地址;DHCP:允许一台计算机加入网络和获取 IP 地址,而不用手工配置&#xff1…

回顾丨2023 SpeechHome 第三届语音技术研讨会

下面是整体会议的内容回顾: 18日线上直播回顾 18日上午9:30,AISHELL & SpeechHome CEO卜辉宣布研讨会开始,并简要介绍本次研讨会的筹备情况以及报告内容。随后,CCF语音对话与听觉专委会副主任、清华大学教授郑方&#xff0c…

Spring AOP入门指南:轻松掌握面向切面编程的基础知识

面向切面编程 1,AOP简介1.1 什么是AOP?1.2 AOP作用1.3 AOP核心概念 2,AOP入门案例2.1 需求分析2.2 思路分析2.3 环境准备2.4 AOP实现步骤步骤1:添加依赖步骤2:定义接口与实现类步骤3:定义通知类和通知步骤4:定义切入点步骤5:制作切面步骤6:将通知类配给…

7-1 建立二叉搜索树并查找父结点(PTA - 数据结构)

按输入顺序建立二叉搜索树,并搜索某一结点,输出其父结点。 输入格式: 输入有三行: 第一行是n值,表示有n个结点; 第二行有n个整数,分别代表n个结点的数据值; 第三行是x,表示要搜索值…

【已解决】修改了网站的class样式name值,会影响SEO,搜索引擎抓取网站及排名吗?

问题: 修改了网站的class样式name值,会影响搜索引擎抓取网站及排名吗? 解答: 如果你仅仅修改了网站class样式的名称,而没有改变网站的结构和内容,那么搜索引擎通常不会因此而影响它对网站的抓取和排名。但…

【C++入门到精通】互斥锁 (Mutex) C++11 [ C++入门 ]

阅读导航 引言一、Mutex的简介二、Mutex的种类1. std::mutex (基本互斥锁)2. std::recursive_mutex (递归互斥锁)3. std::timed_mutex (限时等待互斥锁)4. std::recursive_timed_mutex (限时等待…

使用VSC从零开始Vue.js——备赛笔记——2024全国职业院校技能大赛“大数据应用开发”赛项——任务3:数据可视化

使用Visual Studio Code(VSC)进行Vue开发非常方便,下面是一些基本步骤: 一、下载和安装Vue 官网下载地址Download | Node.js Vue.js是基于Node.js的,所以首先需要安装Node.js,官网下载地址:No…

PSP - 结构生物学中的机器学习 (NIPS MLSB Workshop 2023.12)

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://spike.blog.csdn.net/article/details/135120094 Machine Learning in Structural Biology (机器学习在结构生物学中) 网址:https://www.mlsb.io/ Workshop at the 37th Co…

应用案例 | 汽车行业基于3D机器视觉引导机器人上下料解决方案

Part.1 背景 近年来,汽车行业蓬勃发展,一度出现供不应求的现象。在汽车零配件、整车大规模制造的过程中,为了降本增效,提升产品质量,工厂急需完成自动化升级。随着人工智能的发展,越来越多的生产环节引入机…

Jupyter Notebook修改默认工作目录

1、参考修改Jupyter Notebook的默认工作目录_jupyter文件路径-CSDN博客修改配置文件 2.在上述博客内容的基础上,这里不是删除【%USERPROFILE%】而是把这个地方替换为所要设置的工作目录路径, 3.【起始位置】也可以更改为所要设置的工作目录路径&#x…

【JVM】一、认识JVM

文章目录 1、虚拟机2、Java虚拟机3、JVM的整体结构4、Java代码的执行流程5、JVM的分类6、JVM的生命周期 1、虚拟机 虚拟机,Virtual Machine,一台虚拟的计算机,用来执行虚拟计算机指令。分为: 系统虚拟机:如VMware&am…

我的创作纪念日——成为创作者第1024天

机缘 一、前言 早上收到CSDN的推送信息,今天是我成为创作者的第1024天,回想起自己已经好久没有写博客了,突然间很有感触,想水一篇文章,跟小伙伴们分享一下我的经历。 二、自我介绍 我出生在广东潮汕地区的一个小城…

TypeScript实战——ChatGPT前端自适应手机端,PC端

前言 「作者主页」:雪碧有白泡泡 「个人网站」:雪碧的个人网站 可以在线体验哦:体验地址 文章目录 前言引言先看效果PC端手机端 实现原理解释 包的架构目录 引言 ChatGPT是由OpenAI开发的一种基于语言模型的对话系统。它是GPT(…

ES排错命令

GET _cat/indices?v&healthred GET _cat/indices?v&healthyellow GET _cat/indices?v&healthgreen确定哪些索引有问题,多少索引有问题。_cat API 可以通过返回结果告诉我们这一点 查看有问题的分片以及原因。 这与索引列表有关,但是索引…

【介质】详解NVMe SSD存储性能影响因素

导读: NVMe SSD的性能时常捉摸不定,为此我们需要打开SSD的神秘盒子,从各个视角分析SSD性能影响因素,并思考从存储软件的角度如何最优化使用NVMe SSD,推进数据中心闪存化进程。本文从NVMe SSD的性能影响因素进行分析&am…

Java智慧工地源码 SAAS智慧工地源码 智慧工地管理可视化平台源码 带移动APP

一、系统主要功能介绍 系统功能介绍: 【项目人员管理】 1. 项目管理:项目名称、施工单位名称、项目地址、项目地址、总造价、总面积、施工准可证、开工日期、计划竣工日期、项目状态等。 2. 人员信息管理:支持身份证及人脸信息采集&#…

PLC物联网,实现工厂设备数据采集

随着工业4.0时代的到来,物联网技术在工厂设备管理领域的应用日益普及。作为物联网技术的重要一环,PLC物联网为工厂设备数据采集带来了前所未有的便捷和高效。本文将围绕“PLC物联网,实现工厂设备数据采集”这一主题,探讨PLC物联网…