CNN-day5-经典神经网络LeNets5

经典神经网络-LeNets5

1998年Yann LeCun等提出的第一个用于手写数字识别问题并产生实际商业(邮政行业)价值的卷积神经网络

参考:论文笔记:Gradient-Based Learning Applied to Document Recognition-CSDN博客

1 网络模型结构

整体结构解读:

输入图像:32×32×1

三个卷积层:

C1:输入图片32×32,6个5×5卷积核 ,输出特征图大小28×28(32-5+1)=28,一个bias参数;

可训练参数一共有:(5×5+1)×6=156

C3 :输入图片14×14,16个5×5卷积核,有6×3+6×4+3×4+1×6=60个通道,输出特征图大小10×10((14-5)/1+1),一个bias参数;

可训练参数一共有:6(3×5×5+1)+6×(4×5×5+1)+3×(4×5×5+1)+1×(6×5×5+1)=1516

C3的非密集的特征图连接:

C3的前6个特征图与S2层相连的3个特征图相连接,后面6个特征图与S2层相连的4个特征图相连 接,后面3个特征图与S2层部分不相连的4个特征图相连接,最后一个与S2层的所有特征图相连。 采用非密集连接的方式,打破对称性,同时减少计算量,共60组卷积核。主要是为了节省算力。

C5:输入图片5×5,16个5×5卷积核,包括120×16个5×5卷积核 ,输出特征图大小1×1(5-5+1),一个bias参数;

可训练参数一共有:120×(16×5×5+1)=48120

两个池化层S2和S4:

都是2×2的平均池化,并添加了非线性映射

S2(下采样层):输入28×28,采样区域2×2,输入相加,乘以一个可训练参数, 再加上一个可训练偏置,使用sigmoid激活,输出特征图大小:14×14(28/2)

S4(下采样层):输入10×10,采样区域2×2,输入相加,乘以一个可训练参数, 再加上一个可训练偏置,使用sigmoid激活,输出特征图大小:5×5(10/2)

两个全连接层:

第一个全连接层:输入120维向量,输出84个神经元,计算输入向量和权重向量之间的点积,再加上一个偏置,结果通过sigmoid函数输出。84的原因是:字符编码是ASCII编码,用7×12大小的位图表示,-1白色1黑色,84可以用于对每一个像素点的值进行估计。

第二个全连接层(Output层-输出层):输出 10个神经元 ,共有10个节点,代表数字0-9。

所有激活函数采用Sigmoid

2 网络模型实现

2.1模型定义

import torch
import torch.nn as nn
​
​
class LeNet5s(nn.Module):def __init__(self):super(LeNet5s, self).__init__()  # 继承父类# 第一个卷积层self.C1 = nn.Sequential(nn.Conv2d(in_channels=1,  # 输入通道out_channels=6,  # 输出通道kernel_size=5,  # 卷积核大小),nn.ReLU(),)# 池化:平均池化self.S2 = nn.AvgPool2d(kernel_size=2)
​# C3:3通道特征融合单元self.C3_unit_6x3 = nn.Conv2d(in_channels=3,out_channels=1,kernel_size=5,)# C3:4通道特征融合单元self.C3_unit_6x4 = nn.Conv2d(in_channels=4,out_channels=1,kernel_size=5,)
​# C3:4通道特征融合单元,剔除中间的1通道self.C3_unit_3x4_pop1 = nn.Conv2d(in_channels=4,out_channels=1,kernel_size=5,)
​# C3:6通道特征融合单元self.C3_unit_1x6 = nn.Conv2d(in_channels=6,out_channels=1,kernel_size=5,)
​# S4:池化self.S4 = nn.AvgPool2d(kernel_size=2)# 全连接层self.fc1 = nn.Sequential(nn.Linear(in_features=16 * 5 * 5, out_features=120), nn.ReLU())self.fc2 = nn.Sequential(nn.Linear(in_features=120, out_features=84), nn.ReLU())self.fc3 = nn.Linear(in_features=84, out_features=10)
​def forward(self, x):# 训练数据批次大小batch_sizenum = x.shape[0]
​x = self.C1(x)x = self.S2(x)# 生成一个empty张量outchannel = torch.empty((num, 0, 10, 10))# 6个3通道的单元for i in range(6):# 定义一个元组:存储要提取的通道特征的下标channel_idx = tuple([j % 6 for j in range(i, i + 3)])x1 = self.C3_unit_6x3(x[:, channel_idx, :, :])outchannel = torch.cat([outchannel, x1], dim=1)
​# 6个4通道的单元for i in range(6):# 定义一个元组:存储要提取的通道特征的下标channel_idx = tuple([j % 6 for j in range(i, i + 4)])x1 = self.C3_unit_6x4(x[:, channel_idx, :, :])outchannel = torch.cat([outchannel, x1], dim=1)
​# 3个4通道的单元,先拿五个,干掉中那一个for i in range(3):# 定义一个元组:存储要提取的通道特征的下标channel_idx = tuple([j % 6 for j in range(i, i + 5)])# 删除第三个元素channel_idx = channel_idx[:2] + channel_idx[3:]print(channel_idx)x1 = self.C3_unit_3x4_pop1(x[:, channel_idx, :, :])outchannel = torch.cat([outchannel, x1], dim=1)
​x1 = self.C3_unit_1x6(x)# 平均池化outchannel = torch.cat([outchannel, x1], dim=1)outchannel = nn.ReLU()(outchannel)
​x = self.S4(outchannel)# 对数据进行变形x = x.view(x.size(0), -1)# 全连接层x = self.fc1(x)x = self.fc2(x)# TODO:SOFTMAXoutput = self.fc3(x)
​return output
​
​
def test001():net = LeNet5s()# 随机一个测试数据input = torch.randn(128, 1, 32, 32)output = net(input)print(output.shape)pass
​
​
if __name__ == "__main__":test001()

2.2全局变量

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import os
​
dir = os.path.dirname(__file__)
modelpath = os.path.join(dir, "weight/model.pth")
datapath = os.path.join(dir, "data")
​
# 数据预处理和加载
transform = transforms.Compose([transforms.Resize((32, 32)),  # 调整输入图像大小为32x32transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,)),]
)
​

2.3模型训练

def train():
​trainset = torchvision.datasets.MNIST(root=datapath, train=True, download=True, transform=transform)trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
​# 实例化模型net = LeNet5()
​# 使用MSELoss作为损失函数criterion = nn.MSELoss()
​# 使用SGD优化器optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
​# 训练模型num_epochs = 10for epoch in range(num_epochs):running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = data
​# 将labels转换为one-hot编码labels_one_hot = torch.zeros(labels.size(0), 10).scatter_(1, labels.view(-1, 1), 1.0)labels_one_hot = labels_one_hot.to(torch.float32)optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels_one_hot)loss.backward()optimizer.step()
​running_loss += loss.item()if i % 100 == 99:print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 100:.3f}")running_loss = 0.0# 保存模型参数torch.save(net.state_dict(), modelpath)print("Finished Training")

2.4验证

def vaild():
​testset = torchvision.datasets.MNIST(root=datapath, train=False, download=True, transform=transform)testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)# 实例化模型net = LeNet5()net.load_state_dict(torch.load(modelpath))# 在测试集上测试模型correct = 0total = 0with torch.no_grad():for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()
​print(f"验证集: {100 * correct / total:.2f}%")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/14415.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[渗透测试]热门搜索引擎推荐— — shodan篇

[渗透测试]热门搜索引擎推荐— — shodan篇 免责声明:本文仅用于分享渗透测试工具,大家使用时,一定需要遵守相关法律法规。 除了shodan,还有很多其他热门的,比如:fofa、奇安信的鹰图、钟馗之眼等&#xff0…

BUU30 [网鼎杯 2018]Fakebook1

是一个登录界面&#xff0c;我们先注册一个试试&#xff1a; 用dirsearch扫描出来robots.txt&#xff0c;也发现了flag.php&#xff0c;并下载user.php.bak 源代码内容&#xff1a; <?phpclass UserInfo {public $name "";public $age 0;public $blog &quo…

索引失效的场景

chatGpt 7. 使用 DISTINCT 或 GROUP BY 当查询中涉及 DISTINCT 或 GROUP BY 时&#xff0c;如果查询没有合适的索引支持&#xff0c;可能会导致性能问题&#xff0c;虽然不完全是索引失效&#xff0c;但会影响查询效率。 sql SELECT DISTINCT department_id FROM employees;…

3D数字化营销:重塑家居电商新生态

随着电商的蓬勃发展&#xff0c;网上订购家具已成为众多消费者的首选。然而&#xff0c;线上选购家具的诸多挑战&#xff0c;如风格不匹配、尺寸不合适、定制效果不如预期以及退换货不便等&#xff0c;一直困扰着消费者。为解决这些问题&#xff0c;家居行业急需一种全新的展示…

论文阅读--LlaVA

数据 使用GPT-4&#xff0c;根据现有的图片对数据&#xff08;image-pair data&#xff09;收集指令跟随数据。作者团队收集了158,000个独特的语言-图像指令遵循样本&#xff0c;其中包括58,000个对话样本、23,000个详细描述样本和77,000个复杂推理样本 以图像描述为例&#x…

【R语言】apply函数族

在R语言中使用循环操作时是使用自身来实现的&#xff0c;效率较低。所以R语言有一个符合其统计语言出身的特点&#xff1a;向量化。R语言中的向量化运用了底层的C语言&#xff0c;而C语言的效率比高层的R语言的效率高。 apply函数族主要是为了解决数据向量化运算的问题&#x…

归一化与伪彩:LabVIEW图像处理的区别

在LabVIEW的图像处理领域&#xff0c;归一化&#xff08;Normalization&#xff09;和伪彩&#xff08;Pseudo-coloring&#xff09;是两个不同的概念&#xff0c;虽然它们都涉及图像像素值的调整&#xff0c;但目的和实现方式截然不同。归一化用于调整像素值的范围&#xff0c…

【3分钟极速部署】在本地快速部署deepseek

第一步&#xff0c;找到网站&#xff0c;下载&#xff1a; 首先找到Ollama &#xff0c; 根据自己的电脑下载对应的版本 。 我个人用的是Windows 我就先尝试用Windows版本了 &#xff0c;文件不是很大&#xff0c;下载也比较的快 第二部就是安装了 &#xff1a; 安装完成后提示…

论文阅读:MGMAE : Motion Guided Masking for Video Masked Autoencoding

MGMAE:Motion Guided Masking for Video Masked Autoencoding Abstract 掩蔽自编码&#xff08;Masked Autoencoding&#xff09;在自监督视频表示学习中展现了出色的表现。时间冗余导致了VideoMAE中高掩蔽比率和定制的掩蔽策略。本文旨在通过引入运动引导掩蔽策略&#xff0…

【Ai】--- 可视化 DeepSeek-r1 接入 Chatbox(超详细)

在编程的艺术世界里&#xff0c;代码和灵感需要寻找到最佳的交融点&#xff0c;才能打造出令人为之惊叹的作品。而在这座秋知叶i博客的殿堂里&#xff0c;我们将共同追寻这种完美结合&#xff0c;为未来的世界留下属于我们的独特印记。 【Ai】--- 可视化 DeepSeek-r1 接入 Chat…

P1049 装箱问题(dp)

#include<bits/stdc.h> using namespace std;int main() {int v,n;cin>>v>>n;int a[30];int dp[20005];for(int i0;i<n;i){cin>>a[i];}memset(dp,0,sizeof(dp));// 设置所有元素为0&#xff0c;表示最大体积为0for(int i0;i<n;i){for(int jv;j&…

程序诗篇里的灵动笔触:指针绘就数据的梦幻蓝图<7>

大家好啊&#xff0c;我是小象٩(๑ω๑)۶ 我的博客&#xff1a;Xiao Xiangζั͡ޓއއ 很高兴见到大家&#xff0c;希望能够和大家一起交流学习&#xff0c;共同进步。 今天我们一起来学习转移表&#xff0c;回调函数&#xff0c;qsort… 目录 一、转移表1.1 定义与原理1.3…

声明式导航,编程式导航,导航传参,下拉刷新

1.页面导航 1.声明式导航 1.1跳转到tabBar页面 1.2跳转到非tabBar页面 1.2后退导航 、 2.编程式导航 2.1跳转到tabBar页面 2.1跳转到非tabBar页面 2.3后退导航 3.导航传参 3.1声名式导航传参 3.2编程式导航传参 3.3在onLoad中接受参数 4.下拉刷新 4.1回顾下拉刷新…

C++ Primer 递增和递减运算符

欢迎阅读我的 【CPrimer】专栏 专栏简介&#xff1a;本专栏主要面向C初学者&#xff0c;解释C的一些基本概念和基础语言特性&#xff0c;涉及C标准库的用法&#xff0c;面向对象特性&#xff0c;泛型特性高级用法。通过使用标准库中定义的抽象设施&#xff0c;使你更加适应高级…

【C++高并发服务器WebServer】-13:多线程服务器开发

本文目录 一、多线程服务器开发二、TCP状态转换三、端口复用 一、多线程服务器开发 服务端代码如下。 #include <stdio.h> #include <arpa/inet.h> #include <unistd.h> #include <stdlib.h> #include <string.h> #include <pthread.h>s…

重生之我要当云原生大师(十一)访问Linux文件系统

目录 一、解释下文件系统、块设备、挂载点、逻辑卷。 二、简述文件系统、块设备、挂载点、逻辑卷之间的关系&#xff1f; 三、如何检查文件系统&#xff1f; 四、挂载和卸载文件系统的流程&#xff1f; 五、find命令都可以根据什么查找文件。 一、解释下文件系统、块设备、…

NetCore Consul动态伸缩+Ocelot 网关 缓存 自定义缓存 + 限流、熔断、超时 等服务治理 + ids4鉴权

网关 OcelotGeteway 网关 Ocelot配置文件 {//单地址多实例负载均衡Consul 实现动态伸缩"Routes": [{// 上游 》》 接受的请求//上游请求方法,可以设置特定的 HTTP 方法列表或设置空列表以允许其中任何方法"UpstreamHttpMethod": [ "Get", &quo…

星网锐捷 DMB-BS LED屏信息发布系统taskexport接口处存在敏感信息泄露

星网锐捷 DMB-BS LED屏信息发布系统taskexport接口处存在敏感信息泄露 漏洞描述 福建星网锐捷通讯股份有限公司成立于2000年,公司秉承“融合创新科技,构建智慧未来"的经营理念,是国内领先的ICT基础设施及AI应用方案提供商。星网锐捷 DMB-BS LED屏信息发布系统taskexp…

国产高端双光子成像系统的自主突破

近年来&#xff0c;高端科研仪器的国产化受到越来越多的关注。在双光子成像系统这一关键领域&#xff0c;我们基于LabVIEW自主开发了一套完整的解决方案&#xff0c;不仅填补了国内空白&#xff0c;也在功能和性能上达到了国际领先水平。我们的目标是让国内科研机构和医疗行业拥…

Python多版本管理

关注后回复 python 获取相关资料 ubuntu18.04 # ubuntu18 默认版本 Python 2.7.17 apt install python python-dev python-pip# ubuntu18 默认版本 Python 3.6.9 apt install python3 python3-dev python3-pip# ubuntu18 使用 python3.8 apt install python3.8 python3.8-dev#…