LeNet网络分析与demo实例

参考自 

  • up主的b站链接:霹雳吧啦Wz的个人空间-霹雳吧啦Wz个人主页-哔哩哔哩视频
  • 这位大佬的博客 Fun'_机器学习,pytorch图像分类,工具箱-CSDN博客

网络分析:

最好是把这个图像和代码对着来看然后进行分析的时候比较快

# 使用torch.nn包来构建神经网络.
import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module): 					# 继承于nn.Module这个父类def __init__(self):						# 初始化网络结构super(LeNet, self).__init__()    	# 多继承需用到super函数self.conv1 = nn.Conv2d(3, 16, 5)self.pool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, 5)self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(32*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):			 # 正向传播过程x = F.relu(self.conv1(x))    # input(3, 32, 32) output(16, 28, 28)x = self.pool1(x)            # output(16, 14, 14)x = F.relu(self.conv2(x))    # output(32, 10, 10)x = self.pool2(x)            # output(32, 5, 5)x = x.view(-1, 32*5*5)       # output(32*5*5)x = F.relu(self.fc1(x))      # output(120)x = F.relu(self.fc2(x))      # output(84)x = self.fc3(x)              # output(10)return x

Conv1:

输入矩阵  : 32*32*3

卷积层:5*5*3   16个

输出:(32-5)/1+1 = 28    28*28*16

MaxPool:

输入矩阵:28*28*16

2*2最大下采样

输出:14*14*16

Conv2:

输入矩阵  : 14*14*16

卷积层:5*5*16   32个

输出:(14-5)/1+1 = 10    10*10*32

MaxPool:

输入矩阵:10*10*32

2*2最大下采样

输出:5*5*32

全连接Linear:

Linear(32*5*5,120)

Linear(120,84)

Linear(120,10)最后这个数字要取决于你要分几类

经卷积后的输出层尺寸计算公式为:

Output= (W−F+2P)/S​+1

输入图片大小 W×W(一般情况下Width=Height)
Filter大小 F×F
步长 S
padding的像素数 P

经过上述分析就可以pytorch构建网络了:

import torch.nn as nn
import torch.nn.functional as F
import numpy as npclass Reshape(nn.Module):def forward(self,x):return x.view(-1,1,28,28)class LeNet(nn.Module):def __init__(self):super(LeNet,self).__init__()self.conv1 = nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5,padding=2)#输入数据的维度 卷积核的个数(即输出数据的维度) 卷积核的大小5*5self.pool1 = nn.MaxPool2d(2,2)#池化层的大小,步长self.conv2 = nn.Conv2d(6,16,5)self.pool2 = nn.MaxPool2d(2,2)self.fc1 = nn.Linear(16*5*5,120)self.fc2 = nn.Linear(120,84)self.fc3 = nn.Linear(84,10)def forward(self,x):x = F.relu(self.conv1(x))       #input(1,28,28) output(6,28,28)x = self.pool1(x)               # output(6,14,14) 池化层不改变数据的维度 (w - f + 2*p)/s+1 w为输入图片大小,f为卷积核的大小,p为填充否,s为步长 计算可得输出图片的大小x = F.relu(self.conv2(x))       #output(16,10,10)x = self.pool2(x)               #output(16,5,5)x = x.view(-1,16*5*5)           #output(16*5*5) 展成一列向量进行操作x = F.relu(self.fc1(x))         #output(120)x = F.relu(self.fc2(x))         #output(84)x = F.relu(self.fc3(x))         #output(10)return ximport torch
input1 = torch.rand([32,1,28,28])
model = LeNet()
print(model)
output = model(input1)

数据集介绍

利用torchvision.datasets函数可以在线导入pytorch中的数据集,包含一些常见的数据集如MNIST等

# 导入10000张测试图片
test_set = torchvision.datasets.CIFAR10(root='./data', train=False,	# 表示是数据集中的测试集download=False,transform=transform)
# 加载测试集
test_loader = torch.utils.data.DataLoader(test_set, batch_size=10000, # 每批用于验证的样本数shuffle=False, num_workers=0)
# 获取测试集中的图像和标签,用于accuracy计算
test_data_iter = iter(test_loader)
test_image, test_label = test_data_iter.next()

2.2 训练过程

epoch   : 对训练集的全部数据进行一次完整的训练,称为 一次 epoch
batch   : 由于硬件算力有限,实际训练时将训练集分成多个批次训练,每批数据的大小为 batch_size
iteration 或 step :   对一个batch的数据训练的过程称为 一个 iteration 或 step

训练过程

net = LeNet()						  				# 定义训练的网络模型
loss_function = nn.CrossEntropyLoss() 				# 定义损失函数为交叉熵损失函数 
optimizer = optim.Adam(net.parameters(), lr=0.001)  # 定义优化器(训练参数,学习率)for epoch in range(5):  # 一个epoch即对整个训练集进行一次训练running_loss = 0.0time_start = time.perf_counter()for step, data in enumerate(train_loader, start=0):   # 遍历训练集,step从0开始计算inputs, labels = data 	# 获取训练集的图像和标签optimizer.zero_grad()   # 清除历史梯度# forward + backward + optimizeoutputs = net(inputs)  				  # 正向传播loss = loss_function(outputs, labels) # 计算损失loss.backward() 					  # 反向传播optimizer.step() 					  # 优化器更新参数# 打印耗时、损失、准确率等数据running_loss += loss.item()if step % 1000 == 999:    # print every 1000 mini-batches,每1000步打印一次with torch.no_grad(): # 在以下步骤中(验证过程中)不用计算每个节点的损失梯度,防止内存占用outputs = net(test_image) 				 # 测试集传入网络(test_batch_size=10000),output维度为[10000,10]predict_y = torch.max(outputs, dim=1)[1] # 以output中值最大位置对应的索引(标签)作为预测输出accuracy = (predict_y == test_label).sum().item() / test_label.size(0)print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %  # 打印epoch,step,loss,accuracy(epoch + 1, step + 1, running_loss / 500, accuracy))print('%f s' % (time.perf_counter() - time_start))        # 打印耗时running_loss = 0.0print('Finished Training')# 保存训练得到的参数
save_path = './Lenet.pth'
torch.save(net.state_dict(), save_path)


 

测试:

# 导入包
import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet# 数据预处理
transform = transforms.Compose([transforms.Resize((32, 32)), # 首先需resize成跟训练集图像一样的大小transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 导入要测试的图像(自己找的,不在数据集中),放在源文件目录下
im = Image.open('horse.jpg')
im = transform(im)  # [C, H, W]
im = torch.unsqueeze(im, dim=0)  # 对数据增加一个新维度,因为tensor的参数是[batch, channel, height, width] # 实例化网络,加载训练好的模型参数
net = LeNet()
net.load_state_dict(torch.load('Lenet.pth'))# 预测
classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
with torch.no_grad():outputs = net(im)predict = torch.max(outputs, dim=1)[1].data.numpy()
print(classes[int(predict)])

输出即为预测的标签。

其实预测结果也可以用 softmax 表示,输出10个概率:

with torch.no_grad():outputs = net(im)predict = torch.softmax(outputs, dim=1)
print(predict)

输出结果中最大概率值对应的索引即为 预测标签 的索引。

tensor([[2.2782e-06, 2.1008e-07, 1.0098e-04, 9.5135e-05, 9.3220e-04, 2.1398e-04,3.2954e-08, 9.9865e-01, 2.8895e-08, 2.8820e-07]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/223554.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【模式识别】探秘分类奥秘:最近邻算法解密与实战

​🌈个人主页:Sarapines Programmer🔥 系列专栏:《模式之谜 | 数据奇迹解码》⏰诗赋清音:云生高巅梦远游, 星光点缀碧海愁。 山川深邃情难晤, 剑气凌云志自修。 目录 🌌1 初识模式识…

【性能优化】MySql数据库查询优化方案

阅读本文你的收获 了解系统运行效率提升的整体解决思路和方向学会MySQl中进行数据库查询优化的步骤学会看慢查询、执行计划、进行性能分析、调优 一、问题:如果你的系统运行很慢,你有什么解决方案? ​关于这个问题,我们通常首先…

什么是动态代理?

目录 一、为什么需要代理? 二、代理长什么样? 三、Java通过什么来保证代理的样子? 四、动态代理实现案例 五、动态代理在SpringBoot中的应用 导入依赖 数据库表设计 OperateLogEntity实体类 OperateLog枚举 RecordLog注解 上下文相…

Python 运算符 算数运算符 关系运算符 赋值运算符 逻辑运算 (逻辑运算符的优先级) 位运算 成员运算符 身份运算符 运算符的优先级

1 运算符算数运算符关系运算符赋值运算符逻辑运算逻辑运算符的优先级 位运算布尔运算符移位运算符 成员运算符身份运算符运算符的优先级 运算符 算数运算符 四则运算 - * / a 8 b 9 print(ab)#与Java类似 也可以进行字符串的连接 注意:字符串数字字符串 不存在会抛出异常…

排序算法——桶排序

把数据放进若干个桶,然后在桶里用其他排序,近乎分治思想。从数值的低位到高位依次排序,有几位就排序几次。例如二位数就排两次,三位数就排三次,依次按照个十百...的顺序来排序。 第一次排序:50 12 …

MongoDB安装部署

二、安装部署 2.1 下载 下载地址:MongoDB Enterprise Server Download | MongoDB 当前最新版本6.0.9,5.0.9对Mac m1需要centos 8.2版本。选择docker安装。 2.2 docker-ce安装 # 安装docker # 默认repo源没有docker-ce安装包,需要新的rep…

【飞凌 OK113i-C 全志T113-i开发板】一些有用的常用的命令测试

一些有用的常用的命令测试 一、系统信息查询 可以查询板子的内核信息、CPU处理器信息、环境变量等 二、CPU频率 从上面的系统信息查询到,这是一颗具有两个ARMv7结构A7内核的处理器,主频最高1.2GHz 可以通过命令查看当前支持的频率以及目前所使用主频 …

面向对象设计与分析40讲(12)简单工厂方法模式

文章目录 定义示例优缺点 定义 简单工厂模式是一种创建型模式,用于根据客户端的需求创建对象实例,所谓的需求反映到编程语言里就是传入的参数。 简单工厂模式包括三个主要部分: 工厂类(Simple Factory):…

生物系统学中的进化树构建和分析R工具包V.PhyloMaker2的介绍和详细使用

V.PhyloMaker2是一个R语言的工具包,专门用于构建和分析生物系统学中的进化树(也称为系统发育树或phylogenetic tree)。以下是对V.PhyloMaker2的一些基本介绍和使用说明: 论文介绍:V.PhyloMaker2: An updated and enla…

持续集成交付CICD:Linux 部署 Jira 9.12.1

目录 一、实验 1.环境 2.K8S master节点部署Jira 3.Jira 初始化设置 4.Jira 使用 一、实验 1.环境 (1)主机 表1 主机 主机架构版本IP备注master1K8S master节点1.20.6192.168.204.180 jenkins slave (从节点) jira9.12.1…

【微服务】:微服务最佳实践

关键需求 最大限度地提高团队的自主性:创建一个团队可以完成更多工作而不必与其他团队协调的环境。 优化开发速度:硬件便宜,人不是。使团队能够轻松快捷地构建强大的服务。 关注自动化:人们犯错误。更多的系统操作也意味着更多的…

QT编写应用的界面自适应分辨率的解决方案

博主在工作机上完成QT软件开发(控件大小与字体大小比例正常),部署到客户机后,发现控件大小与字体大小比例失调,具体表现为控件装不下字体,即字体显示不全,推测是软件不能自适应分辨率导致的。 文…

Hadoop入门学习笔记——六、连接到Hive

视频课程地址:https://www.bilibili.com/video/BV1WY4y197g7 课程资料链接:https://pan.baidu.com/s/15KpnWeKpvExpKmOC8xjmtQ?pwd5ay8 Hadoop入门学习笔记(汇总) 目录 六、连接到Hive6.1. 使用Hive的Shell客户端6.2. 使用Beel…

Confluent 与阿里云将携手拓展亚太市场,提供消息流平台服务

10 月 31 日,杭州云栖大会上,阿里云云原生应用平台负责人丁宇宣布,Confluent 成为阿里云技术合作伙伴,合作全新升级,一起拓展和服务亚太市场。 本次合作伙伴签约,阿里云与消息流开创领导者 Confluent 将进一…

每次maven刷新jdk都要重新设置

pom.xml <java.version>17</java.version> 改为<java.version>1.8</java.version>

MATLAB - 四元数(quaternion)

系列文章目录 前言 一、简介 四元数是一种四元超复数&#xff0c;用于三维旋转和定向。 四元数的表示形式为 abicjdk&#xff0c;其中 a、b、c 和 d 为实数&#xff0c;i、j 和 k 为基元&#xff0c;满足等式&#xff1a;i2 j2 k2 ijk -1。 四元数集用 H 表示&#xff0c…

推荐给前端开发的 5 款 Chrome 扩展

工欲善其事&#xff0c;必先利其器。Chrome 可能是前端开发中使用最多的浏览器。在日常开发中&#xff0c;下列几款 Chrome 扩展也许能让你的开发工作事半功倍 &#x1f680; Vue.js devtools ⚙️ vue 官方专为 vue 应用开发的调试工具。 通过使用它&#xff0c;你可以快速查看…

redis 从0到1完整学习 (五):集合 IntSet 数据结构

文章目录 1. 引言2. redis 源码下载3. IntSet 数据结构4. 参考 1. 引言 前情提要&#xff1a; 《redis 从0到1完整学习 &#xff08;一&#xff09;&#xff1a;安装&初识 redis》 《redis 从0到1完整学习 &#xff08;二&#xff09;&#xff1a;redis 常用命令》 《redi…

智能优化算法应用:基于鹈鹕算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于鹈鹕算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于鹈鹕算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.鹈鹕算法4.实验参数设定5.算法结果6.参考文献7.MA…

微服务之配置中心与服务跟踪

zookeeper 配置中心 实现的架构图如下所示&#xff0c;采取数据加载到内存方式解决高效获取的问题&#xff0c;借助 zookeeper 的节点监听机制来实现实时感知。 配置中心数据分类 事件调度&#xff08;kafka&#xff09; 消息服务和事件的统一调度&#xff0c;常用用 kafka …