《深度学习》—— 神经网络模型对手写数字的识别

神经网络模型对手写数字的识别

import torch
from torch import nn  # 导入神经网络模块
from torch.utils.data import DataLoader  # 数据包管理工具,打包数据,
from torchvision import datasets  # 封装了很多与图像相关的模型,数据集
from torchvision.transforms import ToTensor  # 数据转换,张量,将其他类型的数据转换为tensor张量"""
MNIST包含70,000张手写数字图像:60,000张用于训练,10,000张用于测试。
图像是灰度的,28x28像素的,并且居中的,以减少预处理和加快运行。
"""
""" 下载训练数据集 (包含训练数据+标签)"""
training_data = datasets.MNIST(root='data',train=True,download=True,transform=ToTensor()  # 张量,图片是不能直接传入神经网络模型
)  # 对于pytorch库能够识别的数据一般是tensor张量.
# NumPy 数组只能在CPU上运行。Tensor可以在GPU上运行,这在深度学习应用中可以显著提高计算速度。""" 下载测试数据集(包含训练图片+标签)"""
test_data = datasets.MNIST(root='data',train=False,download=True,transform=ToTensor()
)
print(len(training_data))""" 展示手写字图片 """
# tensor --> numpy 矩阵类型的数据
from matplotlib import pyplot as pltfigure = plt.figure()
for i in range(9):img, label = training_data[i + 59000]  # 提取第59000张图片figure.add_subplot(3, 3, i + 1)  # 图像窗口中创建多个小窗口,小窗口用于显示图片plt.title(label)plt.axis("off")  # 关闭坐标plt.imshow(img.squeeze(), cmap="gray")a = img.squeeze()  # img.squeeze()从张量img中去掉维度为1的(降维)
plt.show()training_dataloader = DataLoader(training_data, batch_size=64)  # 64张图片为一个包
test_dataloader = DataLoader(test_data, batch_size=64)
for X, y in test_dataloader:  # X 表示打包好的每一个数据包print(f"Shape of X [N, C, H, W]: {X.shape}")print(f"Shape of y: {y.shape} {y.dtype}")break""" 判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU """
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")class NeuralNetwork(nn.Module):  # 通过调用类的形式来使用神经网络,神经网络的模型->nn.moduledef __init__(self):super().__init__()  # 继承的父类初始化self.flatten = nn.Flatten()  # 展开,创建一个展开对象flattenself.hidden1 = nn.Linear(28 * 28, 256)  # 第1个参数:有多少个神经元传入进来,第2个参数:有多少个数据传出去前一层神经元的个数,当前本层神经元个数self.hidden2 = nn.Linear(256, 128)  # 输出必需和标签的类别相同,输入必须是上一层的神经元个数self.hidden3 = nn.Linear(128, 256)self.hidden4 = nn.Linear(256, 128)self.out = nn.Linear(128, 10)#def forward(self, x):  # 前向传播,你得告诉它,数据的流向。是神经网络层连接起来,函数名称不能改。当你调用forward函数的时候,传入进来的图像数据x = self.flatten(x)x = self.hidden1(x)x = torch.sigmoid(x)  # 激活函数x = self.hidden2(x)x = torch.sigmoid(x)x = self.hidden3(x)x = torch.sigmoid(x)x = self.hidden4(x)x = torch.sigmoid(x)x = self.out(x)return xmodel = NeuralNetwork().to(device)  # 把刚刚创建的模型传入到gpu或cpu
print(model)# 定义训练模型的函数
def train(dataloader, model, loss_fn, optimizer):model.train()  # 告诉模型,开始训练,模型中w进行随机化操作,已经更新w。在训练过程中,w会被修改的# pytorch提供2种方式来切换训练和测试的模式,分别是:model.train()和 model.eval()。# 一般用法是:在训练开始之前写上model.trian(),在测试时写上model.eval()。batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)  # 把训练数据集和标签传入cpu或GPUpred = model.forward(X)  # .forward可以被省略,父类中已经对次功能进行了设置。自动初始化w权值loss = loss_fn(pred, y)  # 通过交叉熵损失函数计算损失值lossoptimizer.zero_grad()  # 梯度值清零loss.backward()  # 反向传播计算得到每个参数的梯度值woptimizer.step()  # 根据梯度更新网络w参数loss_value = loss.item()  # 从tensor数据中提取数据出来,tensor获取损失值if batch_size_num % 200 == 0:print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1# 定义测试模型的函数
def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()  # 测试,w就不能再更新。test_loss, correct = 0, 0with torch.no_grad():  # 一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()的时候for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)test_loss += loss_fn(pred, y).item()  # test loss是会自动累加每一个批次的损失值correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y)  # dim=1表示每一行中的最大值对应的索引号,dim=0表示每一列中的最大值对应的索引号b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batches  # 衡量模型测试的好坏。correct /= size  # 平均的正确率print(f"Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}")loss_fn = nn.CrossEntropyLoss()  # 创建交叉熵损失函数对象,因为手写字识别中一共有10个数字,输出会有10个结果optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # 创建一个优化器# 设置训练轮数
epochs = 10
for e in range(epochs):print(f"Epoch {e + 1}\n")train(training_dataloader, model, loss_fn, optimizer)
print("Done!")
# 测试模型
test(test_dataloader, model, loss_fn)
  • 展示的手写数字图片如下:
    在这里插入图片描述
  • 模型结构如下:
    在这里插入图片描述
  • 训练结果如下:
  • 共有10轮训练
    在这里插入图片描述
  • 测试结果如下:
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/427608.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

分布式事务seata

文章目录 CAP理论BASE 理论seata解决分布式事务seata重要对象XA模式AT模式TCC模式saga模式 CAP理论 CAP理论指出在分布式系统中三个属性不可能同时满足。 Consistency 一致性:在分布式的多个节点(副本)的数据必须是一样的(强一致…

展锐平台的手机camera 系统开发过程

展锐公司有自己的isp 图像处理引擎,从2012 年底就开始在智能手机上部署应用。最初的时候就几个人做一款isp的从hal 到kernel 驱动的完整软件系统,分工不是很明确,基本是谁擅长哪些就搞哪些,除了架构和编码实现之外,另外…

软硬件项目运维方案(Doc原件完整版套用)

1 系统的服务内容 1.1 服务目标 1.2 信息资产统计服务 1.3 网络、安全系统运维服务 1.4 主机、存储系统运维服务 1.5 数据库系统运维服务 1.6 中间件运维服务 2 运维服务流程 3 服务管理制度规范 3.1 服务时间 3.2 行为规范 3.3 现场服务支持规范 3.4 问题记录规范…

【数据结构】排序算法---基数排序

文章目录 1. 定义2. 算法步骤2.1 MSD基数排序2.2 LSD基数排序 3. LSD 基数排序动图演示4. 性质5. 算法分析6. 代码实现C语言PythonJavaCGo 结语 ⚠本节要介绍的不是计数排序 1. 定义 基数排序(英语:Radix sort)是一种非比较型的排序算法&…

秋招常问的面试题:Cookie和Session的区别

《网安面试指南》http://mp.weixin.qq.com/s?__bizMzkwNjY1Mzc0Nw&mid2247484339&idx1&sn356300f169de74e7a778b04bfbbbd0ab&chksmc0e47aeff793f3f9a5f7abcfa57695e8944e52bca2de2c7a3eb1aecb3c1e6b9cb6abe509d51f&scene21#wechat_redirect 《Java代码审…

LeetCode[中等] 3. 无重复字符的最长子串

给定一个字符串 s ,请你找出其中不含有重复字符的 最长子串 的长度。 思路:滑动窗口,设置左右指针left与right,maxLength存储长度 利用HashSet性质,存储滑动窗口中的字符 如果没有重复的,那么right继续向…

LeetCode_sql_day28(1767.寻找没有被执行的任务对)

描述:1767.寻找没有被执行的任务对 表:Tasks ------------------------- | Column Name | Type | ------------------------- | task_id | int | | subtasks_count | int | ------------------------- task_id 具有唯一值的列。 ta…

无人机企业合法运营必备运营合格证详解

无人机企业运营合格证,是由国家相关航空管理部门(如中国民用航空局或其授权机构)颁发的,用于确认无人机企业具备合法、安全、高效运营无人机能力的资质证书。该证书是无人机企业开展商业运营活动的必要条件,标志着企业…

原生+jquery写自动消失的提示框

<!DOCTYPE html> <html lang"en"> <head> <meta charset"UTF-8"> <meta name"viewport" content"widthdevice-width, initial-scale1.0"> <title>自动消失消息提示</title> <style>/…

信息安全数学基础(9)素数的算数基本定理

前言 在信息安全数学基础中&#xff0c;素数的算数基本定理&#xff08;也称为唯一分解定理或算术基本定理&#xff09;是一个极其重要的定理&#xff0c;它描述了正整数如何唯一地分解为素数的乘积。这个定理不仅是数论的基础&#xff0c;也是许多密码学算法&#xff08;如RSA…

同为TVT设备主动注册协议接入SVMSPro平台

同为设备主动注册协议接入SVMSPro平台 步骤一&#xff1a;进设备网页或者NVR配置界面&#xff0c;进功能面板&#xff0c;网络&#xff0c;平台接入 接入类型&#xff1a;平台软件&#xff0c;勾选启用主动上报 服务器地址&#xff1a;平台服务IP 端口&#xff1a;12009 ID&…

高级算法设计与分析 学习笔记6 B树

B树定义 一个块里面存了1000个数和1001个指针&#xff0c;指针指向的那个块里面的数据大小介于指针旁边的两个数之间 标准定义&#xff1a; B树上的操作 查找B树 创建B树 分割节点 都是选择正中间的那个&#xff0c;以免一直分裂。 插入数字 在插入的路上就会检查节点需不需要…

RabbitMQ 高级特性——持久化

文章目录 前言持久化交换机持久化队列持久化消息持久化 前言 前面我们学习了 RabbitMQ 的高级特性——消息确认&#xff0c;消息确认可以保证消息传输过程的稳定性&#xff0c;但是在保证了消息传输过程的稳定性之后&#xff0c;还存在着其他的问题&#xff0c;我们都知道消息…

Delphi5利用DLL实现窗体的重用

文章目录 效果图参考利用DLL实现窗体的重用步骤1 设计出理想窗体步骤2 编写一个用户输出的函数或过程&#xff0c;在其中对窗体进行创建使它实例化步骤3 对工程文件进行相应的修改以适应DLL格式的需要步骤4 编译工程文件生成DLL文件步骤5 在需要该窗体的其他应用程序中重用该窗…

不会JS逆向也能高效结合Scrapy与Selenium实现爬虫抓取

1. 创建基础的scrapy项目 1.1 基础项目 在pycharm中安装scrapy框架 pip install scrapy 创建项目 scrapy startproject 项目名称 我们现在可以看到整体文件的目录&#xff1a; firstBlood ├── firstBlood # 项目跟目录 │ ├── init.py │ ├── items.py # 封装数…

【网络】高级IO——select版本TCP服务器

目录 前言 一&#xff0c;select函数 1.1.参数一&#xff1a;nfds 1.2.参数二&#xff1a; readfds, writefds, exceptfds 1.2.1.fd_set类型和相关操作宏 1.2.2.readfds, writefds, exceptfds 1.2.3.怎么理解 readfds, writefds, exceptfds是输入输出型参数 1.3.参数三…

数据结构之二叉树遍历

二叉树的遍历 先序遍历 先输入父节点&#xff0c;再遍历左子树和右子树&#xff1a;A、B、D、E、C、F、G 中序遍历 先遍历左子树&#xff0c;再输出父节点&#xff0c;再遍历右子树&#xff1a;D、B、E、A、F、C、G 后序遍历 先遍历左子树&#xff0c;再遍历右子树&#xff0c;…

SpringBoot设置mysql的ssl连接

因工作需要&#xff0c;mysql连接需要开启ssl认证&#xff0c;本文主要讲述客户端如何配置ssl连接。 开发环境信息&#xff1a; SpringBoot&#xff1a; 2.0.5.RELEASE mysql-connector-java&#xff1a; 8.0.18 mysql version&#xff1a;8.0.18 一、检查服务端是否开启ssl认…

微信公众号开发入门

微信公众号开发是指开发者基于微信公众平台&#xff08;WeChat Official Accounts Platform&#xff09;所提供的接口与功能&#xff0c;开发和构建自定义的功能与服务&#xff0c;以满足企业、组织或个人在微信生态中的应用需求。微信公众号开发主要围绕公众号消息处理、菜单管…

K1计划100%收购 MariaDB; TDSQL成为腾讯云核心战略产品; Oracle@AWS/Google/Azure发布

重要更新 1. 腾讯全球数字生态大会与9月5日-6日举行&#xff0c;发布“5T”战略&#xff0c;包括TDSQL、TencentOS、TCE&#xff08;专有云 &#xff09;、TBDS&#xff08;大数据&#xff09;、TI &#xff08;人工智能开发平台&#xff09;等 ( [2] ) ; 并正式向原子开源基金…