Pytorch-以数字识别更好地入门深度学习

目录

一、数据介绍

二、下载数据 

三、可视化数据

四、模型构建

五、模型训练

六、模型预测


一、数据介绍

MNIST数据集是深度学习入门的经典案例,因为它具有以下优点:

1. 数据量小,计算速度快。MNIST数据集包含60000个训练样本和10000个测试样本,每张图像的大小为28x28像素,这样的数据量非常适合在GPU上进行并行计算。

2. 标签简单,易于理解。MNIST数据集的标签只有0-9这10个数字,相比其他图像分类数据集如CIFAR-10等更加简单易懂。

3. 数据集已标准化。MNIST数据集中的图像已经被归一化到0-1之间,这使得模型可以更快地收敛并提高准确率。

4. 适合初学者练习。MNIST数据集是深度学习入门的最佳选择之一,因为它既不需要复杂的数据预处理,也不需要大量的计算资源,可以帮助初学者快速上手深度学习。

综上所述,MNIST数据集是深度学习入门的经典案例,它具有数据量小、计算速度快、标签简单、数据集已标准化、适合初学者练习等优点,因此被广泛应用于深度学习的教学和实践中。

手写数字识别技术的应用非常广泛,例如在金融、保险、医疗、教育等领域中,都有很多应用场景。手写数字识别技术可以帮助人们更方便地进行数字化处理,提高工作效率和准确性。此外,手写数字识别技术还可以用于机器人控制、智能家居等方面  。

使用torch.datasets.MNIST下载到指定目录下:./data,当download=True时,如果已经下载了不会再重复下载,同train选择下载训练数据还是测试数据

官方提供的类:

class MNIST(root: str,train: bool = True,transform: ((...) -> Any) | None = None,target_transform: ((...) -> Any) | None = None,download: bool = False
)
Args:root (string): Root directory of dataset where MNIST/raw/train-images-idx3-ubyteand MNIST/raw/t10k-images-idx3-ubyte exist.train (bool, optional): If True, creates dataset from train-images-idx3-ubyte,otherwise from t10k-images-idx3-ubyte.download (bool, optional): If True, downloads the dataset from the internet andputs it in root directory. If dataset is already downloaded, it is not downloaded again.transform (callable, optional): A function/transform that takes in an PIL imageand returns a transformed version. E.g, transforms.RandomCroptarget_transform (callable, optional): A function/transform that takes in thetarget and transforms it.

二、下载数据 

# 导入数据集
# 训练集
import torch
from torchvision import datasets,transforms
from torch.utils.data import Dataset
train_loader = torch.utils.data.DataLoader(datasets.MNIST(root="./data",train=True,download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])),batch_size=64,shuffle=True)# 测试集
test_loader = torch.utils.data.DataLoader(datasets.MNIST("./data",train=False,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])),batch_size=64,shuffle=True
)

pytorch也提供了自定义数据的方法,根据自己数据进行处理

使用PyTorch提供的Dataset和DataLoader类来定制自己的数据集。如果想个性化自己的数据集或者数据传递方式,也可以自己重写子类。

以下是一个简单的例子,展示如何创建一个自定义的数据集并将其传递给模型进行训练:

import torch
from torch.utils.data import Dataset, DataLoaderclass MyDataset(Dataset):def __init__(self, data, labels):self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self, index):x = self.data[index]y = self.labels[index]return x, ydata = torch.randn(100, 3, 32, 32)
labels = torch.randint(0, 10, (100,))my_dataset = MyDataset(data, labels)
my_dataloader = DataLoader(my_dataset, batch_size=4, shuffle=True)

详细完整流程可参考: Pytorch快速搭建并训练CNN模型?

三、可视化数据

mport matplotlib.pyplot as plt
import numpy as np
import torchvision
def imshow(img):img = img / 2 + 0.5 # 逆归一化npimg = img.numpy()plt.imshow(np.transpose(npimg,(1,2,0)))plt.title("Label")plt.show()# 得到batch中的数据
dataiter = iter(train_loader)
images,labels = next(dataiter)
# 展示图片
imshow(torchvision.utils.make_grid(images))

四、模型构建

定义模型类并继承nn.Module基类

# 构建模型
import torch.nn as nn
import torch
import torch.nn.functional as F
class MyNet(nn.Module):def __init__(self):super(MyNet,self).__init__()# 输入图像为单通道,输出为六通道,卷积核大小为5×5self.conv1 = nn.Conv2d(1,6,5)self.conv2 = nn.Conv2d(6,16,5)# 将16×4×4的Tensor转换为一个120维的Tensor,因为后面需要通过全连接层self.fc1 = nn.Linear(16*4*4,120)self.fc2 = nn.Linear(120,84)self.fc3 = nn.Linear(84,10)def forward(self,x):# 在(2,2)的窗口上进行池化x = F.max_pool2d(F.relu(self.conv1(x)),2)x = F.max_pool2d(F.relu(self.conv2(x)),2)# 将维度转换成以batch为第一维,剩余维数相乘为第二维x = x.view(-1,self.num_flat_features(x))x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xdef num_flat_features(self,x):size = x.size()[1:]num_features = 1for s in size:num_features *= sreturn num_featuresnet = MyNet()
print(net)

输出: 

MyNet((conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))(fc1): Linear(in_features=256, out_features=120, bias=True)(fc2): Linear(in_features=120, out_features=84, bias=True)(fc3): Linear(in_features=84, out_features=10, bias=True)
)

简单的前向传播

# 前向传播
print(len(images))
image = images[:2]
label = labels[:2]
print(image.shape)
print(image.size())
print(label)
out = net(image)
print(out)

输出: 

16
torch.Size([2, 1, 28, 28])
torch.Size([2, 1, 28, 28])
tensor([6, 0])
tensor([[ 1.5441e+00, -1.2524e+00,  5.7165e-01, -3.6299e+00,  3.4144e+00,2.7756e+00,  1.1974e+01, -6.6951e+00, -1.2850e+00, -3.5383e+00],[ 6.7947e+00, -7.1824e+00,  8.8787e-01, -5.2218e-01, -4.1045e+00,4.6080e-01, -1.9258e+00,  1.8958e-01, -7.7214e-01, -6.3265e-03]],grad_fn=<AddmmBackward0>)

计算损失:

# 计算损失
# 因为是多分类,所有采用CrossEntropyLoss函数,二分类用BCELoss
image = images[:2]
label = labels[:2]
out = net(image)
criterion = nn.CrossEntropyLoss()
loss = criterion(out,label)
print(loss)

输出:

tensor(2.2938, grad_fn=<NllLossBackward0>)

五、模型训练

# 开始训练
model = MyNet()
# device = torch.device("cuda:0")
# model = model.to(device)
import torch.optim as optim
optimizer = optim.SGD(model.parameters(),lr=0.01) # lr表示学习率
criterion = nn.CrossEntropyLoss()
def train(epoch):# 设置为训练模式:某些层的行为会发生变化(dropout和batchnorm:会根据当前批次的数据计算均值和方差,加速模型的泛化能力)model.train()running_loss = 0.0for i,data in enumerate(train_loader):# 得到输入和标签inputs,labels = data# 消除梯度optimizer.zero_grad()# 前向传播、计算损失、反向传播、更新参数outputs = model(inputs)loss = criterion(outputs,labels)loss.backward()optimizer.step()# 打印日志running_loss += loss.item()if i % 100 == 0:print("[%d,%5d] loss: %.3f"%(epoch+1,i+1,running_loss/100))running_loss = 0train(10)

输出:

[11,    1] loss: 0.023
[11,  101] loss: 2.302
[11,  201] loss: 2.294
[11,  301] loss: 2.278
[11,  401] loss: 2.231
[11,  501] loss: 1.931
[11,  601] loss: 0.947
[11,  701] loss: 0.601
[11,  801] loss: 0.466
[11,  901] loss: 0.399

六、模型预测

# 模型预测结果
correct = 0
total = 0
with torch.no_grad():for data in test_loader:images,labels = dataoutputs = model(images)# 最大的数值及最大值对应的索引value,predicted = torch.max(outputs.data,1)total += labels.size(0)# 对bool型的张量进行求和操作,得到所有预测正确的样本数,采用item将整数类型的张量转换为python中的整型对象correct += (predicted == labels).sum().item()print("predicted:{}".format(predicted[:10].tolist()))print("label:{}".format(labels[:10].tolist()))print("Accuracy of the network on the 10 test images: %d %%"% (100*correct/total))imshow(torchvision.utils.make_grid(images[:10],nrow=len(images[:10])))

输出:

predicted:[1, 0, 7, 6, 5, 2, 4, 3, 2, 6]
label:[1, 0, 7, 6, 5, 2, 4, 8, 2, 6]
Accuracy of the network on the 10 test images: 91 %

对应类别的准确率:

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
classes = [i for i in range(10)]with torch.no_grad():# model.eval()for data in test_loader:images,labels = dataoutputs = model(images)value,predicted = torch.max(outputs,1)c = (predicted == labels).squeeze()# 对所有labels逐个进行判断for i in range(len(labels)):label = labels[i]class_correct[label] += c[i].item()class_total[label] += 1print("class_correct:{}".format(class_correct))print("class_total:{}".format(class_total))# 每个类别的指标
for i in range(10):print('Accuracy of -> class %d : %2d %%'%(classes[i],100*class_correct[i]/class_total[i]))

输出:

class_correct:[958.0, 1119.0, 948.0, 938.0, 901.0, 682.0, 913.0, 918.0, 748.0, 902.0]
class_total:[980.0, 1135.0, 1032.0, 1010.0, 982.0, 892.0, 958.0, 1028.0, 974.0, 1009.0]
Accuracy of -> class 0 : 97 %
Accuracy of -> class 1 : 98 %
Accuracy of -> class 2 : 91 %
Accuracy of -> class 3 : 92 %
Accuracy of -> class 4 : 91 %
Accuracy of -> class 5 : 76 %
Accuracy of -> class 6 : 95 %
Accuracy of -> class 7 : 89 %
Accuracy of -> class 8 : 76 %
Accuracy of -> class 9 : 89 %

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/111240.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【复杂网络建模】——ER网络和SF网络的阈值分析

目录 1、介绍ER网络和SF网络 2、计算网络阈值 2.1 ER&#xff08;Erdős-Rnyi&#xff09;网络 2.2 SF&#xff08;Scale-Free&#xff09;网络 3、 研究网络阈值的意义 1、介绍ER网络和SF网络 在复杂网络理论中&#xff0c;ER网络&#xff08;Erdős-Rnyi网络&#xff…

mybatis:动态sql【2】+转义符+缓存

目录 一、动态sql 1.set、if 2.foreach 二、转义符 三、缓存cache 1. 一级缓存 2. 二级缓存 一、动态sql 1.set、if 在update语句中使用set标签&#xff0c;动态更新set后的sql语句&#xff0c;&#xff0c;if作为判断条件。 <update id"updateStuent" pa…

【USRP】集成化仪器系列3 :频谱仪,基于labview实现

USRP 频谱仪 1、设备IP地址&#xff1a;默认为192.168.10.2&#xff0c;请勿 修改&#xff0c;运行阶段无法修改。 2、天线输出端口是TX1&#xff0c;请勿修改。 3、通道&#xff1a;0 对应RF A、1 对应 RF B&#xff0c;运行 阶段无法修改。 4、中心频率&#xff1a;当需要…

Unity 结构少继承多组合

为什么不推荐使用继承&#xff1f; 继承是面向对象的四大特性之一&#xff0c;用来表示类之间的 is-a 关系&#xff0c;可以解决代码复用的问题。虽然继承有诸多作用&#xff0c;但继承层次过深、过复杂&#xff0c;也会影响到代码的可维护性。所以&#xff0c;对于是否应该在…

程序员为什么要写bug,不能一次性写好吗?

仅仅听到“Bug”这个词就会让你作为一个开发人员感到畏缩。我们相信&#xff0c;优秀的程序员是那些编写无错误代码的人。随着一些开发人员强调要成为一名零错误程序员&#xff0c;我们进行了更深刻的思考&#xff0c;并发现事实的准确性。 所有制作的软件都应该没有错误。对此…

强化学习笔记

马尔科夫决策过程 markov chain&#xff1a; S \mathcal{S} S MRP&#xff1a; S &#xff0c; R \mathcal{S&#xff0c;R} S&#xff0c;R MDP&#xff1a; S &#xff0c; A ( s ) &#xff0c; R &#xff0c; P \mathcal{S&#xff0c;A(s)&#xff0c;R&#xff0c;P} …

在 Redis 中处理键值 | Navicat

Redis 是一个键值存储系统&#xff0c;允许我们将值与键相关联起来。与关系型数据库不同的是&#xff0c; 在Redis 中&#xff0c;不需要使用数据操作语言 &#xff08;DML&#xff09; 和查询语法&#xff0c;那么我们如何进行数据的写入、读取、更新和删除操作呢&#xff1f;…

怎么检测UI卡顿?(线上及线下)

什么是UI卡顿&#xff1f; 在Android系统中&#xff0c;我们知道UI线程负责我们所有视图的布局&#xff0c;渲染工作&#xff0c;UI在更新期间&#xff0c;如果UI线程的执行时间超过16ms&#xff0c;则会产生丢帧的现象&#xff0c;而大量的丢帧就会造成卡顿&#xff0c;影响用…

prometheus + grafana进行服务器资源监控

在性能测试中&#xff0c;服务器资源是值得关注一项内容&#xff0c;目前&#xff0c;市面上已经有很多的服务器资 源监控方法和各种不同的监控工具&#xff0c;方便在各个项目中使用。 但是&#xff0c;在性能测试中&#xff0c;究竟哪些指标值得被关注呢&#xff1f; 监控有…

springcloud-gateway简述

Spring Cloud Gateway 是一个用于构建 API 网关的项目&#xff0c;它是 Spring Cloud 生态系统中的一部分&#xff0c;旨在为微服务架构提供动态路由、负载均衡、安全性和监控等功能。 网关工程对应pom文件 <?xml version"1.0" encoding"UTF-8"?>…

【ag-grid-vue】基本使用

ag-grid是一款功能和性能强大外观漂亮的表格插件&#xff0c;ag-grid几乎能满足你对数据表格所有需求。固定列、拖动列大小和位置、多表头、自定义排序等等各种常用又必不可少功能。关于收费的问题&#xff0c;绝大部分应用用免费的社区版就够了&#xff0c;ag-grid-community社…

在线设计APP ui的网站,分享这7款

在数字时代&#xff0c;用户界面&#xff08;UI&#xff09;设计变得非常重要&#xff0c;因为良好的UI设计可以改善用户体验&#xff0c;增强产品吸引力。随着科学技术的发展&#xff0c;越来越多的应用在线设计网站出现&#xff0c;为设计师和团队提供了一种新的创作方式。本…

【大数据知识】大数据平台和数据中台的定义、区别以及联系

数据行业有太多数据名词&#xff0c;例如大数据、大数据平台、数据中台、数据仓库等等。但大家很容易混淆&#xff0c;也很容易产生疑问&#xff0c;今天我们就来简单聊聊大数据平台和数据中台的定义、区别以及联系。 大数据平台和数据中台的定义 大数据平台&#xff1a;一个…

阿里云大数据实战记录8:拆开 json 的每一个元素,一行一个

目录 一、前言二、目标介绍三、使用 pgsql 实现3.1 拆分 content 字段3.2 拆分 level 字段3.3 拼接两个拆分结果 四、使用 ODPS SQL 实现4.1 拆分 content 字段4.2 拆分 level 字段4.3 合并拆分 五、使用 MySQL 实现六、总结 一、前言 商业场景中&#xff0c;经常会出现新的业…

JUC并发编程--------基础篇

一、多线程的相关知识 栈与栈帧 我们都知道 JVM 中由堆、栈、方法区所组成&#xff0c;其中栈内存是给谁用的呢&#xff1f;其实就是线程&#xff0c;每个线程启动后&#xff0c;虚拟 机就会为其分配一块栈内存。 每个栈由多个栈帧&#xff08;Frame&#xff09;组成&#xf…

极智嘉(Geek+)再获重磅荣誉,持续力领跑智慧物流行业发展

近日&#xff0c;全球仓储机器人引领者极智嘉(Geek)再度传来好消息&#xff0c;凭借着全球化的专业服务能力和稳健增长的亮眼海外成绩&#xff0c;一举荣登“2023出海品牌服务商”价值榜&#xff0c;成为唯一登榜的物流机器人企业。 作为率先出海的物流机器人企业&#xff0c…

博客系统后端(项目系列2)

目录 前言 &#xff1a; 1.准备工作 1.1创建项目 1.2引入依赖 1.3创建必要的目录 2.数据库设计 2.1博客数据 2.2用户数据 3.封装数据库 3.1封装数据库的连接操作 3.2创建两个表对应的实体类 3.3封装一些必要的增删改查操作 4.前后端交互逻辑的实现 4.1博客列表页 …

十年测试工程师叙述自动化测试学习思路

自动化测试介绍 自动化测试(Automated Testing)&#xff0c;是指把以人为驱动的测试行为转化为机器执行的过程。实际上自动化测试往往通过一些测试工具或框架&#xff0c;编写自动化测试用例&#xff0c;来模拟手工测试过程。比如说&#xff0c;在项目迭代过程中&#xff0c;持…

C++中的虚继承、多态以及模板的介绍

菱形继承 概念 菱形继承又称为钻石继承&#xff0c;由公共基类派生出多个中间子类&#xff0c;又由中间子类共同派生出汇聚子类。汇聚子类会得到中间子类从公共基类继承下来的多份成员 格式 A --------公共基类/ \B C ------- 中间子类\ /D -----…