源码
import torch
from torch import nn
from torchsummary import summaryclass Inception(nn.Module):def __init__(self, in_channels, c1, c2, c3, c4):super().__init__()self.ReLu = nn.ReLU()#路径1self.p1_1 = nn.Conv2d(in_channels=in_channels, out_channels=c1, kernel_size=1)#路径2self.p2_1 = nn.Conv2d(in_channels=in_channels, out_channels=c2[0], kernel_size=1)self.p2_2 = nn.Conv2d(in_channels=c2[0], out_channels=c2[1], kernel_size=3, padding=1)#路径3self.p3_1 = nn.Conv2d(in_channels=in_channels, out_channels=c3[0], kernel_size=1)self.p3_2 = nn.Conv2d(in_channels=c3[0], out_channels=c3[1], kernel_size=5, padding=2)#路径4self.p4_1 = nn.MaxPool2d(kernel_size=3, padding=1, stride=1)self.p4_2 = nn.Conv2d(in_channels=in_channels, out_channels=c4, kernel_size=1)def forward(self, x):p1 = self.ReLu(self.p1_1(x))p2 =self.ReLu(self.p2_2(self.ReLu(self.p2_1(x))))p3 =self.ReLu(self.p3_2(self.ReLu(self.p3_1(x))))p4 =self.ReLu(self.p4_2(self.p4_1(x)))return torch.cat((p1, p2, p3, p4), dim=1)class GoogleNet(nn.Module):def __init__(self, Inception):super().__init__()self.block1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7, stride=2, padding=3),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.block2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1),nn.ReLU(),nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, padding=3),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.block3 = nn.Sequential(Inception(192, 64, (96, 128), (16, 32), 32),Inception(256, 128, (128, 192), (32, 96), 64),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.block4 = nn.Sequential(Inception(480, 192, (96, 208), (16, 48), 64),Inception(512, 160, (112, 224), (24, 64), 64),Inception(512, 128, (128, 256), (24, 64), 64),Inception(512, 112, (128, 288), (32, 64), 64),Inception(528, 256, (160, 320), (32, 128), 128),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.block5 = nn.Sequential(Inception(832, 256, (160, 320), (32, 128), 128),Inception(832, 384, (192, 384), (48, 128), 128),nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(),nn.Linear(1024, 10))for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0 ,0.01)if m.bias is not None:nn.init.constant_(m.bias, 0)def forward(self, x):x = self.block1(x)x = self.block2(x)x = self.block3(x)x = self.block4(x)x = self.block5(x)return xif __name__ == "__main__":device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = GoogleNet(Inception).to(device)print(summary(model, (1, 224, 224)))
从整个链路上看,googlenet的复杂度相比于之前我们提到的cnn网络更复杂。仔细分析可以看到,googlenet的网络结构里面有多个核心模块inception。搞懂inception就基本搞清楚了googlenet。
Inception
Inception 模块的设计动机
-
传统串联卷积的局限性
- 传统网络通过堆叠卷积层逐步提取特征,但不同尺度的特征(如边缘、纹理、物体部件)需不同大小的卷积核。
- 堆叠大卷积核(如 5x5)会导致计算量暴增(参数会增加很多)。
-
关键优化目标
- 多尺度特征融合:同时提取不同尺度的特征。
- 减少计算量:通过 1x1 卷积降维,控制参数规模。
Inception模块设计思路
- 并行多分支设计:Inception模块包含多个并行分支,典型结构包括1x1卷积、3x3卷积、5x5卷积和3x3最大池化层。不同尺寸的卷积核可同时捕捉局部细节和全局特征。
- 特征图拼接:各分支输出的特征图在通道维度进行拼接,形成综合特征表达,增强模型对不同尺度的适应性。从图片可以看到,每个inception块有四条路径,之前的cnn大多是单一路径。
class Inception(nn.Module):def __init__(self, in_channels, c1, c2, c3, c4):super().__init__()self.ReLu = nn.ReLU()#路径1self.p1_1 = nn.Conv2d(in_channels=in_channels, out_channels=c1, kernel_size=1)#路径2self.p2_1 = nn.Conv2d(in_channels=in_channels, out_channels=c2[0], kernel_size=1)self.p2_2 = nn.Conv2d(in_channels=c2[0], out_channels=c2[1], kernel_size=3, padding=1)#路径3self.p3_1 = nn.Conv2d(in_channels=in_channels, out_channels=c3[0], kernel_size=1)self.p3_2 = nn.Conv2d(in_channels=c3[0], out_channels=c3[1], kernel_size=5, padding=2)#路径4self.p4_1 = nn.MaxPool2d(kernel_size=3, padding=1, stride=1)self.p4_2 = nn.Conv2d(in_channels=in_channels, out_channels=c4, kernel_size=1)def forward(self, x):p1 = self.ReLu(self.p1_1(x))p2 =self.ReLu(self.p2_2(self.ReLu(self.p2_1(x))))p3 =self.ReLu(self.p3_2(self.ReLu(self.p3_1(x))))p4 =self.ReLu(self.p4_2(self.p4_1(x)))return torch.cat((p1, p2, p3, p4), dim=1)
从代码可以看出,每个inception块都分成了四个路径。1,2,3路径都是纯卷积,第四条路径是池化层+卷积。另外,卷积核的大小是固定的,卷积核的通道数是可以通过传参设置的。
传参如下表所示:
参数 | 含义 | 示例值 |
---|---|---|
in_channels | 输入特征图的通道数 | 192 |
c1 | 路径1的输出通道数 | 64 |
c2 | 路径2的通道数元组 (降维, 输出) | (96, 128) |
c3 | 路径3的通道数元组 (降维, 输出) | (16, 32) |
c4 | 路径4的输出通道数 | 32 |
总输出通道数 = c1 + c2 + c3 + c4
。示例:64 + 128 + 32 + 32 = 256。
前向传播
当时写代码,我有一个疑问,inception里的前向传播是什么时候触发的,是googlenet在处理block代码流程的时候自动触发的吗?
这个问题涉及到forward方法的隐式调用。
在PyTorch中,当通过 模块实例直接调用输入数据 时,forward
方法会被自动触发。例如:
inception = Inception(...) # 实例化模块
output = inception(x) # 隐式调用forward(x)
所以在googlenet前向传播的时候,完成了inception的前向传播。
另外在学习这块还学到个小知识,就是forward方法不能显式调用。会绕过一些关键步骤(如梯度计算),就导致无法反向传播了!
张量拼接
在PyTorch中,torch.cat((p1, p2, p3, p4), dim=1)
这句话的作用是沿着通道维度(channel dimension)将四个张量(p1, p2, p3, p4)拼接成一个更大的张量。以下是详细解释:
假设输入张量 x
的形状为 (batch_size, in_channels, height, width)
,经过Inception模块的四条路径处理后,每个路径的输出形状如下:
-
p1
:(batch_size, c1, height, width)
(1x1卷积直接输出c1
个通道) -
p2
:(batch_size, c2, height, width)
(1x1卷积降维到c2
,再通过3x3卷积输出c2
个通道) -
p3
:(batch_size, c3, height, width)
(1x1卷积降维到c3
,再通过5x5卷积输出c3
个通道) -
p4
:(batch_size, c4, height, width)
(最大池化后通过1x1卷积输出c4
个通道)
所有路径输出的高度(height)和宽度(width)必须一致,否则拼接会失败。批数量和通道数可以不相同。
可以在别的维度拼接吗?不太行,原因是:
dim=0
:沿批量维度拼接,会合并不同样本的数据,破坏批量独立性。dim=2/3
:沿空间维度拼接,会破坏特征图的空间结构,导致后续卷积无法正常操作。
参数初始化
# 遍历模型的所有子模块(包括嵌套模块)
for m in self.modules():# 对二维卷积层进行初始化if isinstance(m, nn.Conv2d):# 使用Kaiming正态分布初始化权重(针对ReLU激活函数优化)nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity='relu')# 如果存在偏置项,将其初始化为0if m.bias is not None:nn.init.constant_(m.bias, 0)# 对全连接层进行初始化 elif isinstance(m, nn.Linear):# 使用正态分布初始化权重(均值0,标准差0.01)nn.init.normal_(m.weight, 0, 0.01)# 如果存在偏置项,将其初始化为0if m.bias is not None:nn.init.constant_(m.bias, 0)
在构建方法里我们增加了参数初始化,参数初始化主要作用是提高收敛速度,减少训练模型时压根不收敛的风险。
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity='relu')
卷积层使用的是kaiming初始化,和relu激活函数搭配使用效果较好。两个参数的含义是:
mode="fan_out"
:根据输出通道数计算缩放系数nonlinearity='relu'
:针对ReLU的负半轴修正
nn.init.constant_(m.bias, 0)
卷积层如果存在偏置就统一初始化为0,避免初始阶段引入偏置。
全连接层使用的是小标准差正态分布,作用是限制初始权重范围,防止激活值过大。适用于浅层网络。
一些初始化方法的特点和适用场景:
方法 | 适用场景 | 核心思想 | PyTorch实现函数 |
---|---|---|---|
Kaiming初始化 | ReLU激活的CNN | 保持前向传播的方差一致性 | kaiming_normal_ /uniform_ |
Xavier初始化 | Tanh/Sigmoid激活 | 平衡输入输出的方差 | xavier_normal_ |
零初始化 | 偏置项 | 避免初始偏好 | constant_(0) |
正交初始化 | RNN/Transformer | 保持矩阵正交性,防止梯度爆炸 | orthogonal_ |