基于CNN的FashionMNIST数据集识别4——GoogleNet模型

源码

import torch
from torch import nn
from torchsummary import summaryclass Inception(nn.Module):def __init__(self, in_channels, c1, c2, c3, c4):super().__init__()self.ReLu = nn.ReLU()#路径1self.p1_1 = nn.Conv2d(in_channels=in_channels, out_channels=c1, kernel_size=1)#路径2self.p2_1 = nn.Conv2d(in_channels=in_channels, out_channels=c2[0], kernel_size=1)self.p2_2 = nn.Conv2d(in_channels=c2[0], out_channels=c2[1], kernel_size=3, padding=1)#路径3self.p3_1 = nn.Conv2d(in_channels=in_channels, out_channels=c3[0], kernel_size=1)self.p3_2 = nn.Conv2d(in_channels=c3[0], out_channels=c3[1], kernel_size=5, padding=2)#路径4self.p4_1 = nn.MaxPool2d(kernel_size=3, padding=1, stride=1)self.p4_2 = nn.Conv2d(in_channels=in_channels, out_channels=c4, kernel_size=1)def forward(self, x):p1 = self.ReLu(self.p1_1(x))p2 =self.ReLu(self.p2_2(self.ReLu(self.p2_1(x))))p3 =self.ReLu(self.p3_2(self.ReLu(self.p3_1(x))))p4 =self.ReLu(self.p4_2(self.p4_1(x)))return torch.cat((p1, p2, p3, p4), dim=1)class GoogleNet(nn.Module):def __init__(self, Inception):super().__init__()self.block1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7, stride=2, padding=3),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.block2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1),nn.ReLU(),nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, padding=3),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.block3 = nn.Sequential(Inception(192, 64, (96, 128), (16, 32), 32),Inception(256, 128, (128, 192), (32, 96), 64),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.block4 = nn.Sequential(Inception(480, 192, (96, 208), (16, 48), 64),Inception(512, 160, (112, 224), (24, 64), 64),Inception(512, 128, (128, 256), (24, 64), 64),Inception(512, 112, (128, 288), (32, 64), 64),Inception(528, 256, (160, 320), (32, 128), 128),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.block5 = nn.Sequential(Inception(832, 256, (160, 320), (32, 128), 128),Inception(832, 384, (192, 384), (48, 128), 128),nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(),nn.Linear(1024, 10))for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0 ,0.01)if m.bias is not None:nn.init.constant_(m.bias, 0)def forward(self, x):x = self.block1(x)x = self.block2(x)x = self.block3(x)x = self.block4(x)x = self.block5(x)return xif __name__ == "__main__":device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = GoogleNet(Inception).to(device)print(summary(model, (1, 224, 224)))

从整个链路上看,googlenet的复杂度相比于之前我们提到的cnn网络更复杂。仔细分析可以看到,googlenet的网络结构里面有多个核心模块inception。搞懂inception就基本搞清楚了googlenet。

Inception

Inception 模块的设计动机

  1. 传统串联卷积的局限性

    • 传统网络通过堆叠卷积层逐步提取特征,但不同尺度的特征(如边缘、纹理、物体部件)需不同大小的卷积核。
    • 堆叠大卷积核(如 5x5)会导致计算量暴增(参数会增加很多)。
  2. 关键优化目标

    • 多尺度特征融合‌:同时提取不同尺度的特征。
    • 减少计算量‌:通过 1x1 卷积降维,控制参数规模。

Inception模块设计思路

  • 并行多分支设计‌:Inception模块包含多个并行分支,典型结构包括1x1卷积、3x3卷积、5x5卷积和3x3最大池化层。不同尺寸的卷积核可同时捕捉局部细节和全局特征‌。
  • 特征图拼接‌:各分支输出的特征图在通道维度进行拼接,形成综合特征表达,增强模型对不同尺度的适应性‌。从图片可以看到,每个inception块有四条路径,之前的cnn大多是单一路径。
class Inception(nn.Module):def __init__(self, in_channels, c1, c2, c3, c4):super().__init__()self.ReLu = nn.ReLU()#路径1self.p1_1 = nn.Conv2d(in_channels=in_channels, out_channels=c1, kernel_size=1)#路径2self.p2_1 = nn.Conv2d(in_channels=in_channels, out_channels=c2[0], kernel_size=1)self.p2_2 = nn.Conv2d(in_channels=c2[0], out_channels=c2[1], kernel_size=3, padding=1)#路径3self.p3_1 = nn.Conv2d(in_channels=in_channels, out_channels=c3[0], kernel_size=1)self.p3_2 = nn.Conv2d(in_channels=c3[0], out_channels=c3[1], kernel_size=5, padding=2)#路径4self.p4_1 = nn.MaxPool2d(kernel_size=3, padding=1, stride=1)self.p4_2 = nn.Conv2d(in_channels=in_channels, out_channels=c4, kernel_size=1)def forward(self, x):p1 = self.ReLu(self.p1_1(x))p2 =self.ReLu(self.p2_2(self.ReLu(self.p2_1(x))))p3 =self.ReLu(self.p3_2(self.ReLu(self.p3_1(x))))p4 =self.ReLu(self.p4_2(self.p4_1(x)))return torch.cat((p1, p2, p3, p4), dim=1)

从代码可以看出,每个inception块都分成了四个路径。1,2,3路径都是纯卷积,第四条路径是池化层+卷积。另外,卷积核的大小是固定的,卷积核的通道数是可以通过传参设置的。

 传参如下表所示:

参数含义示例值
in_channels输入特征图的通道数192
c1路径1的输出通道数64
c2路径2的通道数元组 (降维, 输出)(96, 128)
c3路径3的通道数元组 (降维, 输出)(16, 32)
c4路径4的输出通道数32

总输出通道数 = c1 + c2 + c3 + c4。示例:64 + 128 + 32 + 32 = 256。

前向传播

当时写代码,我有一个疑问,inception里的前向传播是什么时候触发的,是googlenet在处理block代码流程的时候自动触发的吗?

这个问题涉及到forward方法的隐式调用。在PyTorch中,当通过 ‌模块实例直接调用输入数据‌ 时,forward 方法会被自动触发。例如:

inception = Inception(...)  # 实例化模块
output = inception(x)       # 隐式调用forward(x)

所以在googlenet前向传播的时候,完成了inception的前向传播。

另外在学习这块还学到个小知识,就是forward方法不能显式调用。会绕过一些关键步骤(如梯度计算),就导致无法反向传播了!

张量拼接

在PyTorch中,torch.cat((p1, p2, p3, p4), dim=1) 这句话的作用是‌沿着通道维度(channel dimension)将四个张量(p1, p2, p3, p4)拼接成一个更大的张量‌。以下是详细解释:

假设输入张量 x 的形状为 (batch_size, in_channels, height, width),经过Inception模块的四条路径处理后,每个路径的输出形状如下:

  • p1: (batch_size, c1, height, width)
    (1x1卷积直接输出c1个通道)

  • p2: (batch_size, c2, height, width)
    (1x1卷积降维到c2,再通过3x3卷积输出c2个通道)

  • p3: (batch_size, c3, height, width)
    (1x1卷积降维到c3,再通过5x5卷积输出c3个通道)

  • p4: (batch_size, c4, height, width)
    (最大池化后通过1x1卷积输出c4个通道)

所有路径输出的‌高度(height)和宽度(width)必须一致‌,否则拼接会失败。批数量和通道数可以不相同。

可以在别的维度拼接吗?不太行,原因是:

  • dim=0:沿批量维度拼接,会合并不同样本的数据,破坏批量独立性。
  • dim=2/3:沿空间维度拼接,会破坏特征图的空间结构,导致后续卷积无法正常操作。

 

参数初始化

# 遍历模型的所有子模块(包括嵌套模块)
for m in self.modules():# 对二维卷积层进行初始化if isinstance(m, nn.Conv2d):# 使用Kaiming正态分布初始化权重(针对ReLU激活函数优化)nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity='relu')# 如果存在偏置项,将其初始化为0if m.bias is not None:nn.init.constant_(m.bias, 0)# 对全连接层进行初始化        elif isinstance(m, nn.Linear):# 使用正态分布初始化权重(均值0,标准差0.01)nn.init.normal_(m.weight, 0, 0.01)# 如果存在偏置项,将其初始化为0if m.bias is not None:nn.init.constant_(m.bias, 0)

在构建方法里我们增加了参数初始化,参数初始化主要作用是提高收敛速度,减少训练模型时压根不收敛的风险。

nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity='relu')

卷积层使用的是kaiming初始化,和relu激活函数搭配使用效果较好。两个参数的含义是:

  • mode="fan_out":根据输出通道数计算缩放系数
  • nonlinearity='relu':针对ReLU的负半轴修正
nn.init.constant_(m.bias, 0)

卷积层如果存在偏置就统一初始化为0,避免初始阶段引入偏置。

全连接层使用的是小标准差正态分布‌,作用是限制初始权重范围,防止激活值过大。适用于浅层网络。

一些初始化方法的特点和适用场景:

方法适用场景核心思想PyTorch实现函数
Kaiming初始化ReLU激活的CNN保持前向传播的方差一致性kaiming_normal_/uniform_
Xavier初始化Tanh/Sigmoid激活平衡输入输出的方差xavier_normal_
零初始化偏置项避免初始偏好constant_(0)
正交初始化RNN/Transformer保持矩阵正交性,防止梯度爆炸orthogonal_

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/37120.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

面试题精选《剑指Offer》:JVM类加载机制与Spring设计哲学深度剖析-大厂必考

一、JVM类加载核心机制 🔥 问题5:类从编译到执行的全链路过程 完整生命周期流程图 关键技术拆解 编译阶段 查看字节码指令:javap -v Robot.class 常量池结构解析(CONSTANT_Class_info等) 类加载阶段 // 手动加载…

(2025|ICLR|华南理工,任务对齐,缓解灾难性遗忘,底层模型冻结和训练早停)语言模型持续学习中的虚假遗忘

Spurious Forgetting in Continual Learning of Language Models 目录 1. 引言 2. 动机:关于虚假遗忘的初步实验 3. 深入探讨虚假遗忘 3.1 受控实验设置 3.2 从性能角度分析 3.3 从损失景观角度分析 3.4 从模型权重角度分析 3.5 从特征角度分析 3.6 结论 …

【css酷炫效果】纯CSS实现火焰文字特效

【css酷炫效果】纯CSS实现火焰文字特效 缘创作背景html结构css样式完整代码基础版进阶版(冰霜版) 效果图 想直接拿走的老板,链接放在这里:https://download.csdn.net/download/u011561335/90492005 缘 创作随缘,不定时更新。 创作背景 刚…

专访LayaAir引擎最有价值专家-施杨

在 LayaAir 引擎的资源商店中,许多开发者都会注意到一个熟悉的名字——“射手座”。他不仅贡献了大量高质量的 Shader 资源,让一些开发者通过他的作品了解到 LayaAir 引擎在 3D 视觉效果上的更多可能,也让大家能够以低成本直接学习并应用这些…

大模型详细配置

Transformer结构 目前主力大模型都是基于Transformer的,以下是Transformer的具体架构 它由编码器(Encoder)以及解码器(Decoder)组成,前者主要负责对输入数据进行理解,将每个输入 词元都编码成一个上下文语义相关的表示向量;后者…

鸿蒙NEXT项目实战-百得知识库04

代码仓地址,大家记得点个star IbestKnowTeach: 百得知识库基于鸿蒙NEXT稳定版实现的一款企业级开发项目案例。 本案例涉及到多个鸿蒙相关技术知识点: 1、布局 2、配置文件 3、组件的封装和使用 4、路由的使用 5、请求响应拦截器的封装 6、位置服务 7、三…

Python数据可视化实战:从基础图表到高级分析

Python数据可视化实战:从基础图表到高级分析 数据可视化是数据分析的重要环节,通过直观的图表可以快速洞察数据规律。本文将通过5个实际案例,手把手教你使用Python的Matplotlib库完成各类数据可视化任务,涵盖条形图、堆积面积图、…

修改原生的<input type=“datetime-local“>样式

效果 基础样式 <input type"datetime-local" class"custom-datetime">input[type"datetime-local"] {/* 重置默认样式 */-webkit-appearance: none;-moz-appearance: none;appearance: none; // 禁用浏览器默认样式/* 自定义基础样式 */w…

scrapy入门(深入)

Scrapy框架简介 Scrapy是:由Python语言开发的一个快速、高层次的屏幕抓取和web抓取框架&#xff0c;用于抓取web站点并从页面中提取结构化的数据&#xff0c;只需要实现少量的代码&#xff0c;就能够快速的抓取。 新建项目 (scrapy startproject xxx)&#xff1a;新建一个新的…

fetch,ajax,axios的区别以及使用

fetch,ajax,axios这些都是发起前端请求的工具&#xff0c;除了这些外还有jquery的$.ajax。ajax和$.ajax都是基于XMLHttpRequest。 介绍下XMLHttpRequest XMLHttpRequest是一种在浏览器中用于与服务器进行异步通信的对象&#xff0c;它是实现 AJAX&#xff08;Asynchronous Ja…

微信小程序的业务域名配置(通过ingress网关的注解)

一、背景 微信小程序的业务域名配置&#xff08;通过kong网关的pre-function配置&#xff09;是依靠kong实现&#xff0c;本文将通过ingress网关实现。 而我们的服务是部署于阿里云K8S容器&#xff0c;当然内核与ingress无异。 找到k8s–>网络–>路由 二、ingress注解 …

LiteratureReading:[2016] Enriching Word Vectors with Subword Information

文章目录 一、文献简明&#xff08;zero&#xff09;二、快速预览&#xff08;first&#xff09;1、标题分析2、作者介绍3、引用数4、摘要分析&#xff08;1&#xff09;翻译&#xff08;2&#xff09;分析 5、总结分析&#xff08;1&#xff09;翻译&#xff08;2&#xff09;…

前后端联调解决跨域问题的方案

引言 在前后端分离的开发模式中&#xff0c;前端和后端通常在不同的服务器或端口运行&#xff0c;这样就会面临跨域问题。跨域问题是指浏览器因安全限制阻止前端代码访问与当前网页源不同的域、协议或端口的资源。对于 Java 后端应用&#xff0c;我们可以通过配置 CORS&#x…

开源软件许可证冲突的原因和解决方法

1、什么是开源许可证以及许可证冲突产生的问题 开源软件许可证是一种法律文件&#xff0c;它规定了软件用户、分发者和修改者使用、复制、修改和分发开源软件的权利和义务。开源许可证是由软件的版权所有者&#xff08;通常是开发者或开发团队&#xff09;发布的&#xff0c;它…

python爬虫笔记(一)

文章目录 html基础标签和下划线无序列表和有序列表表格加边框 html的属性a标签&#xff08;网站&#xff09;target属性换行线和水平分割线 图片设置宽高width&#xff0c;height html区块——块元素与行内元素块元素与行内元素块元素举例行内元素举例 表单from标签type属性pla…

电脑节电模式怎么退出 分享5种解决方法

在使用电脑的过程中&#xff0c;许多用户为了节省电力&#xff0c;通常会选择开启电脑的节能模式。然而&#xff0c;在需要更高性能或进行图形密集型任务时&#xff0c;节能模式可能会限制系统的性能表现。这时&#xff0c;了解如何正确地关闭或调整节能设置就显得尤为重要了。…

AI学习——卷积神经网络(CNN)入门

作为人类&#xff0c;我们天生擅长“看”东西&#xff1a;一眼就能认出猫狗、分辨红绿灯、读懂朋友的表情……但计算机的“眼睛”最初是一片空白。直到卷积神经网络&#xff08;CNN&#xff09;​的出现&#xff0c;计算机才真正开始理解图像。今天&#xff0c;我们就用最通俗的…

2025年渗透测试面试题总结- shopee-安全工程师(题目+回答)

网络安全领域各种资源&#xff0c;学习文档&#xff0c;以及工具分享、前沿信息分享、POC、EXP分享。不定期分享各种好玩的项目及好用的工具&#xff0c;欢迎关注。 目录 shopee-安全工程师 信息安全相关Response头详解 1. 关键安全头及防御场景 Linux与Docker核心命令速查…

IntelliJ IDEA 中 Maven 的 `pom.xml` 变灰带横线?一文详解解决方法

前言 在使用 IntelliJ IDEA 进行 Java 开发时&#xff0c;如果你发现项目的 pom.xml 文件突然变成灰色并带有删除线&#xff0c;这可能是 Maven 的配置或项目结构出现了问题。 一、问题现象与原因分析 现象描述 文件变灰&#xff1a;pom.xml 在项目资源管理器中显示为灰色。…

Spring MVC 接口数据

访问路径设置 RequestMapping("springmvc/hello") 就是用来向handlerMapping中注册的方法注解! 秘书中设置路径和方法的对应关系&#xff0c;即RequestMapping("/springmvc/hello")&#xff0c;设置的是对外的访问地址&#xff0c; 路径设置 精准路径匹…