【Pytorch+torchvision】MNIST手写数字识别

深度学习入门项目,含代码详细解析

在本文中,我们将在PyTorch中构建一个简单的卷积神经网络,并使用MNIST数据集训练它识别手写数字。 MNIST包含70,000张手写数字图像: 60,000张用于培训,10,000张用于测试。图像是灰度(即通道数为1)28x28像素,并且居中的,以减少预处理和加快运行。

目录

 1.整体代码

 2.代码解析

2.1参数设置

2.2数据集

2.3查看测试数据 

2.4定义卷积神经网络​编辑

2.5初始化网络与优化器

3.实验结果


 1.整体代码

import torch
import torchvision
from torch.utils.data import DataLoader
import torch.nn as nn #torch.nn层中包含可训练的参数
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
#注意下面两行在matplotlib使用上出错时,加上可不出错
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'n_epochs = 3 #epoch的数量定义了将循环整个训练数据集的次数
batch_size_train = 64 #每次投喂的样本数量
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5 #优化器的超参数
log_interval = 10
random_seed = 1
torch.manual_seed(random_seed) #对于可重复的实验,须为任何使用随机数产生的东西设置随机种子
#训练集数据
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('./data/', train=True, download=True, #加载该数据集(download=True)transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])), #Normalize()转换使用的值0.1307和0.3081是该数据集的全局平均值和标准偏差,这里将它们作为给定值batch_size=batch_size_train, shuffle=True)
#测试集数据
test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('./data/', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size_test, shuffle=True) #使用size=1000对这个数据集进行测试
#查看一批测试数据由什么组成
examples = enumerate(test_loader) #enumerate指循环,类似for
batch_idx, (example_data, example_targets) = next(examples) #example_targets是图片实际对应的数字标签,example_data是指图片本身数据
print(example_targets)
print(example_data.shape) #输出torch.Size([1000, 1, 28, 28]),意味着我们有1000个例子的28x28像素的灰度(即没有rgb通道)#定义卷积神经网络
class Net(nn.Module):def __init__(self):super(Net, self).__init__()# batch*1*28*28(每次会送入batch个样本,输入通道数1(黑白图像),图像分辨率是28x28)# 下面的卷积层Conv2d的第一个参数指输入通道数,第二个参数指输出通道数(即用了几个卷积核),第三个参数指卷积核的大小self.conv1 = nn.Conv2d(1, 10, kernel_size=5) #因为图像为黑白的,所以输入通道为1,此时输出数据大小变为28-5+1=24.所以batchx1x28x28 -> batchx10x24x24self.conv2 = nn.Conv2d(10, 20, kernel_size=5) #第一个卷积层的输出通道数等于第二个卷积层是输入通道数。self.conv2_drop = nn.Dropout2d() #在前向传播时,让某个神经元的激活值以一定的概率p停止工作,可以使模型泛化性更强,因为它不会太依赖某些局部的特征self.fc1 = nn.Linear(320, 50) #由于下部分前向传播处理后,输出数据为20x4x4=320,传递给全连接层。# 输入通道数是320,输出通道数是50self.fc2 = nn.Linear(50, 10)#输入通道数是50,输出通道数是10,(即10分类(数字1-9),最后结果需要分类为几个就是几个输出通道数)。全连接层(Linear):y=x乘A的转置+bdef forward(self, x):x = F.relu(F.max_pool2d(self.conv1(x), 2)) # batch*10*24*24 -> batch*10*12*12(2*2的池化层会减半,步长为2)(激活函数ReLU不改变形状)x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) #此时输出数据大小变为12-5+1=8(卷积核大小为5)(2*2的池化层会减半)。所以 batchx10x12x12 -> batchx20x4x4。x = x.view(-1, 320) # batch*20*4*4 -> batch*320x = F.relu(self.fc1(x)) #进入全连接层x = F.dropout(x, training=self.training) #减少遇到过拟合问题,dropout层是一个很好的规范模型。x = self.fc2(x)#计算log(softmax(x))return F.log_softmax(x)
#初始化网络和优化器
#如果我们使用GPU进行训练,应使用例如network.cuda()将网络参数发送给GPU。将网络参数传递给优化器之前,将它们传输到适当的设备很重要,否则优化器无法以正确的方式跟踪它们。
network = Net()
optimizer = optim.SGD(network.parameters(), lr=learning_rate,momentum=momentum)
train_losses = []
train_counter = []
test_losses = []
test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]
#每个epoch对所有训练数据进行一次迭代。加载单独批次由DataLoader处理
#训练函数
def train(epoch):network.train() #在训练模型时会在前面加上for batch_idx, (data, target) in enumerate(train_loader):optimizer.zero_grad() #使用optimizer.zero_grad()手动将梯度设置为零,因为PyTorch在默认情况下会累积梯度output = network(data) #生成网络的输出(前向传递)loss = F.nll_loss(output, target) #计算输出(output)与真值标签(target)之间的负对数概率损失loss.backward() #对损失反向传播optimizer.step() #收集一组新的梯度,并使用optimizer.step()将其传播回每个网络参数if batch_idx % log_interval == 0: #log_interval=10,每10次投喂后输出一次print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))train_losses.append(loss.item()) #添加进训练损失列表中train_counter.append((batch_idx * 64) + ((epoch - 1) * len(train_loader.dataset)))#神经网络模块以及优化器能够使用.state_dict()保存和加载它们的内部状态。这样,如果需要,我们就可以继续从以前保存的状态dict中进行训练——只需调用.load_state_dict(state_dict)。torch.save(network.state_dict(), './model.pth')torch.save(optimizer.state_dict(), './optimizer.pth')train(1)#测试函数。总结测试损失,并跟踪正确分类的数字来计算网络的精度。
def test():network.eval() #在测试模型时在前面使用test_loss = 0correct = 0with torch.no_grad(): #使用上下文管理器no_grad(),我们可以避免将生成网络输出的计算结果存储在计算图(计算过程的构建,以便梯度反向传播等操作)中。(with是使用的意思)for data, target in test_loader:output = network(data) #生成网络的输出(前向传递)# 将一批的损失相加test_loss += F.nll_loss(output, target, size_average=False).item() #NLLLoss 的输入是一个对数概率向量和一个目标标签pred = output.data.max(1, keepdim=True)[1] ## 找到概率最大的下标correct += pred.eq(target.data.view_as(pred)).sum() #预测正确的数量相加test_loss /= len(test_loader.dataset)test_losses.append(test_loss)print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))test()#我们将在循环遍历n_epochs之前手动添加test()调用,以使用随机初始化的参数来评估我们的模型。
for epoch in range(1, n_epochs + 1):train(epoch)test()#评估模型的性能,画损失曲线
fig = plt.figure()
plt.plot(train_counter, train_losses, color='blue')
plt.scatter(test_counter, test_losses, color='red')
plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
plt.xlabel('number of training examples seen')
plt.ylabel('negative log likelihood loss')
plt.show()#输出自己找的测试图片,比较模型的输出。
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
with torch.no_grad():output = network(example_data)
fig1 = plt.figure()
for i in range(6):plt.subplot(2,3,i+1)plt.tight_layout()plt.imshow(example_data[i][0], cmap='gray', interpolation='none')plt.title("Prediction: {}".format(output.data.max(1, keepdim=True)[1][i].item()))plt.xticks([])plt.yticks([])
plt.show()#继续对网络进行训练,并看看如何从第一次培训运行时保存的state_dicts中继续进行训练。我们将初始化一组新的网络和优化器。
continued_network = Net()
continued_optimizer = optim.SGD(network.parameters(), lr=learning_rate,momentum=momentum)network_state_dict = torch.load('model.pth') #见左侧项目列表,有该文件
continued_network.load_state_dict(network_state_dict) #使用.load_state_dict(),我们现在可以加载网络的内部状态,并在最后一次保存它们时优化它们。
optimizer_state_dict = torch.load('optimizer.pth') #见左侧项目列表,有该文件
continued_optimizer.load_state_dict(optimizer_state_dict)
#同样,运行一个训练循环应该立即恢复我们之前的训练。为了检查这一点,我们只需使用与前面相同的列表来跟踪损失值
for i in range(4,9):test_counter.append(i*len(train_loader.dataset))train(i)test()
#我们再次看到测试集的准确性从一个epoch到另一个epoch有了(运行更慢的,慢的多了)提高。
#输出自己找的测试图片,比较模型的输出。
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
with torch.no_grad():output = network(example_data)
fig1 = plt.figure()
for i in range(6):plt.subplot(2,3,i+1)plt.tight_layout()plt.imshow(example_data[i][0], cmap='gray', interpolation='none')plt.title("Prediction: {}".format(output.data.max(1, keepdim=True)[1][i].item()))plt.xticks([])plt.yticks([])
plt.show()

 2.代码解析

2.1参数设置

(1)深度学习中Epoch、Batch以及Batch size的设定 - 知乎 (zhihu.com)

Epoch(时期):将所有训练样本训练一次的过程

Batch:将整个训练样本分为若干个Batch

Batch_Size:每个Batch的样本数量

Iteration:训练一个Batch就是一个Iteration 

(2)学习率一般设置为0.1或0.01

(3)Pytorch——momentum动量_momentum pytorch_Chukai123的博客-CSDN博客

Momentum作用:动量,跳出局部最优解。

引入momentum之后的权重更新:v=momentum∗v−Lr∗dw;w=w+v

V为速度一般初始为0

(4)log_interval=10:间隔10个Batch输出一次

(5)【pytorch】torch.manual_seed()用法详解_torch.seed_Xavier Jiezou的博客-CSDN博客

torch.manual_seed(seed):设置每次运行py文件生成的随机数相同。

2.2数据集

(1)torch.utils.data.DataLoader

Shuffle=True:打乱数据

(2)torchvision.datasets.MNIST

Root:MNIST数据集根目录

Train:true则从training.pt创建数据集,否则从test.pt创建

Download:true则从internet下载放在根目录

Transform:

torchvision.transforms 参数解读/中文使用手册_torchvision.transforms.functional.rotate_江南蜡笔小新的博客-CSDN博客

torchvision.transforms.ToTensor
PIL图片或者numpy.ndarray转成Tensor类型的

torchvision.transforms.functional.normalize(tensor, mean, std)
根据给定的标准差和方差归一化tensor图片
参数:

  • tensor(Tensor)—— 形状为(C,H,W)Tensor图片
  • mean(squence) —— 每个通道的均值,序列
  • std (sequence) —— 每个通道的标准差,序列
    返回:返回归一化后的Tensor图片。

2.3查看测试数据 

Enumerate:将一个可遍历对象组合为一个索引序列

Next:返回迭代器的下一个项目

2.4定义卷积神经网络

Super:调用父类方法

卷积输出大小 = 输入分辨率 – 卷积核大小 + 1

输出通道数 = 使用卷积核数量

第一个全连接层输入分辨率如何确定?

28->24,24/2->12,12->8,8/2->4

这么说可能有些抽象,看下面的图就知道怎么来的了。

F.relu对应右侧图示的激活函数

PyTorch常用激活函数解析_f.leaky_relu_orientliu96的博客-CSDN博客

F.max_pool2d(,2):对卷积层进行最大池化,“2”为步长(2*2的池化层)

x.view:将tensor reshape成一维向量

F.log_softmax:归一化输出

2.5初始化网络与优化器

Optim.SGD:随机梯度下降

[i*len(train_loader.dataset) for i in range(n_epochs + 1)] 使用列表推导式构建一个样本数列表

 F.nll_lossNLLLoss 函数输入 input 之前,需要对 input 进行 log_softmax 处理,即将 input 转换成概率分布的形式,并且取对数,底数为 e。其损失函数为负对数似然。

3.实验结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/82246.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

(亲测解决)PyCharm 从目录下导包提示 unresolved reference(完整图解)

最近在进行一个Flask项目的过程中遇到了unresolved reference 包名的问题,在网上找了好久解决方案,并没有一个能让我一步到位解决问题的。 后来,我对该问题和网上的解决方案进行了分析,发现网上大多数都是针对项目同一目录下的py…

PHP Mysql查询全部全部返回字符串类型

设置pdo属性 $pdo->setAttribute(PDO::ATTR_EMULATE_PREPARES, true);

MongoDB 入门

1.1 数据库管理系统 在了解MongoDB之前需要先了解先数据库管理系统 1.1.1 什么是数据? 数据(英语:data),是指未经过处理的原始记录。 一般而言,数据缺乏组织及分类,无法明确的表达事物代表的意…

MySQL 慢查询探究分析

目录 背景: mysql 整体结构: SQL查询语句执行过程是怎样的: 知道了mysql的整体架构,那么一条查询语句是怎么被执行的呢: 什么是索引: 建立索引越多越好吗:   如何发现慢查询&#xff1…

C语言学习笔记 vscode使用外部console-11

前言 在默认情况下,我们运行C语言程序都是在vscode终端的,在小程序运行时这个是没有问题的,但是当程序变得复杂它就不好用了,这时我们可以将这个终端设置为外部console,这样方便处理更多、更复杂的程序。 步骤 1.点击…

SSM——环境搭建、产品操作、订单操作

SSM 环境搭建与产品操作 1. 环境准备 1.1 数据库与表结构 1.1.1 创建用户与授权 数据库我们使用 Oracle Oracle 为每个项目创建单独 user , oracle 数据表存放在表空间下,每个用户有独立表空间 创建用户及密码 语法 [ 创建用户 ] : crea…

后端进阶之路——深入理解Spring Security配置(二)

前言 「作者主页」:雪碧有白泡泡 「个人网站」:雪碧的个人网站 「推荐专栏」: ★java一站式服务 ★ ★前端炫酷代码分享 ★ ★ uniapp-从构建到提升★ ★ 从0到英雄,vue成神之路★ ★ 解决算法,一个专栏就够了★ ★ 架…

根证书和中间证书有什么区别?

通常即使是获取了SSL证书的人,也只知道他们需要SSL证书,而且他们必须在服务器上安装SSL证书,才能通过HTTPS为网站提供服务。当进一步提到中间证书、根证书时,大多数人都感到陌生。本文小编就将为您介绍根证书与中间证书的定义以及…

实力认证!TDengine 入选 Gartner 中国数据分析与人工智能技术成熟度曲线

近日,国际权威研究机构 Gartner 发布了《2023 年中国数据分析及人工智能技术成熟度曲线》(即《Hype Cycle for Data, Analytics and AI in China, 2023》)报告,TDengine 成功入选实时数据管理领域代表产品。 作为评估全球新技术成…

Java使用String来开发验证码

Java使用String来开发验证码 需求分析代码实现小结Time 需求分析 使用String来开发验证码。 实现随机产生验证码,验证码的每位可能是数字、大写字母、小写字母 根据需求分析,步骤如下: 1.首先,设计一个方法,该方法接收…

机器学习复习题

1 单选题 ID3算法、C4.5算法、CART算法都是( )研究方向的算法。 A . 决策树 B. 随机森林 C. 人工神经网络 D. 贝叶斯学习 参考答案:A ( )作为机器学习重要算法之一,是一种利用多个树分类器进行分类和预测…

《Python入门到精通》os模块详解,Python os标准库

「作者主页」:士别三日wyx 「作者简介」:CSDN top100、阿里云博客专家、华为云享专家、网络安全领域优质创作者 「推荐专栏」:小白零基础《Python入门到精通》 os模块详解 1、文件目录操作os.stat() 获取文件状态os.utime() 修改文件时间os.r…

textarea 标签如何创建多行文本输入框?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ textarea 的写法⭐ 代码含义⭐ 写在最后 ⭐ 专栏简介 前端入门之旅:探索Web开发的奇妙世界 记得点击上方或者右侧链接订阅本专栏哦 几何带你启航前端之旅 欢迎来到前端入门之旅!这个专栏是为那些对Web开发感兴趣、…

5G RedCap

5G RedCap指的是3GPP所提出的5G标准。与之前发布的5G标准相比,功能更加精简。5G RedCap于2019年6月首次被纳入3GPP R17研究项目。 把一些不必要的功能去掉就可以带来模组价格的降低。背后的基本想法是:为物联网应用定义一种新的、不那么复杂的NR设备。 …

linux系统虚拟主机开启支持Swoole Loader扩展

特别说明:只是安装支持Swoole扩展,主机并没有安装服务端。目前支持版本php5.4-php7.2。 1、登陆主机控制面板,找到【远程文件下载】这个功能。 2、远程下载文件填写http://download.myhostadmin.net/vps/SwooleLoader_linux.zip 下载保存的路…

React 入门学习

React 入门 一、基本认识1.1、前言1.2、什么是1.3、编译<br>1.4、特点1.5、高效 二、React环境和基本使用2.1、环境搭建2.2、脚手架项目基本使用2.2.1、src2.2.2、public2.2.3、package.json 三、JSX的理解和使用四、模块与模块化, 组件与组件化的理解4.1、模块与组件4.2…

【Matplotlib】一文搞定Matplotlib绘图配置(大三学长的万字笔记)

文章目录 一、Matplotlib介绍1 - 介绍2 - 安装 二、基本配置1 - 中文配置2 - 查看字体库3 - 基本绘图4 - 线样式和颜色 三、画布配置1 - 基本配置2 - 多图绘制 | 同一画布&#xff08;重叠&#xff09;3 - 多图绘制 | 多个画布4 - 多图绘制 | 同一画布&#xff08;子图&#xf…

海外媒体发稿:软文写作方法方式?一篇好的软文理应合理规划?

不同种类的软文会有不同的方式&#xff0c;下面小编就来来给大家分析一下&#xff1a; 方法一、要选定文章的突破点&#xff1a; 所说突破点就是这篇文章文章软文理应以什么样的视角、什么样的见解、什么样的语言设计理念、如何文章文章的标题来写。不同种类的传播效果&#…

【Leetcode】对称二叉树||递归(击败100%)

step by step. 题目&#xff1a; 给你一棵二叉树的根节点 root &#xff0c;翻转这棵二叉树&#xff0c;并返回其根节点。 示例 1&#xff1a; 输入&#xff1a;root [4,2,7,1,3,6,9] 输出&#xff1a;[4,7,2,9,6,3,1]示例 2&#xff1a; 输入&#xff1a;root [2,1,3] 输出…

使用node.js 搭建一个简单的HelloWorld Web项目

文档结构 config.ini #将本文件放置于natapp同级目录 程序将读取 [default] 段 #在命令行参数模式如 natapp -authtokenxxx 等相同参数将会覆盖掉此配置 #命令行参数 -config 可以指定任意config.ini文件 [default] authtokencc83c08d73357802 #对应一条隧…