手写数据集minist基于pytorch分类学习

1.Mnist数据集介绍
1.1 基本介绍
Mnist数据集可以算是学习深度学习最常用到的了。这个数据集包含70000张手写数字图片,分别是60000张训练图片和10000张测试图片,训练集由来自250个不同人手写的数字构成,一般来自高中生,一半来自工作人员,测试集(test set)也是同样比例的手写数字数据,并且保证了测试集和训练集的作者不同。每个图片都是2828个像素点,数据集会把一张图片的数据转成一个2828=784的一维向量存储起来。
里面的图片数据如下所示,每张图是0-9的手写数字黑底白字的图片,存储时,黑色用0表示,白色用0-1的浮点数表示。


1.2 数据集下载
1)官网下载
Mnist数据集的下载地址如下:http://yann.lecun.com/exdb/mnist/
打开后会有四个文件:


训练数据集:train-images-idx3-ubyte.gz
训练数据集标签:train-labels-idx1-ubyte.gz
测试数据集:t10k-images-idx3-ubyte.gz
测试数据集标签:t10k-labels-idx1-ubyte.gz
将这四个文件下载后放置到需要用的文件夹下即可不要解压!下载后是什么就怎么放!

2)代码导入
文件夹下运行下面的代码,即可自动检测数据集是否存在,若没有会自动进行下载,下载后在这一路径:

下载数据集:

# 下载数据集
from torchvision import datasets, transformstrain_set = datasets.MNIST("data",train=True,download=True, transform=transforms.ToTensor(),)
test_set = datasets.MNIST("data",train=False,download=True, transform=transforms.ToTensor(),)

参数解释:

datasets.MNIST:是Pytorch的内置函数torchvision.datasets.MNIST,可以导入数据集
train=True :读入的数据作为训练集
transform:读入我们自己定义的数据预处理操作
download=True:当我们的根目录(root)下没有数据集时,便自动下载
如果这时候我们通过联网自动下载方式download我们的数据后,它的文件路径是以下形式:原文件夹/data/MNIST/raw

14轮左右,模型识别准确率达到98%以上

 

 加载数据集

import os.path
import matplotlib.pyplot as plt
import torch
from torchvision.datasets import MNIST
from PIL import Image
from torch.utils.data import Dataset,DataLoader
from torchvision import datasets, transforms
# 下载数据集
from torchvision import datasets, transformstransform = transforms.Compose([transforms.ToTensor(), # 将灰度图片像素值(0~255)转为Tensor(0~1),方便后续处理transforms.Normalize((0.1307,),(0.3081,))# 归一化,均值0,方差1;mean:各通道的均值std:各通道的标准差inplace:是否原地操作
])train_data = MNIST(root='./minist_data',train=True,download=False,transform=transform)
train_loader = DataLoader(dataset=train_data,shuffle=True,batch_size=64)
test_data = MNIST(root='./minist_data',train=False,download=False,transform=transform)
test_loader = DataLoader(dataset=test_data,shuffle=True,batch_size=64)# train_data返回的是很多张图,每一张图是一个元组,包含图片和对应的数字
# print(test_data[0])
# print(train_data[0][0].show())train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度:{}".format(train_data_size))
print("测试数据集的长度:{}".format(test_data_size))

构建模型,模型主要由两个卷积层,两个池化层,以及一个全连接层构成,激活函数使用relu. 

 

class Model(torch.nn.Module):def __init__(self):super(Model,self).__init__()self.conv1 = torch.nn.Conv2d(in_channels=1,out_channels=10,stride=1,kernel_size=5,padding=0)self.maxpool1 = torch.nn.MaxPool2d(2)self.conv2 = torch.nn.Conv2d(in_channels=10,out_channels=20,kernel_size=5,stride=1,padding=0)self.maxpool2 = torch.nn.MaxPool2d(2)self.linear = torch.nn.Linear(320,10)def forward(self,x):x = torch.relu(self.conv1(x))x = self.maxpool1(x)x = torch.relu(self.conv2(x))x = self.maxpool2(x)x = x.view(x.size(0),-1)x = self.linear(x)return x
model = Model()criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(),lr=0.14)# 交叉熵损失,相当于Softmax+Log+NllLoss
# 线性多分类模型Softmax,给出最终预测值对于10个类别出现的概率,Log:将乘法转换为加法,减少计算量,保证函数的单调性
# NLLLoss:计算损失,此过程不需要手动one-hot编码,NLLLoss会自动完成
# SGD,优化器,梯度下降算法e

模型训练
每次训练完成后会自动保存参数到pkl模型中,如果路径中有Pkl文件,下次运行会自动加载上一次的模型参数,在这个基础上继续训练,第一次运行时没有模型参数,结束后会自动生成。

# 模型训练
def train():# index = 0for index, data in enumerate(train_loader):  # 获取训练数据以及对应标签# for data in train_loader:input, target = data  # input为输入数据,target为标签y_predict = model(input)  # 模型预测loss = criterion(y_predict, target)optimizer.zero_grad()  # 梯度清零loss.backward()  # loss值反向传播optimizer.step()  # 更新参数# index += 1if index % 100 == 0:  # 每一百次保存一次模型,打印损失torch.save(model.state_dict(), "model.pkl")  # 保存模型torch.save(optimizer.state_dict(), "optimizer.pkl")print("训练次数为:{},损失值为:{}".format(index, loss.item()))

加载模型
第一次运行这里需要一个空的model文件夹

if os.path.exists('model.pkl'):model.load_state_dict(torch.load("model.pkl"))

模型测试

def test():correct = 0total = 0with torch.no_grad():for index,data in enumerate(test_loader):inputs,target = dataoutput = model(inputs)probability,predict = torch.max(input=output.data, dim=1)total += target.size(0)  # target是形状为(batch_size,1)的矩阵,使用size(0)取出该批的大小correct += (predict == target).sum().item()  # predict 和target均为(batch_size,1)的矩阵,sum求出相等的个数print("测试准确率为:%.6f" % (correct / total))

自己手写数字图片识别函数(可选用)
这部分主要是加载训练好的pkl模型测试自己的数据,因此在进行自己手写图的测试时,需要有训练好的pkl文件,并且就不要调用train()函数和test()函数啦注意:这个图片像素也要说黑底白字,28*28像素,否则无法识别

def test_mydata():image = Image.open('5fd4e4c2c99a24e3e27eb9b2ee3b053c.jpg')  # 读取自定义手写图片image = image.resize((28, 28))  # 裁剪尺寸为28*28image = image.convert('L')  # 转换为灰度图像transform = transforms.ToTensor()image = transform(image)image = image.resize(1, 1, 28, 28)output = model(image)probability, predict = torch.max(output.data, dim=1)print("此手写图片值为:%d,其最大概率为:%.2f " % (predict[0], probability))plt.title("此手写图片值为:{}".format((int(predict))), fontname='SimHei')plt.imshow(image.squeeze())plt.show()

MNIST中的数据识别测试数据
训练过程中的打印信息我进行了修改,这里设置的训练轮数是15轮,每次训练生成的pkl模型参数也是会更新的,想要更多训练信息可以查看对应的教程哦~

if __name__ == '__main__':# 训练与测试for i in range(15):  # 训练和测试进行5轮print({"————————第{}轮测试开始——————".format(i + 1)})train()test()test_mydata()

完整代码:

import os.path
import matplotlib.pyplot as plt
import torch
from torchvision.datasets import MNIST
from PIL import Image
from torch.utils.data import Dataset,DataLoader
from torchvision import datasets, transforms
# 下载数据集
from torchvision import datasets, transformstransform = transforms.Compose([transforms.ToTensor(), # 将灰度图片像素值(0~255)转为Tensor(0~1),方便后续处理transforms.Normalize((0.1307,),(0.3081,))# 归一化,均值0,方差1;mean:各通道的均值std:各通道的标准差inplace:是否原地操作
])train_data = MNIST(root='./minist_data',train=True,download=False,transform=transform)
train_loader = DataLoader(dataset=train_data,shuffle=True,batch_size=64)
test_data = MNIST(root='./minist_data',train=False,download=False,transform=transform)
test_loader = DataLoader(dataset=test_data,shuffle=True,batch_size=64)# train_data返回的是很多张图,每一张图是一个元组,包含图片和对应的数字
# print(test_data[0])
# print(train_data[0][0].show())train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度:{}".format(train_data_size))
print("测试数据集的长度:{}".format(test_data_size))class Model(torch.nn.Module):def __init__(self):super(Model,self).__init__()self.conv1 = torch.nn.Conv2d(in_channels=1,out_channels=10,stride=1,kernel_size=5,padding=0)self.maxpool1 = torch.nn.MaxPool2d(2)self.conv2 = torch.nn.Conv2d(in_channels=10,out_channels=20,kernel_size=5,stride=1,padding=0)self.maxpool2 = torch.nn.MaxPool2d(2)self.linear = torch.nn.Linear(320,10)def forward(self,x):x = torch.relu(self.conv1(x))x = self.maxpool1(x)x = torch.relu(self.conv2(x))x = self.maxpool2(x)x = x.view(x.size(0),-1)x = self.linear(x)return x
model = Model()criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(),lr=0.14)# 交叉熵损失,相当于Softmax+Log+NllLoss
# 线性多分类模型Softmax,给出最终预测值对于10个类别出现的概率,Log:将乘法转换为加法,减少计算量,保证函数的单调性
# NLLLoss:计算损失,此过程不需要手动one-hot编码,NLLLoss会自动完成
# SGD,优化器,梯度下降算法e# 模型训练
def train():# index = 0for index, data in enumerate(train_loader):  # 获取训练数据以及对应标签# for data in train_loader:input, target = data  # input为输入数据,target为标签y_predict = model(input)  # 模型预测loss = criterion(y_predict, target)optimizer.zero_grad()  # 梯度清零loss.backward()  # loss值反向传播optimizer.step()  # 更新参数# index += 1if index % 100 == 0:  # 每一百次保存一次模型,打印损失torch.save(model.state_dict(), "model.pkl")  # 保存模型torch.save(optimizer.state_dict(), "optimizer.pkl")print("训练次数为:{},损失值为:{}".format(index, loss.item()))if os.path.exists('model.pkl'):model.load_state_dict(torch.load("model.pkl"))def test():correct = 0total = 0with torch.no_grad():for index,data in enumerate(test_loader):inputs,target = dataoutput = model(inputs)probability,predict = torch.max(input=output.data, dim=1)total += target.size(0)  # target是形状为(batch_size,1)的矩阵,使用size(0)取出该批的大小correct += (predict == target).sum().item()  # predict 和target均为(batch_size,1)的矩阵,sum求出相等的个数print("测试准确率为:%.6f" % (correct / total))def test_mydata():image = Image.open('5fd4e4c2c99a24e3e27eb9b2ee3b053c.jpg')  # 读取自定义手写图片image = image.resize((28, 28))  # 裁剪尺寸为28*28image = image.convert('L')  # 转换为灰度图像transform = transforms.ToTensor()image = transform(image)image = image.resize(1, 1, 28, 28)output = model(image)probability, predict = torch.max(output.data, dim=1)print("此手写图片值为:%d,其最大概率为:%.2f " % (predict[0], probability))plt.title("此手写图片值为:{}".format((int(predict))), fontname='SimHei')plt.imshow(image.squeeze())plt.show()if __name__ == '__main__':# 训练与测试for i in range(15):  # 训练和测试进行5轮print({"————————第{}轮测试开始——————".format(i + 1)})train()test()test_mydata()

 

 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/339729.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linux中最基础使用的命令

小白学习记录: 前情提要:Linux命令基础格式!查看 ls看目录的小技巧 进入指定目录 cd查看当前工作目录 pwd创建一个新的目录(文件夹) mkdir创建文件 touch查看文件内容 cat、more操作文件、文件夹- 复制 cp- 移动 mv- 删除【危险操作&#xff…

OpenAI 近期动荡:解雇 Sam Altman 事件分析与 AI 未来展望

引言 OpenAI 的动荡从未停止。最近,由于 OpenAI 高层领导的更迭,引发了广泛的关注和讨论。特别是在 Sam Altman 被解雇后,再次回归 CEO 职位的过程,更是引起了公众和业内的巨大反响。前 OpenAI 董事会成员 Helen Toner 在最新一期…

Java1.8+ idea hbuilder+ uniapp、vue上门家政小程序APP源码开发

Java1.8 idea hbuilder uniapp、vue上门家政小程序APP源码开发 家政服务系统是一种专为家庭提供全方位服务的综合性系统。该系统通过整合多种服务功能和智能化管理,旨在提高家庭生活的质量和效率。 家政服务系统技术开发环境: 技术架构:spri…

心动(GDI+)

文章目录 前言实现步骤源代码心形坐标类心形函数定时器方法绘制函数完整源码 结束语 前言 近期学习了一段时间的GDI,突然想着用GDI绘制点啥,用来验证下类与方法。有兴趣的,可以查阅Windows GDI学习笔记相关文章。 效果展示 实现步骤 定义心形函数 。…

【微服务】部署mysql集群,主从复制,读写分离

两台服务器做如下操作 1.安装mysqldocker pull mysql:5.72.启动以及数据挂载 mkdir /root/mysql/data /root/mysql/log /root/mysql/conf touch my.conf //mysql的配置文件docker run --name mysql \ -e MYSQL_ROOT_PASSWORD123456 \ -v /root/mysql/data:/var/lib/mysql \ -v…

HarmonyOS鸿蒙学习笔记(28)@entry和@Component的生命周期

entry和Component的生命周期 entry和Component的关系Component生命周期Entry生命周期 生命周期流程图生命周期展示示例代码参考资料 HarmonyOS的生命周期可以分为Compnent的生命周期和Entry的生命周期,也就是自定义组件的生命周期和页面的生命周期。 entry和Compone…

RabbitMQ-发布/订阅模式

RabbitMQ-默认读、写方式介绍 RabbitMQ-直连交换机(direct)使用方法 目录 1、发布/订阅模式介绍 2、交换机(exchange) 3、fanout交换机的使用方式 3.1 声明交换机 3.2 发送消息到交换机 3.2 扇形交换机发送消息代码 3.2 声明队列,用于接收消息 3.3 binding …

sigmoid, softmax

∙ \bullet ∙ sigmoid函数 值域(0,1) 常用于二分类问题 ∙ \bullet ∙ softmax函数 每一项的区间范围的(0,1) 所有项相加的和为1. 常用于多分类问题 ∙ \bullet ∙ 区别: softmax 当类别数是2时,它退化为二项分布,而它和sigmoid真正的区别…

解决VSCode右键没有Open In Default Browser问题

在VSCode进行Web小程序测试时,我们在新建的HTML文件中输入 !会自动生成页面代码骨架,写入内容后,我们想要右键在浏览器中预览。发现右键没有“Open In Default Browser”选项。原因是没有安装插件。 下面是解决方案:首先在VSCode找…

探索 Android Studio 中的 Gemini:加速 Android 开发的新助力

探索 Android Studio 中的 Gemini:加速 Android 开发的新助力 在 Gemini 时代的下一篇章中,Gemini融入了更多产品中,Android Studio 正在使用 Gemini 1.0 Pro 模型,使 Android 开发变得更快、更简单。 Studio Bot 现已更名为 And…

The Isle恐龙岛服务器开服联机教程

服务端区别:The lsle 是测试服 ;The lsle Evrima 是正式服(运行内存需要上到12G才可以运行) 1、购买后登录服务器 进入控制面板后会出现正在安装的界面,安装大约5分钟(如长时间处于安装中请联系我们的客服人…

Unity 自定义编辑器根据枚举值显示变量

public class Test : MonoBehaviour {[HideInInspector][Header("数量")][SerializeField]public int num;[Header("分布类型")][SerializeField]public DistributionType distType;[HideInInspector][Header("位置")][SerializeField]public Li…

Vue之组件基础(插槽)

在HTML中,开发者可以在双标签内添加一些信息。而在Vue中,组件以标签的形式引用,那么如何在组件的标签内添加一些信息并将信息渲染到页面中呢?其实,Vue 提供了插槽,专门用来实现这样的效果。 一.什么是插槽 Vue为组件…

视频修复工具助你完成高质量的视频作品!

在短视频发展兴起的时代,各种视频层出不穷的出现在了视野中,人们已经从追求数量转向追求质量。内容相同的视频,你视频画质好、质量高的更受大家欢迎,那么如何制作高质量、高清晰度的视频呢?与您分享三个视频修复工具。…

命名空间,缺省参数和函数重载

前言:本文章主要介绍一些C中的小语法。 目录 命名空间 namespace的使用 访问全局变量 namespace可以嵌套 不同文件中定义的同名的命名空间可以合并进一个命名空间,并且其中不可以有同名的变量 C中的输入和输出 缺省参数(默认参数&#…

电脑的kernelbase.dll故障怎么处理?kernelbase.dll是什么文件

遇到由于“kernelbase.dll”文件出错导致的应用程序崩溃或系统不稳定的问题。这种情况不仅会影响工作效率,还可能导致数据损失或更严重的系统问题。kernelbase.dll是Windows操作系统中的一个关键系统文件,它包含了多个执行基础系统功能的程序代码。因此&…

3389连接器,3389连接器如何进行安全设置

在计算机网络领域,3389端口作为Windows系统默认的远程桌面协议(RDP)端口,在远程办公、技术支持等场景中发挥着重要作用。然而,由于其广泛的使用和直接暴露在互联网上的特性,3389端口也极易成为黑客攻击的目…

python 贪心算法(Greedy Algo)

贪婪是一种算法范式,它逐步构建解决方案,始终选择提供最明显和直接收益的下一个部分。贪婪算法用于解决优化问题。 如果问题具有以下属性,则可以使用贪心法解决优化问题: 每一步,我们都可以做出当前看来最好的选择&…

git 恢复本地文件被误删除

查找自己执行命令出现的文件移除 或者创建的地方找到提交的 哈希值 然后执行 命令 git checkout c818f15(这个后面是你执行的哈希代码) main 里面有个代码值 把这个复制到你的命令行就好了 执行 然后就恢复文件了 还有一个是查找命令日志的 如果不小心…

[数据集][目标检测]水下管道泄漏破损检测数据集VOC+YOLO格式2069张2类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):2069 标注数量(xml文件个数):2069 标注数量(txt文件个数):2069 标注…