6.6 实现卷积神经网络LeNet训练并预测手写体数字

模型架构

在这里插入图片描述
在这里插入图片描述

代码实现

import torch
from torch import nn
from d2l import torch as d2l
net = nn.Sequential(nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),#padding=2补偿5x5卷积核导致的特征减少。nn.AvgPool2d(kernel_size=2,stride=2),nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2,stride=2),nn.Flatten(),nn.Linear(16*5*5,120),nn.Sigmoid(),nn.Linear(120,84),nn.Sigmoid(),nn.Linear(84,10)
)
'''定义X,并打印模型的形状'''
# 第一个参数是样本
X = torch.rand(size=(1,1,28,28),dtype=torch.float32)
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape: \t',X.shape)
# 输出如下:
Conv2d output shape: 	 torch.Size([1, 6, 28, 28])
Sigmoid output shape: 	 torch.Size([1, 6, 28, 28])
AvgPool2d output shape: 	 torch.Size([1, 6, 14, 14])
Conv2d output shape: 	 torch.Size([1, 16, 10, 10])
Sigmoid output shape: 	 torch.Size([1, 16, 10, 10])
AvgPool2d output shape: 	 torch.Size([1, 16, 5, 5])
Flatten output shape: 	 torch.Size([1, 400])
Linear output shape: 	 torch.Size([1, 120])
Sigmoid output shape: 	 torch.Size([1, 120])
Linear output shape: 	 torch.Size([1, 84])
Sigmoid output shape: 	 torch.Size([1, 84])
Linear output shape: 	 torch.Size([1, 10])
'''定义训练批次并加载训练集和测试集'''
batch_size = 256
# 按照batch_size把数据集取出来。取出来之后是放到内存中的,后面要把它加载到GPU中
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)
# 计算预测正确的个数
def accuracy(y_hat,y):'''计算预测正确的数量'''if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:# y_hat是下标表示类别,值是该类别的概率。模型结果是预测10个类的概率,谁的概率最大,就取谁的下标y_hat = y_hat.argmax(axis=1)#  y_hat.type(y.dtype):因为==对数据类型很敏感,因此我们将y_hat的数据类型转换为与y的数据类型一致。#  y_hat.type(y.dtype) == y,将预测值y_hat与真实值y比较,返回一个包含 0和1的张量,赋值给cam,最后求和会得到正确预测的数量。cam = y_hat.type(y.dtype) == yreturn float(cam.type(y.dtype).sum())
def evaluate_accuracy_gpu(net,data_iter,device=None):if isinstance(net,nn.Module):net.eval() # 将模型设置为评估模式if not device:'''iter(net.parameters())是将参数集合转换为迭代器,并获取其中的第一个元素next(iter(net.parameters())).device ,指取到net.parameters()的第一个元素,获取该元素的设备。'''device = next(iter(net.parameters())).device# Accumulator用于对多个变量进行累加,d2l.Accumulator(2) 是在Accumulator实例中创建了2个变量,分别用于存储正确预测的数量和预测的总数量。当我们遍历数据集时,两者都随着时间的推移而累加。metric = d2l.Accumulator(2) # 正确预测数,预测总数with torch.no_grad():for X,y in data_iter:if isinstance(X,list): # 详见文章最下面的补充内容X = [x.to(device) for x in X] # 令X使用设备deviceelse:X = X.to(device)y = y.to(device)# y.numel()是批次中样本的数量,accuracy(net(X),y)是用于计算模型在输入数据X上的输出结果与标签Y之间的准确率。# metric.add函数将正确预测的数量 和 样本数量作为参数传递进去,用于记录和累计这些指标的值。metric.add(accuracy(net(X),y),y.numel())return metric[0]/metric[1] # 返回准确率,其中metric[0]存放的是正确预测的个数,metric[1]存放的是样本数量,
def train_ch6(net,train_iter,test_iter,num_epochs,lr,device):def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d: # 对神经网络中的线性层和卷积层的权重进行初始化nn.init.xavier_uniform_(m.weight) #用于初始化权重的函数,net.apply(init_weights)print('training on',device)net.to(device) # 设置模型使用deviceoptimizer = torch.optim.SGD(net.parameters(),lr=lr)loss = nn.CrossEntropyLoss()'''该代码创建了一个名为animator的动画器,用于在训练过程中可视化损失函数和准确率的变化情况'''animator = d2l.Animator(xlabel='epoch',xlim=[1,num_epochs],legend=['train loss','train acc','test acc'])timer,num_batches = d2l.Timer(),len(train_iter)for epoch in range(num_epochs):# 创建 Accumulator类,统计训练损失之和,正确预测个数之和,样本数metric = d2l.Accumulator(3)net.train()for i,(X,y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X,y = X.to(device),y.to(device)y_hat = net(X)l = loss(y_hat,y)l.backward()optimizer.step()with torch.no_grad():metric.add(l*X.shape[0], d2l.accuracy(y_hat,y), X.shape[0])timer.stop()train_l = metric[0] / metric[2] # 损失之和 / 样本数train_acc = metric[1] / metric[2] # 正确预测个数 / 样本数if (i+1) % (num_batches//5)==0 or i == num_epochs-1:animator.add(epoch + (i+1)/num_epochs,(train_l,train_acc,None))test_acc = evaluate_accuracy_gpu(net,test_iter)animator.add(epoch+1,(None,None,test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')
# 定义学习率和批次 开始训练
lr,num_epochs = 0.9,10
train_ch6(net,train_iter,test_iter,num_epochs,lr,d2l.try_gpu())

在这里插入图片描述

使用训练好的模型进行预测

y_hat = net(x)

补充:

isinstance(net,nn.Module)

isinstance(net,nn.Module)是Python的内置函数,用于判断一个对象是否属于制定类或其子类的实例。如果net是nn.Module类或子类的实例,那么表达式返回True,否则返回False. nn.Module是pytorch中用于构建神经网络模型的基类,其他神经网络都会继承它,因此使用 isinstance(net,nn.Module),可以确定Net对象是否为一个有效的神经网络模型。

`nn.init.xavier_uniform_(m.weight)

nn.init.xavier_uniform_(m.weight) 是一个用于初始化权重的函数,采用的是 Xavier 均匀分布初始化方法。

在神经网络中,权重的初始化非常重要,合适的初始化可以帮助网络更好地学习和收敛。Xavier 初始化方法是一种常用的权重初始化方法之一,旨在使权重在前向传播过程中保持方差不变。

具体而言,nn.init.xavier_uniform_() 函数会对输入的权重张量 m.weight 进行操作,将其初始化为一个均匀分布中的随机值。这个均匀分布的范围根据权重张量的形状进行调整,以保持前向传播过程中特征的方差稳定。

通过使用 Xavier 初始化方法,可以加速神经网络的训练过程,并且有助于避免梯度消失或梯度爆炸等问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/79234.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

黑马大数据学习笔记5-案例

目录 需求分析背景介绍目标需求数据内容DBeaver连接到Hive建库建表加载数据 ETL数据清洗数据问题需求实现查看结果扩展 指标计算需求需求指标统计 可视化展示BIFineBI的介绍及安装FineBI配置数据源及数据准备 可视化展示 P73~77 https://www.bilibili.com/video/BV1WY4y197g7?…

宋浩概率论笔记(三)随机向量/二维随机变量

第三更:本章的内容最重要的在于概念的理解与抽象,二重积分通常情况下不会考得很难。此外,本次暂且忽略【二维连续型随机变量函数的分布】这一章节,非常抽象且难度较高,之后有时间再更新。

回归决策树模拟sin函数

# -*-coding:utf-8-*- import numpy as np from sklearn import tree import matplotlib.pyplot as pltplt.switch_backend("TkAgg") # 创建了一个随机数生成器对象 rng rngnp.random.RandomState(1) print("rng",rng) #5*rng.rand(80,1)生成一个80行、1列…

恒盛策略:上交所过户费收费标准?

随着我国股市的发展,股票买卖所的过户费成为了广阔投资者关注的焦点之一。在我国股票商场中,上交所是最重要的买卖所之一,因而上交所过户费的收费规范备受到了广泛关注。那么,上交所过户费的收费规范究竟怎么拟定?对投…

【Docker】Docker私有仓库的使用

目录 一、搭建私有仓库 二、上传镜像到私有仓库 三、从私有仓库拉取镜像 一、搭建私有仓库 首先我们需要拉取仓库的镜像 docker pull registry 然后创建私有仓库容器 docker run -it --namereg -p 5000:5000 registry 这个时候我们可以打开浏览器访问5000端口看是否成功&…

python-opencv对极几何 StereoRectify

OpenCV如何正确使用stereoRectify函数 函数介绍 用于双目相机的立体校正环节中,这里只谈谈这个函数怎么使用,参数具体指哪些函数参数 随便去网上一搜或者看官方手册就能得到参数信息,但是!!相对关系非常容易出错&…

【MySQL】事务的多版本并发控制(MVCC)

目录 一、数据库并发的三种场景二、MVCC2.1 三个记录隐藏字段2.2 undo log(撤销日志)2.3 模拟MVCC2.3.1 模拟更新(update)2.3.1 模拟删除(delete)2.3.1 模拟插入(insert)2.3.1 模拟查…

maven中常见问题

文章目录 一、配置项提示二、父子打包三、打包之后不显示target四、自定义打包之后的jar包名称五、整个项目打包5.1、父项目管理插件和微服务打包 一、配置项提示 SpringBoot中提示错误信息 表示的是SpringBoot中的注释提示没有配置!那么可以来使用一下springboot官…

安全学习DAY14_JS信息打点

信息打点——前端JS框架 文章目录 信息打点——前端JS框架小节概述-思维导图JS安全概述什么是JS渗透测试?前后端差异JS安全问题流行的Js框架如何判定JS开发应用? 测试方法(JS文件的获取以及分析方法1、手工搜索分析2、半自动Burp分析插件介绍…

problem(3):python IDE和python解释器

为什么写这篇文章呢?遇到了下面的问题,相同的解释器,如果运行angr库的代码,会出现 这样的情况,但是用spyder IDE 会显示正常,很奇怪 应该就是IDE的原因 IDE的循环导入问题 检查IDE配置: 如果可…

引流精准客源方法,学会这一招就够你用的了

科思创业汇 大家好,这里是科思创业汇,一个轻资产创业孵化平台。赚钱的方式有很多种,我希望在科思创业汇能够给你带来最快乐的那一种! 第一,你要想一想,你想吸引什么样的人? 您的排水目的是推…

构建语言模型:BERT 分步实施指南

学习目标 了解 BERT 的架构和组件。了解 BERT 输入所需的预处理步骤以及如何处理不同的输入序列长度。获得使用 TensorFlow 或 PyTorch 等流行机器学习框架实施 BERT 的实践知识。了解如何针对特定下游任务(例如文本分类或命名实体识别)微调 BERT。为什么我们需要 BERT? 正…

Vue3+SpringBoot快速开发模板

起因:个人开发过程经常会使用到Vue3SpringBoot技术栈来开发项目,每次在项目初始化时都需要涉及一些重复的整理工作,于是结合一些个人觉得不错的前后端模板进行整合,打通一些大多数项目都需要的实现的基础功能,以便于快…

Spring 事务管理

目录 1. 事务管理 1.1. Spring框架的事务支持模型的优势 1.1.1. 全局事务 1.1.2. 本地事务 1.1.3. Spring框架的一致化编程模型 1.2. 了解Spring框架的事务抽象(Transaction Abstraction) 1.2.1. Hibernate 事务设置 1.3. 用事务同步资源 1.3.1…

协议,序列化,反序列化,Json

文章目录 协议序列化和反序列化网络计算器protocol.hppServer.hppServer.ccClient.hppClient.cclog.txt通过结果再次理解通信过程 Json效果 协议 协议究竟是什么呢?首先得知道主机之间的网络通信交互的是什么数据,像平时使用聊天APP聊天可以清楚&#x…

springboot 对接 minio 分布式文件系统

1. minio介绍 Minio 是一个基于Go语言的对象存储服务。它实现了大部分亚马逊S3云存储服务接口,可以看做是是S3的开源版本,非常适合于存储大容量非结构化的数据,例如图片、视频、日志文件、备份数据和容器/虚拟机镜像等,而一个对象…

npm install时出现的问题Failed at the node-sass@4.14.1 postinstall script

从阿里云上拉取下来项目后,首先使用npm install 命令进行安装所需依赖,意想不到的事情发生了,报出了Failed at the node-sass4.14.1 postinstall script,这个问题,顿时一脸懵逼;询问前端大佬,给…

内存快照:宕机后,Redis如何实现快速恢复?RDB

AOF的回顾 回顾Redis 的AOF的持久化机制。 Redis 避免数据丢失的 AOF 方法。这个方法的好处,是每次执行只需要记录操作命令,需要持久化的数据量不大。一般而言,只要你采用的不是 always 的持久化策略,就不会对性能造成太大影响。 …

CS 144 Lab Six -- building an IP router

CS 144 Lab Six -- building an IP router 引言路由器的实现测试 对应课程视频: 【计算机网络】 斯坦福大学CS144课程 Lab Six 对应的PDF: Lab Checkpoint 5: building an IP router 引言 在本实验中,你将在现有的NetworkInterface基础上实现一个IP路由器&#xf…

scala连接mysql数据库

scala中通常是通过JDBC组件来连接Mysql。JDBC, 全称为Java DataBase Connectivity standard。 加载依赖 其中包含 JDBC driver <dependency><groupId>mysql</groupId><artifactId>mysql-connector-java</artifactId><version>8.0.29&l…