卷积神经网络实现MNIST手写数字识别 - P1

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:365天深度学习训练营-第P1周:实现mnist手写数字识别
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

目录

  • 环境
  • 步骤
    • 环境设置
      • 引用需要的包
      • 设置GPU
    • 数据准备
      • 下载数据集
      • 数据集预览
      • 数据集准备
    • 模型设计
    • 模型训练
      • 超参数设置
      • helper函数
      • 正式训练
    • 结果呈现
  • 总结与心得体会


环境

  • 系统:Linux
  • 语言: Python 3.8.10
  • 深度学习框架:PyTorch 2.0.0+cu118

步骤

环境设置

引用需要的包

Python写程序都需要做的事

import torch # 有些API直接在模块下
import torch.nn as nn # 大部分和模型相关的API
import torch.optim as optim # 优化器相关API
# 一些可以直接调用的函数封装(和nn下的很多方法是一样的效果不同的形式)
import torch.nn.functional as F from torch.utils.data import DataLoader # 数据集做分批,随机排序
from torchvision import datasets, transforms # 预置数据集下载,数据增强import matplotlib.pyplot as plt # 图表库
import numpy as np # 用来操作numpy数组,图像展示用from torchinfo import summary # 打开模型结构

设置GPU

首先用一个全局的对象设置一下当前的设备,是使用CPU还是CPU

# 有显卡就用显卡,没有就用CPU
device = torch.device('cuda'if torch.cuda.is_available() else 'cpu')

数据准备

下载数据集

调用torchvision包预置的API可以一键下载MNIST数据集

train_dataset = datasets.MNIST(root='data',  # 数据存放位置train=True, # 加载训练集还是验证集download=True,  # 本地没有是否从远程下载transform=transforms.ToTensor()) # 载入后将图像转换成pytorch的tensor对象
test_dataset = datasets.MNIST(root='data',  train=False,  # False说明是验证集download=True,transform=transforms.ToTensor())

数据集预览

先看看数据集中图像的样子,比如是单通道还是三通道,长宽是多少,然后就可以设置缩放以及模型的一些参数

image, label = train_dataset[0]
image.shape

图片信息
结果表明数据集中的图片应该是单通道的高28宽28的图像

打印里面20个图看看是什么样的

plt.figure(figsize=(20, 4)) # 设置一个plt图表画板的宽和高,单位是英寸。。
for i in range(20):image, label = train_dataset[i]plt.subplot(2, 10, i+1) # 以2行10列的形式展示图片# 先把tensor转为了numpy数组,然后把(1, 28, 28)第0维用squeeze去掉# cmap=plt.cm.binary说明是一个单通道的灰度图plt.imshow(np.squeeze(image.numpy()), cmap=plt.cm.binary)plt.title(label) # 打印一下对应的标签plt.axis('off') # 不显示坐标轴

图像预览

数据集准备

设置一下数据的批次大小

batch_size = 32
# 训练集上将数据的顺序打乱一下
train_loader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True)
test_loader= DataLoader(test_dataset, batch_size=batch_size)

模型设计

采用一个类似于LeNet的小型卷积网络

class Model(nn.Module):def __init__(self, num_classes):super().__init__()# 定义两个卷积层,核都是3x3的,通道数递增self.conv1 = nn.Conv2d(1, 16, kernel_size=3)self.conv2 = nn.Conv2d(16, 32, kernel_size=3)# 池化层没有参数需要学习,可以复用一个self.maxpool = nn.MaxPool2d(2)# 全连接层的输入维度要结果计算,可以在forward的时候算一下self.fc1 = nn.Linear(5*5*32128)# 最后一层的输出得是分类的数量self.fc2 = nn.Linear(128, num_classes)def forward(self, x):# 28x28 -> conv1 -> 26x26 -> maxpool -> 13x13x = self.maxpool(F.relu(self.conv1(x)))# 13x13 -> conv2 -> 11x11 -> maxpool -> 5x5x = self.maxpool(F.relu(self.conv2(x)))# 这里要进全连接层了,需要把数据压平,保留第0维,从第1维开始压x = torch.Flatten(start_dim=1)x = F.relu(self.fc1(1))# 最后一层就不加激活函数了x = self.fc2()
# 将模型创建后,设备设置为上面定义的设备对象
model = Model(10).to(device)
# 一定要加input_size,不然打印的就不是实际执行的样子,而是按self中定义的顺序,复用的组件也展示不出来
summary(model, input_size(1, 1, 28, 28))

模型结构

模型训练

接下来就到了训练模型的环节了

超参数设置

需要设置的超参数有训练的轮次epoch和学习率learning_rate

# 轮次
epochs = 10
# 学习率
larning_rate = 0.001
# 创建优化器,将模型参数进去,并设置学习率
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# 分类问题,无脑使用交叉熵损失
loss_fn = nn.CrossEntropyLoss()

helper函数

编写两个函数用来封装模型训练和模型验证的过程

  1. 模型训练
def train(train_loader, model, loss_fn, optimizer):size = len(train_loader.dataset) # 训练总数据量num_batches = len(train_loader) # 批次数量train_loss, train_acc = 0, 0 # 记录并返回本次训练过程的状态数据for x, y in train_loader:x, y = x.to(device), y.to(device) # 将数据加载到和模型相同的设备中,不然取不到值preds = model(x) # 这样模型会自动调用forward并进行一些参数的跟踪操作等loss = loss_fn(preds, y) # 计算当前批次的损失optimizer.zero_grad() # 清空之前训练时产生的梯度loss.backward() # 在损失函数上对参数执行反向传播计算梯度optimizer.step() # 执行参数更新操作# 累加当前数据train_loss += loss.item()# 计算正确数需要使用argmax求概率最大的一个分类然后和ground truth比较train_acc += (preds.argmax(1) == y).type(torch.float).sum().item()train_loss /= num_batches # 因为一个批次只计算一次损失,求平均值train_acc /= size # 正确率是在总数上计算的return train_loss, train_loss # 返回数据
  1. 模型验证
# 基本上就是train函数的简化
def test(test_loader, model, loss_fn):size = len(test_loader.dataset)num_batches = len(test_loader)test_loss, test_acc = 0, 0for x, y in test_loader:x, y = x.to(device), y.to(device)preds = model(x)loss = loss_fn(preds, y)test_loss += loss.item()test_acc += (preds.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchestest_acc /= sizereturn test_loss, test_acc

正式训练

开始正式训练,其实也可以封装成一个helper

# 记录训练过程的数据
train_loss, train_acc = [],[]
test_loss, test_acc = [],[]for epoch in range(epochs):model.train() # 切换模型为训练模式epoch_train_loss, epoch_train_acc = train(train_loader, model, loss_fn, optimizer)model.eval() # 切换模型为评估模式epoch_test_loss, epoch_test_acc = test(test_loader, model, loss_fn)# 记录本轮次数据train_loss.append(epoch_train_loss)train_acc.append(epoch_train_acc)test_loss.append(epoch_test_loss)test_acc.append(epoch_test_acc)# 打印本轮次的数据信息print(f"Epoch:{epoch+1}, Train loss: {epoch_train_loss:.3f}, Train accuracy: {epoch_train_loss*100:.1f}, Validation loss: {epoch_test_loss:.3f}, Validation accuracy: {epoch_test_acc*100:.1f}")

训练过程

结果呈现

上面打印的结果不够直观我们可以用折线图打印一下

plt.figure(figsize=(16, 4))
series = range(epochs)
plt.subplot(1, 2, 1) # 一排两个图表
plt.plot(series, train_loss, label='train loss')
plt.plot(series, test_loss, label='validation loss')
plt.legend(loc='upper right')
plt.title('Loss')
plt.subplot(1, 2, 2)
plt.plot(series, train_acc, label='train accuracy')
plt.plot(series, test_acc, label='validation accuracy')
plt.legend(loc='lower right')
plt.title('Accuracy')

训练结果


总结与心得体会

通过整个过程可以发现,手写数字的识别还是非常简单的,训练的效率比较快,结果也不错。非常适合拿来练手,学习一些基本概念、深度学习框架和分类任务实践过程等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/82857.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SPM(Swift Package Manager)开发及常见事项

SPM怎么使用的不再赘述,其优点是Cocoapods这样的远古产物难以望其项背的,而且最重要的是可二进制化、对xcproj项目无侵入,除了网络之外简直就是为团队开发的项目库依赖最好的管理工具,是时候抛弃繁杂低下的cocoapods了。 一&…

C语言:打开调用堆栈

第一步:打断点 第二步:FnF5 第三步:按如图找到调用堆栈

使用Flask.Request的方法和属性,获取get和post请求参数(二)

1、Flask中的request 在Python发送Post、Get等请求时,我们使用到requests库。Flask中有一个request库,有其特有的一些方法和属性,注意跟requests不是同一个。 2、Post请求:request.get_data() 用于服务端获取客户端请求数据。注…

积累常见的有针对性的python面试题---python面试题001

1.考点列表的.remove方法的参数是传入的对应的元素的值,而不是下标 然后再看remove这里,注意这个是,删除写的那个值,比如这里写3,就是删除3, 而不是下标. remove不是下标删除,而是内容删除. 2.元组操作,元组不支持修改,某个下标的内容 可以问他如何修改元组的某个元素 3.…

【MMU】认识 MMU 及内存映射的流程

MMU(Memory Manager Unit),是内存管理单元,负责将虚拟地址转换成物理地址。除此之外,MMU 实现了内存保护,进程无法直接访问物理内存,防止内存数据被随意篡改。 目录 一、内存管理体系结构 1、…

idea打开多个项目需要开多个窗口(恢复询问弹窗)

【版权所有,文章允许转载,但须以链接方式注明源地址,否则追究法律责任】【创作不易,点个赞就是对我最大的支持】 前言 仅作为学习笔记,供大家参考 总结的不错的话,记得点赞收藏关注哦! 使用…

【TypeScript】中定义与使用 Class 类的解读理解

目录 类的概念类的继承 :类的存取器:类的静态方法与静态属性:类的修饰符:参数属性:抽象类:类的类型: 总结: 类的概念 类是用于创建对象的模板。他们用代码封装数据以处理该数据。JavaScript 中的…

一起学SF框架系列7.1-spring-AOP-基础知识

AOP(Aspect-oriented Programming-面向切面编程)是一种编程模式,是对OOP(Object-oriented Programming-面向对象编程)一种有益补充。在OOP中,万事万物都是独立的对象,对象相互耦合关系是基于业务进行的;但在…

MySQL之深入InnoDB存储引擎——Undo页

文章目录 一、UNDO日志格式1、INSERT操作对应的UNDO日志2、DELETE操作对应的undo日志3、UPDATE操作对应的undo日志1)不更新主键2)更新主键的操作 3、增删改操作对二级索引的影响 二、UNDO页三、UNDO页面链表四、undo日志具体写入过程五、回滚段1、回滚段…

C语言系列之原码、反码和补码

一.欢迎来到我的酒馆 讨论c语言中,原码、反码、补码。 目录 一.欢迎来到我的酒馆二.原码 二.原码 2.1在计算机中,所有数据都是以二进制存储的,但不是直接存储二进制数,而是存储二进制的补码。原码很好理解,就是对应的…

SQL Server数据库如何添加Oracle链接服务器(Windows系统)

SQL Server数据库如何添加Oracle链接服务器 一、在添加访问Oracle的组件1.1 下载Oracle的组件 Oracle Provider for OLE DB1.2 注册该组件1.2.1 下载的压缩包解压位置1.2.2 接着用管理员运行Cmd 此处一定要用管理员运行,否则会报错 二、配置环境变量三、 重启SQL Se…

IDEA开启并配置services窗口

一、选择view -> Tool Windows -> Services 二、底下栏会出现Services 然后右键添加工程即可

Apache DolphinScheduler 3.1.8 版本发布,修复 SeaTunnel 相关 Bug

近日,Apache DolphinScheduler 发布了 3.1.8 版本。此版本主要基于 3.1.7 版本进行了 bug 修复,共计修复 16 个 bug, 1 个 doc, 2 个 chore。 其中修复了以下几个较为重要的问题: 修复在构建 SeaTunnel 任务节点的参数时错误的判断条件修复 …

【学习笔记】Java安全之反序列化

文章目录 反序列化方法的对比PHP的反序列化Java的反序列化Python反序列化 URLDNS链利用链分析触发DNS请求 CommonCollections1利用链利用TransformedMap构造POC利用LazyMap构造POCCommonsCollections6 利用链 最近在学习Phith0n师傅的知识星球的Java安全漫谈系列,随…

【多音音频测试信号】具有指定采样率和样本数的多音信号,生成多音信号的相位降低波峰因数研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

【torch.nn.PixelShuffle】和 【torch.nn.UnpixelShuffle】

文章目录 torch.nn.PixelShuffle直观解释官方文档 torch.nn.PixelUnshuffle直观解释官方文档 torch.nn.PixelShuffle 直观解释 PixelShuffle是一种上采样方法,它将形状为 ( ∗ , C r 2 , H , W ) (∗, C\times r^2, H, W) (∗,Cr2,H,W)的张量重新排列转换为形状为…

TechTool Pro for mac(硬件监测和系统维护工具)

TechTool Pro 是为 Mac OS X 重新设计的全新工具程序,不但保留旧版原有的硬件侦测功能,还可检查系统上其他重要功能,如:网络连接,区域网络等。 TechTool Pro for mac随时监控和保护您的电脑,并可预设定期检…

机器学习深度学习——非NVIDIA显卡怎么做深度学习(坑点排查)

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er 🌌上期文章:机器学习&&深度学习——数值稳定性和模型化参数(详细数学推导) 📚订阅专栏:机器…

mac安装vscode 配置git

1、安装vscode 官网地址 下载mac稳定版安装很慢的解决办法 (转自) mac电脑如何解决下载vscode慢的问题 选择谷歌浏览器右上角的3个点,选择下载内容,右键选择复制链接地址,在新窗口粘贴地址, 把地址中的一段替换成下面的cscode.sd…

电脑文件丢失如何找回?使用这个方法轻松找回!

电脑文件丢失怎么办?有没有免费的电脑文件恢复软件?相信很多人在日常办公中也都经常会遇到这种现象,不管是在学习中,还是日常的办公,往往也都会在电脑上存储大量的数据文件,那么如果我们在日常办公操作过程…