CNN卷积网络实现MNIST数据集手写数字识别

步骤一:加载MNIST数据集

train_data = MNIST(root='./data',train=True,download=False,transform=transforms.ToTensor())
train_loader = DataLoader(train_data,shuffle=True,batch_size=64)
# 测试数据集
test_data = MNIST(root='./data',train=False,download=False,transform=transforms.ToTensor())
test_loader = DataLoader(test_data,shuffle=False,batch_size=64)

首先,通过MNIST类创建了train_data对象,指定了数据集的路径root='./data',并且将数据集标记为训练集train=Truedownload=False表示不自动从网络上下载数据集,而是使用已经下载好的数据集。我是之前自己已经下载过该数据集所以这里填的是False,如果之前没有下载的话就要填True。下面测试集也是一样。transforms.ToTensor()将数据转换为张量形式。

然后,通过DataLoader类创建了train_loader对象,指定了使用train_data作为数据源。shuffle=True表示在每个epoch开始时,将数据打乱顺序。batch_size=64表示每次抓取64个样本。

接下来,同样的步骤也被用来创建了测试集的数据加载器test_loader。不同的是,这里将数据集标记为测试集train=False,并且shuffle=False表示不需要打乱顺序。

加载完的数据集存在MNIST文件夹的raw文件夹下内容如下:

其中t10k-images-idx3-ubyte是测试集的图像,t10k-labels-idx3-ubyte是测试集的标签。train-images-idx3-ubyte是训练集的图像,train-labels-idx1-ubyte是训练集的标签。

存下来的这些数据集是二进制的形式,可以通过下面的代码(1.py)读取:

"""
Created on Sat Jul 27 15:26:39 2024@author: wangyiyuan
"""
# 导入包
import struct
import numpy as np
from PIL import Imageclass MnistParser:# 加载图像def load_image(self, file_path):# 读取二进制数据binary = open(file_path,'rb').read()# 读取头文件fmt_head = '>iiii'offset = 0# 读取头文件magic_number,images_number,rows_number,columns_number = struct.unpack_from(fmt_head,binary,offset)# 打印头文件信息print('图片数量:%d,图片行数:%d,图片列数:%d'%(images_number,rows_number,columns_number))# 处理数据image_size = rows_number * columns_numberfmt_data = '>'+str(image_size)+'B'offset = offset + struct.calcsize(fmt_head)# 读取数据images = np.empty((images_number,rows_number,columns_number))for i in range(images_number):images[i] = np.array(struct.unpack_from(fmt_data, binary, offset)).reshape((rows_number, columns_number))offset = offset + struct.calcsize(fmt_data)# 每1万张打印一次信息if (i+1) % 10000 == 0:print('> 已读取:%d张图片'%(i+1))# 返回数据return images_number,rows_number,columns_number,images# 加载标签def load_labels(self, file_path):# 读取数据binary = open(file_path,'rb').read()# 读取头文件fmt_head = '>ii'offset = 0# 读取头文件magic_number,items_number = struct.unpack_from(fmt_head,binary,offset)# 打印头文件信息print('标签数:%d'%(items_number))# 处理数据fmt_data = '>B'offset = offset + struct.calcsize(fmt_head)# 读取数据labels = np.empty((items_number))for i in range(items_number):labels[i] = struct.unpack_from(fmt_data, binary, offset)[0]offset = offset + struct.calcsize(fmt_data)# 每1万张打印一次信息if (i+1)%10000 == 0:print('> 已读取:%d个标签'%(i+1))# 返回数据return items_number,labels# 图片可视化def visualaztion(self, images, labels, path):d = {0:0, 1:0, 2:0, 3:0, 4:0, 5:0, 6:0, 7:0, 8:0, 9:0}for i in range(images.__len__()):im = Image.fromarray(np.uint8(images[i]))im.save(path + "%d_%d.png"%(labels[i], d[labels[i]]))d[labels[i]] += 1# im.show()if (i+1)%10000 == 0:print('> 已保存:%d个图片'%(i+1))# 保存为图片格式
def change_and_save():mnist =  MnistParser()trainImageFile = './train-images-idx3-ubyte'_, _, _, images = mnist.load_image(trainImageFile)trainLabelFile = './train-labels-idx1-ubyte'_, labels = mnist.load_labels(trainLabelFile)mnist.visualaztion(images, labels, "./images/train/")testImageFile = './train-images-idx3-ubyte'_, _, _, images = mnist.load_image(testImageFile)testLabelFile = './train-labels-idx1-ubyte'_, labels = mnist.load_labels(testLabelFile)mnist.visualaztion(images, labels, "./images/test/")# 测试
if __name__ == '__main__':change_and_save()

将这个1.py文件和下载好的数据集放在同一个文件夹下:

新建一个文件夹images,在文件夹images里面新建两个文件夹分别叫test和train。

运行完可以发现train和test里的内容如下:

步骤二:建立模型

class Model(nn.Module):def __init__(self):super(Model,self).__init__()self.linear1 = nn.Linear(784,256)self.linear2 = nn.Linear(256,64)self.linear3 = nn.Linear(64,10) # 10个手写数字对应的10个输出def forward(self,x):x = x.view(-1,784) # 变形x = torch.relu(self.linear1(x))x = torch.relu(self.linear2(x))# x = torch.relu(self.linear3(x))return x

这里是建立了一个神经网络模型类(Model)。这个模型有三个线性层(linear1、linear2、linear3)。输入维度为784(因为每一张图片的大小是28*28=784),输出维度为256、64、10(因为有十个类)。forward函数定义了模型的前向传播过程,其中x.view(-1, 784)将输入张量x变形为(batch_size, 784)的大小。然后经过三个线性层和relu激活函数进行运算,最后返回输出结果x。

步骤三:训练模型

model = Model()
criterion = nn.CrossEntropyLoss() # 交叉熵损失,相当于Softmax+Log+NllLoss
optimizer = torch.optim.SGD(model.parameters(),0.8) # 第一个参数是初始化参数值,第二个参数是学习率# 模型训练
# def train():
for index,data in enumerate(train_loader):input,target = data # input为输入数据,target为标签optimizer.zero_grad() # 梯度清零y_predict = model(input) # 模型预测loss = criterion(y_predict,target) # 计算损失loss.backward() # 反向传播optimizer.step() # 更新参数if index % 100 == 0: # 每一百次保存一次模型,打印损失torch.save(model.state_dict(),"./model/model.pkl") # 保存模型torch.save(optimizer.state_dict(),"./model/optimizer.pkl")print("损失值为:%.2f" % loss.item())

首先创建了一个模型对象model,一个损失函数对象criterion和一个优化器对象optimizer。然后使用一个for循环遍历训练数据集train_loader,每次取出一个batch的数据。接着将优化器的梯度清零,然后使用模型前向传播得到预测结果y_predict,计算损失值loss,然后进行反向传播和参数更新。每训练100个batch,保存模型和优化器的参数,并打印当前的损失值。

步骤四:保存模型参数

if os.path.exists('./model/model.pkl'):model.load_state_dict(torch.load("./model/model.pkl")) # 加载保存模型的参数

在当前文件夹下新建一个名叫model的文件夹。保存步骤三中训练完模型的参数。

步骤五:检验模型

correct = 0 # 正确预测的个数total = 0 # 总数with torch.no_grad(): # 测试不用计算梯度for data in test_loader:input,target = dataoutput=model(input) # output输出10个预测取值,其中最大的即为预测的数probability,predict=torch.max(output.data,dim=1) # 返回一个元组,第一个为最大概率值,第二个为最大值的下标total += target.size(0) # target是形状为(batch_size,1)的矩阵,使用size(0)取出该批的大小correct += (predict == target).sum().item() # predict和target均为(batch_size,1)的矩阵,sum()求出相等的个数print("准确率为:%.2f" % (correct / total))

参数说明:

  • correct:记录正确预测的个数
  • total:记录总样本数
  • test_loader:测试集的数据加载器
  • input:输入数据
  • target:目标标签
  • output:模型的输出结果
  • probability:最大概率值
  • predict:最大值的下标

过程:

  • 使用torch.no_grad()包装测试过程,表示不需要计算梯度
  • 遍历测试集中的每个数据,获取输入数据和目标标签
  • 将输入数据输入模型,得到模型的输出结果
  • 使用torch.max()函数返回预测结果中的最大概率值和最大值的下标
  • 更新总数和正确预测的个数
  • 最后计算并输出准确率。

步骤六:检测自己的手写数据

if __name__ == '__main__':# 自定义测试image = Image.open('C:/Users/wangyiyuan/Desktop/20201116160729670.jpg') # 读取自定义手写图片image = image.resize((28,28)) # 裁剪尺寸为28*28image = image.convert('L') # 转换为灰度图像transform = transforms.ToTensor()image = transform(image)image = image.resize(1,1,28,28)output = model(image)probability,predict=torch.max(output.data,dim=1)print("此手写图片值为:%d,其最大概率为:%.2f" % (predict[0],probability))plt.title('此手写图片值为:{}'.format((int(predict))),fontname="SimHei")plt.imshow(image.squeeze())plt.show()

这里的C:/Users/wangyiyuan/Desktop/20201116160729670.jpg是我自己从网上找的的手写图片。这段代码意思如下:

  1. 打开并读取一张手写图片,图片的路径为'C:/Users/wangyiyuan/Desktop/20201116160729670.jpg'。
  2. 调整图片尺寸为28x28。
  3. 将图片转换为灰度图像,以便后续处理。
  4. 使用transforms.ToTensor()将图片转换为PyTorch张量。
  5. 调整图片尺寸为(1, 1, 28, 28)以适应模型的输入要求。
  6. 将处理后的图片输入模型,获取预测输出。
  7. 通过torch.max函数获得输出中的最大值及其索引,即预测的数字和其概率。
  8. 打印预测的数字和概率。
  9. 在图像上显示预测结果和手写图片。
  10. 展示图像。

步骤七:结果展示

我的原图是:

测试得到的结果为:


损失值为:4.16
损失值为:0.93
损失值为:0.31
损失值为:0.19
损失值为:0.24
损失值为:0.15
损失值为:0.13
损失值为:0.11
损失值为:0.18
损失值为:0.02
此手写图片值为:2,其最大概率为:6.57

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/389272.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

GBase8c psycopg2安装(centos6)

GBase8c psycopg2安装(centos6) 安装步骤: [rootcentos6 ~]# cd /opt/python/ [rootcentos6 python]# ls psycopg2-2.7.7.tar.gz [rootcentos6 python]# tar -zxf psycopg2-2.7.7.tar.gz [rootcentos6 python]# cd psycopg2-2.7.7 # 安装命令 [rootcentos6 psycop…

B站安全开发流程落地实践

一. 什么是安全开发生命周期(SDL) 1.1 SDL诞生背景 随着互联网技术的快速发展,网络系统及应用在给人们的生活带来巨大便利的同时,信息安全问题也逐渐成为用户和企业关注的焦点。然而,安全问题的管理和解决需要一个系统…

武汉流星汇聚:亚马逊Prime会员日后,确保持续稳定出单的五大策略

随着亚马逊Prime会员日的圆满落幕,无数商家沉浸在销售高峰的喜悦之中,但狂欢之后的冷静思考同样重要。对于所有卖家而言,如何在会员日热潮退去后,依然保持稳定的订单量,成为关乎长远发展的关键。以下,武汉流…

MySQL数据库入门基础知识 【1】推荐

数据库就是储存和管理数据的仓库,对数据进行增删改查操作,其本质是一个软件。 首先数据有两种,一种是关系型数据库,另一种是非关系型数据库。 关系型数据库是以表的形式来存储数据,表和表之间可以有很多复杂的关系&a…

nova7(华为)相机关闭画质优化

模板 文章目录 模板 如果对你有帮助,就点赞收藏把!(。・ω・。)ノ♡ 不知道大家有没有遇到这种苦恼 想拍一张,夜景照片 明明按下快门的时候还是如上图所示 但是到图库就只能看到下图的照片…

多路径 bbr mpbbr 公平性推演

mptcp 推出很久了,先看 rfc6356 三原则: 对自己,mptcp 的吞吐不能比用 sp(single path)tcp 时更差;对它者,mptcp 子流对资源的占用不能侵害其它 sptcp 流量;负载分担,要将孬 subflow 流量分担到…

SX_初识GitLab_1

1、对GitLab的理解: 目前对GitLab的理解是其本质是一个远程代码托管平台,上面托管多个项目,每个项目都有一个master主分支和若干其他分支,远程代码能下载到本机,本机代码也能上传到远程平台 1.分支的作用&#xff1a…

20.rabbitmq插件实现延迟队列

问题 前面谈到基于死信的延迟队列,存在的问题:如果第一个消息延时时间很长,而第二个消息延时时间很短,第二个消息并不会优先得到执行。 下载插件 地址:https://github.com/rabbitmq/rabbitmq-delayed-message-excha…

JAVA基础 - 反射

目录 一. 简介 二. java.lang.Class类 三. java.lang.reflect包 四. 创建对象 五. 调用方法 六. 调用成员变量 一. 简介 反射是 Java 语言中的一种强大机制,允许程序在运行时动态地获取类的信息、访问类的成员(包括字段、方法和构造函数&#xff…

Tomato靶机攻略

1、启动靶机 2、通过nmap -sA 192.168.168.0/24得到靶机IP 3、扫描目录 用dirb http://192.168.49.128扫描敏感目录 4、访问敏感目录 5、通过查看源码,发现其存在文件包含漏洞,利用该漏洞查看日志文件 http://192.168.168.131/antibot_image/antibots/…

gitee的fork

通过fork操作,可以复制小组队长的库。通过复制出一模一样的库,先在自己的库修改,最后提交给队长,队长审核通过就可以把你做的那一份也添加入库 在这fork复制一份到你自己的仓库,一般和这个项目同名 现在你有了自己的库…

vue2以及vue3基于el-table实现表格正则校验功能

常见需求: 在项目中,通常会在表格中添加多条数据,并需要对添加的数据进行校验功能,这时候就是很头疼的事了,下面酱酱仔给你们写个示例,你们无脑粘贴复制即可。 注意事项: 1、校验里面用到了正…

【Unity】3D功能开发入门系列(一)

Unity3D功能开发入门系列(一) 一、开发环境(一)安装 Unity(二)创建项目(三)Unity 窗口布局 二、场景与视图(一)场景(二)游戏物体&…

前端日历插件VCalendar

官网地址 API | VCalendar 1.安装 yarn add v-calendarnext popperjs/core 2.全局引入 mian.js //日历插件 import VCalendar from v-calendar; import v-calendar/style.css;app.use(VCalendar); 3.使用 <div><VCalendar reservationTime expanded borderless…

java各种锁有什么区别

Java 虚拟机&#xff08;JVM&#xff09;中有几种不同类型的锁&#xff0c;每种锁都有其特定的用途和性能特点。下面我将为你介绍几种常见的锁&#xff1a; 1.独占锁&#xff08;也称为悲观锁&#xff09;&#xff1a; 1.synchronized&#xff1a;这是 Java 提供的一种内置的独…

股指期货的套利策略存在哪些风险?

股指期货套利的交易策略。它能够纠正市场上不合理的价格偏差&#xff0c;将价格拉回到正常的轨道。套利交易以其稳健的收益吸引着投资者&#xff0c;但同时也容易让人陷入一个误区——认为套利是无风险的。实际上&#xff0c;套利同样存在风险&#xff0c;只是相对于纯粹的投机…

问题易如反掌?5个常用的AI人工智能助手推荐

&#x1f3ac; 鸽芷咕&#xff1a;个人主页 &#x1f525; 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想&#xff0c;就是为了理想的生活! 如今的人工智能技术正以惊人的速度改变着我们的生活方式和工作方式。作为这一变革的关键驱动力&#xff0c;人工智能不仅在科技…

短剧CPS分销系统框架+资源对接是怎么对接的?

目录 前言&#xff1a; 一、前端uniapp内容有什么&#xff1f; 二、后台管理 三、搭建CPS需要准备什么&#xff1f; 总结&#xff1a; 前言&#xff1a; 目前短剧目前在国内是非常的热门&#xff0c;观看的人群非常的多。如果希望能够通过推广短剧来做副业的话&#xff0c…

深入理解PreparedStatement

预处理 Overridepublic boolean login(String username, String userpwd) {Connection con DBUtils.getConnection();try {if(con ! null){PreparedStatement pstmt con.prepareStatement("select username,userpwd from " " users where username? and us…