深度学习:神经网络--手写数字识别

目录

一、datasets

1.datasets简介

2.主要特点

二、MNIST

三、使用神经网络实现手写数字识别

1.创建数据加载器

2.判断是否使用GPU

3.创建神经网络

4.创建训练集模型

5.创建测试集模型

6.创建损失函数和优化器并训练


一、datasets

1.datasets简介

        datasets是一个广泛使用的库,尤其在机器学习和自然语言处理领域,用于方便地加载和处理各种数据集。它提供了标准化的接口,使得数据集的访问、处理和分割变得更加简单。

 

2.主要特点

  1. 丰富的数据集库

    提供了数百个公开数据集,涵盖不同领域,如文本、图像、音频等。
  2. 简化数据加载

    只需几行代码即可加载数据集,支持多种格式(如 CSV、JSON、文本文件等)。
  3. 高效的数据处理

    提供数据预处理功能,如分词、编码、数据增强等,方便进行数据清洗和转换。
  4. 数据集分割

    支持轻松地将数据集分为训练集、验证集和测试集。
  5. 自定义数据集

    用户可以自定义数据集,方便地与现有数据集结合使用。
  6. 与深度学习框架的兼容性

    可以与 PyTorch、TensorFlow 等深度学习框架无缝集成,便于模型训练。

 

二、MNIST

  • MNIST 是一个经典的手写数字识别数据集,广泛用于机器学习和计算机视觉领域的研究和教学。
  • 它包含 70,000 张 28x28 像素的灰度图像,分为 60,000 张训练样本和 10,000 张测试样本,每张图像对应一个 0 到 9 的数字标签。

代码实现:

import torch
from torch import nn  # 导入神经网络模块
from torch.utils.data import DataLoader  # 数据包管理工具,打包数据
from torchvision import datasets  # 封装了很多与图像相关的模型,数据集
from torchvision.transforms import ToTensor  # 数据转换,张量,将其他类型的数据转换为tensor张量'''下载训练数据集(包含训练图片和标签)'''
train_data = datasets.MNIST(root='data', train=True, download=True, transform=ToTensor()  # 张量,图片是不能直接传入神经网络模型
)  # 对于pytorch库能够识别的数据一般是tensor张量.'''下载测试数据集(包含训练图片和标签)'''
test_data = datasets.MNIST(root='data', train=False, download=True, transform=ToTensor()
)  # NumPy 数组只能在CPU上运行.Tensor可以在GPU上运行,这在深度学习应用中可以显著提高计算速度.
print(len(train_data))'''展示手写数字图片'''
import matplotlib.pyplot as pltfigure = plt.figure()
for i in range(9):img, label = train_data[i + 100]figure.add_subplot(3, 3, i + 1)plt.title(label)plt.axis('off')plt.imshow(img.squeeze(), cmap='gray')a = img.squeeze()
plt.show()

输出:

 

三、使用神经网络实现手写数字识别

1.创建数据加载器

  • 对数据进行打包
  • 通过打包数据,可以提高工作效率,简化数据管理流程,并提升模型的训练和推理性能。
'''
创建数据DataLoader(数据加载器)batch_size:将数据集分成多份,每一份为batch_size个数据.优点:可以减少内存的使用,提高训练速度.
'''
train_dataloader = DataLoader(train_data, batch_size=64)  # 64张图片为一个包
test_dataloader = DataLoader(test_data, batch_size=64)
for x, y in test_dataloader:print(f"shape of x [N ,C,H,W]:{x.shape}")print(f"shape of y :{y.shape} {y.dtype}")break

输出:

  • 第一行输出表示一个包64张一个通道的图片,每张图片像素大小为28*28
  • 第二行输出表示包的大小为64,数据类型为torch里的int64
shape of x [N ,C,H,W]:torch.Size([64, 1, 28, 28])
shape of y :torch.Size([64]) torch.int64

 

2.判断是否使用GPU

'''判断当前设备是否支持GPU mps是苹果m系列芯片GPU'''
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_avaibale() else 'cpu'
print(f"using {device} device")

输出:

using cuda device

 

3.创建神经网络

'''创建神经网络类'''class NeuralNetwork(nn.Module):def __init__(self):super().__init__()  # 继承的父类初始化self.flatten = nn.Flatten()  # 展开,创建一个展开对象flattenself.hidden1 = nn.Linear(28 * 28, 512)  # 第1个参数:有多少个神经元传入进来,第2个参数:有多少个数据传出去 前一层神经元的个数,当前本层神经元个数self.hidden2 = nn.Linear(512, 256)  # 第二个隐藏层self.hidden3 = nn.Linear(256, 128)  # 第三个隐藏层self.hidden4 = nn.Linear(128, 64)  # 第四个隐藏层self.out = nn.Linear(64, 10)  # 输出必须和标签的类别相同,输入必须是上一层的神经元个数def forward(self, x):  # 前向传播,你得告诉它,数据的流向.是神经网络层连接起来,函数名称不能改.当你调用forward函数的时候,传入进来的图像数据x = self.flatten(x)  # 图像进行展开x = self.hidden1(x)x = torch.relu(x)  # 激活函数,relu,tanh,sigmod  relu没有梯度消失问题,且计算消耗大大降低x = self.hidden2(x)x = torch.relu(x)x = self.hidden3(x)x = torch.relu(x)x = self.hidden4(x)x = torch.relu(x)x = self.out(x)return xmodel = NeuralNetwork().to(device)  # 把刚刚创建的模型传入到GPU
print(model)

输出:

  • 输出的就是神经网络的模型
  • 输入层--隐藏层1--隐藏层2--隐藏层3--隐藏层4--输出层
  • 每一层的输入神经元,输出神经元,偏置项
NeuralNetwork((flatten): Flatten(start_dim=1, end_dim=-1)(hidden1): Linear(in_features=784, out_features=512, bias=True)(hidden2): Linear(in_features=512, out_features=256, bias=True)(hidden3): Linear(in_features=256, out_features=128, bias=True)(hidden4): Linear(in_features=128, out_features=64, bias=True)(out): Linear(in_features=64, out_features=10, bias=True)
)

 

4.创建训练集模型

def train(dataloader, model, loss_fn, optimizer):model.train()  # 告诉模型,我要开始训练,模型中w进行随机化操作,已经更新w.在训练过程中,w会被修改的# pytorch提供2种方式来切换训练和测试的模式,分别是:model.train()和 model.eval().# 一般用法是: 在训练开始之前写上model.trian(),在测试时写上model.eval().batch_size_num = 1for x, y in dataloader:x, y = x.to(device), y.to(device)  # 把训练数据集和标签传入CPU或GPUpred = model.forward(x)  # 向前传播loss = loss_fn(pred, y)  # 通过交叉熵损失函数计算损失值lossoptimizer.zero_grad()  # 梯度值清零loss.backward()  # 反向传播计算得到每个参数的梯度值woptimizer.step()  # 根据梯度更新网络w参数loss_value = loss.item()  # 从tensor数据中提取数据出来,tensor获取损失值if batch_size_num % 200 == 0:print(f"loss:{loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1

 

5.创建测试集模型

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()  # 测试,w就不能再更新。test_loss, correct = 0, 0with torch.no_grad():  # 一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()的时候。这可以减少计算所占用的消耗for x, y in dataloader:x, y = x.to(device), y.to(device)pred = model.forward(x)test_loss += loss_fn(pred, y).item()  # test loss是会自动累加每一个批次的损失值correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y)  # dim=1表示每一行中的最大值对应的索引号,dim=0表示每一列中的最大值对应的索引号b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batches  # 能来衡量模型测试的好坏。correct /= size  # 平均的正确率print(f"Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}")

 

6.创建损失函数和优化器并训练

loss_fn = nn.CrossEntropyLoss()  # 处理多分类 损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.0012)  #
# params:要训练的参数,一般我们传入的都是model.parameters()
# lr:leqrning rate学习率,也就是步长epochs = 4  # 到底选择多少呢?
for t in range(epochs):print(f"Epoch {t + 1}\n--------------")train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)

输出:

  • 可以看到每一轮之后模型损失值都在变小,说明模型在进行优化
Epoch 1
--------------
loss:0.459520 [number:200]
loss:0.307480 [number:400]
loss:0.213908 [number:600]
loss:0.196316 [number:800]
Epoch 2
--------------
loss:0.175202 [number:200]
loss:0.251854 [number:400]
loss:0.085344 [number:600]
loss:0.091463 [number:800]
Epoch 3
--------------
loss:0.050711 [number:200]
loss:0.184462 [number:400]
loss:0.043357 [number:600]
loss:0.056628 [number:800]
Epoch 4
--------------
loss:0.022849 [number:200]
loss:0.125797 [number:400]
loss:0.019434 [number:600]
loss:0.023895 [number:800]
Done!
Test result: Accuracy: 96.58%, Avg loss: 0.14490024165602944

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/429701.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[mysql]mysql排序和分页

#排序和分页本身是两块内容,因为都比较简单,我们就把它分到通一个内容里. #1排序: SELECT * FROM employees #我们会发现,我们没有做排序操作,但是最后出来的107条结果还是会按顺序发出,而且是每次都一样.这我们就有一个疑惑了,现在我们的数据库是根据什么来排序的,在我们没有进…

windows 驱动实例分析系列-COM驱动案例讲解

COM也被称之为串口,这是一种非常简单的通讯接口,这种结构简单的接口被广泛的应用在开发中,几乎所有系统都能支持这种通讯接口,它有RS232和RS485等分支,但一般我们都会使用RS232作为常见的串口,因为它足够简单和高效。 几乎所有的开发板,都会提供用于烧录、调试、日志的…

redis为什么不使用一致性hash

Redis节点间通信时,心跳包会携带节点的所有槽信息,它能以幂等方式来更新配置。如果采用 16384 个插槽,占空间 2KB (16384/8);如果采用 65536 个插槽,占空间 8KB (65536/8)。 今天我们聊个知识点为什么Redis使用哈希槽而不是一致性…

FastAPI 的隐藏宝石:自动生成 TypeScript 客户端

在现代 Web 开发中,前后端分离已成为标准做法。这种架构允许前端和后端独立开发和扩展,但同时也带来了如何高效交互的问题。FastAPI,作为一个新兴的 Python Web 框架,提供了一个优雅的解决方案:自动生成客户端代码。本…

引领长期投资新篇章:价值增长与财务安全的双重保障

随着全球金融市场的不断演变,长期投资策略因其稳健性和对价值增长的显著推动作用而日益受到投资者的重视。在这一背景下,Zeal Digital Shares(ZDS)项目以其创新的数字股票产品,为全球投资者提供了一个全新的长期投资平…

re题(38)BUUCTF-[FlareOn6]Overlong

BUUCTF在线评测 (buuoj.cn) 运行一下.exe文件 查壳是32位的文件,放到ida反汇编 对unk_402008前28位进行一个操作,我们看到运行.exe文件的窗口正好是28个字符,而unk_402008中不止28个数据,所以猜测MessageBoxA(&#x…

cv中每个patch的关联

在计算机视觉任务中,当图像被划分为多个小块(patches)时,每个 patch 的关联性可以通过不同的方法来计算。具体取决于使用的模型和任务,以下是一些常见的计算 patch 关联性的方法: 1. Vision Transformer (…

Shell运行原理与Linux权限概念

shell的运行原理 Linux严格意义上说的是一个操作系统。我们称之为“核心(kernel)”,但我们一般用户,不能直接使用kernel,二十通过kernel的“外壳”程序,也就是所谓的shell,来与kernel沟通。 从…

网络穿透:TCP 打洞、UDP 打洞与 UPnP

在现代网络中,很多设备都处于 NAT(网络地址转换)或防火墙后面,这使得直接访问这些设备变得困难。在这种情况下,网络穿透技术就显得非常重要。本文将介绍三种常用的网络穿透技术:TCP 打洞、UDP 打洞和 UPnP。…

qt-C++笔记之作用等同的宏和关键字

qt-C笔记之作用等同的宏和关键字 code review! Q_SLOT 和 slots: Q_SLOT是slots的替代宏,用于声明槽函数。 Q_SIGNAL 和 signals: Q_SIGNAL类似于signals,用于声明信号。 Q_EMIT 和 emit: Q_EMIT 是 Qt 中用于发射…

18.2K Star,AI 高效视频监控摄像头

Hi,骚年,我是大 G,公众号「GitHub 指北」会推荐 GitHub 上有趣有用的项目,一分钟 get 一个优秀的开源项目,挖掘开源的价值,欢迎关注。 导语 在家庭和企业安防领域,实时视频监控是保障安全的核…

AIGC8: 高通骁龙AIPC开发者大会记录B

图中是一个小男孩在市场卖他的作品。 AI应用开发出来之后,无论是个人开发者还是企业开发者。 如何推广分发是面临的大问题。 做出来的东西一定要符合商业规律。否则就是实验室里面的玩物,或者自嗨的东西。 背景 上次是回顾和思考前面两个硬件营销总的…

【JVM】类加载

1. 类加载过程 Java虚拟机(JVM)的 类加载 过程是将字节码文件(.class文件)从存储设备加载到内存,并为其创建相应的类对象的过程。类加载是Java程序运行的基础,保证了程序的动态性和安全性。JVM的类加载过程…

人工智能 | 基于ChatGPT开发人工智能服务平台

简介 ChatGPT 在刚问世的时候,其产品形态就是一个问答机器人。而基于ChatGPT的能力还可以对其做一些二次开发和拓展。比如模拟面试功能、或者智能机器人功能。 模拟面试功能包括个性化问题生成、实时反馈、多轮面试模拟、面试报告。 智能机器人功能提供24/7客服支…

将阮一峰老师的《ES6入门教程》的源码拷贝本地运行和发布

你好同学,我是沐爸,欢迎点赞、收藏、评论和关注。 阮一峰老师的《ES6入门教程》应该是很多同学学习 ES6 知识的重要参考吧,应该也有很多同学在看该文档的时候,想知道这个教程的前端源码是怎么实现的,也可能有同学下载…

esp32 wifi 联网后,用http 发送hello 用pc 浏览器查看网页

参考chatgpt Esp32可以配置为http服务器,可以socket编程。为了免除编写针对各种操作系统的app。完全可以用浏览器仿问esp32服务器,获取esp32的各种数据,甚至esp的音频,视频。也可以利用浏览器对esp进行各种操作。但esp不能主动仿…

【医学半监督】置信度指导遮蔽学习的半监督医学图像分割

摘要: 半监督学习(Semi-supervised learning)旨在利用少数标记数据和多数未标记数据训练出高性能模型。现有方法大多采用预测任务机制,在一致性或伪标签的约束下获得精确的分割图,但该机制通常无法克服确认偏差。针对这一问题,本文提出了一种用于半监督医学图像分割的新…

【C++笔记】C++编译器拷贝优化和内存管理

【C笔记】C编译器拷贝优化和内存管理 🔥个人主页:大白的编程日记 🔥专栏:C笔记 文章目录 【C笔记】C编译器拷贝优化和内存管理前言一.对象拷贝时的编译器优化二.C/C内存管理2.1练习2.2 C内存管理方式2.3 operator new与operator…

分布式锁优化之 使用lua脚本改造分布式锁保证判断和删除的原子性(优化之LUA脚本保证删除的原子性)

文章目录 1、lua脚本入门1.1、变量:弱类型1.2、流程控制1.3、在lua中执行redis指令1.4、实战:先判断是否自己的锁,如果是才能删除 2、AlbumInfoApiController --》testLock()3、AlbumInfoServiceImpl --》testLock() 1、lua脚本入门 Lua 教程…

长亭WAF绕过测试

本文的Bypass WAF 的核心思想在于,一些 WAF 产品处于降低误报考虑,对用户上传文件的内 容不做匹配,直接放行 0、环境 环境:两台服务器,一台配置宝塔面板,一台配置长亭雷池WAF 思路主要围绕:m…