经典卷积神经网络 - LeNet

image-20231022164234916

该模型用于手写的数字识别。
LeNet模型包含了多个卷积层和池化层,以及最后的全连接层用于分类。其中,每个卷积层都包含了一个卷积操作和一个非线性激活函数,用于提取输入图像的特征。池化层则用于缩小特征图的尺寸,减少模型参数和计算量。全连接层则将特征向量映射到类别概率上。

MNISt数据集

50000个训练数据,10000个测试数据。图像大小为28x28,共10类(0~9)。

  • LeNet是早期成功的神经网络
  • 先使用卷积层来学习图片空间信息
  • 然后使用全连接层来转换到类别空间

对于padding

通用的卷积时padding 的选择

如卷积核宽高为3时 padding 选择1

如卷积核宽高为5时 padding 选择2

如卷积核宽高为7时padding选择3

至于选择填充多少像素,通常有两个选择,分别叫做Valid卷积和Same卷积。

Valid卷积意味着不填充,这样的话,如果你有一个 n × n n\times n n×n的图像,用一个 f × f f\times f f×f的过滤器卷积,它将会给你一个 ( n − f + 1 ) × ( n − f + 1 ) (n-f+1)\times (n-f+1) (nf+1)×(nf+1)维的输出。例如,有一个6×6的图像,通过一个3×3的过滤器,得到一个4×4的输出。

Same卷积意味你填充后,你的输出大小和输入大小是一样的。根据这个公式 n − f + 1 n-f+1 nf+1,当你填充 p p p个像素点,n就变成了 n + 2 p n+2p n+2p,最后公式变为 n + 2 p − f + 1 n+2p-f+1 n+2pf+1。因此如果你有一个 n × n n\times n n×n的图像,用 p p p个像素填充边缘,输出的大小就是这样的 ( n + 2 p − f + 1 ) × ( n + 2 p − f + 1 ) (n+2p−f+1)\times (n+2p−f+1) (n+2pf+1)×(n+2pf+1)。如果你想让 ( n + 2 p − f + 1 ) = n (n+2p−f+1)=n (n+2pf+1)=n的话,使得输出和输入大小相等,如果你用这个等式求解 p p p,那么 p = ( f − 1 ) / 2 p=(f-1)/2 p=(f1)/2。所以当 f f f是一个奇数的时候,只要选择相应的填充尺寸,你就能确保得到和输入相同尺寸的输出。

代码实现

LeNet(LeNet-5)由两个部分组成:卷积编码器和全连接层密集块。

model.py

from torch import nnclass Reshape(nn.Module):def forward(self,x):return x.reshape((-1,1,28,28))class MyLeNet(nn.Module):def __init__(self, *args, **kwargs) -> None:super().__init__(*args, **kwargs)# 假如sigmoid激活函数后 损失不下降self.model = nn.Sequential(Reshape(),nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),nn.AvgPool2d(2,stride=2),nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),nn.AvgPool2d(2,stride=2),nn.Flatten(),nn.Linear(16*5*5,120),nn.Linear(120,84),nn.Linear(84,10))def forward(self,x):return self.model(x)

train.py

# 扫描数据次数
epochs = 20
# 分组大小
batch = 64
# 学习率
learning_rate = 0.05
# 训练次数
train_step = 0
# 测试次数
test_step = 0# 定义图像转换
transform = transforms.Compose([transforms.ToTensor()
])
# 读取数据
train_dataset = datasets.MNIST(root="./dataset",train=True,transform=transform,download=True)
test_dataset = datasets.MNIST(root="./dataset",train=False,transform=transform,download=True)
# 加载数据
train_dataloader = DataLoader(train_dataset,batch_size=batch,shuffle=True,num_workers=0)
test_dataloader = DataLoader(test_dataset,batch_size=batch,shuffle=True,num_workers=0)
# 数据大小
train_size = len(train_dataset)
test_size = len(test_dataset)
print("训练集大小:{}".format(train_size))
print("验证集大小:{}".format(test_size))# GPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")# 创建网络
net = MyLeNet()
net = net.to(device)
# 定义损失函数
loss = nn.CrossEntropyLoss()
loss = loss.to(device)
# 定义优化器
optimizer = torch.optim.SGD(net.parameters(),lr=learning_rate)writer = SummaryWriter("logs")
# 训练
for epoch in range(epochs):print("-------------------第 {} 轮训练开始-------------------".format(epoch))net.train()for data in train_dataloader:train_step = train_step + 1images,targets = dataimages = images.to(device)targets = targets.to(device)outputs = net(images)loss_out = loss(outputs,targets)optimizer.zero_grad()loss_out.backward()optimizer.step()if train_step%100==0:writer.add_scalar("Train Loss",scalar_value=loss_out.item(),global_step=train_step)print("训练次数:{},Loss:{}".format(train_step,loss_out.item()))# 测试net.eval()total_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:test_step = test_step + 1images, targets = dataimages = images.to(device)targets = targets.to(device)outputs = net(images)loss_out = loss(outputs, targets)total_loss = total_loss + loss_outaccuracy = (targets == torch.argmax(outputs,dim=1)).sum()total_accuracy = total_accuracy + accuracy# 计算精确率print(total_accuracy)accuracy_rate = total_accuracy / test_sizeprint("第 {} 轮,验证集总损失为:{}".format(epoch+1,total_loss))print("第 {} 轮,精确率为:{}".format(epoch+1,accuracy_rate))writer.add_scalar("Test Total Loss",scalar_value=total_loss,global_step=epoch+1)writer.add_scalar("Accuracy Rate",scalar_value=accuracy_rate,global_step=epoch+1)torch.save(net,"./model/net_{}.pth".format(epoch+1))print("模型net_{}.pth已保存".format(epoch+1))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/167903.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【IO】JavaIO流:字节流、字符流、缓冲流、转换流、序列化流等

个人简介:Java领域新星创作者;阿里云技术博主、星级博主、专家博主;正在Java学习的路上摸爬滚打,记录学习的过程~ 个人主页:.29.的博客 学习社区:进去逛一逛~ IO流 Java IO1. 认识IO2. FileOutputStream(写…

【python】--python环境安装及配置

目录 一、python开发环境部署1、下载安装Miniconda2、python环境3、进入或退出python环境4、对应python环境安装工具/库5、进入pyhton环境,查看已安装的工具/库6、安装pycharm专业版7、pycharm创建项目并关联python版本环境 一、python开发环境部署 要安装一个pyth…

实战SRC

附言:从补天的公益src公司中选中了幸运儿。 1. 通过hunter鹰图平台搜索公司的相关资产,发现其采用了华途应用安全网关。 2.访问相关地址,尝试使用弱口令登录,发现直接利用admin/admin就登录了,可以看到后台的相关日志…

【项目实战】日志系统

目录 前言 整体架构 工具类的实现 日期类 文件类 判断文件存在 获取文件路径 创建目录 日志等级的规划 日志信息模块 消息格式化模块 格式化组件 抽象基类 派生子类 日期格式化子类 其他内容格式化子类 格式化类 根据字符创建不同对象 格式化字符串的解析 …

先后在影酷/传祺E9/昊铂GT量产交付,这家ADAS厂商何以领跑

智能泊车赛道正在迎来黄金增长期,以魔视智能为代表的玩家正在驶入大规模量产的“快车道”。 继在广汽传祺影酷、广汽传祺 E9实现规模化量产交付之后,魔视智能的Magic Parking智能泊车系列解决方案再度在广汽埃安旗下高端智能轿跑——昊铂GT上面实现量产…

Go 代码包与引入:如何有效组织您的项目

一、引言 在软件开发中,代码的组织和管理是成功项目实施的基础之一。特别是在构建大型、可扩展和可维护的应用程序时,这一点尤为重要。Go语言为这一需求提供了一个强大而灵活的工具:代码包(Packages)。代码包不仅允许…

Selenium+Pytest自动化测试框架详解

前言 selenium自动化 pytest测试框架 本章你需要 一定的python基础——至少明白类与对象,封装继承;一定的selenium基础——本篇不讲selenium,不会的可以自己去看selenium中文翻译网 一、测试框架简介 测试框架有什么优点 代码复用率高&…

Java 基础 面试 多线程

1.多线程 1.1 线程(Thread) 线程时一个程序内部的一条执行流程,java的main方法就是由一条默认的主线程执行 1.2 多线程 多线程是指从软硬件上实现的多条执行流程的技术(多条线程由CPU负责调度执行) 许多平台都离不开多…

【Nginx34】Nginx学习:安全链接、范围分片以及请求分流模块

Nginx学习:安全链接、范围分片以及请求分流模块 又迎来新的模块了,今天的内容不多,但我们都进行了详细的测试,所以可能看起来会多一点哦。这三个模块之前也从来都没用过,但是通过学习之后发现,貌似还都挺有…

前端react入门day01-了解react和JSX基础

(创作不易,感谢有你,你的支持,就是我前行的最大动力,如果看完对你有帮助,请留下您的足迹) 目录 React介绍 React是什么 React的优势 React的市场情况 开发环境搭建 使用create-react-app快速搭建…

Qt窗体设计的布局

本文介绍Qt窗体的布局。 Qt窗体的布局分为手动布局和自动布局,手动布局即靠手工排布各控件的位置。而自动布局则是根据选择的布局类型自动按此类型排布各控件的位置,使用起来比较方便,本文主要介绍Qt的自动布局。 1.垂直布局 垂直布局就是…

看微功耗遥测终端机如何轻松应对野外环境挑战?

在野外,数据的实时监测和传输是至关重要的。无论是环境温度、湿度,还是水位、流量,都需要精准把控。然而,传统的监测方法往往受限于电源供应问题,而无法充分发挥其功能。这时候,一款微功耗遥测终端机&#…

Zabbix“专家坐诊”第207期问答汇总

问题一 Q:不小心把host表删除了,怎么处理?现在使用的zabbix 4.0.3的server,agent是4.2.1,能不能不动agent的情况下升级server版本,重新部署? A:数据库有备份话恢复即可,…

在线零售多用户多门店连锁商城系统

在线零售多用户商城系统和多门店连锁商城系统的核心都是线上线下相结合的,线上和线下结合,一体化是在线新零售多用户商城系统发展的趋势,现在移动互联网时代,越来越多的传统企业,如:连锁店铺,连…

SpringBoot中的日志使用

SpringBoot的默认使用 观察SpringBoot的Maven依赖图 可以看出来,SpringBoot默认使用的日志系统是使用Slf4j作为门户,logback作为日志实现 编写一个测试代码看是否是这样 SpringBootTest class SpringbootLogDemoApplicationTests {//使用Slf4j来创建LOG…

Android音视频开发之基础知识

一、视频文件 1、视频格式 常见格式:mp4、mkv、flv 封装的数据:音频码流、视频码流 常用工具: [FFmpeg下载]:https://ffmpeg.org/download.html 下载、安装并配置环境变量 ffmpeg.exe 视频编解码 ffplay.exe 播放器库 ffprobe.exe 音视频分…

17-spring aop调用过程概述

文章目录 1.源码2. debug过程 1.源码 public class TestAop {public static void main(String[] args) throws Exception {saveGeneratedCGlibProxyFiles(System.getProperty("user.dir") "/proxy");ApplicationContext ac new ClassPathXmlApplication…

【TES605】基于Virtex-7 FPGA的高性能实时信号处理平台

板卡概述 TES605是一款基于Virtex-7 FPGA的高性能实时信号处理平台,该平台采用1片TI的KeyStone系列多核DSP TMS320C6678作为主处理单元,采用1片Xilinx的Virtex-7系列FPGA XC7VX690T作为协处理单元,具有2个FMC子卡接口,各个处理节…

【PyTorch】深度学习实践 02 线性模型

深度学习的准备过程 准备数据集选择模型模型训练进行推理预测 问题 对某种产品花费 x 个工时,即可得到 y 收益,现有 x 和 y 的对应表格如下: x (hours) y(points)12243648 求花费4个工时可得…

Power BI 傻瓜入门 5. 准备数据源

本章内容将介绍: 定义Power BI支持的数据源类型探索如何在Power BI中连接和配置数据源了解选择数据源的最佳做法 现代组织有很多数据。因此,不用说,微软等企业软件供应商已经构建了数据源连接器,以帮助组织将数据导入Power BI等…