[深度学习]图像分类项目-食物分类

图像分类项目-食物分类(监督学习和半监督学习)

文章目录

  • 图像分类项目-食物分类(监督学习和半监督学习)
    • 项目介绍
    • 数据处理
      • 设定随机种子
      • 读取文件内容
      • 图像增广
      • 定义Dataset类
    • 模型定义
      • 迁移学习
    • 定义超参
      • Adam和AdamW
    • 训练过程
    • 半监督学习
      • 定义Dataset类
      • 模型定义
      • 定义超参
      • 训练过程

项目介绍

image-20250214102822207

数据处理

设定随机种子

由于神经网络的训练具有随机性,为了保证之前得到的好的训练效果可以得到复现,设定随机种子,让训练过程中的随机行为每次训练都是相同。

def seed_everything(seed):torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.benchmark = Falsetorch.backends.cudnn.deterministic = Truerandom.seed(seed)np.random.seed(seed)os.environ['PYTHONHASHSEED'] = str(seed)
#################################################################
seed_everything(0)
###############################################

读取文件内容

进行数据处理前,需要了解数据的形式训练集中,有标签的数据按照11类分别存放在11个文件夹中,因此要循环依次读取这11个文件夹的内容:

image-20250214104855956

首先从文件夹中读出每张图片和对应标签(读取的是带标签的数据):

HW = 224def read_file(path):for i in tqdm(range(11)):file_dir = path + "/%02d" % ifile_list = os.listdir(file_dir)  # 列出文件夹下所有文件名字xi = np.zeros((len(file_list), HW, HW, 3), dtype=np.uint8)  # 每个元素存一个图片,图片为整形类型yi = np.zeros(len(file_list))for j, img_name in enumerate(file_list):img_path = os.path.join(file_dir, img_name)  # 拼接地址img = Image.open(img_path)  # 打开图片img = img.resize((HW, HW))  # 修改成模型接受的大小xi[j, ...] = imgyi[j] = iif i == 0:#第一个数据赋值X = xiY = yielse:#后续数据尾插X = np.concatenate((X, xi), axis=0)Y = np.concatenate((Y, yi), axis=0)print("读到了%d个数据" % len(Y))return X, Y

图像增广

模型对训练使用的图片数据有好的效果,但是如果对图片数据进行一定的变化,模型的效果就变差,因此在训练时,不仅使用原图片训练,还要对图片进行旋转,放大裁切等图像操作,将原图片和操作后的图片都作为训练数据,也就是图像增广,让模型的效果更好。

image-20250221112528853

train_transform = transforms.Compose(#定义训练集增广方式[transforms.ToPILImage(), #224,224,3模型:3,224,244transforms.RandomResizedCrop(224),#随机放大裁切transforms.RandomRotation(50),#50度以内随机旋转transforms.ToTensor()#模型运行的数据类型为张量]
)val_transform = transforms.Compose(#验证集不需要增广[transforms.ToPILImage(),  # 224,224,3模型:3,224,244transforms.ToTensor()  # 模型运行的数据类型为张量]
)

定义Dataset类

class food_Dataset(Dataset):#继承Dateset类def __init__(self, path, mode="train"):self.X, self.Y = read_file(path)self.Y = torch.LongTensor(self.Y)#图片数据类型为整形if mode == "train":#根据模式选择增广类型self.transform = train_transformelse:self.transform = val_transformdef __getitem__(self, item):return self.transform(self.X[item]), self.Y[item] #使用图片增广def __len__(self):return len(self.Y)

模型定义

在模型中设定一些卷积、归一化、池化、激活函数对数据进行特征提取。

class myModel(nn.Module):def __init__(self, num_class):super(myModel, self).__init__()#3*224*224->512*7*7->拉直->全连接分类self.conv1 = nn.Conv2d(3, 64, 3, 1, 1) #3厚度,64个卷积核,卷积核大小3,padding为1,步长为1 64*224*224self.bn1 = nn.BatchNorm2d(64)#归一化self.relu = nn.ReLU()self.pool1 = nn.MaxPool2d(2)   #64*112*112self.layer1 = nn.Sequential(nn.Conv2d(64, 128, 3, 1, 1),    # 128*112*112nn.BatchNorm2d(128),nn.ReLU(),nn.MaxPool2d(2)   #128*56*56)self.layer2 = nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1),nn.BatchNorm2d(256),nn.ReLU(),nn.MaxPool2d(2)   #256*28*28)self.layer3 = nn.Sequential(nn.Conv2d(256, 512, 3, 1, 1),nn.BatchNorm2d(512),nn.ReLU(),nn.MaxPool2d(2)   #512*14*14)self.pool2 = nn.MaxPool2d(2)    #512*7*7self.fc1 = nn.Linear(25088, 1000)   #25088->1000self.relu2 = nn.ReLU()self.fc2 = nn.Linear(1000, num_class)  #1000-11def forward(self, x):#使用定义的模型进行前向过程x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.pool1(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.pool2(x)x = x.view(x.size()[0], -1) #拉直,x.size(0) 就是批量大小(batch_size),表示你有多少个样本输入到模型中。-1 是自动计算剩下的维度以便将数据展平。x = self.fc1(x)x = self.relu2(x)x = self.fc2(x)return x

迁移学习

良好的模型是需要大量的数据训练得到的,由于我们设备加上数据量的限制训练出来的模型效果不会特别好,甚至预测效果接近随机预测,因此我们要进行迁移学习。简单来说,**迁移学习就是使用大佬们用大量数据训练出来的现成模型,**由于大佬的模型经过训练后有很好的特征提取效果,因此我们只需要使用大佬的模型然后加上分类头作为训练的模型即可。

image-20250221105346896

预训练是指在无关当前任务的模型训练,迁移学习使用的模型就是进行过预训练的模型,迁移学习时可以进行线性探测微调,线性探测就是在训练中不进行参数的调整,完全信任迁移学习使用的模型,微调就是在训练过程中会进行参数调整。

迁移学习时可选择只使用架构和使用架构和参数,虽然迁移学习使用的架构很优秀但是参数是更加重要的部分,因此使用架构和参数的效果要更好,要想使用迁移学习的预训练参数要保持架构一致。

from torchvision.models import resnet18#导入模型
model = resnet18(pretrained=True)#使用架构和参数
in_fetures = model.fc.in_features#获取模型的特征提取后的输出维度
model.fc = nn.Linear(in_fetures, 11)#全连接分类头

定义超参

定义学习率、损失函数、优化器、训练轮次等超参数。

Adam和AdamW

Adam优化器不仅考虑当前点的梯度还考虑之前的梯度,并且会自动更改学习率,由于参数更改时要减去学习率×梯度,当这个值过大时,Adam会自动更改学习率,AdamW是在Adam的基础上增加了权重衰减使得模型曲线更加平滑。

训练过程

def train_val(model, train_loader, val_loader, device, epochs, optimizer, loss, save_path):model = model.to(device)plt_train_loss = [] #记录所有轮次的LOSSplt_val_loss = []plt_train_acc = [] #记录准确率plt_val_acc = []max_acc = 0.0for epoch in range(epochs): #开始训练train_loss = 0.0val_loss = 0.0train_acc = 0.0#用准确率表示模型效果val_acc = 0.0start_time = time.time()model.train() #模型调为训练模式,有时训练模式和测试模式的模型不同for batch_x, batch_y in train_loader:x, target = batch_x.to(device), batch_y.to(device)pred = model(x)train_bat_loss = loss(pred, target) #获取一批数据的LOSStrain_bat_loss.backward() #梯度回传optimizer.step()#更新模型optimizer.zero_grad()train_loss += train_bat_loss.cpu().item() #将gpu上的张量数据放到cpu上取出数据计算,累加记录本轮LOSStrain_acc += np.sum(np.argmax(pred.detach().cpu().numpy(), axis=1) == target.cpu().numpy())#记录预测对的数量plt_train_loss.append(train_loss / train_loader.__len__()) #除以轮次数,得到每个轮次的LOSS平均值plt_train_acc.append(train_acc/train_loader.dataset.__len__()) #记录准确率,model.eval()#调为验证模式with torch.no_grad():#所有模型中的张量计算都积攒梯度,而验证时不需要梯度for batch_x, batch_y in val_loader:x, target = batch_x.to(device), batch_y.to(device)pred = model(x)val_bat_loss = loss(pred, target)val_loss += val_bat_loss.cpu().item()val_acc += np.sum(np.argmax(pred.detach().cpu().numpy(), axis=1) == target.cpu().numpy())  # 记录预测对的数量plt_val_loss.append(val_loss / val_loader.dataset.__len__())plt_val_acc.append(val_acc/val_loader.dataset.__len__()) #记录准确率,if val_acc > max_acc: #如果当前模型效果更好,进行记录torch.save(model, save_path)max_acc = val_loss#训练效果打印print('[%03d/%03d] %2.2f sec(s) TrainLoss : %.6f | valLoss: %.6f Trainacc : %.6f | valacc: %.6f' % \(epoch, epochs, time.time() - start_time, plt_train_loss[-1], plt_val_loss[-1], plt_train_acc[-1],plt_val_acc[-1]))  # 打印训练结果。 注意python语法, %2.2f 表示小数位为2的浮点数, 后面可以对应。plt.plot(plt_train_loss)plt.plot(plt_val_loss)plt.title("loss")plt.legend(["train", "val"])plt.show()plt.plot(plt_train_acc)plt.plot(plt_val_acc)plt.title("loss")plt.legend(["train", "val"])plt.show()

半监督学习

监督学习是指每个训练样本都有对应的标签,模型通过学习这些标注数据训练,目标是让模型能够根据新的输入数据预测正确的标签。

半监督学习是介于监督学习和无监督学习之间的一种方法。在半监督学习中,训练数据包含大量的未标注数据和少量的标注数据。模型利用少量的标注数据来进行学习,同时也借助未标注数据来进一步提高模型的性能。

  • 模型首先使用标注数据进行训练。
  • 模型的效果达到一定程度后,用训练得到的模型对未标注数据进行预测。
  • 若预测结果结果的置信值(成功率)达到一定值后,将预测结果(伪标签)添加到训练数据集中。

image-20250221121136819

为了加入半监督学习,对监督学习的代码进行修改。

定义Dataset类

class food_Dataset(Dataset):def __init__(self, path, mode="train"):self.mode = modeif mode == "semi":#若为半监督模式,数据只有X,没有标签Yself.X = self.read_file(path)else:self.X, self.Y = self.read_file(path)self.Y = torch.LongTensor(self.Y)  #标签转为长整形if mode == "train":#训练模式需要图片增广等操作self.transform = train_transformelse:#非训练模式,包括半监督模式下,只需要让数据转换成符合模型输入的格式即可self.transform = val_transformdef read_file(self, path):#读取数据函数if self.mode == "semi":file_list = os.listdir(path)xi = np.zeros((len(file_list), HW, HW, 3), dtype=np.uint8)# 列出文件夹下所有文件名字for j, img_name in enumerate(file_list):img_path = os.path.join(path, img_name)img = Image.open(img_path)img = img.resize((HW, HW))xi[j, ...] = imgprint("读到了%d个数据" % len(xi))return xielse:for i in tqdm(range(11)):file_dir = path + "/%02d" % ifile_list = os.listdir(file_dir)xi = np.zeros((len(file_list), HW, HW, 3), dtype=np.uint8)yi = np.zeros(len(file_list), dtype=np.uint8)# 列出文件夹下所有文件名字for j, img_name in enumerate(file_list):img_path = os.path.join(file_dir, img_name)img = Image.open(img_path)img = img.resize((HW, HW))xi[j, ...] = imgyi[j] = iif i == 0:X = xiY = yielse:X = np.concatenate((X, xi), axis=0)Y = np.concatenate((Y, yi), axis=0)print("读到了%d个数据" % len(Y))return X, Ydef __getitem__(self, item):if self.mode == "semi":return self.transform(self.X[item]), self.X[item]#前者为为了输入模型进行转换的X用于训练得到伪标签,后者为原始数据X用于加入半监督数据集else:return self.transform(self.X[item]), self.Y[item]def __len__(self):return len(self.X)class semiDataset(Dataset):#半监督数据集Dataset类def __init__(self, no_label_loder, model, device, thres=0.99):#传入无标签数据,预测模型,置信度x, y = self.get_label(no_label_loder, model, device, thres)if x == []:#如果预测得到的伪标签都不符合要求,如置信度低,导致半监督数据集为空进行标记self.flag = Falseelse:self.flag = Trueself.X = np.array(x)self.Y = torch.LongTensor(y)self.transform = train_transform#得到的半监督数据集同样用于模型训练def get_label(self, no_label_loder, model, device, thres):#给半监督数据打标签model = model.to(device)pred_prob = []#记录预测类型中最高概率labels = []#记录最高概率对应的标签x = []y = []soft = nn.Softmax()with torch.no_grad():#只要通过模型就会积攒梯度,只要不进行模型训练调整,积攒的梯度就没用for bat_x, _ in no_label_loder:bat_x = bat_x.to(device)pred = model(bat_x)pred_soft = soft(pred)pred_max, pred_value = pred_soft.max(1)#维度1为横向,取出最高概率和其对应的标签,由于loader中一个元素是一批数据,因此pred_max和pred_value的一个元素中包含对应批数个值pred_prob.extend(pred_max.cpu().numpy().tolist())labels.extend(pred_value.cpu().numpy().tolist())for index, prob in enumerate(pred_prob):if prob > thres:#大于置信度加入半监督数据集x.append(no_label_loder.dataset[index][1])   #调用到原始的getitem,因为要加入半监督数据集y.append(labels[index])return x, ydef __getitem__(self, item):return self.transform(self.X[item]), self.Y[item]def __len__(self):return len(self.X)def get_semi_loader(no_label_loder, model, device, thres):#获取半监督数据集semiset = semiDataset(no_label_loder, model, device, thres)if semiset.flag == False:return Noneelse:semi_loader = DataLoader(semiset, batch_size=16, shuffle=False)return semi_loader

模型定义

加入半监督学习只需要复用监督学习的训练模型进行预测即可。

定义超参

加入半监督学习要额外定义包括置信度的超参。

训练过程

def train_val(model, train_loader, val_loader, no_label_loader, device, epochs, optimizer, loss, thres, save_path):model = model.to(device)semi_loader = Noneplt_train_loss = []plt_val_loss = []plt_train_acc = []plt_val_acc = []max_acc = 0.0for epoch in range(epochs):train_loss = 0.0val_loss = 0.0train_acc = 0.0val_acc = 0.0semi_loss = 0.0#半监督数据集LOSSsemi_acc = 0.0#对半监督数据集的预测准确率start_time = time.time()model.train()#训练模式for batch_x, batch_y in train_loader:#使用有标签训练集训练x, target = batch_x.to(device), batch_y.to(device)pred = model(x)train_bat_loss = loss(pred, target)train_bat_loss.backward()optimizer.step()  # 更新参数 之后要梯度清零否则会累积梯度optimizer.zero_grad()train_loss += train_bat_loss.cpu().item()train_acc += np.sum(np.argmax(pred.detach().cpu().numpy(), axis=1) == target.cpu().numpy())plt_train_loss.append(train_loss / train_loader.__len__())plt_train_acc.append(train_acc/train_loader.dataset.__len__()) #记录准确率,if semi_loader!= None:#若半监督数据集非空,使用半监督数据集进行训练for batch_x, batch_y in semi_loader:x, target = batch_x.to(device), batch_y.to(device)pred = model(x)semi_bat_loss = loss(pred, target)semi_bat_loss.backward()optimizer.step()  # 更新参数 之后要梯度清零否则会累积梯度,因为下一轮数据要重新计算梯度optimizer.zero_grad()semi_loss += train_bat_loss.cpu().item()semi_acc += np.sum(np.argmax(pred.detach().cpu().numpy(), axis=1) == target.cpu().numpy())print("半监督数据集的训练准确率为", semi_acc/train_loader.dataset.__len__())model.eval()with torch.no_grad():for batch_x, batch_y in val_loader:x, target = batch_x.to(device), batch_y.to(device)pred = model(x)val_bat_loss = loss(pred, target)val_loss += val_bat_loss.cpu().item()val_acc += np.sum(np.argmax(pred.detach().cpu().numpy(), axis=1) == target.cpu().numpy())plt_val_loss.append(val_loss / val_loader.dataset.__len__())plt_val_acc.append(val_acc / val_loader.dataset.__len__())if epoch%3 == 0 and plt_val_acc[-1] > 0.6:#将模型训练至一定能力后,再进行半监督学习semi_loader = get_semi_loader(no_label_loader, model, device, thres)if val_acc > max_acc:torch.save(model, save_path)max_acc = val_lossprint('[%03d/%03d] %2.2f sec(s) TrainLoss : %.6f | valLoss: %.6f Trainacc : %.6f | valacc: %.6f' % \(epoch, epochs, time.time() - start_time, plt_train_loss[-1], plt_val_loss[-1], plt_train_acc[-1], plt_val_acc[-1]))  # 打印训练结果。 注意python语法, %2.2f 表示小数位为2的浮点数, 后面可以对应。plt.plot(plt_train_loss)plt.plot(plt_val_loss)plt.title("loss")plt.legend(["train", "val"])plt.show()plt.plot(plt_train_acc)plt.plot(plt_val_acc)plt.title("acc")plt.legend(["train", "val"])plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/39864.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++初阶入门基础二——类和对象(中)

1类的默认成员函数 默认成员函数就是用户没有显式实现,编译器会自动生成的成员函数称为默认成员函数。一个类,我们不写的情况下编译器会默认生成以下6个默认成员函数,需要注意的是这6个中最重要的是前4个,最后两个取地址重载不重…

基于SSM框架的线上甜品销售系统(源码+lw+部署文档+讲解),源码可白嫖!

摘要 网络技术和计算机技术发展至今,已经拥有了深厚的理论基础,并在现实中进行了充分运用,尤其是基于计算机运行的软件更是受到各界的关注。加上现在人们已经步入信息时代,所以对于信息的宣传和管理就很关键。因此网上销售信息的…

3.25学习总结java 接口+内部类

JDK8以后新增的方法 可以将接口中静态方法和抽象方法中重复的部分抽离出来,作为私有方法,用去private修饰,此方法只为接口提供服务,不需要外界访问。 接口的应用 接口代表规则,是行为的抽象,想让哪个类拥有…

Linux--环境变量

ok,今天我们来学习Linux中的环境变量、地址空间、虚拟内存 环境变量 基本概念 环境变量(environmentvariables)⼀般是指在操作系统中⽤来指定操作系统运⾏环境的⼀些参数如:我们在编写C/C代码的时候,在链接的时候,从来不知道我…

Java 集合 List、Set、Map 区别与应用

一、核心特性对比 二、底层实现与典型差异 ‌List‌ ‌ArrayList‌:动态数组结构,随机访问快(O(1)),中间插入/删除效率低(O(n))‌‌LinkedList‌:双向链表结构,头尾操作…

基于 arco 的 React 和 Vue 设计系统

arco 是字节跳动出品的企业级设计系统,支持React 和 Vue。 安装模板工具 npm i -g arco-cli创建项目目录 cd someDir arco init hello-arco-pro? 请选择你希望使用的技术栈React❯ Vue? 请选择一个分类业务组件组件库Lerna Menorepo 项目❯ Arco Pro 项目看到以…

JVM-GC(G1)实践—GC异常定位、参数调整、GC更换

前言 如SpringBoot官方介绍所说的那样,从SpringBoot3.x开始支持的最低JDK版本为:JDK17(官方推荐使用BellSoft Liberica JDK),其对应的GC为G1。 本文笔者从应用实践的角度出发,记录一些关于GC的一些实践总…

吾爱出品,文件分类助手,高效管理您的 PC 资源库

在日常使用电脑的过程中,文件杂乱无章常常让人感到困扰。无论是桌面堆积如山的快捷方式,还是硬盘中混乱的音频、视频、文档等资源,都急需一种高效的整理方法。文件分类助手应运而生,它是一款文件管理工具,能够快速、智…

修改Flutter工程中Android项目minSdkVersion配置

Flutter项目开发过程中,根据模板自动生成.android项目,其中app>build.gradle中minSdkVersion的值是19,但是依赖了一个三方库,它的Android sdk 最小版本只支持到21,运行报错如下: 我们可以手动修改.andro…

如何设计一个订单号生成服务?应该考虑那些问题?

如何设计一个订单号生成服务?应该考虑那些问题? description: 在高并发的电商系统中,生成全局唯一的订单编号是关键。本文探讨了几种常见的订单编号生成方法,包括UUID、数据库自增、雪花算法和基于Redis的分布式组件,并…

Java学习总结-Stream流

啥是Stream流? 用于操作集合或数组的数据。他就像把数据化为成一条河流,我们可以对这条流操作,例如过滤。 获取Stream流 Stream流的常用方法: Stream流的终结方法: 收集Stream流

《TypeScript 面试八股:高频考点与核心知识点详解》

“你好啊!能把那天没唱的歌再唱给我听吗? ” 前言 因为主包还是主要学习js,ts浅浅的学习了一下,在简历中我也只会写了解,所以我写一些比较基础的八股,如果是想要更深入的八股的话还是建议找别人的。 Ts基…

热门面试题第14天|Leetcode 513找树左下角的值 112 113 路径总和 105 106 从中序与后序遍历序列构造二叉树 (及其扩展形式)以一敌二

找树左下角的值 本题递归偏难,反而迭代简单属于模板题, 两种方法掌握一下 题目链接/文章讲解/视频讲解:https://programmercarl.com/0513.%E6%89%BE%E6%A0%91%E5%B7%A6%E4%B8%8B%E8%A7%92%E7%9A%84%E5%80%BC.html 我们来分析一下题目&#…

Qt窗口控件之浮动窗口QDockWidget

浮动窗口QDockWidget QDockWidget 用于表示 Qt 中的浮动窗口,浮动窗口与工具栏类似,可以停靠在主窗口的上下左右位置,也可以单独拖出来作浮动窗口。 1. QDockWidget方法 方法说明setWidget(QWiget*)用于使浮动窗口能够被添加控件。setAllo…

Web前端之JavaScript的DOM操作冷门API

MENU 前言1、Element.checkVisibility()2、TreeWalker3、Node.compareDocumentPosition()4、scrollIntoViewIfNeeded()5、insertAdjacentElement()6、Range.surroundContents()7、Node.isEqualNode()8、document.createExpression()小结 前言 作为前端开发者,我们每…

【Linux-驱动开发-系统调用流程】

Linux-驱动开发-系统调用流程 ■ Linux-系统调用流程■ Linux-file_operations 结构体 ■ Linux-系统调用流程 ■ Linux-file_operations 结构体 在 Linux 内核文件 include/linux/fs.h 中有个叫做 file_operations 的结构体,此结构体就是 Linux 内核驱动操作函数集…

ToolsSet之:ASCII字符表和国际标准代码表

ToolsSet是微软商店中的一款包含数十种实用工具数百种细分功能的工具集合应用,应用基本功能介绍可以查看以下文章: Windows应用ToolsSet介绍https://blog.csdn.net/BinField/article/details/145898264 ToolsSet中Other菜单下的ASCII Table是一个ASCII…

C语言判断闰年相关问题

一、简单闰年问题引入 写一个判断年份是否为闰年的程序? 运行结果: 二、闰年问题进阶 使用switch语句根据用户输入的年份和月份,判断该月份有多少天? 第一种写法(判断年份写在switch的case的里面): 运行结果: 第二种解法(先判断闰年): 运行结果: 三、补充 switch中的ca…

基于Java的班级事务管理系统(源码+lw+部署文档+讲解),源码可白嫖!

摘要 随着世界经济信息化、全球化的到来和电子商务的飞速发展,推动了很多行业的改革。若想达到安全,快捷的目的,就需要拥有信息化的组织和管理模式,建立一套合理、畅通、高效的线上管理系统。当前的班级事务管理存在管理效率低下…

javaweb后端登录功能cookie session

登录功能 只需要这几个,用原来的返回太多用不上的信息,新写一个类只返回登录的结果 Ctrli 实现service的方法 和mapper相关的起名不用和业务一样 登录校验 登录校验思路 会话技术 cookie 创建cookie对象,响应给浏览器 服务端设置的…