第P2周:Pytorch实现CIFAR10彩色图片识别

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

目标

  1. 实现CIFAR-10的彩色图片识别
  2. 实现比P1周更复杂一点的CNN网络

具体实现

(一)环境

语言环境:Python 3.10
编 译 器: PyCharm
框 架: Pytorch 2.5.1

(二)具体步骤
1.
import torch  
import torch.nn as nn  
import matplotlib.pyplot as plt  
import torchvision  # 第一步:设置GPU  
def USE_GPU():  if torch.cuda.is_available():  print('CUDA is available, will use GPU')  device = torch.device("cuda")  else:  print('CUDA is not available. Will use CPU')  device = torch.device("cpu")  return device  device = USE_GPU()  

输出:CUDA is available, will use GPU

  
# 第二步:导入数据。同样的CIFAR-10也是torch内置了,可以自动下载  
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,  transform=torchvision.transforms.ToTensor())  
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,  transform=torchvision.transforms.ToTensor())  batch_size = 32  
train_dataload = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)  
test_dataload = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)  # 取一个批次查看数据格式  
# 数据的shape为:[batch_size, channel, height, weight]  
# 其中batch_size为自己设定,channel,height和weight分别是图片的通道数,高度和宽度。  
imgs, labels = next(iter(train_dataload))  
print(imgs.shape)  # 查看一下图片  
import numpy as np  
plt.figure(figsize=(20, 5))  
for i, images in enumerate(imgs[:20]):  # 使用numpy的transpose将张量(C,H, W)转换成(H, W, C),便于可视化处理  npimg = imgs.numpy().transpose((1, 2, 0))  # 将整个figure分成2行10列,并绘制第i+1个子图  plt.subplot(2, 10, i+1)  plt.imshow(npimg, cmap=plt.cm.binary)  plt.axis('off')  
plt.show()  

输出:
Files already downloaded and verified
Files already downloaded and verified
torch.Size([32, 3, 32, 32])
image.png

# 第三步,构建CNN网络  
import torch.nn.functional as F  num_classes = 10  # 因为CIFAR-10是10种类型  
class Model(nn.Module):  def __init__(self):  super(Model, self).__init__()  # 提取特征网络  self.conv1 = nn.Conv2d(3, 64, 3)  self.pool1 = nn.MaxPool2d(kernel_size=2)  self.conv2 = nn.Conv2d(64, 64, 3)  self.pool2 = nn.MaxPool2d(kernel_size=2)  self.conv3 = nn.Conv2d(64, 128, 3)  self.pool3 = nn.MaxPool2d(kernel_size=2)  # 分类网络  self.fc1 = nn.Linear(512, 256)  self.fc2 = nn.Linear(256, num_classes)  # 前向传播  def forward(self, x):  x = self.pool1(F.relu(self.conv1(x)))  x = self.pool2(F.relu(self.conv2(x)))  x = self.pool3(F.relu(self.conv3(x)))  x = torch.flatten(x, 1)  x = F.relu(self.fc1(x))  x = self.fc2(x)  return x  from torchinfo import summary  
# 将模型转移到GPU中  
model = Model().to(device)  
summary(model)  

image.png

# 训练模型  
loss_fn = nn.CrossEntropyLoss() # 创建损失函数  
learn_rate = 1e-2   # 设置学习率  
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)    # 设置优化器  # 编写训练函数  
def train(dataloader, model, loss_fn, optimizer):  size = len(dataloader.dataset) # 训练集的大小 ,这里一共是60000张图片  num_batches = len(dataloader)   # 批次大小,这里是1875(60000/32=1875)  train_acc, train_loss = 0, 0    # 初始化训练正确率和损失率都为0  for X, y in dataloader: # 获取图片及标签,X-图片,y-标签(也是实际值)  X, y = X.to(device), y.to(device)  # 计算预测误差  pred = model(X) # 网络输出预测值  loss = loss_fn(pred, y) # 计算网络输出的预测值和实际值之间的差距  # 反向传播  optimizer.zero_grad()   # grad属性归零  loss.backward() # 反向传播  optimizer.step()    # 第一步自动更新  # 记录正确率和损失率  train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()  train_loss += loss.item()  train_acc /= size  train_loss /= num_batches  return train_acc, train_loss  # 测试函数  
def test(dataloader, model, loss_fn):  size = len(dataloader.dataset) # 测试集大小,这里一共是10000张图片  num_batches = len(dataloader)   # 批次大小 ,这里312,即10000/32=312.5,向上取整  test_acc, test_loss = 0, 0  # 因为是测试,因此不用训练,梯度也不用计算不用更新  with torch.no_grad():  for imgs, target in dataloader:  imgs, target = imgs.to(device), target.to(device)  # 计算loss  target_pred = model(imgs)  loss = loss_fn(target_pred, target)  test_loss += loss.item()  test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()  test_acc /= size  test_loss /= num_batches  return test_acc, test_loss  # 正式训练  
epochs = 10  
train_acc, train_loss, test_acc, test_loss = [], [], [], []  for epoch in range(epochs):  model.train()  epoch_train_acc, epoch_train_loss = train(train_dataload, model, loss_fn, opt)  model.eval()  epoch_test_acc, epoch_test_loss = test(test_dataload, model, loss_fn)  train_acc.append(epoch_train_acc)  train_loss.append(epoch_train_loss)  test_acc.append(epoch_test_acc)  test_loss.append(epoch_test_loss)  template = 'Epoch:{:2d}, 训练正确率:{:.1f}%, 训练损失率:{:.3f}, 测试正确率:{:.1f}%, 测试损失率:{:.3f}'  print(template.format(epoch+1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))  print('Done')  # 结果可视化  
# 隐藏警告  
import warnings  
warnings.filterwarnings('ignore')   # 忽略警告信息  
plt.rcParams['font.sans-serif'] = ['SimHei']    # 正常显示中文标签  
plt.rcParams['axes.unicode_minus'] = False  # 正常显示+/-号  
plt.rcParams['figure.dpi'] = 100    # 分辨率  epochs_range = range(epochs)  plt.figure(figsize=(12, 3))  plt.subplot(1, 2, 1)    # 第一张子图  
plt.plot(epochs_range, train_acc, label='训练正确率')  
plt.plot(epochs_range, test_acc, label='测试正确率')  
plt.legend(loc='lower right')  
plt.title('训练和测试正确率比较')  plt.subplot(1, 2, 2)    # 第二张子图  
plt.plot(epochs_range, train_loss, label='训练损失率')  
plt.plot(epochs_range, test_loss, label='测试损失率')  
plt.legend(loc='upper right')  
plt.title('训练和测试损失率比较')  plt.show()# 保存模型  
torch.save(model, './models/cnn-cifar10.pth')

image.png
再次设置epochs为50训练结果:
image.png
epochs增加到100,训练结果:
image.png
可以看到训练集和测试集的差距有点大,不太理想。做一下数据增加试试:

data_transforms= {  'train': transforms.Compose([  transforms.RandomHorizontalFlip(),  transforms.ToTensor(),  ]),  'test': transforms.Compose([  transforms.ToTensor(),  ])  
}

在dataset中:

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,  transform=data_transforms['train'])  
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=data_transforms['test'])

运行结果:
image.png
image.png
比较漂亮了,再调整batch_size=16和epochs=20,提高了近6个百分点。
image.png
batch_size=16,epochs=50:有第20轮左右的时候,验证集的确认性基本就没有再提高了。和上面基本一样。
image.png

(三)总结
  1. epochs并不是越多越好。batch_size同样的道理
  2. 数据增强确实可以提高模型训练的准确性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/491683.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Django基础 - 01入门简介

一、 基本概念 1.1 Django说明 Django发布于2005年, 网络框架, 用Python编写的开源的Web应用框架。采用了MVC框架模式,也称为MTV模式。官网: https://www.djangoproject.com1.2 MVC框架 Model: 封装和数据库相关…

什么是事务?隔离级别

一、什么是事务? 事务是代表单个工作单元的一组SQL语句,它确保这些语句要么全部成功,要么全部失败回滚。(想象一下,你去银行转账。你要把一笔钱从一个账户(账户A)转到另一个账户(账…

[LeetCode-Python版]206. 反转链表(迭代+递归两种解法)

题目 给你单链表的头节点 head ,请你反转链表,并返回反转后的链表。 示例 1: 输入:head [1,2,3,4,5] 输出:[5,4,3,2,1] 示例 2: 输入:head [1,2] 输出:[2,1] 示例 3&#xff1…

2025山东科技大学考研专业课复习资料一览

[冲刺]2025年山东科技大学020200应用经济学《814经济学之西方经济学[宏观部分]》考研学霸狂刷870题[简答论述计算题]1小时前[强化]2025年山东科技大学085600材料与化工《817物理化学》考研强化检测5套卷22小时前[冲刺]2025年山东科技大学030100法学《704综合一[法理学、国际法学…

kubernetes学习-使用metrics-server监控集群资源和查看日志

kubernetes学习-使用metrics-server监控集群资源和查看日志 一 、简介 Metrics Server 是一个用于 Kubernetes 集群的监控工具,它用于收集、存储和提供关于集群中各种资源的度量数据。Metrics Server 是 Kubernetes 中一个核心的指标收集器,可以提供关…

如何使用git新建本地仓库并关联远程仓库的步骤(详细易懂)

一、新建本地仓库并关联远程仓库的步骤 新建本地仓库 打开终端(在 Windows 上是命令提示符或 PowerShell,在 Linux 和Mac上是终端应用),进入你想要创建仓库的目录。例如,如果你想在桌面上创建一个名为 “my - project”…

【小沐学GIS】基于C++绘制三维数字地球Earth(OpenGL、glfw、glut、QT)第三期

🍺三维数字地球系列相关文章如下🍺:1【小沐学GIS】基于C绘制三维数字地球Earth(456:OpenGL、glfw、glut)第一期2【小沐学GIS】基于C绘制三维数字地球Earth(456:OpenGL、glfw、glut)第二期3【小沐…

【机器学习】在向量的流光中,揽数理星河为衣,以线性代数为钥,轻启机器学习黎明的瑰丽诗章

文章目录 线性代数入门:机器学习零基础小白指南前言一、向量:数据的基本单元1.1 什么是向量?1.1.1 举个例子: 1.2 向量的表示与维度1.2.1 向量的维度1.2.2 向量的表示方法 1.3 向量的基本运算1.3.1 向量加法1.3.2 向量的数乘1.3.3…

《拉依达的嵌入式\驱动面试宝典》—C/CPP基础篇(一)

《拉依达的嵌入式\驱动面试宝典》—C/CPP基础篇(一) 你好,我是拉依达。 感谢所有阅读关注我的同学支持,目前博客累计阅读 27w,关注1.5w人。其中博客《最全Linux驱动开发全流程详细解析(持续更新)-CSDN博客》已经是 Lin…

【Linux】—简单实现一个shell(myshell)

大家好呀,我是残念,希望在你看完之后,能对你有所帮助,有什么不足请指正!共同学习交流哦! 本文由:残念ing原创CSDN首发,如需要转载请通知 个人主页:残念ing-CSDN博客&…

数据结构day5:单向循环链表 代码作业

一、loopLink.h #ifndef __LOOPLINK_H__ #define __LOOPLINK_H__#include <stdio.h> #include <stdlib.h>typedef int DataType;typedef struct node {union{int len;DataType data;};struct node* next; }loopLink, *loopLinkPtr;//创建 loopLinkPtr create();//…

利用CNN与多尺度特征、注意力机制的融合实现低分辨率人脸表情识别,并给出模型介绍与代码实现

大家好&#xff0c;我是微学AI&#xff0c;今天给大家介绍一下利用CNN与多尺度特征、注意力机制的融合实现低分辨率人脸表情识别&#xff0c;并给出模型介绍与代码实现。在当今社会&#xff0c;人脸识别技术已广泛应用&#xff0c;但特定场景下的低质量图像仍是一大挑战。 低分…

Scala—“==“和“equals“用法(附与Java对比)

Scala 字符串比较—""和"equals"用法 Scala 的 在 Scala 中&#xff0c; 是一个方法调用&#xff0c;实际上等价于调用 equals 方法。不仅适用于字符串&#xff0c;还可以用于任何类型&#xff0c;并且自动处理 null。 Demo&#xff1a; Java 的 在 J…

驱动开发-入门【1】

1.内核下载地址 Linux内核源码的官方网站为https://www.kernel.org/&#xff0c;可以在该网站下载最新的Linux内核源码。进入该网站之后如下图所示&#xff1a; 从上图可以看到多个版本的内核分支&#xff0c;分别为主线版本&#xff08;mainline&#xff09;、稳定版本&#…

Android Room 数据库使用详解

一、Room介绍 Android Room 是 Google 提供的一个 Android 数据持久化库&#xff0c;是 Android Jetpack 组成部分之一。它提供了一个抽象层&#xff0c;使得 SQLite 数据库的使用更为便捷。通过 Room&#xff0c;开发者可以轻松地操作数据库&#xff0c;不需要直接编写繁琐的…

使用k6进行Redis基准测试

1.安装环境 前提条件&#xff1a;已经安装go 安装xk6 go install go.k6.io/xk6/cmd/xk6latest 安装成功会在GOPATH目录生成xk6可执行文件 安装xk6-redis 切换到xk6工作目录&#xff0c;执行如下命令 cd /Users/wan/go/bin ./xk6 build --with github.com/grafana/xk6-re…

【期末复习】JavaEE(上)

1. Java EE概述 开发环境及开发工具 1.1. HTTP协议 开发模式 2. Java Web技术 JSP技术 2.1. Servlet技术 2.1.1. HttpServletRequest 常用方法 2.1.2. HttpServletRequest 请求乱码 tomcat7 及以下&#xff08;对于每个参数单独进行编码转换&#xff09;&#xff1a; 2.…

KeyFormer:使用注意力分数压缩KV缓存

Keyformer: KV Cache Reduction through Key Tokens Selection for Efficient Generative Inference 202403&#xff0c;发表在Mlsys Introduction 优化KV cache的策略&#xff0c;主要是集中在系统级别的优化上&#xff0c;比如FlashAttention、PagedAttention&#xff0c;它…

Win11安装安卓子系统WSA

文章目录 简介一、启用Hyper-V二、安装WSA三、安装APKAPK商店参考文献 简介 WSA&#xff1a;Windows Subsystem For Android 一、启用Hyper-V 控制面板 → 程序和功能 → 启用或关闭 Windows 功能 → 勾选 Hyper-V 二、安装WSA 进入 Microsoft Store&#xff0c;下拉框改为 …

PHP排序算法:数组内有A~E,A移到C或者C移到B后排序,还按原顺序排序,循环

效果 PHP代码 public function demo($params){function moveNext($arr){$length count($arr);$lastElement $arr[$length - 1];for ($i $length - 1; $i > 0; $i--) {$arr[$i] $arr[$i - 1];}$arr[0] $lastElement;return $arr;}function moveAndReplace($array, $from…