Pytorch | 从零构建ResNet对CIFAR10进行分类

Pytorch | 从零构建ResNet对CIFAR10进行分类

  • CIFAR10数据集
  • ResNet
      • 核心思想
      • 网络结构
      • 创新点
      • 优点
      • 应用
  • ResNet结构代码详解
    • 结构代码
    • 代码详解
      • BasicBlock 类
      • ResNet 类
      • ResNet18、ResNet34、ResNet50、ResNet101、ResNet152函数
  • 训练过程和测试结果
  • 代码汇总
    • resnet.py
    • train.py
    • test.py

前面文章我们构建了AlexNet、Vgg、GoogleNet对CIFAR10进行分类:
Pytorch | 从零构建AlexNet对CIFAR10进行分类
Pytorch | 从零构建Vgg对CIFAR10进行分类
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
这篇文章我们来构建ResNet.

CIFAR10数据集

CIFAR-10数据集是由加拿大高级研究所(CIFAR)收集整理的用于图像识别研究的常用数据集,基本信息如下:

  • 数据规模:该数据集包含60,000张彩色图像,分为10个不同的类别,每个类别有6,000张图像。通常将其中50,000张作为训练集,用于模型的训练;10,000张作为测试集,用于评估模型的性能。
  • 图像尺寸:所有图像的尺寸均为32×32像素,这相对较小的尺寸使得模型在处理该数据集时能够相对快速地进行训练和推理,但也增加了图像分类的难度。
  • 类别内容:涵盖了飞机(plane)、汽车(car)、鸟(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)这10个不同的类别,这些类别都是现实世界中常见的物体,具有一定的代表性。

下面是一些示例样本:
在这里插入图片描述

ResNet

ResNet(Residual Network)即残差网络,是由微软研究院的何恺明等人在2015年提出的一种深度卷积神经网络架构,它在ILSVRC 2015图像识别挑战赛中取得了优异成绩,在图像分类、目标检测、语义分割等计算机视觉任务中具有广泛应用。以下是对ResNet的详细介绍:

核心思想

  • 解决梯度消失和退化问题:随着神经网络层数的增加,会出现梯度消失或梯度爆炸问题,导致模型难以训练。同时,还会出现网络退化现象,即增加网络层数后,准确率反而下降。ResNet的核心思想是引入残差连接(Residual Connection),通过跨层的shortcut连接,将输入直接传递到后面的层,使得后面的层可以学习到输入的残差,从而缓解了梯度消失和网络退化问题。

网络结构

  • 基本残差块:ResNet的基本组成单元是残差块(Residual Block)。一个典型的残差块包含两个3×3卷积层,中间有一个ReLU激活函数,并且在第二个卷积层之后也有一个ReLU激活函数。输入通过一个shortcut连接直接与残差块的输出相加,形成残差学习。
  • 不同层数的架构:ResNet有多种不同层数的架构,如ResNet-18、ResNet-34、ResNet-50、ResNet-101和ResNet-152等。其中,数字表示网络中卷积层和全连接层的总层数。层数越深,模型的表示能力越强,但计算成本也越高。

创新点

  • 瓶颈结构:在ResNet-50及更深的网络中,采用了瓶颈结构(Bottleneck)的残差块。这种结构先使用1×1卷积层进行降维,然后使用3×3卷积层进行特征提取,最后再使用1×1卷积层进行升维,这样可以在减少计算量的同时增加网络的深度和宽度,提高模型的性能。
  • 全局平均池化:在网络的最后一层,ResNet采用了全局平均池化(Global Average Pooling)代替传统的全连接层进行分类。全局平均池化可以将每个特征图的空间维度压缩为一个值,得到一个固定长度的特征向量,然后直接输入到分类器中进行分类。

优点

  • 训练深度网络更容易:残差连接使得梯度能够更有效地在网络中传播,大大降低了训练深度网络的难度,使得可以成功训练上百层甚至上千层的网络。
  • 性能出色:在各种图像识别任务中,ResNet都取得了非常出色的性能,相比之前的网络结构,具有更高的准确率和更好的泛化能力。
  • 模型可扩展性强:可以方便地通过增加残差块的数量来扩展网络的深度,以适应不同的任务和数据集需求。

应用

  • 图像分类:ResNet在图像分类任务中取得了巨大成功,如在ImageNet数据集上达到了很高的准确率,成为了图像分类领域的主流模型之一。
  • 目标检测:与其他目标检测算法结合,如Faster R-CNN、YOLO等,通过提取图像的特征,提高目标检测的准确率和召回率。
  • 语义分割:用于对图像进行像素级的分类,将图像中的每个像素分配到不同的类别中,在城市景观分割、医学图像分割等领域有广泛应用。

ResNet结构代码详解

结构代码

import torch
import torch.nn as nnclass BasicBlock(nn.Module):expansion = 1def __init__(self, in_channels, out_channels, stride=1):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels * BasicBlock.expansion)self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels * BasicBlock.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * BasicBlock.expansion))def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out += self.shortcut(x)out = self.relu(out)return outclass ResNet(nn.Module):def __init__(self, block, num_blocks, num_classes):super(ResNet, self).__init__()self.in_channels = 64self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)self.layer1 = self._make_layer(block, 64, num_blocks[0], 1)self.layer2 = self._make_layer(block, 128, num_blocks[1], 2)self.layer3 = self._make_layer(block, 256, num_blocks[2], 2)self.layer4 = self._make_layer(block, 512, num_blocks[3], 2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)def _make_layer(self, block, out_channels, num_blocks, stride=1):strides = [stride] + [1] * (num_blocks - 1)layers = []for stride in strides:layers.append(block(self.in_channels, out_channels, stride))self.in_channels = out_channels * block.expansionreturn nn.Sequential(*layers)def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.maxpool(out)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.fc(out)return out# ResNet18, ResNet34
def ResNet18(num_classes):return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)def ResNet34(num_classes):return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)# ResNet50, ResNet101, ResNet152 需要 BottleNeck 
class Bottleneck(nn.Module):expansion = 4def __init__(self, in_channels, out_channels, stride=1):super(Bottleneck, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)self.bn1= nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)self.relu = nn.ReLU(inplace=True)self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels * self.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * self.expansion))def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += self.shortcut(x)out = self.relu(out)return outdef ResNet50(num_classes):return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)def ResNet101(num_classes):return ResNet(Bottleneck, [3, 4, 23, 3], num_classes)def ResNet152(num_classes):return ResNet(Bottleneck, [3, 8, 36, 3], num_classes)

代码详解

以下是对上述提供的PyTorch代码的详细解释,这段代码实现了经典的ResNet(残差网络)系列模型,包括ResNet-18、ResNet-34、ResNet-50、ResNet-101和ResNet-152等不同深度的网络架构:

BasicBlock 类

class BasicBlock(nn.Module):expansion = 1def __init__(self, in_channels, out_channels, stride=1):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels * BasicBlock.expansion)self.shortcut = nn.Sequential()if stride!= 1 or in_channels!= out_channels * BasicBlock.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * BasicBlock.expansion))
  • 类定义与属性
    • 定义了一个名为BasicBlock的类,继承自nn.Module,这是PyTorch中定义神经网络模块的基类。
    • expansion属性被设置为1,用于表示该基本块在通道维度上的扩展倍数,在BasicBlock中通道数不会进行额外的扩展(后续的Bottleneck块会有不同的扩展倍数)。
  • 初始化方法__init__
    • 首先调用父类nn.Module的初始化方法super(BasicBlock, self).__init__(),确保模块正确初始化。
    • 定义了两个卷积层conv1conv2
      • conv1:输入通道数为in_channels,输出通道数为out_channels,卷积核大小为3×3,步长为stride,填充为1,并且不使用偏置(bias=False),这是遵循ResNet论文中的实现方式,通常配合后续的BatchNorm使用。
      • conv2:输入通道数为out_channels,输出通道数为out_channels * BasicBlock.expansion(实际就是out_channels,因为expansion1),卷积核大小同样是3×3,填充为1,无偏置。
    • 定义了两个BatchNorm2dbn1bn2,分别对应两个卷积层之后,用于对卷积后的特征进行归一化处理,有助于加速训练和提高模型的稳定性。
    • 定义了一个ReLU激活函数relu,并且设置inplace=True,表示直接在原张量上进行激活操作,节省内存空间(但要注意使用不当可能导致梯度计算问题,如前面提到的错误情况)。
    • 定义了shortcut,初始化为一个空的nn.Sequential序列。当输入和输出的通道数不一致或者步长不为1时(意味着尺寸或通道数有变化),会重新构建shortcut,使其包含一个1×1卷积层(用于调整通道数)和一个BatchNorm2d层,以保证shortcut连接的特征维度能与主分支的输出特征维度相匹配,便于后续进行相加操作。
    def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out += self.shortcut(x)out = self.relu(out)return out
  • 前向传播方法forward
    • 首先将输入x经过conv1卷积、bn1归一化后,再通过relu激活函数得到中间特征。
    • 接着将该中间特征再经过conv2卷积和bn2归一化。
    • 然后将主分支得到的特征outshortcut分支(直接连接输入x经过调整后的特征)进行逐元素相加,实现残差连接的操作。
    • 最后再经过一次relu激活函数后返回结果,作为该基本块的输出。

ResNet 类

class ResNet(nn.Module):def __init__(self, block, num_blocks, num_classes):super(ResNet, self).__init__()self.in_channels = 64self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)self.layer1 = self._make_layer(block, 64, num_blocks[0], 1)self.layer2 = self._make_layer(block, 128, num_blocks[1], 2)self.layer3 = self._make_layer(block, 256, num_blocks[2], 2)self.layer4 = self._make_layer(block, 512, num_blocks[3], 2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)
  • 类定义与属性
    • 定义了ResNet类,同样继承自nn.Module,用于构建完整的ResNet网络架构。
    • 初始化了一个属性in_channels64,用于记录当前层的输入通道数,后续会动态更新。
    • 定义了网络的起始层,包括一个3×3卷积层conv1(输入通道为3,对应彩色图像的RGB三个通道,输出通道为64),一个BatchNorm2dbn1用于归一化,一个ReLU激活函数relu,以及一个最大池化层maxpool(其参数设置按照常规的ResNet结构配置)。
    • 分别定义了layer1layer2layer3layer4这四层网络结构,它们通过调用_make_layer方法来构建,每层的输出通道数以及重复的块数量由传入的参数决定,并且随着层数加深,步长会相应改变(从第二层开始步长为2,用于逐步减小特征图尺寸)。
    • 定义了一个自适应平均池化层avgpool,它能将输入的特征图尺寸自适应地变为(1, 1)大小,无论输入特征图的尺寸原本是多少,便于后续全连接层处理。最后定义了一个全连接层fc,用于将池化后的特征映射到指定的类别数num_classes上进行分类。
    def _make_layer(self, block, out_channels, num_blocks, stride=1):strides = [stride] + [1] * (num_blocks - 1)layers = []for stride in strides:layers.append(block(self.in_channels, out_channels, stride))self.in_channels = out_channels * block.expansionreturn nn.Sequential(*layers)
  • _make_layer方法
    • 这个方法用于构建ResNet中的每一层网络结构(由多个基本块组成)。
    • 首先根据传入的stridenum_blocks生成一个步长列表strides,例如,如果传入stride=2num_blocks=3,那么strides会是[2, 1, 1],意味着第一个基本块可能会改变特征图的尺寸和通道数,后面的基本块保持步长为1
    • 然后循环遍历strides列表,每次创建一个指定的block(可以是BasicBlock或者后续定义的Bottleneck块),并传入当前的输入通道数、输出通道数以及对应的步长,将创建好的块添加到layers列表中。同时,更新self.in_channels为当前块输出的通道数(考虑了块的扩展倍数)。
    • 最后将layers列表中的所有块组合成一个nn.Sequential序列并返回,形成一层完整的网络结构。
    def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.maxpool(out)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.fc(out)return out
  • 前向传播方法forward
    • 首先将输入x依次经过网络起始层的卷积、归一化、激活和池化操作,得到初步的特征表示。
    • 然后将该特征依次通过layer1layer2layer3layer4这四层网络结构,不断提取和融合特征,每一层都会进一步加深特征的抽象程度并且改变特征图的尺寸和通道数。
    • 接着经过自适应平均池化层avgpool,将特征图变为(1, 1)大小的特征向量。
    • 通过out.view(out.size(0), -1)操作将特征向量展平为一维向量,使其能输入到全连接层fc中。
    • 最后将全连接层的输出作为整个网络的最终输出,返回分类结果。

ResNet18、ResNet34、ResNet50、ResNet101、ResNet152函数

# ResNet18, ResNet34
def ResNet18(num_classes):return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)def ResNet34(num_classes):return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)
  • 这两个函数分别用于创建ResNet-18ResNet-34网络模型。它们通过调用ResNet类的构造函数,传入BasicBlock作为构建块类型,以及对应不同层数的重复块数量列表(如ResNet-18中每层分别重复2个基本块),还有指定的类别数num_classes,最终返回构建好的相应深度的ResNet模型实例,用于图像分类等任务。
# ResNet50, ResNet101, ResNet152 需要 BottleNeck 
class Bottleneck(nn.Module):expansion = 4def __init__(self, in_channels, out_channels, stride=1):super(Bottleneck, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)self.bn1= nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)self.relu = nn.ReLU(inplace=True)self.shortcut = nn.Sequential()if stride!= 1 or in_channels!= out_channels * self.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * self.expansion))
  • Bottleneck类定义与初始化
    • 定义了Bottleneck类,同样继承自nn.Module,用于构建更深层的ResNet网络(如ResNet-50及以上)中的基本块。
    • expansion属性被设置为4,意味着该块在经过一系列操作后,输出通道数会是输入通道数的4倍,通过这种方式在增加网络深度的同时控制计算量。
    • 在初始化方法中,定义了三个卷积层conv1conv2conv3,分别是1×1卷积用于降维、3×3卷积进行主要的特征提取、1×1卷积用于升维,并且每个卷积层后都有对应的BatchNorm2d层进行归一化,还有ReLU激活函数用于激活中间特征。
    • 同样定义了shortcut,其构建逻辑和BasicBlock中类似,根据输入输出通道数和步长情况来决定是否需要构建包含1×1卷积和BatchNorm2d层的调整结构,以保证残差连接的维度匹配。
    def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += self.shortcut(x)out = self.relu(out)return out
  • Bottleneck块的前向传播方法
    • 前向传播过程与BasicBlock类似,只是中间经过了三个卷积层及对应的归一化和激活操作,最后同样是将主分支特征与shortcut分支特征相加后再经过ReLU激活函数输出,实现残差学习。
def ResNet50(num_classes):return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)def ResNet101(num_classes):return ResNet(Bottleneck, [3, 4, 23, 3], num_classes)def ResNet152(num_classes):return ResNet(Bottleneck, [3, 8, 36, 3], num_classes)
  • 这几个函数分别用于创建ResNet-50ResNet-101ResNet-152网络模型,它们与创建ResNet-18ResNet-34的函数类似,只是传入的构建块类型变为Bottleneck,以及对应不同层数的重复Bottleneck块数量列表,还有指定的类别数num_classes,最终返回相应深度的ResNet模型实例,用于更复杂的图像分类等任务,这些更深层的网络结构在处理大规模图像数据集时往往能取得更好的性能表现。

训练过程和测试结果

训练过程损失函数变化曲线:
在这里插入图片描述

训练过程准确率变化曲线:

在这里插入图片描述

测试结果:
在这里插入图片描述

代码汇总

项目github地址
项目结构:

|--data
|--models|--__init__.py|-resnet.py|--...
|--results
|--weights
|--train.py
|--test.py

resnet.py

import torch
import torch.nn as nnclass BasicBlock(nn.Module):expansion = 1def __init__(self, in_channels, out_channels, stride=1):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels * BasicBlock.expansion)self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels * BasicBlock.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * BasicBlock.expansion))def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out += self.shortcut(x)out = self.relu(out)return outclass ResNet(nn.Module):def __init__(self, block, num_blocks, num_classes):super(ResNet, self).__init__()self.in_channels = 64self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)self.layer1 = self._make_layer(block, 64, num_blocks[0], 1)self.layer2 = self._make_layer(block, 128, num_blocks[1], 2)self.layer3 = self._make_layer(block, 256, num_blocks[2], 2)self.layer4 = self._make_layer(block, 512, num_blocks[3], 2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)def _make_layer(self, block, out_channels, num_blocks, stride=1):strides = [stride] + [1] * (num_blocks - 1)layers = []for stride in strides:layers.append(block(self.in_channels, out_channels, stride))self.in_channels = out_channels * block.expansionreturn nn.Sequential(*layers)def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.maxpool(out)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.fc(out)return out# ResNet18, ResNet34
def ResNet18(num_classes):return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)def ResNet34(num_classes):return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)# ResNet50, ResNet101, ResNet152 需要 BottleNeck 
class Bottleneck(nn.Module):expansion = 4def __init__(self, in_channels, out_channels, stride=1):super(Bottleneck, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)self.bn1= nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)self.relu = nn.ReLU(inplace=True)self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels * self.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * self.expansion))def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += self.shortcut(x)out = self.relu(out)return outdef ResNet50(num_classes):return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)def ResNet101(num_classes):return ResNet(Bottleneck, [3, 4, 23, 3], num_classes)def ResNet152(num_classes):return ResNet(Bottleneck, [3, 8, 36, 3], num_classes)

train.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from models import *
import matplotlib.pyplot as pltimport ssl
ssl._create_default_https_context = ssl._create_unverified_context# 定义数据预处理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加载CIFAR10训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,shuffle=True, num_workers=2)# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 实例化模型
model_name = 'ResNet18'
if model_name == 'AlexNet':model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':model = GoogleNet(num_classes=10).to(device)
elif model_name == 'ResNet18':model = ResNet18(num_classes=10).to(device)
elif model_name == 'ResNet34':model = ResNet34(num_classes=10).to(device)
elif model_name == 'ResNet50':model = ResNet50(num_classes=10).to(device)
elif model_name == 'ResNet101':model = ResNet101(num_classes=10).to(device)
elif model_name == 'ResNet152':model = ResNet152(num_classes=10).to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练轮次
epochs = 15def train(model, trainloader, criterion, optimizer, device):model.train()running_loss = 0.0correct = 0total = 0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()epoch_loss = running_loss / len(trainloader)epoch_acc = 100. * correct / totalreturn epoch_loss, epoch_accif __name__ == "__main__":loss_history, acc_history = [], []for epoch in range(epochs):train_loss, train_acc = train(model, trainloader, criterion, optimizer, device)print(f'Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')loss_history.append(train_loss)acc_history.append(train_acc)# 保存模型权重,每5轮次保存到weights文件夹下if (epoch + 1) % 5 == 0:torch.save(model.state_dict(), f'weights/{model_name}_epoch_{epoch + 1}.pth')# 绘制损失曲线plt.plot(range(1, epochs+1), loss_history, label='Loss', marker='o')plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('Training Loss Curve')plt.legend()plt.savefig(f'results\\{model_name}_train_loss_curve.png')plt.close()# 绘制准确率曲线plt.plot(range(1, epochs+1), acc_history, label='Accuracy', marker='o')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.title('Training Accuracy Curve')plt.legend()plt.savefig(f'results\\{model_name}_train_acc_curve.png')plt.close()

test.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import *import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# 定义数据预处理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加载CIFAR10测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,shuffle=False, num_workers=2)# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 实例化模型
model_name = 'ResNet18'
if model_name == 'AlexNet':model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':model = GoogleNet(num_classes=10).to(device)
elif model_name == 'ResNet18':model = ResNet18(num_classes=10).to(device)
elif model_name == 'ResNet34':model = ResNet34(num_classes=10).to(device)
elif model_name == 'ResNet50':model = ResNet50(num_classes=10).to(device)
elif model_name == 'ResNet101':model = ResNet101(num_classes=10).to(device)
elif model_name == 'ResNet152':model = ResNet152(num_classes=10).to(device)criterion = nn.CrossEntropyLoss()# 加载模型权重
weights_path = f"weights/{model_name}_epoch_15.pth"  
model.load_state_dict(torch.load(weights_path, map_location=device))def test(model, testloader, criterion, device):model.eval()running_loss = 0.0correct = 0total = 0with torch.no_grad():for data in testloader:inputs, labels = data[0].to(device), data[1].to(device)outputs = model(inputs)loss = criterion(outputs, labels)running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()epoch_loss = running_loss / len(testloader)epoch_acc = 100. * correct / totalreturn epoch_loss, epoch_accif __name__ == "__main__":test_loss, test_acc = test(model, testloader, criterion, device)print(f"================{model_name} Test================")print(f"Load Model Weights From: {weights_path}")print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/493473.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

gpu硬件架构

1.简介 NVIDIA在视觉计算和人工智能(AI)领域处于领先地位;其旗舰GPU已成为解决包括高性能计算和人工智能在内的各个领域复杂计算挑战所不可或缺的。虽然它们的规格经常被讨论,但很难掌握各种组件的清晰完整的图景。 这些GPU的高性…

Java中的方法重写:深入解析与最佳实践

在Java编程中,方法重写(Method Overriding)是面向对象编程(OOP)的核心概念之一。它允许子类提供一个与父类中同名方法的具体实现,从而实现多态性(Polymorphism)。本文将深入探讨Java…

使用vcpkg安装opencv>=4.9后#include<opencv2/opencv.hpp>#include<opencv2/core.hpp>无效

使用vcpkg安装opencv>4.9后#include<opencv2/opencv.hpp>#include<opencv2/core.hpp>无效\无法查找或打开 至少从2024年开始&#xff0c;发布的vcpkg默认安装的opencv版本都是4.x版。4.8版本及以前&#xff0c;vcpkg编译后的opencv头文件目录是*/vcpkg/x64-win…

基于java web在线商城购物系统源码+论文

一、环境信息 开发语言&#xff1a;JAVA JDK版本&#xff1a;JDK8及以上 数据库&#xff1a;MySql5.6及以上 Maven版本&#xff1a;任意版本 操作系统&#xff1a;Windows、macOS 开发工具&#xff1a;Idea、Eclipse、MyEclipse 开发框架&#xff1a;SpringbootHTMLjQueryMysq…

基于字节大模型的论文翻译(含免费源码)

基于字节大模型的论文翻译 源代码&#xff1a; &#x1f44f; star ✨ https://github.com/boots-coder/LLM-application 展示 项目简介 本项目是一个基于大语言模型&#xff08;Large Language Model, LLM&#xff09;的论文阅读与翻译辅助工具。它通过用户界面&#xff08…

centos7下docker 容器实现redis主从同步

1.下载redis 镜像 docker pull bitnami/redis2. 文件夹授权 此文件夹是 你自己映射到宿主机上的挂载目录 chmod 777 /app/rd13.创建docker网络 docker network create mynet4.运行docker 镜像 安装redis的master -e 是设置环境变量值 docker run -d -p 6379:6379 \ -v /a…

实现 WebSocket 接入文心一言

目录 什么是 WebSocket&#xff1f; 为什么需要 WebSocket&#xff1f; HTTP 的局限性 WebSocket 的优势 总结&#xff1a;HTTP 和 WebSocket 的区别 WebSocket 的劣势 WebSocket 常见应用场景 WebSocket 握手过程 WebSocket 事件处理和生命周期 WebSocket 心跳机制 …

机动车油耗计算API集成指南

机动车油耗计算API集成指南 引言 在当今社会&#xff0c;随着机动车数量的持续增长和环保意识的不断增强&#xff0c;如何有效管理和降低车辆油耗成为了车主、车队管理者以及交通政策制定者共同关注的问题。为了帮助这些群体更好地理解和优化燃油消耗情况&#xff0c;本接口能…

前端yarn工具打包时网络连接问题排查与解决

最近线上前端打包时提示 “There appears to be trouble with your network connection”&#xff0c;以此文档记录下排查过程。 前端打包方式 docker启动临时容器打包&#xff0c;命令如下 docker run --rm -w /app -v pwd:/app alpine-node-common:v16.20-pro sh -c "…

IIC I2C子协议 SMBus协议 通信协议原理 时序 SMBus深度剖析

引言&#xff1a;系统管理总线&#xff08;SMBus&#xff09;是一种双线接口&#xff0c;通过该接口&#xff0c;各种系统组件芯片和设备可以相互以及与系统其他部分通信。它基于IC总线的操作原理。附录B提供了一些SMBus特性与IC总线不同的方式的描述。 SMBus为系统和电源管理相…

【Lua热更新】上篇

Lua 热更新 - 上篇 下篇链接&#xff1a;【Lua热更新】下篇 文章目录 Lua 热更新 - 上篇一、AssetBundle1.理论2. AB包资源加载 二、Lua 语法1. 简单数据类型2.字符串操作3.运算符4.条件分支语句5.循环语句6.函数7. table数组8.迭代器遍历9.复杂数据类型 - 表9.1字典9.2类9.3…

React图标库: 使用React Icons实现定制化图标效果

React图标库: 使用React Icons实现定制化图标效果 图标库介绍 是一个专门为React应用设计的图标库&#xff0c;它包含了丰富的图标集合&#xff0c;覆盖了常用的图标类型&#xff0c;如FontAwesome、Material Design等。React Icons可以让开发者在React应用中轻松地添加、定制各…

如何使用 WebAssembly 扩展后端应用

1. WebAssembly 简介 随着互联网的发展&#xff0c;越来越多的应用借助 Javascript 转到了 Web 端&#xff0c;但人们也发现&#xff0c;随着移动互联网的兴起&#xff0c;需要把大量的应用迁移到手机端&#xff0c;随着手端的应用逻辑越来越复杂&#xff0c;Javascript 的解析…

Fastdfs V6.12.1集群部署(arm/x86均可用)

文章目录 一、 Fastdfs 介绍二、部署 信息三、步骤tracker/storage 机器的 compose 内容storage 机器的 composetracker 与 storage 启动目录层级与配置文件测试测试集群扩容与缩减注意事项 一、 Fastdfs 介绍 FastDFS 是一款高性能的分布式文件系统&#xff0c;特别适合用于存…

maven-resources-production:ratel-fast: java.lang.IndexOutOfBoundsException

Maven生产环境中遇到java.lang.IndexOutOfBoundsException的问题&#xff0c;尝试了重启电脑、重启IDEA等常规方法无效&#xff0c;最终通过直接重建工程解决了问题。 Rebuild Project 再启动OK

1. JasperSoft介绍与安装

Jaspersoft介绍 Jaspersoft是一款开源的&#xff0c;强大灵活并且使用广泛的报表软件。能够展示丰富的页面内容&#xff0c;并将之转换成PDF、HTML或者XML格式&#xff0c;该库完全由Java写出&#xff0c;可以用于在各种Java应用程序&#xff0c;非常适合Java开发者用来做报表生…

知网研学 | 知网文献(CAJ+PDF)批量下载

知网文献&#xff08;CAJPDF&#xff09;批量下载 一、知网研学安装二、插件及脚本安装三、CAJ批量下载四、脚本下载及PDF批量下载浏览器取消拦截窗口 一、知网研学安装 批量下载知网文件&#xff0c;格式为es6文件&#xff0c;需使用知网研学软件打开&#xff0c;故需先安装该…

WeakAuras NES Script(lua)

WeakAuras NES Script 修星脚本字符串 脚本1&#xff1a;NES !WA:2!TMZFWXX1zDxVAs4siiRKiBN4eV(sTRKZ5Z6opYbhQQSoPtsxr(K8ENSJtS50(J3D7wV3UBF7E6hgmKOXdjKsgAvZFaPTtte0mD60XdCmmecDMKruyykDcplAZiGPfWtSsag6myGuOuq89EVDV9wPvKeGBM7U99EFVVVV33VFFB8Z2TJ8azYMlZj7Ur3QDR(…

[数据结构] 链表

目录 1.链表的基本概念 2.链表的实现 -- 节点的构造和链接 节点如何构造? 如何将链表关联起来? 3.链表的方法(功能) 1).display() -- 链表的遍历 2).size() -- 求链表的长度 3).addFirst(int val) -- 头插法 4).addLast(int val) -- 尾插法 5).addIndex -- 在任意位置…

springmvc的拦截器,全局异常处理和文件上传

拦截器: 拦截不符合规则的&#xff0c;放行符合规则的。 等价于过滤器。 拦截器只拦截controller层API接口。 如何定义拦截器。 定义一个类并实现拦截器接口 public class MyInterceptor implements HandlerInterceptor {public boolean preHandle(HttpServletRequest reque…