实战:基于卷积的MNIST手写体分类

前面实现了基于多层感知机的MNIST手写体识别,本章将实现以卷积神经网络完成的MNIST手写体识别。

1.  数据的准备

在本例中,依旧使用MNIST数据集,对这个数据集的数据和标签介绍,前面的章节已详细说明过了,相对于前面章节直接对数据进行“折叠”处理,这里需要显式地标注出数据的通道,代码如下:

import numpy as npimport einops.layers.torch as elt#载入数据x_train = np.load("../dataset/mnist/x_train.npy")y_train_label = np.load("../dataset/mnist/y_train_label.npy")x_train = np.expand_dims(x_train,axis=1)   #在指定维度上进行扩充print(x_train.shape)

这里是对数据的修正,np.expand_dims的作用是在指定维度上进行扩充,这里在第二维(也就是PyTorch的通道维度)进行扩充,结果如下:

(60000, 1, 28, 28)

2.  模型的设计

下面使用PyTorch 2.0框架对模型进行设计,在本例中将使用卷积层对数据进行处理,完整的模型如下:

import torch
import torch.nn as nn
import numpy as np
import einops.layers.torch as elt
class MnistNetword(nn.Module):def __init__(self):super(MnistNetword, self).__init__()#前置的特征提取模块self.convs_stack = nn.Sequential(nn.Conv2d(1,12,kernel_size=7),  	#第一个卷积层nn.ReLU(),nn.Conv2d(12,24,kernel_size=5), 	#第二个卷积层nn.ReLU(),nn.Conv2d(24,6,kernel_size=3)  	#第三个卷积层)#最终分类器层self.logits_layer = nn.Linear(in_features=1536,out_features=10)def forward(self,inputs):image = inputsx = self.convs_stack(image)        #elt.Rearrange的作用是对输入数据的维度进行调整,读者可以使用torch.nn.Flatten函数完成此工作x = elt.Rearrange("b c h w -> b (c h w)")(x)logits = self.logits_layer(x)return logits
model = MnistNetword()
torch.save(model,"model.pth")

这里首先设定了3个卷积层作为前置的特征提取层,最后一个全连接层作为分类器层,需要注意的是,对于分类器的全连接层,输入维度需要手动计算,当然读者可以一步一步尝试打印特征提取层的结果,依次将结果作为下一层的输入维度。最后对模型进行保存。

3.  基于卷积的MNIST分类模型

下面进入本章的最后示例部分,也就是MNIST手写体的分类。完整的训练代码如下:

import torch
import torch.nn as nn
import numpy as np
import einops.layers.torch as elt
#载入数据
x_train = np.load("../dataset/mnist/x_train.npy")
y_train_label = np.load("../dataset/mnist/y_train_label.npy")
x_train = np.expand_dims(x_train,axis=1)
print(x_train.shape)
class MnistNetword(nn.Module):def __init__(self):super(MnistNetword, self).__init__()self.convs_stack = nn.Sequential(nn.Conv2d(1,12,kernel_size=7),nn.ReLU(),nn.Conv2d(12,24,kernel_size=5),nn.ReLU(),nn.Conv2d(24,6,kernel_size=3))self.logits_layer = nn.Linear(in_features=1536,out_features=10)def forward(self,inputs):image = inputsx = self.convs_stack(image)x = elt.Rearrange("b c h w -> b (c h w)")(x)logits = self.logits_layer(x)return logits
device = "cuda" if torch.cuda.is_available() else "cpu"
#注意记得将model发送到GPU计算
model = MnistNetword().to(device)
model = torch.compile(model)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
batch_size = 128
for epoch in range(42):train_num = len(x_train)//128train_loss = 0.for i in range(train_num):start = i * batch_sizeend = (i + 1) * batch_sizex_batch = torch.tensor(x_train[start:end]).to(device)y_batch = torch.tensor(y_train_label[start:end]).to(device)pred = model(x_batch)loss = loss_fn(pred, y_batch)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()  # 记录每个批次的损失值# 计算并打印损失值train_loss /= train_numaccuracy = (pred.argmax(1) == y_batch).type(torch.float32).sum().item() / batch_sizeprint("epoch:",epoch,"train_loss:", round(train_loss,2),"accuracy:",round(accuracy,2))

在这里,我们使用了本章新定义的卷积神经网络模块作为局部特征抽取,而对于其他的损失函数以及优化函数,只使用了与前期一样的模式进行模型训练。最终结果如下所示,请读者自行验证。

(60000, 1, 28, 28)
epoch: 0 train_loss: 2.3 accuracy: 0.11
epoch: 1 train_loss: 2.3 accuracy: 0.13
epoch: 2 train_loss: 2.3 accuracy: 0.2
epoch: 3 train_loss: 2.3 accuracy: 0.18
…
epoch: 58 train_loss: 0.5 accuracy: 0.98
epoch: 59 train_loss: 0.49 accuracy: 0.98
epoch: 60 train_loss: 0.49 accuracy: 0.98
epoch: 61 train_loss: 0.48 accuracy: 0.98
epoch: 62 train_loss: 0.48 accuracy: 0.98Process finished with exit code 0

本文节选自《PyTorch 2.0深度学习从零开始学》,本书实战案例丰富,可带领读者快速掌握深度学习算法及其常见案例。

   

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/116671.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

控制goroutine 的并发执行数量

goroutine的数量上限是1048575吗? 正常项目,协程数量超过十万就需要引起重视。如果有上百万goroutine,一般是有问题的。 但并不是说协程数量的上限是100多w 1048575的来自类似如下的demo代码: package mainimport ( "fmt" "ma…

MySQL的mysql-bin.00xx binlog日志文件的清理

目录 引言手工清理配置自动清理 引言 公司一个项目生产环境mysql数据盘占用空间增长得特别快,经过排查发现是开启了mysql的binlog日志。如果把binlog日志关闭,如果操作万一出现问题,就没有办法恢复数据,很不安全,只能…

WPF怎么实现文件拖放功能winform怎么实现拖拽功能

WPF怎么实现文件拖放功能winform怎么实现文件拖拽功能,在管理员模式下wpf winform怎么实现文件的拖拽功能 WPF实现文件拖放功能,正常情况并没有什么问题,但是如果你的程序使用管理员身份启动,你就会发现文件拖放功能就会失效。同…

css元素定位:通过元素的标签或者元素的id、class属性定位

前言 大部分人在使用selenium定位元素时,用的是xpath元素定位方式,因为xpath元素定位方式基本能解决定位的需求。xpath元素定位方式更直观,更好理解一些。 css元素定位方式往往被忽略掉了,其实css元素定位方式也有它的价值&…

全新纠错码将量子计算提效10倍!

上周,来自两个研究小组的最新模拟报告称,一类新兴的量子纠错码的效率比目前的“黄金标准”(即表面码)高出一个数量级。 量子纠错码的工作原理都是将大量容易出错的量子比特转换成更小的“受保护”量子比特,这些量子比特…

前端Vue仿企查查天眼查高管信息列表组件

随着技术的不断发展,传统的开发方式使得系统的复杂度越来越高。在传统开发过程中,一个小小的改动或者一个小功能的增加可能会导致整体逻辑的修改,造成牵一发而动全身的情况。为了解决这个问题,我们采用了组件化的开发模式。通过组…

Ansible学习笔记8

group模块: 创建一个group组: [rootlocalhost ~]# ansible group1 -m group -a "nameaaa gid5000" 192.168.17.105 | CHANGED > {"ansible_facts": {"discovered_interpreter_python": "/usr/bin/python"}…

穿上App外衣,保持Web灵魂——PWA温故

早在2015年,设计师弗朗西斯贝里曼和Google Chrome的工程师亚历克斯罗素提出“PWA(渐进式网络应用程序)”概念,将网络之长与应用之长相结合,其核心目标就是提升 Web App 的性能,改善 Web App以媲美Native的流…

HttPClient简介及示例:学习如何与Web服务器进行通信

文章目录 前言一、引入依赖二、使用步骤1.创建被调用者2.创建调用者三、结果被调用者服务:调用者服务: 总结 前言 欢迎来到本篇博客,这是一个关于HttPClient的入门案例的指南。🎉 在今天的网络世界中,与服务器进行数据…

qt.qpa.plugin:找不到Qt平台插件“wayland“|| (下载插件)Ubuntu上解决方案

相信大家也都知道这个地方应该做什么,当然是下载这个qt平台的插件wayland,但是很多人可能不知道怎么下载这个插件。 那么我现在要说的这个方法就是针对这种的。 sudo apt install qtwayland5完事儿了奥兄弟们。 看看效果 正常了奥。

FFmpeg5.0源码阅读——FFmpeg大体框架(以GIF转码为示例)

摘要:前一段时间熟悉了下FFmpeg主流程源码实现,对FFmpeg的整体框架有了个大概的认识,因此在此做一个笔记,希望以比较容易理解的文字描述FFmpeg本身的结构,加深对FFmpeg的框架进行梳理加深理解,如果文章中有…

java从入门到起飞(六)——用Socket实现网络通信

文章目录 背景网络编程网络编程三要素 2.DatagramSocket之UDP通信程序2.1 UDP发送数据2.2UDP接收数据2.3 3. Socket之TCP通信程序3.1TCP发送数据3.2TCP接收数据 背景 网络编程 ● 计算机网络 是指将地理位置不同的具有独立功能的多台计算机及其外部设备,通过通信线…

Matlab图像处理-加法运算

加法运算 图像加法运算的一个应用是将一幅图像的内容叠加到另一幅图像上,生成叠加图像效果,或给图像中每个像素叠加常数改变图像的亮度。 在MATLAB图像处理工具箱中提供的函数imadd()可实现两幅图像的相加或者一幅图像和常量的相加。 程序代码 I1 i…

剑指 Offer 44. 数字序列中某一位的数字(中等)

题目: class Solution { //本题单纯找规律,要注意通过n%digits来判断有几个位数为digits的数 public:int findNthDigit(int n) {long base 9, digits 1; //digits代表位数while(n-base*digits>0){ //该循环是为了确定目标数字所在…

指针(一)------指针概念+指针类型+野指针+指针运算+二级指针

💓博主csdn个人主页:小小unicorn ⏩专栏分类:C语言 🚚代码仓库:小小unicorn的代码仓库🚚 🌹🌹🌹关注我带你学习编程知识 指针(一) 指针是什么指针…

WordPress关注公众号可见内容插件源码

Wordpress公众号引流工具——关注公众号可见内容插件推荐 通过关注微信公众号,获取随机验证码从而获得隐藏文本的访问权限。 插件特点 隐藏内容扫码关注获取验证码 可以作为引流公众号 支持无必须API接口,无备案域名也可以 自定义验证接口URL 自定…

java八股文面试[数据库]——B树和B+树的区别

B树是一种树状数据结构,它能够存储数据、对其进行排序并允许以O(logn)的时间复杂度进行查找、顺序读取、插入和删除等操作。 1、B树的特性 B树中允许一个结点中包含多个key,可以是3个、4个、5个甚至更多,并不确定,需要看具体的实…

Linux——常用命令大汇总(带你快速入门Linux)

纵有疾风起,人生不言弃。本文篇幅较长,如有错误请不吝赐教,感谢支持。 💬文章目录 一.终端和shell命令解析器终端和shell命令解析器概述终端提示符的格式常用快捷键 二.Linux命令格式帮助文档:man 三.目录基础知识Wind…

LabVIEW是如何控制硬件的?

概述 工程 师 和 科学 家 可以 使用 LabVIEW 与 数千 种 不同 的 硬件 设备 无缝 集成, 并 通过 方便 的 功能 和 跨 所有 硬件 的 一致 编 程 框架 帮助 节省 开发 时间。 内容 通过更简单的系统集成节省开发时间 连接到任何硬件 NI 硬件 第三方硬件 快速找到…

基础知识回顾:借助 SSL/TLS 和 NGINX 进行 Web 流量加密

原文作者: Robert Haynes 原文链接: 基础知识回顾:借助 SSL/TLS 和 NGINX 进行 Web 流量加密 NGINX 唯一中文官方社区 ,尽在 nginx.org.cn 网络攻击者肆无忌惮、作恶多端,几乎每天都有网络入侵、数据窃取或勒索软件攻击…