Pytorch实现CIFAR10训练模型

文章目录

  • 简述
  • 模型结构
  • 模型参数、优化器、损失函数
    • 参数初始化
    • 优化器
    • 损失函数
  • 模型训练、测试集预测、模型保存、日志记录
    • 训练
    • 测试集测试
    • 模型保存
    • 模型训练完整代码
  • tensorboard训练可视化结果
    • train_loss
    • 测试准确率
    • 测试集loss
  • 模型应用
    • 模型独立应用代码`api.py`
    • 预测结果

简述

使用pytorch实现一个用于训练CIFAR10的模型,在训练过程中使用CIFAR10的测试数据集记录准确度。训练结束后,搜集一些图片,单独实现对训练后模型的应用代码。

另外会在文中尽量给出各种用法的官方文档链接。

代码分为:

  1. 模型训练代码train.py,包含数据加载、模型封装、训练、tensorboard记录、模型保存等;
  2. 模型应用代码api.py,包含对训练所保存模型的加载、数据准备、结果预测等;

注意:

本文目的是使用pytorch来构建一个结构完善的模型,体现出pytorch的各种功能函数、模型设计理念,来学习深度学习,而非训练一个高精度的分类识别模型。

不足:

  1. 参数初始化或许可以考虑kaiming(因为用的是ReLU);
  2. 可以加上k折交叉验证;
  3. 训练时可以把batch_size的图片加入tensorboard,文中batch_size=256,若每个batch_size都加的话数据太多了,所以文中是每逢整百的训练次数时记录一下该批次的loss值,加图片的话可以在该代码处添加;

模型结构

来源:https://www.researchgate.net/profile/Yiren-Zhou-6/publication/312170477/figure/fig1/AS:448817725218816@1484017892071/Structure-of-LeNet-5.png

在这里插入图片描述

在上述图片基础上增加了nn.BatchNorm2dnn.ReLU以及nn.Dropout,最终结构如下:

layers = nn.Sequential(  # shape(3,32,32) -> shape(32,32,32)  nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),  nn.BatchNorm2d(32),  nn.ReLU(),  # shape(32,32,32) -> shape(32,16,16)  nn.MaxPool2d(kernel_size=2, stride=2),  # shape(32,16,16) -> shape(32,16,16)  nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2),  nn.BatchNorm2d(32),  nn.ReLU(),  # shape(32,16,16) -> shape(32,8,8)  nn.MaxPool2d(kernel_size=2, stride=2),  # shape(32,8,8) -> shape(64,8,8)  nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),  nn.BatchNorm2d(64),  nn.ReLU(),  # shape(64, 8, 8) -> shape(64,4,4)  nn.MaxPool2d(kernel_size=2, stride=2),  # shape(64,4,4) -> shape(64 * 4 * 4,)  nn.Flatten(),  nn.Linear(64 * 4 * 4, 64),  nn.ReLU(),  nn.Dropout(0.5),  nn.Linear(64, 10)  
)

可以看看使用tensorboard的writer.add_graph函数实现的模型结构图:

在这里插入图片描述

模型参数、优化器、损失函数

参数初始化

模型参数使用nn.init.normal_作初始化,但模型中存在ReLU,应考虑使用kaiming He初始化。

apply函数:Module — PyTorch 2.4 documentation

参数初始化函数:torch.nn.init — PyTorch 2.4 documentation

def init_normal(m):  # 考虑使用kaiming  if m is nn.Linear:  nn.init.normal_(m.weight, mean=0, std=0.01)  nn.init.zeros_(m.bias)# 定义模型、数据初始化  
net = CIFAR10Net()  
net.apply(init_normal)

优化器

优化器使用Adam,即MomentumAdaGrad的结合。

文档:Adam — PyTorch 2.4 documentation

# 优化器  
weight_decay = 0.0001optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)

损失函数

分类任务,自然是用交叉熵损失函数了。

loss_fn = nn.CrossEntropyLoss()

模型训练、测试集预测、模型保存、日志记录

注意,代码前面部分代码有定义
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

训练

net.train()  
for images, labels in train_loader:  images, labels = images.to(device), labels.to(device)  outputs = net(images)  loss = loss_fn(outputs, labels)  # 优化器处理  optimizer.zero_grad()  loss.backward()  optimizer.step()  total_train_step += 1  if total_train_step % 100 == 0:  print(f'Epoch: {epoch + 1}, 累计训练次数: {total_train_step}, 本次loss: {loss.item():.4f}')  writer.add_scalar('train_loss', loss.item(), total_train_step)  current_time = time.time()  writer.add_scalar('train_time', current_time-start_time, total_train_step)

测试集测试

net.eval()  
total_test_loss = 0  
total_test_acc = 0  # 整个测试集正确个数  
with torch.no_grad():  for images, labels in test_loader:  images, labels = images.to(device), labels.to(device)  outputs = net(images)  loss = loss_fn(outputs, labels)  total_test_loss += loss.item()  accuracy = (outputs.argmax(1) == labels).sum()  total_test_acc += accuracy  print(f'整个测试集loss值和: {total_test_loss:.4f}, batch_size: {batch_size}')  
print(f'整个测试集正确率: {(total_test_acc / test_data_size) * 100:.4f}%')  
writer.add_scalar('test_loss', total_test_loss, epoch + 1)  
writer.add_scalar('test_acc', (total_test_acc / test_data_size) * 100, epoch + 1)

模型保存

torch.save(net.state_dict(),  './save/epoch_{}_params_acc_{}.pth'.format(epoch+1, (total_test_acc / test_data_size)))

模型训练完整代码

train.py

import torch  
import torchvision  
from torch.utils.tensorboard import SummaryWriter  
from torchvision import transforms  
from torch.utils import data  
from torch import nn  
import time  
from datetime import datetime  def load_data_CIFAR10(resize=None):  """  下载 CIFAR10 数据集,然后将其加载到内存中  transforms.ToTensor() 转换为形状为C x H x W的FloatTensor,并且会将像素值从[0, 255]缩放到[0.0, 1.0]  """    trans = [transforms.ToTensor()]  if resize:  trans.insert(0, transforms.Resize(resize))  trans = transforms.Compose(trans)  cifar_train = torchvision.datasets.CIFAR10(root="../data", train=True, transform=trans, download=False)  cifar_test = torchvision.datasets.CIFAR10(root="../data", train=False, transform=trans, download=False)  return cifar_train, cifar_test  class CIFAR10Net(torch.nn.Module):  def __init__(self):  super(CIFAR10Net, self).__init__()  layers = nn.Sequential(  # shape(3,32,32) -> shape(32,32,32)  nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),  nn.BatchNorm2d(32),  nn.ReLU(),  # shape(32,32,32) -> shape(32,16,16)  nn.MaxPool2d(kernel_size=2, stride=2),  # shape(32,16,16) -> shape(32,16,16)  nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2),  nn.BatchNorm2d(32),  nn.ReLU(),  # shape(32,16,16) -> shape(32,8,8)  nn.MaxPool2d(kernel_size=2, stride=2),  # shape(32,8,8) -> shape(64,8,8)  nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),  nn.BatchNorm2d(64),  nn.ReLU(),  # shape(64, 8, 8) -> shape(64,4,4)  nn.MaxPool2d(kernel_size=2, stride=2),  # shape(64,4,4) -> shape(64 * 4 * 4,)  nn.Flatten(),  nn.Linear(64 * 4 * 4, 64),  nn.ReLU(),  nn.Dropout(0.5),  nn.Linear(64, 10)  )  self.layers = layers  def forward(self, x):  return self.layers(x)  def init_normal(m):  # 考虑使用kaiming  if m is nn.Linear:  nn.init.normal_(m.weight, mean=0, std=0.01)  nn.init.zeros_(m.bias)  if __name__ == '__main__':  # 超参数  epochs = 6  batch_size = 256  learning_rate = 0.01  num_workers = 0  weight_decay = 0  # 数据记录  total_train_step = 0  total_test_step = 0  train_loss_list = list()  test_loss_list = list()  train_acc_list = list()  test_acc_list = list()  # 准备数据集  train_data, test_data = load_data_CIFAR10()  train_loader = data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)  test_loader = data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)  train_data_size = len(train_data)  test_data_size = len(test_data)  print(f'训练测试集长度: {train_data_size}, 测试数据集长度: {test_data_size}, batch_size: {batch_size}\n')  # device = torch.device("cpu")  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  print(f'\ndevice: {device}')  # 定义模型、数据初始化  net = CIFAR10Net().to(device)  # net.apply(init_normal)  # 损失函数  loss_fn = nn.CrossEntropyLoss().to(device)  # 优化器  optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)  # now_time = datetime.now()  # now_time = now_time.strftime("%Y%m%d-%H%M%S")    # tensorboard    writer = SummaryWriter('./train_logs')  # 随便定义个输入, 好使用add_graph  tmp = torch.rand((batch_size, 3, 32, 32)).to(device)  writer.add_graph(net, tmp)  start_time = time.time()  for epoch in range(epochs):  print('------------Epoch {}/{}'.format(epoch + 1, epochs))  # 训练  net.train()  for images, labels in train_loader:  images, labels = images.to(device), labels.to(device)  outputs = net(images)  loss = loss_fn(outputs, labels)  # 优化器处理  optimizer.zero_grad()  loss.backward()  optimizer.step()  total_train_step += 1  if total_train_step % 100 == 0:  print(f'Epoch: {epoch + 1}, 累计训练次数: {total_train_step}, 本次loss: {loss.item():.4f}')  writer.add_scalar('train_loss', loss.item(), total_train_step)  current_time = time.time()  writer.add_scalar('train_time', current_time-start_time, total_train_step)  # 测试  net.eval()  total_test_loss = 0  total_test_acc = 0  # 整个测试集正确个数  with torch.no_grad():  for images, labels in test_loader:  images, labels = images.to(device), labels.to(device)  outputs = net(images)  loss = loss_fn(outputs, labels)  total_test_loss += loss.item()  accuracy = (outputs.argmax(1) == labels).sum()  total_test_acc += accuracy  print(f'整个测试集loss值和: {total_test_loss:.4f}, batch_size: {batch_size}')  print(f'整个测试集正确率: {(total_test_acc / test_data_size) * 100:.4f}%')  writer.add_scalar('test_loss', total_test_loss, epoch + 1)  writer.add_scalar('test_acc', (total_test_acc / test_data_size) * 100, epoch + 1)  torch.save(net.state_dict(),  './save/epoch_{}_params_acc_{}.pth'.format(epoch+1, (total_test_acc / test_data_size)))  writer.close()

tensorboard训练可视化结果

train_loss

纵轴为每个batch_size损失值,横轴为训练次数,其中batch_size = 256。

在这里插入图片描述

测试准确率

纵轴为整个CIFAR10测试集的准确率(%),横轴为epoch,其中epochs=50。

在这里插入图片描述

测试集loss

纵轴为CIFAR10整个测试集的每个batch_size的loss之和,batch_size = 256。横轴为epoch,其中epochs=50。

在这里插入图片描述

模型应用

模型训练过程中,每个epoch保存一次模型。

torch.save(net.state_dict(),  './save/epoch_{}_params_acc_{}.pth'.format(epoch+1, (total_test_acc / test_data_size)))  

这里实现一个,将保存的模型加载,并对自行搜集的图片进行预测。

项目结构:

  1. ./autodl_save/cuda_params_acc_75.pth:训练时保存的模型参数文件;

  2. ./test_images:网上搜集的卡车、狗、飞机、船图片,大小不一,保存时未作处理,如下:
    在这里插入图片描述

  3. api.py:实现图片的预处理(裁剪、ToTensor、封装为数据集等)、模型加载、图片推理等;

模型独立应用代码api.py

import os  import torch  
import torchvision  
from PIL import Image  
from torch import nn  class CIFAR10Net(torch.nn.Module):  def __init__(self):  super(CIFAR10Net, self).__init__()  layers = nn.Sequential(  # shape(3,32,32) -> shape(32,32,32)  nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),  nn.BatchNorm2d(32),  nn.ReLU(),  # shape(32,32,32) -> shape(32,16,16)  nn.MaxPool2d(kernel_size=2, stride=2),  # shape(32,16,16) -> shape(32,16,16)  nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2),  nn.BatchNorm2d(32),  nn.ReLU(),  # shape(32,16,16) -> shape(32,8,8)  nn.MaxPool2d(kernel_size=2, stride=2),  # shape(32,8,8) -> shape(64,8,8)  nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),  nn.BatchNorm2d(64),  nn.ReLU(),  # shape(64, 8, 8) -> shape(64,4,4)  nn.MaxPool2d(kernel_size=2, stride=2),  # shape(64,4,4) -> shape(64 * 4 * 4,)  nn.Flatten(),  nn.Linear(64 * 4 * 4, 64),  nn.ReLU(),  nn.Dropout(0.5),  nn.Linear(64, 10)  )  self.layers = layers  def forward(self, x):  return self.layers(x)  def build_data(images_dir):  image_list = os.listdir(images_dir)  image_paths = []  for image in image_list:  image_paths.append(os.path.join(images_dir, image))  transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),  torchvision.transforms.ToTensor()])  # 存储转换后的张量  images_tensor = []  for image_path in image_paths:  try:  # 加载图像并转换为 RGB(如果它已经是 RGB,这步是多余的)  image_pil = Image.open(image_path).convert('RGB')  # 应用转换并添加到列表中  images_tensor.append(transform(image_pil))  except IOError:  print(f"Cannot open {image_path}. Skipping...")  # 转换列表为单个张量,如果需要的话  # 注意:这里假设所有图像都被成功加载和转换  if images_tensor:  # 使用 torch.stack 来合并张量列表  images_tensor = torch.stack(images_tensor)  else:  # 如果没有图像,返回一个空的张量或根据需要处理  images_tensor = torch.empty(0, 3, 32, 32)  return images_tensor, image_list  def predict(state_dict_path, image):  net = CIFAR10Net()  net.load_state_dict(torch.load(state_dict_path))  net.cuda()  with torch.no_grad():  image = image.cuda()  output = net(image)  return output  if __name__ == '__main__':  images, labels = build_data("./test_images")  outputs = predict("./autodl_save/cuda_params_acc_75.pth", images)  # 选取结果(即得分最大的下标)  res = outputs.argmax(dim=1)  kinds = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']  for i in range(len(res)):  classes_idx = res[i]  print(f'文件(正确标签): {labels[i]},  预测结果: {classes_idx}, {kinds[classes_idx]}\n')

预测结果

7个识别出4个。

在这里插入图片描述

注意这个索引和标签的对应关系可以从数据集中查看。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/408610.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Axure设计之三级菜单导航教程(中继器)

中继器作为复杂的元件,通常被用来制作“高保真”的动态原型,以达到良好的视觉效果和交互效果。本文将教大家通过AxureRP9工具如何使用中继器设计三级菜单导航。 一、案例效果 原型预览:https://1zvcwx.axshare.com 主要效果: 1…

数据结构(Java实现):链表与LinkedList

文章目录 1. 单向链表1.1 链表的概念及结构1.2 链表的实现1.2.1 单向链表类和节点1.2.2 打印每个节点的值1.2.3 计算链表长度1.2.4 头插节点1.2.5 尾插节点1.2.6 在指定下标插入新节点1.2.7 判断是否存在某个节点1.2.8 移除某个节点1.2.9 移除所有指定节点1.2.10 清空链表1.2.1…

redis | 认识非关系型数据库Redis的哈希数据类型

Redis 非关 kv型 哈希通用命令python 操作hash应用场景 数据类型 数据类型丰富,字符串strings,散列hashes,列表lists,集合sets,有序集合sorted sets等等 哈希 定义 1、由field和关联的value组成的键值对 类似于python的键值对 2、field和value.是字符…

一文学会Shell中case语句和函数

大家好呀!今天简单聊一聊Shell中的case语句与函数。在多选择情况下使用case语句将非常方便,同时,函数的学习和使用对于学好一门编程语言也是非常重要的。 一、case语句 case语句为多选择语句。可以用case语句匹配一个值与一个模式&#xff0c…

OpenCV绘图函数详解及其用法示例

MFC类库中的CDC类有划线,画矩形,画椭圆,画多边形,文字等绘图函数,OpenCV也有类似的绘图函数。二者的区别在于MFC画图是在一定的区域内绘制图形,而OpenCV则是在图像上绘制,主要用于图像标注。 OpenCV的常用绘图函数有arrowedLine,circle ,drawContours, drawMarker, dra…

AI数字时代客户体验白皮书5G云算力网络云网终端AIGC人工智能宽带政企物联网专线 IDC智慧城市专家学者教授培训讲师分享

客户体验的时代已然来临 在过去的几十年里,中国企业逐步从产品驱动转向市场驱动,从规模竞争走向创新竞争。然而,随着市场竞争的白热化和产品、服务的高度同质化,企业之间的差异化逐渐被削弱,传统的价格战、渠道战已经…

layui table表单 checkbox选中一个其它也要选中

当我们选中其中一个商品的时候同类型的商品状态也要跟着改变 所以要在表单加载完成后去监听checkbox ,done:function (res) {console.log(详情表格数据,res)tableDetailList res.data;// 监听表格复选框选择table.on(checkbox( INST_SELECTORS.instLayFilters.unpaidTableDe…

Python优化算法13——飞蛾扑火优化算法(MFO)

科研里面优化算法都用的多,尤其是各种动物园里面的智能仿生优化算法,但是目前都是MATLAB的代码多,python几乎没有什么包,这次把优化算法系列的代码都从底层手写开始。 需要看以前的优化算法文章可以参考:Python优化算…

用4种不同视角理解矩阵乘法

目录 1. 背景 2. 线性方程组视角(向量点积视角) 3. 列向量观点视角 4. 向量变换视角(矩阵函数) 5. 坐标变换视角 1. 背景 矩阵诞生于线性方程组的求解,最基本的运算方法来自于高斯消元法,所以矩阵整个…

Linux 离线安装docker和docker-compose

前言 公司有 docker 和 docker-compose 离线包安装部署的需求,本文应运而生撰写时间:2024-06-07(初稿) 1 应用版本 docker:20.10.7, build f0df350docker-compose:1.25.1 2 物料准备 服务器账号/密码d…

[数据集][目标检测]电力场景输电线防震锤检测数据集VOC+YOLO格式2721张2类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):2721 标注数量(xml文件个数):2721 标注数量(txt文件个数):2721 标注…

《javaEE篇》--线程池

线程池是什么 线程的诞生是因为进程创建和销毁的成本太大,但是也是相对而言,如果频繁的创建和销毁线程那么这个成本就不能忽略了。 一般有两种方法来进一步提高效率,一种是协程(这里不多做讨论),另一种就是线程池 假如说有一个学校食堂窗口…

智能控制,高效节能。ZLG致远电子能源智慧管理解决方案

面对楼宇及建筑群能源管理与设备控制的复杂需求,ZLG致远电子推出了一套能源智慧管理解决方案。该方案集设备管理、任务调度和数据可视化于一体,不仅实现数据的实时监控与分析,还助力系统节能降耗。 ZLG致远电子能源智慧管理解决方案 在ZLG致…

ShareSDK 企业微信

本篇文档主要讲解如何使用企业微信并进行分享和授权。 创建应用 登录企业微信并通过企业认证。选择应用管理 > 应用 >创建应用。编辑应用信息。配置授权登录信息。 以下为创建过程示例,图中信息仅为示例,创建时请按照真实信息填写,否…

如何查看ubuntu版本

在当前的技术环境中,了解操作系统的具体版本对于用户来说至关重要。这不仅能确保软件兼容性,还有助于进行系统管理和故障排查。对于使用Ubuntu系统的用户来说,有几种不同的方法可以查看当前系统的版本。下面将详细介绍如何查看您的Ubuntu系统…

Spring Boot(快速上手)

Spring Boot 零、环境配置 1. 创建项目 2. 热部署 添加依赖&#xff1a; <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-devtools</artifactId><optional>true</optional> </dependency&…

polarctf靶场[WEB]cookie欺骗、upload、签到

[web]cookie欺骗 考点&#xff1a;cookie值 工具&#xff1a;Burp Suite抓包 根据题目提示&#xff0c;cookie欺骗&#xff0c;所以要在cookie值寻找关键 进入网页之后&#xff0c;说只有admin用户才能得到flag&#xff0c;而我们此时只属于普通访客 我们查看cookie值&…

「Python程序设计」基本数据类型:字符串

​在python的程序设计过程中&#xff0c;字符串是需要经常处理的变量类型。字符串在程序中的存储方式&#xff0c;类似于一维数组&#xff0c;每个字符占据数组中的一个单元格。 字符串可以存储字符类型的变量&#xff0c;即使是数字类型&#xff0c;也可以通过字符串来进行存…

vue3+vite配置环境变量实现开发、测试、生产的区分

文章目录 一、为什么需要区分 (dev)、测试 (test) 和生产 (prod) 环境二、vue3的项目如何通过配置方式区分不同的环境1、创建不同环境的.env文件2、在不同的.env文件中配置相应的环境变量1&#xff09;.env.develoment2&#xff09;.env.test3&#xff09;.env.production 3、在…

Git之git stash高级用法(五十)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 新书发布&#xff1a;《Android系统多媒体进阶实战》&#x1f680; 优质专栏&#xff1a; Audio工程师进阶系列…