LeNet网络复现

文章目录

  • 1. LeNet历史背景
    • 1.1 早期神经网络的挑战
    • 1.2 LeNet的诞生背景
  • 2. LeNet详细结构
    • 2.1 总览
    • 2.2 卷积层与其特点
    • 2.3 子采样层(池化层)
    • 2.4 全连接层
    • 2.5 输出层及激活函数
  • 3. LeNet实战复现
    • 3.1 模型搭建model.py
    • 3.2 训练模型train.py
    • 3.3 测试模型test.py
  • 4. LeNet的变种与实际应用
    • 4.1 LeNet-5及其优化
    • 4.2 从LeNet到现代卷积神经网络

1. LeNet历史背景

1.1 早期神经网络的挑战

早期神经网络面临了许多挑战。首先,它们经常遇到训练难题,例如梯度消失和梯度爆炸,特别是在使用传统激活函数如sigmoid或tanh时。另外,当时缺乏大规模、公开的数据集,导致模型容易过拟合并且泛化性能差。再者,受限于当时的计算资源,模型的大小和训练速度受到了很大的限制。最后,由于深度学习领域还处于萌芽阶段,缺少许多现代技术来优化和提升模型的表现。

1.2 LeNet的诞生背景

LeNet的诞生背景是为了满足20世纪90年代对手写数字识别的实际需求,特别是在邮政和银行系统中。Yann LeCun及其团队意识到,对于图像这种有结构的数据,传统的全连接网络并不是最佳选择。因此,他们引入了卷积的概念,设计出了更适合图像处理任务的网络结构,即LeNet。

2. LeNet详细结构

2.1 总览

在这里插入图片描述

2.2 卷积层与其特点

卷积层是卷积神经网络(CNN)的核心。这一层的主要目的是通过卷积操作检测图像中的局部特征。

特点:

  1. 局部连接性: 卷积层的每个神经元不再与前一层的所有神经元相连接,而只与其局部区域相连接。这使得网络能够专注于图像的小部分,并检测其中的特征。
  2. 权值共享: 在卷积层中,一组权值在整个输入图像上共享。这不仅减少了模型的参数,而且使得模型具有平移不变性。
  3. 多个滤波器: 通常会使用多个卷积核(滤波器),以便在不同的位置检测不同的特征。

2.3 子采样层(池化层)

池化层是卷积神经网络中的另一个关键组件,用于缩减数据的空间尺寸,从而减少计算量和参数数量。

主要类型:

  1. 最大池化(Max pooling): 选择覆盖区域中的最大值作为输出。
  2. 平均池化(Average pooling): 计算覆盖区域的平均值作为输出。

2.4 全连接层

在卷积神经网络的最后,经过若干卷积和池化操作后,全连接层用于将提取的特征进行“拼接”,并输出到最终的分类器。

特点:

  1. 完全连接: 全连接层中的每个神经元都与前一层的所有神经元相连接。
  2. 参数量大: 由于全连接性,此层通常包含网络中的大部分参数。
  3. 连接多个卷积或池化层的特征: 它的主要目的是整合先前层中提取的所有特征。

2.5 输出层及激活函数

输出层:
输出层是神经网络的最后一层,用于输出预测结果。输出的数量和类型取决于特定任务,例如,对于10类分类任务,输出层可能有10个神经元。

激活函数:
激活函数为神经网络提供了非线性,使其能够学习并进行复杂的预测。

  1. Sigmoid: 取值范围为(0, 1)。
  2. Tanh: 取值范围为(-1, 1)。
  3. ReLU (Rectified Linear Unit): 最常用的激活函数,将所有负值置为0。
  4. Softmax: 常用于多类分类的输出层,它返回每个类的概率。

3. LeNet实战复现

3.1 模型搭建model.py

import torch
from torch import nn# 自定义网络模型
class LeNet(nn.Module):# 1. 初始化网络(定义初始化函数)def __init__(self):super(LeNet, self).__init__()# 定义网络层self.Sigmoid = nn.Sigmoid()self.c1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)self.s2 = nn.AvgPool2d(kernel_size=2, stride=2)self.c3 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)self.s4 = nn.AvgPool2d(kernel_size=2, stride=2)self.c5 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5)# 展开self.flatten = nn.Flatten()self.f6 = nn.Linear(120, 84)self.output = nn.Linear(84, 10)# 2. 前向传播网络def forward(self, x):x = self.Sigmoid(self.c1(x))x = self.s2(x)x = self.Sigmoid(self.c3(x))x = self.s4(x)x = self.c5(x)x = self.flatten(x)x = self.f6(x)x = self.output(x)return xif __name__ == "__main__":x = torch.rand([1, 1, 28, 28])model = LeNet()y = model(x)

3.2 训练模型train.py

import torch
from torch import nn
from model import LeNet
from torch.optim import lr_scheduler
from torchvision import datasets, transforms
import os# 数据转换为tensor格式
data_transformer = transforms.Compose([transforms.ToTensor()
])# 加载训练的数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=data_transformer, download=True)
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)# 加载测试的数据集
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_transformer, download=True)
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=16, shuffle=True)# 使用GPU进行训练
device = "cuda" if torch.cuda.is_available() else "cpu"# 调用搭好的模型,将模型数据转到GPU上
model = LeNet().to(device)# 定义一个损失函数(交叉熵损失)
loss_fn = nn.CrossEntropyLoss()# 定义一个优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)# 学习率每隔10轮, 变换原来的0.1
lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)# 定义训练函数
def train(dataloader, model, loss_fn, optimizer):loss, current, n = 0.0, 0.0, 0for batch, (x, y) in enumerate(dataloader):# 前向传播x, y = x.to(device), y.to(device)output = model(x)cur_loss = loss_fn(output, y)_, pred = torch.max(output, axis=1)cur_acc = torch.sum(y == pred)/output.shape[0]optimizer.zero_grad()cur_loss.backward()optimizer.step()loss += cur_loss.item()current += cur_acc.item()n = n + 1print("train_loss" + str(loss/n))print("train_acc" + str(current/n))# 定义测试函数
def val(dataloader, model, loss_fn):model.eval()loss, current, n = 0.0, 0.0, 0with torch.no_grad():for batch, (x, y) in enumerate(dataloader):# 前向传播x, y = x.to(device), y.to(device)output = model(x)cur_loss = loss_fn(output, y)_, pred = torch.max(output, axis=1)cur_acc = torch.sum(y == pred) / output.shape[0]loss += cur_loss.item()current += cur_acc.item()n = n + 1print("val_loss" + str(loss / n))print("val_acc" + str(current / n))return current/n# 开始训练
epoch = 50
min_acc = 0for t in range(epoch):print(f'epoch{t+1}\n-----------------------------------------------------------')train(train_dataloader, model, loss_fn, optimizer)a = val(test_dataloader, model, loss_fn)# 保存最好的模型权重if a > min_acc:folder = "sava_model"if not os.path.exists(folder):os.mkdir("sava_model")min_acc = aprint("sava best model")torch.save(model.state_dict(), 'sava_model/best_model.pth')
print('Done!')

3.3 测试模型test.py

import torch
from model import LeNet
from torchvision import datasets, transforms
from torchvision.transforms import ToPILImage# 数据转换为tensor格式
data_transformer = transforms.Compose([transforms.ToTensor()
])# 加载训练的数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=data_transformer, download=True)
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)# 加载测试的数据集
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_transformer, download=True)
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=16, shuffle=True)# 使用GPU进行训练
device = "cuda" if torch.cuda.is_available() else "cpu"# 调用搭好的模型,将模型数据转到GPU上
model = LeNet().to(device)# 加载模型权重并设置为评估模式
model.load_state_dict(torch.load("./sava_model/best_model.pth"))
model.eval()# 获取结果
classes = ["0","1","2","3","4","5","6","7","8","9",
]# 把tensor转换为图片, 方便可视化
show = ToPILImage()# 进入验证
for i in range(5):x, y = test_dataset[i]show(x).show()x = torch.unsqueeze(x, dim=0).float().to(device)with torch.no_grad():pred = model(x)predicted, actual = classes[torch.argmax(pred[0])], classes[y]print(f'Predicted: {predicted}, Actual: {actual}')

4. LeNet的变种与实际应用

4.1 LeNet-5及其优化

LeNet-5是由Yann LeCun于1998年设计的,并被广泛应用于手写数字识别任务。它是卷积神经网络的早期设计之一,主要包含卷积层、池化层和全连接层。

结构:

  1. 输入层:接收32×32的图像。
  2. 卷积层C1:使用5×5的滤波器,输出6个特征图。
  3. 池化层S2:2x2的平均池化。
  4. 卷积层C3:使用5×5的滤波器,输出16个特征图。
  5. 池化层S4:2x2的平均池化。
  6. 卷积层C5:使用5×5的滤波器,输出120个特征图。
  7. 全连接层F6。
  8. 输出层:10个单元,对应于0-9的手写数字。
    激活函数:Sigmoid或Tanh。

优化:

  1. ReLU激活函数: 原始的LeNet-5使用Sigmoid或Tanh作为激活函数,但现代网络更喜欢使用ReLU,因为它的训练更快,且更少受梯度消失的影响。
  2. 更高效的优化算法: 如Adam或RMSProp,它们通常比传统的SGD更快、更稳定。
  3. 批量归一化: 加速训练,并提高模型的泛化能力。
  4. Dropout: 在全连接层中引入Dropout可以增强模型的正则化效果。

4.2 从LeNet到现代卷积神经网络

从LeNet-5开始,卷积神经网络已经经历了巨大的发展。以下是一些重要的里程碑:

  1. AlexNet (2012): 在ImageNet竞赛中取得突破性的成功。它具有更深的层次,使用ReLU激活函数,以及Dropout来防止过拟合。
  2. VGG (2014): 由于其统一的结构(仅使用3×3的卷积和2x2的池化)而闻名,拥有多达19层的版本。
  3. GoogLeNet/Inception (2014): 引入了Inception模块,可以并行执行多种大小的卷积。
  4. ResNet (2015): 引入了残差块,使得训练非常深的网络变得可能。通过这种方式,网络可以达到上百甚至上千的层数。
  5. DenseNet (2017): 每层都与之前的所有层连接,导致具有非常稠密的特征图。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/145463.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MyBatisPlus(七)等值查询

等值查询 条件查询&#xff1a;使用 Wrapper 对象&#xff0c;传递查询条件。 QueryWrapper&#xff08;不要使用&#xff09; 代码 Testvoid eq() {QueryWrapper<User> wrapper new QueryWrapper<>();wrapper.eq("name", "张三");List<…

httpserver 下载服务器demo

实现效果如下&#xff1a; 图片可以直接显示 cpp h 这些可以直接显示 其他的 则是提示是否要下载 单线程 还有bug 代码如下 先放上来 #include "httpserver.h" #include "stdio.h" #include <stdlib.h> #include <arpa/inet.h> #include…

Vue控制textarea可输入行数限制-案例

控制只能输入六行内容 UI部分代码 //我使用了antd ui库 <a-form-model-item ref"address_group" label"规则描述" prop"address_group" > <a-textarea:rows"6"style"width: 60%"placeholder"一次最多输入6行…

【数据结构】队列和栈

大家中秋节快乐&#xff0c;玩了好几天没有学习&#xff0c;今天分享的是栈以及队列的相关知识&#xff0c;以及栈和队列相关的面试题 1.栈 1.1栈的概念及结构 栈&#xff1a;一种特殊的线性表&#xff0c;其只允许在固定的一端进行插入和删除元素操作。进行数据插入和删除操作…

MySQL数据查询性能如何分析--Explain介绍说明

1、Explain是什么 Explain是MySQL执行查看执行计划命令的指令&#xff0c;使用EXPLAIN关键字可以模拟优化器执行SQL查询语句&#xff0c;从而知道MySQL是如何处理你的SQL语句的。分析你的查询语句或是表结构的性能瓶颈。 2、Explain官网介绍 http://dev.mysql.com/doc/refma…

【MySQL】数据类型(二)

文章目录 一. char字符串类型二. varchar字符串类型2.1 char和varchar比较 三. 日期和时间类型四. enum和set类型4.1 set的查询 结束语 一. char字符串类型 char (L) 固定长度字符串 L是可以存储的长度&#xff0c;单位是字符&#xff0c;最大长度是255 MySQL中的字符&#xff…

Makefile学习

一、Makefile的介绍 1.1 什么是Makefile 相信在Linux系统中经常会用到make这个命令来编译程序&#xff0c;而执行make命令所依赖的文件便是Makefile文件&#xff0c;make命令通过Makefile文件编写的内容对程序进行编译。make命令根据文件更新的时间戳来决定哪些文件需要重新编…

纯css html 真实水滴效果

惯例,不多说直接上图 秉承着开源精神,我们将这段代码无私地分享给大家&#xff0c;因为我们深信&#xff0c;信息的共享和互相学习是推动科技进步的关键。我们鼓励大家在使用这段代码的同时&#xff0c;也能够将其中的原理、思想和经验分享给更多的人。 这份代码是我们团队用心…

一百八十六、大数据离线数仓完整流程——步骤五、在Hive的DWS层建动态分区表并动态加载数据

一、目的 经过6个月的奋斗&#xff0c;项目的离线数仓部分终于可以上线了&#xff0c;因此整理一下离线数仓的整个流程&#xff0c;既是大家提供一个案例经验&#xff0c;也是对自己近半年的工作进行一个总结。 二、数仓实施步骤 &#xff08;五&#xff09;步骤五、在Hive的…

Purism 推出注重隐私的 Linux 平板电脑

导读一款昂贵的 Linux 平板电脑&#xff0c;注重安全和隐私。让我们拭目以待。 Purism 是一家日益流行的计算机硬件产品制造商&#xff0c;专门提供配备注重隐私的开源 Linux 发行版的笔记本电脑、台式机和移动设备。 最近&#xff0c;他们发布了一款新产品 Librem 11 平板电…

ARM底层汇编基础指令

汇编语言的组成 伪操作 不参与程序执行&#xff0c;但是用于告诉编译器程序怎么编译.text .global .end .if .else .endif .data 汇编指令 编译器将一条汇编指令编译成一条机器码&#xff0c;在内存里一条指令占4字节内存&#xff0c;一条指令可以实现一个特定的功能 伪指令 不…

嵌入式Linux应用开发-基础知识-第十六章GPIO和Pinctrl子系统的使用

嵌入式Linux应用开发-基础知识-第十六章GPIO和Pinctrl子系统的使用 第十六章 GPIO 和 Pinctrl 子系统的使用16.1 Pinctrl 子系统重要概念16.1.1 引入16.1.2 重要概念16.1.3 示例16.1.4 代码中怎么引用pinctrl 16.2 GPIO子系统重要概念16.2.1 引入16.2.2 在设备树中指定引脚16.2…

软件设计模式系列之二十一——观察者模式

1 观察者模式的定义 观察者模式&#xff08;Observer Pattern&#xff09;是一种行为型设计模式&#xff0c;它允许对象之间建立一对多的依赖关系&#xff0c;当一个对象的状态发生变化时&#xff0c;所有依赖于它的对象都会得到通知并自动更新。这个模式也被称为发布-订阅模式…

数据结构 | 二叉树

基本形状 可参照 数据结构&#xff1a;树(Tree)【详解】_数据结构 树_UniqueUnit的博客-CSDN博客 二叉树的性质 三种顺序遍历

BUUCTF reverse wp 51 - 55

findKey shift f12 找到一个flag{}字符串, 定位到关键函数, F5无效, 大概率是有花指令, 读一下汇编 这里连续push两个byte_428C54很奇怪, nop掉下面那个, 再往上找到函数入口, p设置函数入口, 再F5 LRESULT __stdcall sub_401640(HWND hWndParent, UINT Msg, WPARAM wPara…

Kafka(一)使用Docker Compose安装单机Kafka以及Kafka UI

文章目录 Kafka中涉及到的术语Kafka镜像选择Kafka UI镜像选择Docker Compose文件Kafka配置项说明KRaft vs Zookeeper和KRaft有关的配置关于Controller和Broker的概念解释Listener的各种配置 Kafka UI配置项说明 测试Kafka集群Docker Compose示例配置 Kafka中涉及到的术语 对于…

Spring5应用之AOP切入点详解

作者简介&#xff1a;☕️大家好&#xff0c;我是Aomsir&#xff0c;一个爱折腾的开发者&#xff01; 个人主页&#xff1a;Aomsir_Spring5应用专栏,Netty应用专栏,RPC应用专栏-CSDN博客 当前专栏&#xff1a;Spring5应用专栏_Aomsir的博客-CSDN博客 文章目录 前言切入点详解切…

简单三步 用GPT-4和Gamma自动生成PPT PDF

1. 用GPT-4 生产PPT内容 我想把下面的文章做成PPT&#xff0c;请你给出详细的大纲和内容 用于谋生的知识&#xff0c;学生主要工作是学习&#xff0c;成年人的工作是养家糊口&#xff0c;这是基本的要求&#xff0c;在这之上&#xff0c;才能有更高的追求。 不要短期期望过高…

C#(CSharp)入门教程

目录 C#的第一个程序 变量 折叠代码 变量类型和声明变量 获取变量类型所占内存空间&#xff08;sizeof&#xff09; 常量 转义字符 隐式转换 显示转换 异常捕获 运算符 算术运算符 布尔逻辑运算符 关系运算符 位运算符 其他运算符 字符串拼接 …

计算机图像处理-高斯滤波

高斯滤波 高斯滤波是一种线性平滑滤波&#xff0c;适用于消除高斯噪声&#xff0c;广泛应用于图像处理的减噪过程。通俗的讲&#xff0c;高斯滤波就是对整幅图像进行加权平均的过程&#xff0c;每一个像素点的值&#xff0c;都由其本身和邻域内的其他像素值经过加权平均后得到…