0基础学习PyTorch——时尚分类(Fashion MNIST)训练和推理

大纲

  • 环境准备
  • 安装依赖
  • 下载训练集
  • 训练
    • 定义模型
    • 训练
      • 加载训练集
      • 定义损失函数和优化器
      • 训练模型
      • 保存模型
      • 完整文件
  • 推理
    • 加载模型
    • 加载并预处理本地文件
    • 推理
    • 完整文件
  • 代码地址
  • 参考资料

时尚分类是PyTorch官方文档中推荐的案例。本文将拆解这个案例,进行部署以及测试。

环境准备

基础环境可以参考《0基础学习PyTorch——最小Demo》来进行部署。

安装依赖

torchvision 是 PyTorch 的一个官方库,专门用于计算机视觉任务。它提供了常用的数据集、模型架构和图像处理工具,简化了计算机视觉项目的开发过程。后续我们的数据都来源于该库。

source env.sh install torchvision

在这里插入图片描述

下载训练集

将下列内容保存为download.py。

# download.py
import torchvision# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', download=True)# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))

然后运行这个文件

python download.py

在这里插入图片描述
此时目录结构如下
在这里插入图片描述

训练

定义模型

将下面内容保存为garmentclassifier.py。该文件会被训练和推理两个环节使用。

# garmentclassifier.py
import torch.nn as nn
import torch.nn.functional as F# 定义一个用于服装分类的卷积神经网络
class GarmentClassifier(nn.Module):def __init__(self):super(GarmentClassifier, self).__init__()# 定义第一个卷积层,输入通道数为1,输出通道数为6,卷积核大小为5x5self.conv1 = nn.Conv2d(1, 6, 5)# 定义最大池化层,池化窗口大小为2x2self.pool = nn.MaxPool2d(2, 2)# 定义第二个卷积层,输入通道数为6,输出通道数为16,卷积核大小为5x5self.conv2 = nn.Conv2d(6, 16, 5)# 定义第一个全连接层,输入大小为16*4*4,输出大小为120self.fc1 = nn.Linear(16 * 4 * 4, 120)# 定义第二个全连接层,输入大小为120,输出大小为84self.fc2 = nn.Linear(120, 84)# 定义第三个全连接层,输入大小为84,输出大小为10(对应10个类别)self.fc3 = nn.Linear(84, 10)def forward(self, x):# 通过第一个卷积层和ReLU激活函数,然后通过最大池化层x = self.pool(F.relu(self.conv1(x)))# 通过第二个卷积层和ReLU激活函数,然后通过最大池化层x = self.pool(F.relu(self.conv2(x)))# 展平张量,从多维张量变为二维张量x = x.view(-1, 16 * 4 * 4)# 通过第一个全连接层和ReLU激活函数x = F.relu(self.fc1(x))# 通过第二个全连接层和ReLU激活函数x = F.relu(self.fc2(x))# 通过第三个全连接层(输出层)x = self.fc3(x)return x

训练

加载训练集

这次我们直接从本地加载训练集,但是需要做归一化处理。

from datetime import datetime
import torch
import torchvision
import torchvision.transforms as transforms
from garmentclassifier import GarmentClassifier# 定义图像转换操作:将图像转换为张量,并进行归一化处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))]) # 对图像的每个通道进行标准化,使得每个通道的像素值具有零均值和单位标准差# 加载FashionMNIST训练数据集,并应用定义的图像转换操作
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform)# 创建数据加载器,用于批量加载训练数据,batch_size为4,数据顺序随机打乱
trainloader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)

定义损失函数和优化器

# 实例化模型
model = GarmentClassifier()
# 定义损失函数为交叉熵损失
loss_fn = torch.nn.CrossEntropyLoss()
# 定义优化器为随机梯度下降(SGD),学习率为0.001,动量为0.9
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

训练模型

# 训练模型,训练2个epoch
for epoch in range(2):running_loss = 0.0  # 初始化累计损失# 枚举数据加载器中的数据,i是批次索引,data是当前批次的数据for i, data in enumerate(trainloader, 0):inputs, labels = data  # 获取输入数据和对应的标签optimizer.zero_grad()  # 清空梯度outputs = model(inputs)  # 前向传播,计算模型输出loss = loss_fn(outputs, labels)  # 计算损失loss.backward()  # 反向传播,计算梯度optimizer.step()  # 更新模型参数running_loss += loss.item()  # 累加损失# 每2000个批次打印一次平均损失if i % 2000 == 1999:print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000}')running_loss = 0.0  # 重置累计损失

保存模型

# 获取当前时间戳,格式为 'YYYYMMDD_HHMMSS'
timestamp = datetime.now().strftime('%Y%m%d%H%M%S.pth')# 定义模型保存路径,包含时间戳
model_path = 'model_{}'.format(timestamp)      # 保存模型的状态字典到指定路径
torch.save(model.state_dict(), model_path)

完整文件

from datetime import datetime
import torch
import torchvision
import torchvision.transforms as transforms
from garmentclassifier import GarmentClassifier# 定义图像转换操作:将图像转换为张量,并进行归一化处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))]) # 对图像的每个通道进行标准化,使得每个通道的像素值具有零均值和单位标准差# 加载FashionMNIST训练数据集,并应用定义的图像转换操作
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform)# 创建数据加载器,用于批量加载训练数据,batch_size为4,数据顺序随机打乱
trainloader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)# 实例化模型
model = GarmentClassifier()
# 定义损失函数为交叉熵损失
loss_fn = torch.nn.CrossEntropyLoss()
# 定义优化器为随机梯度下降(SGD),学习率为0.001,动量为0.9
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)# 训练模型,训练2个epoch
for epoch in range(2):running_loss = 0.0  # 初始化累计损失# 枚举数据加载器中的数据,i是批次索引,data是当前批次的数据for i, data in enumerate(trainloader, 0):inputs, labels = data  # 获取输入数据和对应的标签optimizer.zero_grad()  # 清空梯度outputs = model(inputs)  # 前向传播,计算模型输出loss = loss_fn(outputs, labels)  # 计算损失loss.backward()  # 反向传播,计算梯度optimizer.step()  # 更新模型参数running_loss += loss.item()  # 累加损失# 每2000个批次打印一次平均损失if i % 2000 == 1999:print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000}')running_loss = 0.0  # 重置累计损失# 获取当前时间戳,格式为 'YYYYMMDD_HHMMSS'
timestamp = datetime.now().strftime('%Y%m%d%H%M%S.pth')# 定义模型保存路径,包含时间戳
model_path = 'model_{}'.format(timestamp)      # 保存模型的状态字典到指定路径
torch.save(model.state_dict(), model_path)

执行该文件,我们会得到一个后缀为pth的模型文件。

推理

加载模型

我们加载上一步创建的模型。

import os
import glob
import torch
import torchvision.transforms as transforms
from PIL import Image
from datetime import datetime
from garmentclassifier import GarmentClassifierdef get_latest_model_path(directory, pattern="model_*.pth"):# 获取目录下所有符合模式的文件model_files = glob.glob(os.path.join(directory, pattern))if not model_files:raise FileNotFoundError("No model files found in the directory.")# 找到最新的模型文件latest_model_file = max(model_files, key=os.path.getmtime)return latest_model_file# 定义图像转换操作:将图像转换为张量,并进行归一化处理
transform = transforms.Compose([transforms.Resize((28, 28)),  # 调整图像大小为28x28transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 加载训练好的模型
model = GarmentClassifier()
model_path = get_latest_model_path('./')  # 获取最新的模型文件
model.load_state_dict(torch.load(model_path, weights_only=False)) # 加载模型参数
model.eval()  # 设置模型为评估模式

加载并预处理本地文件

在这里插入图片描述

# 从本地加载图像
image_path = 'shoe.jpg'  # 替换为实际的图像路径
image = Image.open(image_path).convert('L')  # 将图像转换为灰度图# 预处理图像
image = transform(image)
image = image.unsqueeze(0)  # 增加一个批次维度

我们使用transform进行归一化处理。

推理

# 推理(预测)
with torch.no_grad():  # 在推理过程中不需要计算梯度outputs = model(image)  # 前向传播,计算模型输出_, predicted = torch.max(outputs, 1)  # 获取预测结果# 定义类别名称
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot')# 打印预测结果
print(f'Predicted label: {classes[predicted.item()]}')

由于推理出来的是索引号,为了方便解读,我们将类型映射打印出来。

完整文件

import os
import glob
import torch
import torchvision.transforms as transforms
from PIL import Image
from datetime import datetime
from garmentclassifier import GarmentClassifierdef get_latest_model_path(directory, pattern="model_*.pth"):# 获取目录下所有符合模式的文件model_files = glob.glob(os.path.join(directory, pattern))if not model_files:raise FileNotFoundError("No model files found in the directory.")# 找到最新的模型文件latest_model_file = max(model_files, key=os.path.getmtime)return latest_model_file# 定义图像转换操作:将图像转换为张量,并进行归一化处理
transform = transforms.Compose([transforms.Resize((28, 28)),  # 调整图像大小为28x28transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 加载训练好的模型
model = GarmentClassifier()
model_path = get_latest_model_path('./')  # 获取最新的模型文件
model.load_state_dict(torch.load(model_path, weights_only=False)) # 加载模型参数
model.eval()  # 设置模型为评估模式# 从本地加载图像
image_path = 'shoe.jpg'  # 替换为实际的图像路径
image = Image.open(image_path).convert('L')  # 将图像转换为灰度图# 预处理图像
image = transform(image)
image = image.unsqueeze(0)  # 增加一个批次维度# 推理(预测)
with torch.no_grad():  # 在推理过程中不需要计算梯度outputs = model(image)  # 前向传播,计算模型输出_, predicted = torch.max(outputs, 1)  # 获取预测结果# 定义类别名称
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot')# 打印预测结果
print(f'Predicted label: {classes[predicted.item()]}')

执行这个文件,我们看到推理结果是:Sandal(凉鞋)。
在这里插入图片描述

代码地址

https://github.com/f304646673/deeplearning/tree/main/FashionMNIST

参考资料

  • https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/432196.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

电路板上电子元件检测系统源码分享

电路板上电子元件检测检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Comp…

SpringCloud源码:客户端分析(二)- 客户端源码分析

背景 我们继续分析EurekaClient的两个自动化配置类: 自动化配置类功能职责EurekaClientAutoConfiguration配置EurekaClient确保了Eureka客户端能够正确地:- 注册到Eureka服务端- 周期性地发送心跳信息来更新服务租约- 下线时通知Eureka服务端- 获取服务实…

TypeScript 设计模式之【建造者模式】

文章目录 **建造者模式**:打造你的梦想之屋建造者的秘密建造者有什么利与害?如何使用建造者搭建各种房子代码实现案例建造者模式的主要优点建造者模式的主要缺点建造者模式的适用场景总结 建造者模式:打造你的梦想之屋 假设你想要一栋完美的…

SpringBoot代码实战(MyBatis-Plus+Thymeleaf)

构建项目 修改pom.xml文件&#xff0c;添加其他依赖以及设置 <!--MyBatis-Plus依赖--><dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-spring-boot3-starter</artifactId><version>3.5.6</version><…

LiveGBS流媒体平台GB/T28181功能-支持电子放大拉框放大直播视频拉框放大录像视频流拉框放大电子放大

LiveGBS流媒体平台GB/T28181功能-支持电子放大拉框放大直播视频拉框放大录像视频流拉框放大电子放大 1、直播播放2、录像播放3、搭建GB28181视频直播平台 1、直播播放 国标设备-》查看通道-》播放 &#xff0c;左键单击可以拉取矩形框&#xff0c;放大选中的范围&#xff0c;释…

序列化流(对象操作输出流)反序列化流(对象操作输入流)

可以把Java中的对象写到本地文件中 序列化流&#xff08;对象操作输出流&#xff09; 构造方法 成员方法 使用对象输出流将对象保存到文件会出现NotSerializableException异常 解决方案&#xff1a;需要让Javabean类实现Serializable接口 Student package myio;import java.…

家政服务预约系统小程序的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;客户管理&#xff0c;员工管理&#xff0c;家政服务管理&#xff0c;服务预约管理&#xff0c;员工风采管理&#xff0c;客户需求管理&#xff0c;接单信息管理 微信端账号功能包括&#xff1a;系统首…

MySQL_子查询

课 程 推 荐我 的 个 人 主 页&#xff1a;&#x1f449;&#x1f449; 失心疯的个人主页 &#x1f448;&#x1f448;入 门 教 程 推 荐 &#xff1a;&#x1f449;&#x1f449; Python零基础入门教程合集 &#x1f448;&#x1f448;虚 拟 环 境 搭 建 &#xff1a;&#x1…

力扣最热一百题——寻找重复数(中等)

目录 题目链接&#xff1a;287. 寻找重复数 - 力扣&#xff08;LeetCode&#xff09; 题目描述 示例 提示&#xff1a; 解法一&#xff1a;暴力搜寻 Java写法&#xff1a; 运行时间 解法二&#xff1a;排序搜寻 Java写法&#xff1a; 运行时间 C写法&#xff1a; 运…

2024/9/26 英语每日一段

In part, that’s because it’s harder to empathize with someone who feels distant or unknown than a close loved one. “The more shared experiences you have with someone, the more of a rich, nuanced representation you can draw on,” Cameron says. But empath…

【Java网络编程】使用Tcp和Udp实现一个小型的回声客户端服务器程序

网络编程的概念 Java中的网络编程是指使用Java语言及其库创建和管理网络应用程序的过程。这一过程使得不同的计算机可以通过网络进行通信和数据交换。Java提供了一系列强大的API&#xff08;应用程序编程接口&#xff09;来支持网络编程&#xff0c;主要涉及以下几个概念&…

简易STL实现 | 红黑树的实现

1、原理 红黑树&#xff08;Red-Black Tree&#xff09;是一种自平衡的二叉搜索树 红黑树具有以下特性&#xff0c;这些特性保持了树的平衡&#xff1a; 节点颜色&#xff1a; 每个节点要么是红色&#xff0c;要么是黑色根节点颜色&#xff1a; 根节点是黑色的。叶子节点&…

【stm32】TIM定时器输出比较-PWM驱动LED呼吸灯/舵机/直流电机

TIM定时器输出比较 一、输出比较简介1、OC&#xff08;Output Compare&#xff09;输出比较2、PWM简介3、输出比较通道(高级)4、输出比较通道(通用)5、输出比较模式6、PWM基本结构配置步骤&#xff1a;程序代码&#xff1a;PWM驱动LED呼吸灯 7、参数计算8、舵机简介程序代码&am…

【笔记】KaiOS 系统框架和应用结构(APP界面逻辑)

KaiOS系统框架 最早自下而上分成Gonk-Gecko-Gaia层,代码有同名的目录,现在已经不用这种称呼。 按照官网3.0的版本迭代介绍,2.5->3.0已经将系统更新成如下部分: 仅分为上层web应用和底层平台核心,通过WebAPIs连接上下层,这也是kaios系统升级变更较大的部分。 KaiOS P…

括号匹配问题 -------------

1.题目说明&#xff1a; 给定一个只包括 (&#xff0c;)&#xff0c;{&#xff0c;}&#xff0c;[&#xff0c;] 的字符串 s &#xff0c;判断字符串是否有效。 有效字符串需满足&#xff1a; 左括号必须用相同类型的右括号闭合。左括号必须以正确的顺序闭合。每个右括号都有…

Jenkins入门:从搭建到部署第一个Springboot项目(踩坑记录)

本文讲述在虚拟机环境下(模拟服务器)&#xff0c;使用docker方式搭建jenkins&#xff0c;并部署一个简单的Springboot项目。仅记录关键步骤和遇到的坑&#xff0c;后续再进行细节补充。 一、环境准备和基础工具安装 1. 环境 系统环境为本机vmware创建的Ubuntu24.04。 2. yum…

【C++】STL--string(下)

1.string类对象的修改操作 erase&#xff1a;指定位置删除 int main() {string str1("hello world");str1.push_back(c);//尾插一个ccout << str1 << endl;string str2;str2.append("hello"); // 在str后追加一个字符"hello"cout…

CNN-LSTM预测 | MATLAB实现CNN-LSTM卷积长短期记忆神经网络时间序列预测

CNN-LSTM预测 | MATLAB实现CNN-LSTM卷积长短期记忆神经网络时间序列预测 目录 CNN-LSTM预测 | MATLAB实现CNN-LSTM卷积长短期记忆神经网络时间序列预测预测效果基本介绍模型描述程序设计参考资料预测效果 基本介绍 本次运行测试环境MATLAB2020b 提出一种包含卷积神经网络和长短…

多机部署,负载均衡-LoadBalance

文章目录 多机部署,负载均衡-LoadBalance1. 开启多个服务2. 什么是负载均衡负载均衡的实现客户端负载均衡 3. Spring Cloud LoadBalance快速上手使用Spring Cloud LoadBalance实现负载均衡修改IP,端口号为服务名称启动多个服务 负载均衡策略自定义负载均衡策略 LoadBalance原理…

c++模拟真人鼠标轨迹算法

一.鼠标轨迹算法简介 鼠标轨迹底层实现采用 C / C语言&#xff0c;利用其高性能和系统级访问能力&#xff0c;开发出高效的鼠标轨迹模拟算法。通过将算法封装为 DLL&#xff08;动态链接库&#xff09;&#xff0c;可以方便地在不同的编程环境中调用&#xff0c;实现跨语言的兼…