深度学习基础--ResNet网络的讲解,ResNet50的复现(pytorch)以及用复现的ResNet50做鸟类图像分类

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

前言

  • 如果说最经典的神经网络,ResNet肯定是一个,这篇文章是本人学习ResNet的学习笔记,并且用pytorch复现了ResNet50,后面用它做了一个鸟类图像分类demo
  • 欢迎收藏 + 关注,本人将会持续更新

文章目录

  • ResNet网络讲解
    • 什么是ResNet?
    • ResNet神经网络突出点
    • 为什么采用残差连接
      • 模型退化、梯度消失、梯度爆炸
      • 解决方法
    • 残差网络
  • ResNet-50复现
    • 1、导入数据
      • 1、导入库
      • 2、查看数据信息和导入数据
      • 3、展示数据
      • 4、数据导入
      • 5、数据划分
      • 6、动态加载数据
    • 2、构建ResNet-50网络
    • 3、模型训练
      • 1、构建训练集
      • 2、构建测试集
      • 3、设置超参数
    • 4、模型训练
    • 5、结果可视化
  • 参考资料

ResNet网络讲解

什么是ResNet?

ResNet网络是CNN的经典网络架构,是有大神何凯明提出的,主要为了解决随着网络的加深而引起的“ 退化 ”问题,主要用于图像分类。

可以说在如今的CV领域里面,大部分网络结构都有参考ResNet网络思想,无论是在图像分类、目标检测、图像识别上,甚至在Transformer网络模型中,也融合了ResNet网络的思想。

ResNet神经网络突出点

  • 网络结构超过1000层
    • ❔ ❔ 超过1000层网络结构不是很容易么? 小编在学习深度学习的时候,曾经遇到过这样一个问题,有时候加深网络结构,反而在准确率、损失率上更差,这种现象称为模型“ 退化 ”现象,而ResNet的残差连接可以保证下一层的输出不会比输入差,从而可以加深网络结构。
  • 提出残差模块(residual):这个是ResNet的核心;
  • 采用大量的归一化在卷积层与激活函数之间.

为什么采用残差连接

模型退化、梯度消失、梯度爆炸

  • 👉 模型退化:指随着网络层数的加深,其效果出现下降趋势,不如层数少的情况。如论文中图示,56层效果不如20层效果;

在这里插入图片描述

  • 👉 梯度消失:这个是指随着网络层数的增加,反向传播,梯度更新的时候可能会造成前面几层的梯度很小、接近于0,这就会导致权重的更新会特别慢,效率低下。
  • 👉 梯度爆炸:指随着网络层数的增加,在反向传播的时候,梯度变得非常大,从而在更新权重的时候,权重值发生大幅度变化,这可能导致网络不稳定,甚至是无法收敛

解决方法

  • 梯度消失、梯度爆炸:在数据预处理和网络层之间加入:BN层(Batch Normalization),从而对数据进行归一化
  • 模型退化:采用残差连接,如论文图,随着网络层数的增加,损失率更低了。

在这里插入图片描述

残差网络

在讲述前,这里先讲述一下恒等映射的概念:

  • 恒等映射核心是复制,就是复制网络层,什么也不干。

可以这么理解:假设在一种网络A的后面添加几层形成新的网络B,如果A的输出经过新的层级变成B的输出没有发送变化,那么就可以说网络A和网络B的错误率是相等的,这样就确保了加深的网络层不会比之前的网络层效果差。


resent网络说明了,更深的网络结构可以有更好的效果,而解决这个的核心就是残差连接,网络结果如图所示:

在这里插入图片描述

上图就是何凯明提出的残差结构,这种结构实现了恒等映射,网络层的输出由两大模块组成:

  • 其一:正常的卷积层;
  • 其二:有一个分支输出到连接上,这个输出值就是输入的值;

最终结果就是:卷积层输出+分支输出,数学公式如下:

在这里插入图片描述

其中F(x)是卷积层的输出,x是分支的输入值。

极端情况:F(x)的网络层中,所有参数都为0,那么H(x)就是恒等映射。这样就确保了最后的错误率不会因为网络层的增加而导致变大


在ResNet中有两个不同的ResNet模块,如图所示:

在这里插入图片描述

左边

  • 有两层残差单元,输出通道都是3*3
  • 使用情况:用于较浅的ResNet网络。

右边

  • 三层残差单元,称为blottlenck模块,作用是:现用1*1卷积进行降维,后用3*3卷积进行特征特权,最后用1*1卷积恢复原来的维度,这个可以很好的减少参数个数,用于较深的神经网络

下图参考一个csdn大神笔记图

在这里插入图片描述

CNN参数计算公式:卷积核尺寸 * 卷积核速度 * 卷积核组数 == 卷积核尺寸 * 输入特征矩阵深度 * 输出矩阵深度。

ResNet经典的网络结构有ResNet-50,ResNet-101等,本文将用pytorch复现ResNet-50,并用其做一个简单的实验–鸟类图片分类

ResNet-50网络结果如下:

在这里插入图片描述

ResNet-50复现

1、导入数据

1、导入库

import torch  
import torch.nn as nn
import torchvision 
import numpy as np 
import os, PIL, pathlib # 设置设备
device = "cuda" if torch.cuda.is_available() else "cpu"device 
'cuda'

2、查看数据信息和导入数据

数据目录有两个文件:一个数据文件,一个权重。

data_dir = "./data/bird_photos"data_dir = pathlib.Path(data_dir)# 类别数量
classnames = [str(path).split('/')[0] for path in os.listdir(data_dir)]classnames
['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']

3、展示数据

import matplotlib.pylab as plt  
from PIL import Image # 获取文件名称
data_path_name = "./data/bird_photos/Bananaquit/"
data_path_list = [f for f in os.listdir(data_path_name) if f.endswith(('jpg', 'png'))]# 创建画板
fig, axes = plt.subplots(2, 8, figsize=(16, 6))for ax, img_file in zip(axes.flat, data_path_list):path_name = os.path.join(data_path_name, img_file)img = Image.open(path_name) # 打开# 显示ax.imshow(img)ax.axis('off')plt.show()


在这里插入图片描述

4、数据导入

from torchvision import transforms, datasets # 数据统一格式
img_height = 224
img_width = 224 data_tranforms = transforms.Compose([transforms.Resize([img_height, img_width]),transforms.ToTensor(),transforms.Normalize(   # 归一化mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225] )
])# 加载所有数据
total_data = datasets.ImageFolder(root="./data/bird_photos", transform=data_tranforms)

5、数据划分

# 大小 8 : 2
train_size = int(len(total_data) * 0.8)
test_size = len(total_data) - train_size train_data, test_data = torch.utils.data.random_split(total_data, [train_size, test_size])

6、动态加载数据

batch_size = 32 train_dl = torch.utils.data.DataLoader(train_data,batch_size=batch_size,shuffle=True
)test_dl = torch.utils.data.DataLoader(test_data,batch_size=batch_size,shuffle=False
)
# 查看数据维度
for data, labels in train_dl:print("data shape[N, C, H, W]: ", data.shape)print("labels: ", labels)break
data shape[N, C, H, W]:  torch.Size([32, 3, 224, 224])
labels:  tensor([0, 1, 0, 1, 2, 1, 1, 0, 2, 2, 1, 2, 1, 3, 1, 2, 2, 2, 2, 1, 2, 1, 2, 2,0, 3, 3, 3, 3, 2, 3, 3])

2、构建ResNet-50网络

在这里插入图片描述

import torch.nn.functional as F# 定义残差模块一,这个用于处理输入和输出通道一样的情况
'''  
卷积核大小:1       3       1
核心特点:尺寸不变:输入和输出的尺寸保持一致。 没有下采样:没有使用步长大于1的卷积操作,因此没有改变特征图的空间尺寸
'''
class Identity_block(nn.Module):def __init__(self, in_channels, kernel_size, filters):super(Identity_block, self).__init__()# 输出通道filter1, filter2, filter3 = filters# 卷积层一self.conv1 = nn.Conv2d(in_channels, filter1, kernel_size=1, stride=1)self.bn1 = nn.BatchNorm2d(filter1)# 卷积层2self.conv2 = nn.Conv2d(filter1, filter2, kernel_size=kernel_size, padding=1)   # 通过卷积输入输出公式发现,padding=1,可以保证输入和输出尺寸相同self.bn2 = nn.BatchNorm2d(filter2)# 卷积层3self.conv3 = nn.Conv2d(filter2, filter3, kernel_size=1, stride=1)self.bn3 = nn.BatchNorm2d(filter3)def forward(self, x):# 记录原始值xx = xx = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = self.bn3(self.conv3(x))# 残差连接,输入、输出维度不变x += xxx = F.relu(x)return x # 定义卷积模块二:用于处理输入和输出不一样的情况
'''  
* 卷积核还是:1 3 1
* stride=2
* 这里的分支是采用一个Conv2D,和一个归一化BN层,也是为了处理数据维度吧, 这种维度的变化,可以用ai举例子核心特点:尺寸变化,stride=2降维
'''
class ConvBlock(nn.Module):def __init__(self, in_channels, kernel_size, filters, stride=2):super(ConvBlock, self).__init__()filter1, filter2, filter3= filters# 卷积层1self.conv1 = nn.Conv2d(in_channels, filter1, kernel_size=1, stride=stride)self.bn1 = nn.BatchNorm2d(filter1)# 卷积2self.conv2 = nn.Conv2d(filter1, filter2, kernel_size=kernel_size, padding=1) # 需要维持维度不变self.bn2 = nn.BatchNorm2d(filter2)# 卷积3self.conv3 = nn.Conv2d(filter2, filter3, kernel_size=1, stride=1)  # stride = 1,维持通道不变self.bn3 = nn.BatchNorm2d(filter3)# 用于匹配维度的shortcut卷积,这个就是上面Identity_block的x分支self.shortcut = nn.Conv2d(in_channels, filter3, kernel_size=1, stride=stride)self.shortcut_bn = nn.BatchNorm2d(filter3)def forward(self, x):xx = xx = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = self.bn3(self.conv3(x))temp = self.shortcut_bn(self.shortcut(xx))x += tempx = F.relu(x)return x # 定义ResNet50
class ResNet50(nn.Module):def __init__(self, classes):   # 类别数量super().__init__()# 头顶self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)self.bn1 = nn.BatchNorm2d(64)self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# 第一部分self.part1_1 = ConvBlock(64, 3, [64, 64, 256], stride=1)self.part1_2 = Identity_block(256, 3, [64, 64, 256])self.part1_3 = Identity_block(256, 3, [64, 64, 256])# 第二部分self.part2_1 = ConvBlock(256, 3, [128, 128, 512])self.part2_2 = Identity_block(512, 3, [128, 128, 512])self.part2_3 = Identity_block(512, 3, [128, 128, 512])self.part2_4 = Identity_block(512, 3, [128, 128, 512])# 第三部分self.part3_1 = ConvBlock(512, 3, [256, 256, 1024])self.part3_2 = Identity_block(1024, 3, [256, 256, 1024])self.part3_3 = Identity_block(1024, 3, [256, 256, 1024])self.part3_4 = Identity_block(1024, 3, [256, 256, 1024])self.part3_5 = Identity_block(1024, 3, [256, 256, 1024])self.part3_6 = Identity_block(1024, 3, [256, 256, 1024])# 第四部分self.part4_1 = ConvBlock(1024, 3, [512, 512, 2048])self.part4_2 = Identity_block(2048, 3, [512, 512, 2048])self.part4_3 = Identity_block(2048, 3, [512, 512, 2048])# 平均池化self.avg_pool = nn.AvgPool2d(kernel_size=7)# 全连接self.fn1 = nn.Linear(2048, classes)def forward(self, x):# 头部x = F.relu(self.bn1(self.conv1(x)))x = self.max_pool(x)x = self.part1_1(x)x = self.part1_2(x)x = self.part1_3(x)x = self.part2_1(x)x = self.part2_2(x)x = self.part2_3(x)x = self.part2_4(x)x = self.part3_1(x)x = self.part3_2(x)x = self.part3_3(x)x = self.part3_4(x)x = self.part3_5(x)x = self.part3_6(x)x = self.part4_1(x)x = self.part4_2(x)x = self.part4_3(x)x = self.avg_pool(x)x = x.view(x.size(0), -1)  # 扁平化x = self.fn1(x)return x model = ResNet50(classes=len(classnames)).to(device)model
ResNet50((conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(max_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)(part1_1): ConvBlock((conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(shortcut): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))(shortcut_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part1_2): Identity_block((conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part1_3): Identity_block((conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part2_1): ConvBlock((conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(2, 2))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))(shortcut_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part2_2): Identity_block((conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part2_3): Identity_block((conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part2_4): Identity_block((conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_1): ConvBlock((conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(2, 2))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(shortcut): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2))(shortcut_bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_2): Identity_block((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_3): Identity_block((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_4): Identity_block((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_5): Identity_block((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_6): Identity_block((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part4_1): ConvBlock((conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(2, 2))(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(shortcut): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2))(shortcut_bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part4_2): Identity_block((conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part4_3): Identity_block((conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(avg_pool): AvgPool2d(kernel_size=7, stride=7, padding=0)(fn1): Linear(in_features=2048, out_features=4, bias=True)
)
model(torch.randn(32, 3, 224, 224).to(device)).shape
torch.Size([32, 4])

3、模型训练

1、构建训练集

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)batch_size = len(dataloader)train_acc, train_loss = 0, 0 for X, y in dataloader:X, y = X.to(device), y.to(device)# 训练pred = model(X)loss = loss_fn(pred, y)# 梯度下降法optimizer.zero_grad()loss.backward()optimizer.step()# 记录train_loss += loss.item()train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_acc /= sizetrain_loss /= batch_sizereturn train_acc, train_loss

2、构建测试集

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)batch_size = len(dataloader)test_acc, test_loss = 0, 0 with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)test_loss += loss.item()test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()test_acc /= sizetest_loss /= batch_sizereturn test_acc, test_loss

3、设置超参数

loss_fn = nn.CrossEntropyLoss()  # 损失函数     
learn_lr = 1e-4             # 超参数
optimizer = torch.optim.Adam(model.parameters(), lr=learn_lr)   # 优化器

4、模型训练

train_acc = []
train_loss = []
test_acc = []
test_loss = []epoches = 80for i in range(epoches):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 输出template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}')print(template.format(i + 1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))print("Done")

在这里插入图片描述

5、结果可视化

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息epochs_range = range(epoches)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training= Loss')
plt.show()


在这里插入图片描述

参考资料

【深度学习】ResNet网络讲解-CSDN博客

K同学啊,训练营文档

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/24418.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【DeepSeek】【GPT-Academic】:DeepSeek集成到GPT-Academic(官方+第三方)

目录 1 官方deepseek 1.1 拉取学术GPT项目 1.2 安装依赖 1.3 修改配置文件中的DEEPSEEK_API_KEY 2 第三方API 2.1 修改DEEPSEEK_API_KEY 2.2 修改CUSTOM_API_KEY_PATTERM 2.3 地址重定向 2.4 修改模型参数 2.5 成功调用 2.6 尝试添加一个deepseek-r1参数 3 使用千帆…

用Golang与WebAssembly构建高性能Web应用:详解`syscall/js`包

用Golang与WebAssembly构建高性能Web应用:详解syscall/js包 引言为什么选择syscall/js包?适用场景 syscall/js包概述syscall/js包的核心概念1. js.Global2. js.Value3. js.Func4. js.Null 和 js.Undefined syscall/js包在WebAssembly中的位置 环境配置与…

本地部署轻量级web开发框架Flask并实现无公网ip远程访问开发界面

文章目录 1. 安装部署Flask2. 安装Cpolar内网穿透3. 配置Flask的web界面公网访问地址4. 公网远程访问Flask的web界面 本篇文章主要讲解如何在本地安装Flask,以及如何将其web界面发布到公网进行远程访问。 Flask是目前十分流行的web框架,采用Python编程…

ChatGPT背后的理论基础:从预训练到微调的深度解析

友情提示:本文内容由银河易创(https://ai.eaigx.com)AI创作平台GPT-4o-mini模型生成,仅供参考。请根据具体情况和需求进行适当的调整和验证。 随着人工智能特别是自然语言处理技术的飞速发展,ChatGPT作为一种强大的对话…

2025面试Go真题第一场

前几天参加了一场面试,GoLang 后端工程师,他们直接给了我 10 道题,我留了一个截图。 在看答案之前,你可以先简单做一下,下面我会对每个题目做一个说明。 文章目录 1、golang map 是否并发安全?2、协程泄漏的原因可能是…

网络安全第三次练习

一、实验拓扑 二、实验要求 配置真实DNS服务信息,创建虚拟服务,配置DNS透明代理功能 三、需求分析 1.创建用户并配置认证策略 2.安全策略划分接口 3.ip与策略配置 四、实验步骤 1.划分安全策略接口 2.创建用户并进行策略认证 3.配置安全策略 4.NAT配…

FTP 实验(ENSP模拟器实现)

目录 FTP 概述 FTP实验 FTP的报文交互 FTP 概述 FTP(File Transfer Protocol,文件传输协议)是一种用于在网络上进行文件传输的标准协议。它允许用户在两台计算机之间上传和下载文件。 1、FTP采用客户端-服务器模型,客户端通过…

Elasticsearch:使用经过训练的 ML 模型理解稀疏向量嵌入

作者:来自 Elastic Dai Sugimori 了解稀疏向量嵌入,理解它们的作用/含义,以及如何使用它们实现语义搜索。 Elasticsearch 提供语义搜索功能,允许用户使用自然语言进行查询并检索相关信息。为此,目标文档和查询必须首先…

Java进阶(vue基础)

目录 1.vue简单入门 ?1.1.创建一个vue程序 1.2.使用Component模板(组件) 1.3.引入AXOIS ?1.4.vue的Methods(方法) 和?compoted(计算) 1.5.插槽slot 1.6.创建自定义事件? 2.Vue脚手架安装? 3.Element-UI的…

密码学基础

第1节 密码学概述 密码是人类在信息活动中的一项伟大发明,是保护秘密信息的工具。它诞生于公元前两千余年的埃及,迄今已有四千多年的历史。在出现年代有可查证记录的科学技术中,密码是历史最为悠久的科学技术之一。 百度百科里对密码的解释&…

Java入门级小案例:网页版简易计算器

网页版简易计算器 目录 网页版简易计算器需求&#xff1a;代码实现&#xff1a;效果显示 需求&#xff1a; 用HTML、CSS、JS进行书写一个具备一定功能的简易计算器。 代码实现&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta cha…

【Uniapp-Vue3】导入uni-id用户体系

在uniapp官网的uniCloud中下载uni-id用户体系 或者直接进入加载&#xff0c;下载地址&#xff1a;uni-id-pages - DCloud 插件市场 进入以后下载插件&#xff0c;打开HbuilderX 选中项目&#xff0c;点击确定 点击跳过 点击合并 右键uniCloud文件夹下的database文件夹&#x…

Python 入门教程(2)搭建环境 | 2.3、VSCode配置Python开发环境

文章目录 一、VSCode配置Python开发环境1、软件安装2、安装Python插件3、配置Python环境4、包管理5、调试程序 前言 Visual Studio Code&#xff08;简称VSCode&#xff09;以其强大的功能和灵活的扩展性&#xff0c;成为了许多开发者的首选。本文将详细介绍如何在VSCode中配置…

Spring Boot电影评论网站系统设计与实现

随着互联网和娱乐产业的发展&#xff0c;电影评论网站逐渐成为人们分享观影体验、交流影评的重要平台。本文将介绍一个基于Spring Boot框架开发的电影评论网站系统的功能设计与实现方案。 功能模块概述 该电影评论网站系统分为管理员模块和用户模块两大核心部分&#xff0c;以…

RT-Thread+STM32L475VET6——TF 卡文件系统

文章目录 前言一、板载资源二、具体步骤1.打开CubeMX进行USB配置1.1 使用外部高速时钟&#xff0c;并修改时钟树1.2 打开SPI1&#xff0c;参数默认即可(SPI根据自己需求调整&#xff09;1.3 打开串口&#xff0c;参数默认1.4 生成工程 2.配置SPI2.1 打开SPI驱动2.2 声明使用SPI…

LabVIEW形状误差测量系统

在机械制造领域&#xff0c;形状与位置公差&#xff08;GD&T&#xff09;直接影响装配精度与产品寿命。国内中小型机加工企业因形状误差导致的返工率高达12%-18%。传统测量方式存在以下三大痛点&#xff1a; ​ 设备局限&#xff1a;机械式千分表需人工读数&#xff0c;精度…

【c语言】字符函数和字符串函数(1)

一、字符分类函数 c语言中有部分函数是专门做字符分类的&#xff0c;也就是一个字符是属于什么类型的字符&#xff0c;这些函 数的使用要包含一个头文件ctype.h中。 其具体如下图所示&#xff1a; 这些函数的使用方式都类似&#xff0c;下面我们通过一个函数来看其…

【Python LeetCode 专题】动态规划

斐波那契类型70. 爬楼梯746. 使用最小花费爬楼梯198. 打家劫舍740. 删除并获得点数矩阵62. 不同路径方法一:二维 DP方法二:递归(`@cache`)64. 最小路径和63. 不同路径 II120. 三角形最小路径和221. 最大正方形字符串139. 单词拆分5. 最长回文子串516. 最长回文子序列72. 编…

Linux相关知识(文件系统、目录树、权限管理)和Shell相关知识(字符串、数组)

仅供自学&#xff0c;请去支持javaGuide原版书籍。 1.Linux 1.1.概述 Linux是一种类Unix系统。 严格来讲&#xff0c;Linux 这个词本身只表示 Linux内核&#xff0c;单独的 Linux 内核并不能成为一个可以正常工作的操作系统。所以&#xff0c;就有了各种 Linux 发行版&#…

第九节: Vue 3 中的 provide 与 inject:优雅的跨组件通信

文章目录 前言什么是 provide 和 inject&#xff1f;provide 的基本使用inject 的基本使用provide 提供响应式数据数据provide 提供修改数据的方法provide 提供只读响应数据provide 使用symbol作为注入名inject 默认值总结 前言 在 Vue 3 中&#xff0c;provide 和 inject 是一…