CNN——LeNet

1.LeNet概述       

         LeNet是Yann LeCun于1988年提出的用于手写体数字识别的网络结构,它是最早发布的卷积神经网络之一,可以说LeNet是深度CNN网络的基石。

        当时,LeNet取得了与支持向量机(support vector machines)性能相媲美的成果,成为监督学习的主流方法。 LeNet当时被广泛用于自动取款机(ATM)机中,帮助识别处理支票的数字。

        下面是整个网络的结构图

        LeNet共有8层,其中包括输入层,3个卷积层,2个子采样层(也就是现在的池化层),1个全连接层和1个高斯连接层。

        上图中用C代表卷积层,用S代表采样层,用F代表全连接层。输入size固定在1*32*32,LeNet图片的输入是二值图像。网络的输出为0~9十个数字的RBF度量,可以理解为输入图像属于0~9数字的可能性大小。

2.详解LeNet

下面对图中每一层做详细的介绍:

  • LeNet使用的卷积核大小都为5*5,步长为1,无填充,只是卷积深度不一样(卷积核个数导致生成的特征图的通道数)
  • 激活函数为Sigmoid
  • 下采样层都是使用最大池化实现,池化的核都为2*2,步长为2,无填充

        input输入层,尺寸为1*32*32的二值图

        C1层是一个卷积层。该层使用6个卷积核,生成特征图尺寸为32-5+1=28,输出为6个大小为28*28的特征图。再经过一个Sigmoid激活函数非线性变换。

        S2层是一个下采样层。生成特征图尺寸为28/2=14,得到6个14*14的特征图。

        C3层是一个卷积层,该层使用16个卷积核,生成特征图尺寸为14-5+1=10,输出为16个10*10的特征图。再经过一个Sigmoid激活函数非线性变换。

        S4层是一个下采样层,生成特征图尺寸为10/2=5,得到16个5*5的特征图

        C5层是一个卷积层,卷积核数量增加至120。生成特征图尺寸为5-5+1=1。得到120个1*1的特征图。这里实际上相当于S4全连接了,但仍将其标为卷积层,原因是如果LeNet-5的输入图片尺寸变大,其他保持不变,那该层特征图的维数也会大于1*1,那就不是全连接了。再经过一个Sigmoid激活函数非线性变换。

        F6层是一个全连接层,该层与C5层全连接,输出84张特征图。再经过一个Sigmoid激活函数非线性变换。

        输出层:输出层由欧式径向基函数(高斯)单元组成,每个类别(0~9数字)对应一个径向基函数单元,每个单元有84个输入。也就是说,每个输出RBF单元计算输入向量和该类别标记向量之间的欧式距离,距离越远,PRF输出越大,同时我们也会将与标记向量欧式距离最近的类别作为数字识别的输出结果。当然现在通常使用的Softmax实现

3.使用LeNet实现Mnist数据集分类

1.导入所需库

import torch
import torch.nn as nn
from torchsummary import summary
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm # 显示训练进度条

2.使用GPU

device = 'cuda' if torch.cuda.is_available() else 'cpu'

3.读取Mnist数据集

# 定义数据转换以进行数据标准化
transform = transforms.Compose([transforms.ToTensor(),  # 将图像转换为 PyTorch 张量
])# 下载并加载 MNIST 训练和测试数据集
train_dataset = datasets.MNIST(root='./dataset', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./dataset', train=False, download=True, transform=transform)# 创建数据加载器以批量加载数据
batch_size = 256
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

4.搭建LeNet

        需要注意的是torch.nn.CrossEntropyLoss自带了softmax函数,所以最后一层使用全连接即可,在训练时使用torch.nn.CrossEntropyLoss

class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2) # Mnist尺寸为28*28,这里设置填充变成32*32self.sigmoid = nn.Sigmoid()self.pool = nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = nn.Conv2d(6, 16, kernel_size=5)self.flatten = nn.Flatten()self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(self.sigmoid(self.conv1(x)))x = self.pool(self.sigmoid(self.conv2(x)))x = self.flatten(x)x = self.sigmoid(self.fc1(x))x = self.sigmoid(self.fc2(x))x = self.fc3(x)return x
# 实例化模型
model = LeNet().to(device)
summary(model, (1, 28, 28))

5.训练函数

def train(model, lr, epochs):# 将模型放入GPUmodel = model.to(device)# 使用交叉熵损失函数loss_fn = nn.CrossEntropyLoss().to(device)# SGDoptimizer = torch.optim.SGD(model.parameters(), lr=lr)# 记录训练与验证数据train_losses = []train_accuracies = []# 开始迭代   for epoch in range(epochs):   # 切换训练模式model.train()  # 记录变量train_loss = 0.0correct_train = 0total_train = 0# 读取训练数据并使用 tqdm 显示进度条for i, (inputs, targets) in tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f"Epoch {epoch+1}/{epochs}", unit='batch'):# 训练数据移入GPUinputs = inputs.to(device)targets = targets.to(device)# 模型预测outputs = model(inputs)# 计算损失loss = loss_fn(outputs, targets)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 使用优化器优化参数optimizer.step()# 记录损失train_loss += loss.item()# 计算训练正确个数_, predicted = torch.max(outputs, 1)total_train += targets.size(0)correct_train += (predicted == targets).sum().item()# 计算训练正确率并记录train_loss /= len(train_dataloader)train_accuracy = correct_train / total_traintrain_losses.append(train_loss)train_accuracies.append(train_accuracy)# 输出训练信息print(f"Epoch [{epoch + 1}/{epochs}] - Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}")# 绘制损失和正确率曲线plt.figure(figsize=(10, 5))plt.subplot(1, 2, 1)plt.plot(range(epochs), train_losses, label='Training Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()plt.subplot(1, 2, 2)plt.plot(range(epochs), train_accuracies, label='Accuracy')plt.xlabel('Epoch')plt.ylabel('Accuracy')plt.legend()plt.tight_layout()plt.show()

6.模型训练

model = LeNet()
lr = 0.9 # sigmoid两端容易饱和,gradient比较小,学得比较慢,所以学习率要大一些
epochs = 20
train(model,lr,epochs)

7.模型测试 

def test(model, test_dataloader, device, model_path):# 将模型设置为评估模式model.eval()# 将模型移动到指定设备上model.to(device)# 从给定路径加载模型的状态字典model.load_state_dict(torch.load(model_path))correct_test = 0total_test = 0# 不计算梯度with torch.no_grad():# 遍历测试数据加载器for inputs, targets in test_dataloader:  # 将输入数据和标签移动到指定设备上inputs = inputs.to(device)targets = targets.to(device)# 模型进行推理outputs = model(inputs)# 获取预测结果中的最大值_, predicted = torch.max(outputs, 1)total_test += targets.size(0)# 统计预测正确的数量correct_test += (predicted == targets).sum().item()# 计算并打印测试数据的准确率test_accuracy = correct_test / total_testprint(f"Accuracy on Test: {test_accuracy:.4f}")return test_accuracy
model_path = save_path
test(model, test_dataloader, device, save_path)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/229871.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

软件测试之白盒测试

概念与定义 白盒测试:侧重于系统或部件内部机制的测试,类型分为分支测试(判定节点测试)、路径测试、语句测试。 控制流分析(基于程序结构):控制流分析是一类用于分析程序控制流结构的静态分析技术,目的在于…

约束满足问题改进技术:基于变量和赋值次序的启发式

回溯搜索的通用算法的问题与改进思路 • 需改善无信息回溯搜索算法的性能。 • 通用改进方法的思路: – 下一步该给哪个变量赋值, 按什么顺序给该变量赋值? – 每步搜索应该做怎样的推理? 当前变量的赋值会对其他未赋值变量产…

【SpringBoot框架篇】34.使用Spring Retry完成任务的重试

文章目录 简要1.为什么需要重试?2.添加maven依赖3.使用Retryable注解实现重试4.基于RetryTemplate模板实现重试 简要 Spring实现了一套重试机制,功能简单实用。Spring Retry是从Spring Batch独立出来的一个功能,已经广泛应用于Spring Batch,…

算法巡练day04Leetcode24交换节点19删除倒数节点142环形链表

今天学习的文章和视频链接 https://www.bilibili.com/video/BV1YT411g7br/?vd_source8272bd48fee17396a4a1746c256ab0ae https://www.bilibili.com/video/BV1if4y1d7ob/?vd_source8272bd48fee17396a4a1746c256ab0ae 24两两交换链表中的节点 给你一个链表,两两…

ASP.NET Core基础之图片文件(一)-WebApi图片文件上传到文件夹

阅读本文你的收获: 了解WebApi项目保存上传图片的三种方式学习在WebApi项目中如何上传图片到指定文件夹中 在ASP.NET Core基础之图片文件(一)-WebApi访问静态图片文章中,学习了如何获取WebApi中的静态图片,本文继续分享如何上传图片。 那么…

八皇后问题(C语言/C++)超详细讲解/由浅入深---深入八皇后问题

介绍引入 在计算机科学中,八皇后问题是一个经典的回溯算法问题。这个问题的目标是找出一种在8x8国际象棋棋盘上放置八个皇后的方法,使得没有任何两个皇后能够互相攻击。换句话说,每一行、每一列以及对角线上只能有一个皇后。 想象一下&…

为什么大学c语言课不顺便教一下Linux,Makefile

为什么大学c语言课不顺便教一下Linux,Makefile,git,gdb等配套工具链呢? 在开始前我有一些资料,是我根据自己从业十年经验,熬夜搞了几个通宵,精心整理了一份「Linux的资料从专业入门到高级教程工具包」&…

Docker 网络管理

一、Docker网络简介 Docker网络是容器化应用程序的重要组成部分,它使得容器之间可以互相通信和连接,同时也提供了容器与外部环境之间的隔离和连接。 二、Docker网络网络模式 Docker 提供了多种网络模式,可以通过docker network ls 命令查看…

MySQL——事物

目录 一.发现问题 二.什么时事物 三.事务提交方式 四.事物的常规操作方式 五. 事务隔离级别 1.如何理解隔离性 2.隔离级别 3.查看与设置隔离性 4.读未提交【Read Uncommitted】 5.读提交【Read Committed】 6.可重复读【Repeatable Read】 7.串行化【serializabl…

Unity游戏资源更新(AB包)

目录 前言: 一、什么是AssetBundle 二、AssetBudle的基本使用 1.AssetBundle打包 2.BuildAssetBundle BuildAssetBundleOptions BuildTarget 示例 3.AssetBundle的加载 LoadFromFile LoadFromMemory LoadFromMemoryAsync UnityWebRequestAsssetBundle 前…

QProgressDialog用法及结合QThread用法,四种线程使用

1 QProgressDialog概述 QProgressDialog类提供耗时操作的进度条。 进度对话框用于向用户指示操作将花费多长时间,并演示应用程序没有冻结。此外,QPorgressDialog还可以给用户一个中止操作的机会。 进度对话框的一个常见问题是很难知道何时使用它们;操作…

Linux shell编程学习笔记38:history命令

目录 0 前言 1 history命令的功能、格式和退出状态1.1 history命令的功能1.2 history命令的格式1.3退出状态2 命令应用实例2.1 history:显示命令历史列表2.2 history -a:将当前会话的命令行历史追加到历史文件~/.bash_history中2.3 history -c&#xf…

如何做好档案数字化前的鉴定工作

要做好档案数字化前的鉴定工作,可以按照以下步骤进行: 1. 确定鉴定目标:明确要鉴定的档案的内容、数量和性质,确定鉴定的范围和目标。 2. 进行档案清点:对档案进行全面清点和登记,包括数量、种类、状况等信…

【Linux】基本指令了解(一)

💗个人主页💗 ⭐个人专栏——数据结构学习⭐ 💫点击关注🤩一起学习C语言💯💫 目录 导读:1. 认识Linux1.1 什么是Linux1.2 Linux特点 2. ls指令3. pwd命令4. cd 指令5. touch命令6. mkdir指令7. …

<JavaEE> TCP 的通信机制(二) -- 连接管理(三次握手和四次挥手)

目录 TCP的通信机制的核心特性 三、连接管理 1)什么是连接管理? 2)“三次握手”建立连接 1> 什么是“三次握手”? 2> “三次握手”的核心作用是什么? 3)“四次挥手”断开连接 1> 什么是“…

听GPT 讲Rust源代码--library/panic_unwind

File: rust/library/panic_unwind/src/seh.rs 在Rust源代码中,rust/library/panic_unwind/src/seh.rs这个文件的作用是实现Windows操作系统上的SEH(Structured Exception Handling)异常处理机制。 SEH是Windows上的一种异常处理机制&#xff…

Mysql 动态链接库配置步骤+ 完成封装init和close接口

1、创建新项目 动态链接库dll 2、将附带的文件都删除,创建LXMysql.cpp 3、项目设置 3.1、预编译头,不使用预编译头 3.2、添加头文件 3.3、添加类 3.4、写初始化函数 4、项目配置 4.1、右键解决方案-属性-常规-输出目录 ..\..\bin 4.2、生成lib文件 右…

3D视觉-相机选用的原则

鉴于不同技术方案都有其适用的场景,立体相机的选型讲究的原则为“先看用途,再看场景,终评精度”,合适的立体相机在方案中可以起到事半功倍的效果。从用途上来进行划分,三维视觉方案主要应用在两个方向:测量…

Linux 进程(六) 环境变量

main函数参数: 这是一个常见的main函数,那么main函数可以带参吗? int main() {return 0; } 答案是可以的! 我们先看这样一段代码,首先给main函数带上两个参数。 然后我们来看输出的结果。 这样一组字符串是命令行解释…

【AI】一文读懂大模型套壳——神仙打架?软饭硬吃?

目录 一、套壳的风波此起彼伏 二、到底什么是大模型的壳 2.1 大模型的3部分,壳指的是哪里 大模型的内核 预训练(Pre-training) 调优(Fine-tuning) 2.2 内核的发展历程和万流归宗 2.3 套壳不是借壳 三、软饭硬…