Pytorch进阶教学——训练一个图像分类模型(GPU)

目录

1、前言 

2、数据集介绍

3、获取数据

4、创建网络

5、训练模型

6、测试模型

6.1、测试整个模型准确率

6.2、测试单张图片


1、前言 

  • 编写一个可以分类蚂蚁和蜜蜂图片的模型,使用数据集对卷积神经网络进行训练。训练后的模型可以对蚂蚁或蜜蜂的图片进行检测。
  • 使用anaconda新建一个虚拟环境,安装好pytorch。后续缺什么包就安装什么包即可。
  • 使用pycharm新建一个项目,配置好环境。

2、数据集介绍

  • 使用的数据集为蚂蚁和蜜蜂的图片,分为训练集和测试集
  • 【注】数据集下载地址。

3、获取数据

  • 代码中获取数据集使用的是txt文件,所以首先需要提取全部图片的地址和标签放入txt文件中。
  • 下述代码为python提取全部图片地址和标签导出为txt文件的脚本。(自行修改)
    • import os  # 导入os模块,用于操作文件路径等操作系统相关功能。def get_file_name(file_path, output_file, type):  # 绝对路径path_list = os.listdir(file_path)  # 列出指定路径下的所有文件和文件夹,并将结果存储在path_list中with open(output_file, 'a') as file:for filename in path_list:all_file_path = os.path.join(file_path, filename)  # 拼接路径file.write(all_file_path + ' ' + type + '\n')if __name__ == '__main__':ants_file_path = r"D:\BaiduNetdiskWorkspace\PyTorch\image_recognition\hymenoptera_data\train\ants"bees_file_path = r"D:\BaiduNetdiskWorkspace\PyTorch\image_recognition\hymenoptera_data\train\bees"output_file = r"D:\BaiduNetdiskWorkspace\PyTorch\image_recognition\hymenoptera_data\train.txt"get_file_name(ants_file_path, output_file, 'ants')get_file_name(bees_file_path, output_file, 'bees')
    •  
  • 将全部地址修改为相对地址。
    • 使用替换操作实现。例如:
  • 最后txt文件的内容如下:
  • 新建一个dataset.py文件。
    • # 读取数据
      import torch
      import torchvision.transforms as transforms
      from PIL import Image# 读取数据类
      class MyDataset(torch.utils.data.Dataset):  # 继承构建自定义数据集的基类def __init__(self, datatxt, datatransform):datas = open(datatxt, 'r').readlines()  # 按行读取,每行包含图像路径和标签self.images = []self.labels = []self.transform = datatransformfor data in datas:item = data.strip().split(' ')  # 去除首尾空格并按空格分割# 分别将图像路径和标签添加到self.images和self.labels列表中self.images.append(item[0])  # 路径self.labels.append(item[1])  # 标签returndef __len__(self):return len(self.images)# 获取数据集中的一个样本。接收一个索引item,根据索引获取对应的图像路径和标签def __getitem__(self, item):imagepath, label = self.images[item], self.labels[item]image = Image.open(imagepath)  # 打开图片return self.transform(image), label  # 返回转换后的图像和对应的标签# 用于测试
      if __name__ == '__main__':# 利用txt文件读取图片信息,txt文件包括图片路径和标签traintxt = './hymenoptera_data/train.txt'valtxt = './hymenoptera_data/val.txt'# 图片转换形式traindata_transfomer = transforms.Compose([transforms.ToTensor(),  # 转为Tensor格式transforms.Resize(60),  # 调整图像大小,调整为高度或宽度为60像素,另一边按比例调整transforms.RandomCrop(48),  # 裁剪图片,随机裁剪成高度和宽度均为48像素的部分transforms.RandomHorizontalFlip(),  # 随机水平翻转transforms.RandomRotation(10),  # 随机旋转transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 对图像进行归一化处理。对每个通道执行了均值为0.5、标准差为0.5的归一化操作])valdata_transfomer = transforms.Compose([transforms.ToTensor(),  # 转为Tensor格transforms.Resize(48),  # 调整图像大小,调整为高度或宽度为48像素,另一边按比例调整transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 加载数据traindataset = MyDataset(traintxt, traindata_transfomer)valdataset = MyDataset(valtxt, valdata_transfomer)print("测试集:" + str(traindataset.__len__()))print("训练集:" + str(valdataset.__len__()))
  • 单独运行结果:(只用于测试)

4、创建网络

  • 新建一个net.py文件。
    • 其中创建了一个简单的三层卷积神经网络。
    • # 三层卷积神经网络
      import torch# 卷积神经网络类
      class SimpleConv3(torch.nn.Module):  # 继承创建神经网络的基类def __init__(self, classes):super(SimpleConv3, self).__init__()# 卷积层self.conv1 = torch.nn.Conv2d(3, 16, 3, 2, 1)  # 输入通道3,输出通道16,3*3的卷积核,步长2,边缘填充1self.conv2 = torch.nn.Conv2d(16, 32, 3, 2, 1)  # 输入通道16,输出通道32,3*3的卷积核,步长2,边缘填充1self.conv3 = torch.nn.Conv2d(32, 64, 3, 2, 1)  # 输入通道32,输出通道64,3*3的卷积核,步长2,边缘填充1# 全连接层self.fc1 = torch.nn.Linear(2304, 100)self.fc2 = torch.nn.Linear(100, classes)def forward(self, x):# 第一次卷积x = torch.nn.functional.relu(self.conv1(x))  # relu为激活函数# 第二次卷积x = torch.nn.functional.relu(self.conv2(x))# 第三次卷积x = torch.nn.functional.relu(self.conv3(x))# 展开成一维向量x = x.view(x.size(0), -1)x = torch.nn.functional.relu(self.fc1(x))x = self.fc2(x)return x# 用于测试
      if __name__ == '__main__':inputs = torch.rand((1, 3, 48, 48))  # 生成一个随机的3通道、48x48大小的张量作为输入net = SimpleConv3(2)  # 二分类output = net(inputs)print(output)
  • 单独运行结果:(只用于测试)

5、训练模型

  • 新建一个train.py文件。
    • 其中可自行设置的参数都有标出。 
    • # 训练模型
      import matplotlibmatplotlib.use('TkAgg')
      import matplotlib.pyplot as plt
      from dataset import MyDataset
      from net import SimpleConv3
      import torch
      import torchvision.transforms as transforms
      from torch.optim import SGD  # 优化相关
      from torch.optim.lr_scheduler import StepLR  # 优化相关
      from sklearn import preprocessing  # 处理label# 图片转换形式
      traindata_transfomer = transforms.Compose([transforms.ToTensor(),  # 转为Tensor格式transforms.Resize(60, antialias=True),  # 调整图像大小,调整为高度或宽度为60像素,另一边按比例调整,antialias=True启用了抗锯齿功能transforms.RandomCrop(48),  # 裁剪图片,随机裁剪成高度和宽度均为48像素的部分transforms.RandomHorizontalFlip(),  # 随机水平翻转transforms.RandomRotation(10),  # 随机旋转transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 对图像进行归一化处理。对每个通道执行了均值为0.5、标准差为0.5的归一化操作
      ])if __name__ == '__main__':traintxt = './hymenoptera_data/train.txt'valtxt = './hymenoptera_data/val.txt'# 加载数据traindataset = MyDataset(traintxt, traindata_transfomer)# 创建卷积神经网络net = SimpleConv3(2)  # 二分类# 使用GPUdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")net.to(device)# 测试GPU是否能使用# print("The device is gpu later?:", next(net.parameters()).is_cuda)# print("The device is gpu,", next(net.parameters()).device)# 将数据提供给模型使用traindataloader = torch.utils.data.DataLoader(traindataset, batch_size=128, shuffle=True,num_workers=1)  # batch_size可以自行调节# 优化器optim = SGD(net.parameters(), lr=0.1, momentum=0.9)  # 使用随机梯度下降(SGD)作为优化器,学习率0.1,动量0.9,加速梯度下降过程,lr可自行调节criterion = torch.nn.CrossEntropyLoss()  # 使用交叉熵损失作为损失函数lr_step = StepLR(optim, step_size=200, gamma=0.1)  # 学习率调度器,动态调整学习率,每200个epoch调整一次,每次调整缩小为原来的0.1倍,step_size可自行调节epochs = 5  # 训练次数accs = []losss = []# 训练循环for epoch in range(0, epochs):batch = 0running_acc = 0.0  # 精度running_loss = 0.0  # 损失for data in traindataloader:batch += 1imputs, labels = data# 将标签从元组转换为tensor类型labels = preprocessing.LabelEncoder().fit_transform(labels)labels = torch.as_tensor(labels)# 利用GPU训练模型imputs = imputs.to(device)labels = labels.to(device)# 将数据输入至网络output = net(imputs)# 计算损失loss = criterion(output, labels)# 平均准确率acc = float(torch.sum(labels == torch.argmax(output, 1))) / len(imputs)# 累加损失和准确率,后面会除以batchrunning_acc += accrunning_loss += loss.data.item()optim.zero_grad()  # 清空梯度loss.backward()  # 反向传播optim.step()  # 更新参数lr_step.step()  # 更新优化器的学习率# 一次训练的精度和损失running_acc = running_acc / batchrunning_loss = running_loss / batchaccs.append(running_acc)losss.append(running_loss)print('epoch=' + str(epoch) + ' loss=' + str(running_loss) + ' acc=' + str(running_acc))# 保存模型torch.save(net, 'model.pth')  # 保存模型的权重和结构x = torch.randn(1, 3, 48, 48).to(device)  # # 生成一个随机的3通道、48x48大小的张量作为输入,新建的张量也要送到GPU中net = torch.load('model.pth')  # 从保存的.pth文件中加载模型net.train(False)  # 设置模型为推理模式,意味着不会进行梯度计算或反向传播torch.onnx.export(net, x, 'model.onnx')  # 使用ONNX格式导出模型# 接受模型net、示例输入x和导出的文件名model.onnx作为参数# 可视化结果fig = plt.figure()plot1, = plt.plot(range(len(accs)), accs)  # 创建一个图形对象plot1,绘制accs列表中的数据plot2, = plt.plot(range(len(losss)), losss)  # 创建另一个图形对象plot2,绘制losss列表中的数据plt.ylabel('epoch')  # 设置y轴的标签为epochplt.legend(handles=[plot1, plot2], labels=['acc', 'loss'])  # 创建图例,指定图表中不同曲线的标签plt.show()  # 展示所绘制的图表
  • 【注】本项目使用的是GPU训练模型。如果GPU可以获得,但是无法使用,可能是pytorch的版本不对,需要重新安装。
  • 运行结果:
  • 保存后的模型如下:

6、测试模型

6.1、测试整个模型准确率

  • 利用测试集,测试整个模型的准确率。
  • 新建一个test.py文件。
    • # 测试整个模型的准确率
      import torch
      import torchvision.transforms as transforms
      from dataset import MyDataset  # 您的数据集类
      from sklearn import preprocessing  # 处理label# 定义测试集的数据转换形式
      valdata_transfomer = transforms.Compose([transforms.ToTensor(),  # 转为Tensor格式transforms.Resize(60, antialias=True),  # 调整图像大小,调整为高度或宽度为60像素,另一边按比例调整,antialias=True启用了抗锯齿功能transforms.CenterCrop(48),  # 中心裁剪图片,裁剪成高度和宽度均为48像素的部分transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 对图像进行归一化处理。对每个通道执行了均值为0.5、标准差为0.5的归一化操作
      ])if __name__ == '__main__':valtxt = './hymenoptera_data/val.txt'  # 测试集数据路径# 加载测试集数据valdataset = MyDataset(valtxt, valdata_transfomer)# 加载已训练好的模型,利用GPU进行测试device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")net = torch.load('model.pth').to(device)net.eval()  # 将模型设置为评估模式,意味着不会进行梯度计算或反向传播# 使用 DataLoader 加载测试集数据valdataloader = torch.utils.data.DataLoader(valdataset, batch_size=1, shuffle=False)correct = 0  # 被正确预测的样本数total = 0  # 测试样本数# 测试模型with torch.no_grad():for data in valdataloader:images, labels = data# 将标签从元组转换为tensor类型labels = preprocessing.LabelEncoder().fit_transform(labels)labels = torch.as_tensor(labels)# 利用GPU训练模型images, labels = images.to(device), labels.to(device)outputs = net(images)  # 输入图像并获取模型预测结果_, predicted = torch.max(outputs.data, 1)  # 获取预测值中最大概率的索引total += labels.size(0)  # 累计测试样本数量correct += (predicted == labels).sum().item()  # 计算正确预测的样本数量# 计算并输出模型在测试集上的准确率accuracy = 100 * correct / totalprint('Test Accuracy: {:.2f}%'.format(accuracy))
  • 运行结果:
    • 因为训练模型时只迭代了200次,所以准确率并不高。可以尝试提高训练次数,提高准确率。 

6.2、测试单张图片

  • 使用训练后的模型,对单张图片进行预测。
  • 新建一个testone.py文件。
    • import torch
      from PIL import Image
      import torchvision.transforms as transforms# 定义图片预处理转换
      image_transforms = transforms.Compose([transforms.Resize(60, antialias=True),  # 调整图像大小transforms.CenterCrop(48),  # 中心裁剪transforms.ToTensor(),  # 转为Tensor格式transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化处理
      ])# 定义类别映射字典
      class_mapping = {0: "ant",1: "bee"
      }# 加载已训练好的模型,利用GPU测试
      device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
      net = torch.load('model.pth').to(device)
      net.eval()  # 将模型设置为评估模式,意味着不会进行梯度计算或反向传播# 加载要测试的图片
      image_path = './hymenoptera_data/val/bees/26589803_5ba7000313.jpg'  # 图片路径
      input_image = Image.open(image_path)  # 加载图片
      input_tensor = image_transforms(input_image).unsqueeze(0)  # 对图片进行预处理转换,并增加 batch 维度# 将输入数据移动到GPU上
      input_tensor = input_tensor.to(device)# 使用模型进行预测
      with torch.no_grad():output = net(input_tensor)_, predicted = torch.max(output, 1)  # 在张量中沿指定维度找到最大值及其对应的索引# 输出预测结果
      predicted_class = predicted.item()  # 得到预测的标签
      predicted_label = class_mapping[predicted_class]  # 将标签转换为文字
      print(f"The predicted class for the image is: {predicted_label}")
  • 运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/212294.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

添加cpack install功能

修改最外层的CMakeLists.txt, 添加几行代码: # If GNUInstallDirs is not included, CMAKE_INSTALL_BINDIR is empty. include(GNUInstallDirs)# it must go before project in order to work set(CMAKE_INSTALL_PREFIX "${PROJECT_SOURCE_DIR}" CACHE …

Android平板还能编程?Ubuntu本地安装code-server远程编程写代码

文章目录 1.ubuntu本地安装code-server2. 安装cpolar内网穿透3. 创建隧道映射本地端口4. 安卓平板测试访问5.固定域名公网地址6.结语 1.ubuntu本地安装code-server 准备一台虚拟机,Ubuntu或者centos都可以,这里以VMwhere ubuntu系统为例 下载code server服务,浏览器…

Vector Quantized Diffusion Model for Text-to-Image Synthesis

Vector Quantized Diffusion Model for Text-to-Image Synthesis Shuyang Gu, University of Science and Technology of China, Microsoft, CVPR2022, Cited: 340, Code, Paper 1. 前言 我们提出了用于文本到图像生成的矢量量化扩散(Vector Quantized Diffusion Model&…

5G入门到精通 - 5G的十大关键技术

文章目录 一、网络切片二、自组织网络三、D2D技术四、低时延技术五、MIMO技术六、毫米波七、内容分发网络八、M2M技术九、频谱共享十、信息中心网络 一、网络切片 5G中的网络切片是一项关键技术,它允许将整个5G网络分割成多个独立的虚拟网络,每个虚拟网络…

vue2+typescript使用高德地图2.0版本

高德地图 webjs api 2.0官网教程 AMap.Driving使用说明 <div class"mmp"><div id"map" ref"mapcontainer"></div></div><script lang"ts"> //安全密钥 window._AMapSecurityConfig{securityJsCode: &qu…

添加新公司代码的配置步骤-Part1

原文地址&#xff1a;配置公司代码 概述 我们生活在一个充满活力的时代&#xff0c;公司经常买卖子公司。对于已经使用 SAP 的公司来说&#xff0c;增加收购就成为一个项目。我开发了一个电子表格&#xff0c;其中包含向您的结构添加新公司代码所需的所有配置更改。当然&…

Linux安全配置

进入ssh配置文件 vim /etc/ssh/sshd_config将port 22中的端口号改为5001 重启ssh服务 systemctl restart sshd拓展 sh与bash iptable与firewall ssh与sshd vps与ssh 参考&#xff1a; 【安全-SSH】SSH安全设置 - CSDN AppLinux VPS服务器SSH端口一键修改脚本​Linux脚本…

「Swift」取消UITableView起始位置在状态栏下方开始

前言&#xff1a;在写页面UI时发现&#xff0c;当隐藏了NavigationBar时&#xff0c;即使UITableView是从(0,0)进行布局&#xff0c;也会一直在手机状态栏下方进行展示布局&#xff0c;而我的想法是希望UITableView可以从状态栏处就进行展示布局 当前页面展示&#xff1a; 问题…

java设计模式学习之【装饰器模式】

文章目录 引言装饰器模式简介定义与用途实现方式 使用场景优势与劣势装饰器模式在Spring中的应用画图示例代码地址 引言 在日常生活中&#xff0c;我们常常对基本事物添加额外的装饰以增强其功能或美观。例如&#xff0c;给手机加一个保护壳来提升其防护能力&#xff0c;或者在…

JAVA网络编程——BIO、NIO、AIO深度解析

I/O 一直是很多Java同学难以理解的一个知识点&#xff0c;这篇帖子将会从底层原理上带你理解I/O&#xff0c;让你看清I/O相关问题的本质。 1、I/O的概念 I/O 的全称是Input/Output。虽常谈及I/O&#xff0c;但想必你也一时不能给出一个完整的定义。搜索了谷哥欠&#xff0c;发…

【UE5】使用场系统炸毁一堵墙

效果 步骤 1. 新建一个空白项目 2. 新建一个Basic关卡&#xff0c;然后添加一个第三人称游戏和初学者内容包到内容浏览器 3. 在场景中添加一堵墙 4. 选项模式选择“破裂” 点击新建 新建一个文件夹用于存储几何体集 点击“统一” 最小和最大Voronoi点数都设置为100 点击“破…

c++11计时器chrono库

去实现一下开始说的高精度计时器&#xff1a; #ifndef _TimerClock_hpp_ #define _TimerClock_hpp_#include <iostream> #include <chrono>using namespace std; using namespace std::chrono;class TimerClock { public:TimerClock(){update();}~TimerClock(){}v…

C语言指针详解上

1 野指针 int main01(){//野指针就是没有初始化的指针,指针的指向是随机的,不可以 操作野指针//int a 0;//指针p保存的地址一定是定义过的(向系统申请过的)int *p;//野指针*p 200;printf("%d\n",*p);system("pause");return 0;}2 空指针 空指针的作用…

力扣78. 子集(java 回溯解法)

Problem: 78. 子集 文章目录 题目描述思路解题方法复杂度Code 题目描述 思路 我们易知&#xff0c;本题目涉及到对元素的穷举&#xff0c;即我们可以使用回溯来实现。对于本题目我们应该较为注重回溯中的决策阶段&#xff1a; 由于涉及到对数组中元素的穷举&#xff0c;即在每…

Java网络编程 *TCP与UDP协议*

网络编程 什么是计算机网络? 把分布在不同地理区域的具有独立功能的计算机,通过通信设备与线路连接起来&#xff0c;由功能完善的软件实现资源共享和信息传递的系统 简单来说就是把不同地区的计算机通过设备连接起来,实现不同地区之前的数据传输 网络编程是干什么的? 网络…

Docker快速理解及简介

docker快速理解及简介 1.Docker为什么出现&#xff1f; 迁移一个项目时&#xff0c;运行文档、配置环境、运行环境、运行依赖包、操作系统发行版、内核等都需要重新安装配置&#xff0c;比较麻烦。 2.Docker是什么&#xff1f; Docker是基于Go语言实现的云开源项目。解决了运行…

JVM 字节码

JVM概述 问题引出 你是否也遇到过这些问题&#xff1f; 运行着的线上系统突然卡死&#xff0c;系统无法访问&#xff0c;甚至直接OOM&#xff01;想解决线上JVM GC问题&#xff0c;但却无从下手。新项目上线&#xff0c;对各种JVM参数设置一脸茫然&#xff0c;直接默认吧&…

Python:核心知识点整理大全2-笔记

目录 2.1 运行 hello_world.py 时发生的情况 第 2 章 变量和简单数据类型 2.2 变量 2.2.1 变量的命名和使用 2.2.2 使用变量时避免命名错误 2.3 字符串 2.3.1 使用方法修改字符串的大小写 在本章中&#xff0c;你将学习可在Python程序中使用的各种数据&#xff0c;还将学…

html/css中用float实现的盒子案例

运行效果&#xff1a; 代码部分&#xff1a; <!doctype html> <html> <head> <meta charset"utf-8"> <title>无标题文档</title> <style type"text/css">.father{width:300px; height:400px; background:gray;…

2个月拿下信息系统项目管理师攻略(攻略超级全)

信息系统项目管理师&#xff08;高项&#xff09;一次性过啦&#xff01;结合这次备考经验&#xff0c;给大家总结一下复习方法。 先上图&#xff0c;开心一下&#xff01; 一、我为什么选择了高项 为什么我会选信息系统项目管理师&#xff0c;也就是我们常说的高项。 原因1…