ResNet-RS 乳腺癌识别

一、模型结构

1.1 模型思路

ResNet-RS是一种改进的ResNet架构,它在2021年由谷歌大脑和UC Berkeley的研究者们提出。ResNet-RS的提出基于对现有ResNet架构的深入研究,研究者们重新审视了ResNet的结构、训练方法以及缩放策略,并提出了一些改进措施。这些改进包括:
  1. 在不改变模型结构的前提下,通过实验验证了不同的正则化方法及其组合的作用,得到了能提升性能的正则化策略。
  2. 提出了简单、高效的缩放策略,包括在可能发生过拟合的情况下优先缩放模型深度,以及更慢地缩放输入分辨率。
  3. 将上述正则化策略和缩放策略应用到ResNet中,提出了ResNet-RS系列,其性能全面超越了EfficientNet系列。

ResNet-RS 是:改进的缩放策略、改进的训练方法、ResNet-D 修改(He 等人,2018 年)和 SqueezeExcitation 模块(Hu 等人,2018 年)的组合

原文如下:

Revisiting ResNets- Improved Training and Scaling Strategies.pdf

1.2 ResNet-RS模型改进

1.改进的训练过程: ResNet-RS并没有在结构上进行大幅度的修改,而是专注于训练过程的优化。例如,它采用了一系列更现代的训练技巧,如学习率调度、数据增强和正则化等。这些改进使得ResNet-RS可以在不增加大量计算成本的情况下,比经典的ResNet取得更好的性能。
2. 调整的架构比例: 论文中提出了一些结构上的调整,例如通过调整卷积层的通道数和层数比例,使得网络在更深的层次上能更有效地传递梯度,减少了梯度消失问题。
3. 正则化方法: ResNet-RS引入了多种正则化方法,如Dropout、Label Smoothing、Stochastic Depth和CutMix。这些正则化技巧有助于缓解过拟合,并提高模型的泛化能力。
4. 增强的数据增强技术: ResNet-RS使用了现代的数据增强技术,如AutoAugment和RandAugment。这些方法在训练过程中生成了更多的高质量样本,进一步提高了模型的鲁棒性。
5. 优化的训练超参数: 该模型通过使用预热学习率调度、Cosine Annealing学习率策略,以及精细调优的超参数(如批量大小和权重衰减)来优化训练过程,从而使模型更快收敛并达到更高的精度。

1.3 ResNet-RS 架构细节

1.3.1 ResNet-D架构

ResNetRS是在ResNet-D架构上面的改进,ResNet-D架构的结构如下:

注意,残差边上多了个池化操作。

ResNet-D (He et al., 2018) 结合了对原始 ResNet 架构的以下四个调整。

  • 首先,在 InceptionV3 (Szegedy et al., 2016) 中首次提出,将干中的 7×7 卷积替换为三个较小的 3×3 卷积。
  • 其次,在下采样块的残差路径中切换前两个卷积的步幅大小。
  • 第三,将下采样块的skip connection路径中的 stride-2 1×1 卷积替换为 stride-2 2×2 平均池化,然后是non-strided 1×1 卷积。
  • 第四,去除 stride2 3×3 max pool layer,下采样发生在下一个bottleneck block的第一个 3×3 卷积中。

1.3.2 ResNet-RS架构

下图显示了使用的所有 ResNet 深度的块布局。 ResNet-50 到 ResNet-200 使用 He 等人的标准块配置。 ResNet-270 及更高版本主要扩展 c3 和 c4 中的块数,作者尝试保持它们的比例大致恒定。作者根据经验发现,在较低阶段添加块会限制过度拟合,因为较低层中的块具有显着较少的参数,即使所有块具有相同数量的 FLOP。

与经典ResNet相比,ResNet-RS不仅在精度上表现更好,还在计算效率方面进行了优化。在图像分类任务(如ImageNet)上,ResNet-RS能够在同样或更少的计算资源下获得比ResNet更高的精度。

二、 数据准备

2.1. 设置GPU

  • 如果设备上支持GPU就使用GPU,否则使用CPU
  • MacOS设备使用的是MPS
import torch
import torch.nn as nn# 设置硬件设备,如果有GPU则使用,没有则使用cpu
def get_device():if torch.cuda.is_available():  # cudadevice = "cuda"elif torch.backends.mps.is_available():  # macosdevice = "mps"else:   # cpu"cpu"return torch.device(devicedevice = get_device()
device
device(type='cuda')

2.2. 导入数据

读取数据并进行查看:
import os,PIL,random,pathlibdata_dir = "./data/26-data/"
data_dir = pathlib.Path(data_dir)data_paths = list(data_dir.glob('*'))
classeNames = [str(path).split("/")[-1] for path in data_paths]
classeNames
['malignant', 'benign']

2.3 数据可视化

2.4 划分数据集

total_datadir = './data/26-data/'# 关于transforms.Compose的更多介绍可以参考:https://blog.csdn.net/qq_38251616/article/details/124878863
train_transforms = transforms.Compose([transforms.Resize([160, 160]),  # 将输入图片resize成统一尺寸transforms.RandAugment(magnitude=10), transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])total_data = datasets.ImageFolder(total_datadir, transform=train_transforms)

使用torch.utils.data.random_split()torch.utils.data.DataLoader()创建训练和测试数据集:

train_size = int(0.8 * len(total_data))
test_size = len(total_data) - train_sizetrain_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])batch_size = 32train_dl = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=1)
test_dl = torch.utils.data.DataLoader(test_dataset,batch_size=batch_size,shuffle=True,num_workers=1)

查看数据集shape

for X, y in test_dl:print("Shape of X [N, C, H, W]: ", X.shape)print("Shape of y: ", y.shape, y.dtype)break
Shape of X [N, C, H, W]:  torch.Size([32, 3, 160, 160])
Shape of y:  torch.Size([32]) torch.int64

三、Pytorch实现ResNet-RS模型

3.1 定义Stem层

ResNet-D型Stem块,StemBlock 作为网络的起始层,用3个卷积层对输入图像进行初步处理。
class StemBlock(nn.Module):def __init__(self, channel_in, channel_out):super(StemBlock, self).__init__()channel = int(channel_out / 2)  # 输出通道数的一半# 定义初始层的结构self.stem = nn.Sequential(nn.Conv2d(channel_in, channel, kernel_size=(3, 3), stride=2, padding=1, bias=False),  # 第一层卷积nn.BatchNorm2d(channel),  # 批归一化nn.ReLU(inplace=True),  # ReLU 激活函数nn.Conv2d(channel, channel, kernel_size=(3, 3), stride=1, padding=1, bias=False),  # 第二层卷积nn.BatchNorm2d(channel),  # 批归一化nn.ReLU(inplace=True),  # ReLU 激活函数nn.Conv2d(channel, channel_out, kernel_size=(3, 3), stride=1, padding=1, bias=False)  # 第三层卷积)self.init_weights()  # 初始化权重def init_weights(self):# 初始化卷积层和批归一化层的权重for _, module in self.named_modules():if isinstance(module, nn.Conv2d):nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')if isinstance(module, nn.BatchNorm2d):nn.init.ones_(module.weight)  # 批归一化权重初始化为1nn.init.zeros_(module.bias)    # 批归一化偏置初始化为0def forward(self, x):return self.stem(x)  # 前向传播,返回经过初始层的结果

3.2 定义残差单元

每个残差块包含多个卷积层,使用1x1和3x3卷积实现降维和特征提取
class Block(nn.Module):def __init__(self, channel_in, channel_out, stride, identity):super(Block, self).__init__()channel = int(channel_out / 4)  # 输出通道数的1/4self.se = SEBlock(channel_in)  # Squeeze-and-Excitation模块# 1x1 卷积层self.bn1 = nn.BatchNorm2d(channel_in)  # 批归一化self.relu1 = nn.ReLU(inplace=True)  # ReLU 激活self.conv1 = nn.Conv2d(channel_in, channel, kernel_size=(1, 1), bias=False)  # 1x1 卷积# 3x3 卷积层self.bn2 = nn.BatchNorm2d(channel)  # 批归一化self.relu2 = nn.ReLU(inplace=True)  # ReLU 激活self.conv2 = nn.Conv2d(channel, channel, kernel_size=(3, 3), stride=stride, padding=1, bias=False)  # 3x3 卷积# 1x1 卷积层self.bn3 = nn.BatchNorm2d(channel)  # 批归一化self.drop_out = DropPath(drop_prob=0.)  # 随机丢弃部分路径self.conv3 = nn.Conv2d(channel, channel_out, kernel_size=(1, 1), bias=False)  # 1x1 卷积# 跳跃连接self.identity = identity  # 是否使用跳跃连接self.downsample = DownSample(channel_in, channel_out, stride)  # 用于匹配特征图尺寸self.init_weights()  # 初始化权重def init_weights(self):# 初始化卷积层和批归一化层的权重for _, module in self.named_modules():if isinstance(module, nn.Conv2d):nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')if isinstance(module, nn.BatchNorm2d):nn.init.ones_(module.weight)  # 批归一化权重初始化为1nn.init.zeros_(module.bias)    # 批归一化偏置初始化为0def forward(self, x):h = self.se(x)  # 通过SEBlock进行通道加权h = self.bn1(h)  # 批归一化h = self.relu1(h)  # ReLU 激活h = self.conv1(h)  # 1x1 卷积h = self.bn2(h)  # 批归一化h = self.relu2(h)  # ReLU 激活h = self.conv2(h)  # 3x3 卷积h = self.bn3(h)  # 批归一化h = self.drop_out(h)  # 随机丢弃部分路径h = self.conv3(h)  # 1x1 卷积shortcut = self.downsample(x) if self.identity else x  # 跳跃连接y = h + shortcut  # 残差连接return y  # 返回输出

3.3 定义SEBlock单元模块

SEBlock(Squeeze-and-Excitation)模块,用于学习输入特征通道的权重。
class SEBlock(nn.Module):def __init__(self, channel, ratio=0.25):super(SEBlock, self).__init__()reduced_channel = int(channel * ratio)  # 降维通道数self.avg_pool = nn.AdaptiveAvgPool2d(1)  # 自适应平均池化self.fc = nn.Sequential(  # 全连接层nn.Linear(channel, reduced_channel, bias=False),  # 降维nn.ReLU(inplace=True),  # ReLU 激活nn.Linear(reduced_channel, channel, bias=False),  # 恢复维度nn.Sigmoid()  # 激活函数)def forward(self, x):b, c, _, _ = x.size()  # 获取批大小和通道数y = self.avg_pool(x).view(b, c)  # 平均池化y = self.fc(y).view(b, c, 1, 1)  # 经过全连接层return x * y.expand_as(x)  # 加权输入特征

3.4 定义DropPath模块

DropPath 是随机丢弃部分路径(stochastic depth)的实现
class DropPath(nn.Module):def __init__(self, drop_prob=None):super(DropPath, self).__init__()self.drop_prob = drop_prob  # 丢弃概率def forward(self, x):# 如果不丢弃或不在训练模式下,返回输入if self.drop_prob is None or self.drop_prob == 0 or not self.training:return xkeep_prob = 1 - self.drop_prob  # 保留概率shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)  # 形状rand_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)  # 随机张量rand_tensor = rand_tensor.floor_()  # 取整out = x.div(keep_prob) * rand_tensor  # 丢弃路径的操作return out

3.5 定义DownSample

DownSample:用于在特征图尺寸变化时对跳跃连接进行匹配
class DownSample(nn.Module):def __init__(self, channel_in, channel_out, stride):super(DownSample, self).__init__()if stride == 1:avg_pool = nn.Identity()  # 如果 stride 为 1,不进行下采样else:avg_pool = nn.AvgPool2d(kernel_size=2, stride=stride)  # 否则使用平均池化# 定义下采样的结构self.downsample = nn.Sequential(avg_pool,nn.Conv2d(channel_in, channel_out, kernel_size=(1, 1), bias=False)  # 1x1 卷积)def forward(self, x):return self.downsample(x)  # 前向传播

3.6 定义全局池化层

作为全局平均池化层,将空间维度池化成1x1
# 作为全局平均池化层,将空间维度池化成1x1
class GlobalAvgPool2d(nn.Module):def __init__(self):super(GlobalAvgPool2d, self).__init__()def forward(self, x):# 平均池化到 1x1,并展平为 (batch_size, num_channels)return F.avg_pool2d(x, kernel_size=x.size()[2:]).view(-1, x.size(1))

3.7 搭建 ResNet-RS 网络

class ResNetRS(nn.Module):def __init__(self, output_dim):super(ResNet50rs, self).__init__()self.stem = StemBlock(channel_in=3, channel_out=64)# Block 1self.id1 = Block(channel_in=64, channel_out=256, stride=2, identity=True)self.block1 = nn.ModuleList([Block(channel_in=256, channel_out=256, stride=1, identity=False) for _ in range(2)])# Block 2self.id2 = Block(channel_in=256, channel_out=512, stride=2, identity=True)self.block2 = nn.ModuleList([Block(channel_in=512, channel_out=512, stride=1, identity=False) for _ in range(3)])# Block 3self.id3 = Block(channel_in=512, channel_out=1024, stride=2, identity=True)self.block3 = nn.ModuleList([Block(channel_in=1024, channel_out=1024, stride=1, identity=False) for _ in range(5)])# Block 4self.id4 = Block(channel_in=1024, channel_out=2048, stride=2, identity=True)self.block4 = nn.ModuleList([Block(channel_in=2048, channel_out=2048, stride=1, identity=False) for _ in range(2)])self.avg_pool = GlobalAvgPool2d()  # 全局平均池化层self.dropout = nn.Dropout(p=0.25)self.fc = nn.Linear(2048, output_dim, bias=False)def forward(self, x):h = self.stem(x)h = self.id1(h)for block in self.block1:h = block(h)h = self.id2(h)for block in self.block2:h = block(h)h = self.id3(h)for block in self.block3:h = block(h)h = self.id4(h)for block in self.block4:h = block(h)h = self.avg_pool(h)h = self.dropout(h)h = torch.relu(h)h = self.fc(h)y = torch.log_softmax(h, dim=-1)return y

3.8 查看模型摘要

======================================================================
Layer (type:depth-idx)                        Param #
======================================================================
ResNet50rs                                    --
├─StemBlock: 1-1                              --
│    └─Sequential: 2-1                        --
│    │    └─Conv2d: 3-1                       864
│    │    └─BatchNorm2d: 3-2                  64
│    │    └─ReLU: 3-3                         --
│    │    └─Conv2d: 3-4                       9,216
│    │    └─BatchNorm2d: 3-5                  64
│    │    └─ReLU: 3-6                         --
│    │    └─Conv2d: 3-7                       18,432
├─Block: 1-2                                  --
│    └─SEBlock: 2-2                           --
│    │    └─AdaptiveAvgPool2d: 3-8            --
│    │    └─Sequential: 3-9                   2,048
│    └─BatchNorm2d: 2-3                       128
│    └─ReLU: 2-4                              --
│    └─Conv2d: 2-5                            4,096
│    └─BatchNorm2d: 2-6                       128
│    └─ReLU: 2-7                              --
│    └─Conv2d: 2-8                            36,864
│    └─BatchNorm2d: 2-9                       128
│    └─DropPath: 2-10                         --
│    └─Conv2d: 2-11                           16,384
│    └─DownSample: 2-12                       --
│    │    └─Sequential: 3-10                  16,384
├─ModuleList: 1-3                             --
│    └─Block: 2-13                            --
│    │    └─SEBlock: 3-11                     32,768
│    │    └─BatchNorm2d: 3-12                 512
│    │    └─ReLU: 3-13                        --
│    │    └─Conv2d: 3-14                      16,384
│    │    └─BatchNorm2d: 3-15                 128
│    │    └─ReLU: 3-16                        --
│    │    └─Conv2d: 3-17                      36,864
│    │    └─BatchNorm2d: 3-18                 128
│    │    └─DropPath: 3-19                    --
│    │    └─Conv2d: 3-20                      16,384
│    │    └─DownSample: 3-21                  65,536
│    └─Block: 2-14                            --
│    │    └─SEBlock: 3-22                     32,768
│    │    └─BatchNorm2d: 3-23                 512
│    │    └─ReLU: 3-24                        --
│    │    └─Conv2d: 3-25                      16,384
│    │    └─BatchNorm2d: 3-26                 128
│    │    └─ReLU: 3-27                        --
│    │    └─Conv2d: 3-28                      36,864
│    │    └─BatchNorm2d: 3-29                 128
│    │    └─DropPath: 3-30                    --
│    │    └─Conv2d: 3-31                      16,384
│    │    └─DownSample: 3-32                  65,536
├─Block: 1-4                                  --
│    └─SEBlock: 2-15                          --
│    │    └─AdaptiveAvgPool2d: 3-33           --
│    │    └─Sequential: 3-34                  32,768
│    └─BatchNorm2d: 2-16                      512
│    └─ReLU: 2-17                             --
│    └─Conv2d: 2-18                           32,768
│    └─BatchNorm2d: 2-19                      256
│    └─ReLU: 2-20                             --
│    └─Conv2d: 2-21                           147,456
│    └─BatchNorm2d: 2-22                      256
│    └─DropPath: 2-23                         --
│    └─Conv2d: 2-24                           65,536
│    └─DownSample: 2-25                       --
│    │    └─Sequential: 3-35                  131,072
├─ModuleList: 1-5                             --
│    └─Block: 2-26                            --
│    │    └─SEBlock: 3-36                     131,072
│    │    └─BatchNorm2d: 3-37                 1,024
│    │    └─ReLU: 3-38                        --
│    │    └─Conv2d: 3-39                      65,536
│    │    └─BatchNorm2d: 3-40                 256
│    │    └─ReLU: 3-41                        --
│    │    └─Conv2d: 3-42                      147,456
│    │    └─BatchNorm2d: 3-43                 256
│    │    └─DropPath: 3-44                    --
│    │    └─Conv2d: 3-45                      65,536
│    │    └─DownSample: 3-46                  262,144
│    └─Block: 2-27                            --
│    │    └─SEBlock: 3-47                     131,072
│    │    └─BatchNorm2d: 3-48                 1,024
│    │    └─ReLU: 3-49                        --
│    │    └─Conv2d: 3-50                      65,536
│    │    └─BatchNorm2d: 3-51                 256
│    │    └─ReLU: 3-52                        --
│    │    └─Conv2d: 3-53                      147,456
│    │    └─BatchNorm2d: 3-54                 256
│    │    └─DropPath: 3-55                    --
│    │    └─Conv2d: 3-56                      65,536
│    │    └─DownSample: 3-57                  262,144
│    └─Block: 2-28                            --
│    │    └─SEBlock: 3-58                     131,072
│    │    └─BatchNorm2d: 3-59                 1,024
│    │    └─ReLU: 3-60                        --
│    │    └─Conv2d: 3-61                      65,536
│    │    └─BatchNorm2d: 3-62                 256
│    │    └─ReLU: 3-63                        --
│    │    └─Conv2d: 3-64                      147,456
│    │    └─BatchNorm2d: 3-65                 256
│    │    └─DropPath: 3-66                    --
│    │    └─Conv2d: 3-67                      65,536
│    │    └─DownSample: 3-68                  262,144
├─Block: 1-6                                  --
│    └─SEBlock: 2-29                          --
│    │    └─AdaptiveAvgPool2d: 3-69           --
│    │    └─Sequential: 3-70                  131,072
│    └─BatchNorm2d: 2-30                      1,024
│    └─ReLU: 2-31                             --
│    └─Conv2d: 2-32                           131,072
│    └─BatchNorm2d: 2-33                      512
│    └─ReLU: 2-34                             --
│    └─Conv2d: 2-35                           589,824
│    └─BatchNorm2d: 2-36                      512
│    └─DropPath: 2-37                         --
│    └─Conv2d: 2-38                           262,144
│    └─DownSample: 2-39                       --
│    │    └─Sequential: 3-71                  524,288
├─ModuleList: 1-7                             --
│    └─Block: 2-40                            --
│    │    └─SEBlock: 3-72                     524,288
│    │    └─BatchNorm2d: 3-73                 2,048
│    │    └─ReLU: 3-74                        --
│    │    └─Conv2d: 3-75                      262,144
│    │    └─BatchNorm2d: 3-76                 512
│    │    └─ReLU: 3-77                        --
│    │    └─Conv2d: 3-78                      589,824
│    │    └─BatchNorm2d: 3-79                 512
│    │    └─DropPath: 3-80                    --
│    │    └─Conv2d: 3-81                      262,144
│    │    └─DownSample: 3-82                  1,048,576
│    └─Block: 2-41                            --
│    │    └─SEBlock: 3-83                     524,288
│    │    └─BatchNorm2d: 3-84                 2,048
│    │    └─ReLU: 3-85                        --
│    │    └─Conv2d: 3-86                      262,144
│    │    └─BatchNorm2d: 3-87                 512
│    │    └─ReLU: 3-88                        --
│    │    └─Conv2d: 3-89                      589,824
│    │    └─BatchNorm2d: 3-90                 512
│    │    └─DropPath: 3-91                    --
│    │    └─Conv2d: 3-92                      262,144
│    │    └─DownSample: 3-93                  1,048,576
│    └─Block: 2-42                            --
│    │    └─SEBlock: 3-94                     524,288
│    │    └─BatchNorm2d: 3-95                 2,048
│    │    └─ReLU: 3-96                        --
│    │    └─Conv2d: 3-97                      262,144
│    │    └─BatchNorm2d: 3-98                 512
│    │    └─ReLU: 3-99                        --
│    │    └─Conv2d: 3-100                     589,824
│    │    └─BatchNorm2d: 3-101                512
│    │    └─DropPath: 3-102                   --
│    │    └─Conv2d: 3-103                     262,144
│    │    └─DownSample: 3-104                 1,048,576
│    └─Block: 2-43                            --
│    │    └─SEBlock: 3-105                    524,288
│    │    └─BatchNorm2d: 3-106                2,048
│    │    └─ReLU: 3-107                       --
│    │    └─Conv2d: 3-108                     262,144
│    │    └─BatchNorm2d: 3-109                512
│    │    └─ReLU: 3-110                       --
│    │    └─Conv2d: 3-111                     589,824
│    │    └─BatchNorm2d: 3-112                512
│    │    └─DropPath: 3-113                   --
│    │    └─Conv2d: 3-114                     262,144
│    │    └─DownSample: 3-115                 1,048,576
│    └─Block: 2-44                            --
│    │    └─SEBlock: 3-116                    524,288
│    │    └─BatchNorm2d: 3-117                2,048
│    │    └─ReLU: 3-118                       --
│    │    └─Conv2d: 3-119                     262,144
│    │    └─BatchNorm2d: 3-120                512
│    │    └─ReLU: 3-121                       --
│    │    └─Conv2d: 3-122                     589,824
│    │    └─BatchNorm2d: 3-123                512
│    │    └─DropPath: 3-124                   --
│    │    └─Conv2d: 3-125                     262,144
│    │    └─DownSample: 3-126                 1,048,576
├─Block: 1-8                                  --
│    └─SEBlock: 2-45                          --
│    │    └─AdaptiveAvgPool2d: 3-127          --
│    │    └─Sequential: 3-128                 524,288
│    └─BatchNorm2d: 2-46                      2,048
│    └─ReLU: 2-47                             --
│    └─Conv2d: 2-48                           524,288
│    └─BatchNorm2d: 2-49                      1,024
│    └─ReLU: 2-50                             --
│    └─Conv2d: 2-51                           2,359,296
│    └─BatchNorm2d: 2-52                      1,024
│    └─DropPath: 2-53                         --
│    └─Conv2d: 2-54                           1,048,576
│    └─DownSample: 2-55                       --
│    │    └─Sequential: 3-129                 2,097,152
├─ModuleList: 1-9                             --
│    └─Block: 2-56                            --
│    │    └─SEBlock: 3-130                    2,097,152
│    │    └─BatchNorm2d: 3-131                4,096
│    │    └─ReLU: 3-132                       --
│    │    └─Conv2d: 3-133                     1,048,576
│    │    └─BatchNorm2d: 3-134                1,024
│    │    └─ReLU: 3-135                       --
│    │    └─Conv2d: 3-136                     2,359,296
│    │    └─BatchNorm2d: 3-137                1,024
│    │    └─DropPath: 3-138                   --
│    │    └─Conv2d: 3-139                     1,048,576
│    │    └─DownSample: 3-140                 4,194,304
│    └─Block: 2-57                            --
│    │    └─SEBlock: 3-141                    2,097,152
│    │    └─BatchNorm2d: 3-142                4,096
│    │    └─ReLU: 3-143                       --
│    │    └─Conv2d: 3-144                     1,048,576
│    │    └─BatchNorm2d: 3-145                1,024
│    │    └─ReLU: 3-146                       --
│    │    └─Conv2d: 3-147                     2,359,296
│    │    └─BatchNorm2d: 3-148                1,024
│    │    └─DropPath: 3-149                   --
│    │    └─Conv2d: 3-150                     1,048,576
│    │    └─DownSample: 3-151                 4,194,304
├─GlobalAvgPool2d: 1-10                       --
├─Dropout: 1-11                               --
├─Linear: 1-12                                4,096
======================================================================
Total params: 46,033,248
Trainable params: 46,033,248
Non-trainable params: 0
======================================================================

四、模型训练与测试

4.1 模型训练

loss_fn    = nn.NLLLoss() # 创建损失函数
learn_rate = 0.00078125 # 学习率
opt        = torch.optim.SGD(model.parameters(),lr=learn_rate, momentum=0.9, weight_decay=4e-5)  # 优化器epochs     = 20
train_loss = []
train_acc  = []test_loss  = []
test_acc   = []for epoch in range(epochs):# model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)# model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
print('Done')
Epoch: 1, Train_acc:70.3%, Train_loss:0.940, Test_acc:78.2%,Test_loss:0.606
Epoch: 2, Train_acc:73.5%, Train_loss:0.634, Test_acc:79.7%,Test_loss:0.461
Epoch: 3, Train_acc:76.6%, Train_loss:0.519, Test_acc:79.5%,Test_loss:0.464
Epoch: 4, Train_acc:78.5%, Train_loss:0.488, Test_acc:80.2%,Test_loss:0.444
Epoch: 5, Train_acc:79.8%, Train_loss:0.457, Test_acc:81.2%,Test_loss:0.435
...
Epoch:17, Train_acc:84.7%, Train_loss:0.356, Test_acc:84.5%,Test_loss:0.361
Epoch:18, Train_acc:84.2%, Train_loss:0.366, Test_acc:85.5%,Test_loss:0.335
Epoch:19, Train_acc:84.7%, Train_loss:0.351, Test_acc:85.4%,Test_loss:0.338
Epoch:20, Train_acc:85.2%, Train_loss:0.351, Test_acc:87.4%,Test_loss:0.304
Done

4.3 结果可视化

4.2 模型验证

from PIL import Image classes = list(total_data.class_to_idx)def predict_one_image(image_path, model, transform, classes):test_img = Image.open(image_path).convert('RGB')  # 展示预测的图片test_img = transform(test_img)img = test_img.to(device).unsqueeze(0)model.eval()output = model(img)_,pred = torch.max(output,1)pred_class = classes[pred]print(f'预测结果是:{pred_class}')# 预测训练集中的某张照片
predict_one_image(image_path='./data/26-data/benign/8863_idx5_x2551_y1451_class0.png', model=model, transform=train_transforms, classes=classes)
预测结果是:benign

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/456395.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

git入门操作(2)

文章目录 git入门操作(2)git diff 查看差异git diff gitignore忽略文件1.在代码仓库创建这个文件2.添加对 log 文件过滤 连接远程仓库与ssh配置远程仓库和本地仓库关联步骤分支基本操作步骤命令: 合并冲突分支合并逻辑1.新建分支 dev&#xf…

MySQL查看当前客户端连接数的方法

每当有客户端连接到 MySQL 时,MySQL 会为该连接创建一个新的线程来处理所有与该连接相关的查询和操作。所以通过查看MySQL当前的连接线程数量就可以知道有多少客户端连接到MySQL。 方法一 Threads_connected 仅显示活跃的客户端连接数 SHOW STATUS LIKE Threads_…

H7-TOOL的LUA小程序教程第15期:电压,电流,NTC热敏电阻以及4-20mA输入(2024-10-21,已经发布)

LUA脚本的好处是用户可以根据自己注册的一批API(当前TOOL已经提供了几百个函数供大家使用),实现各种小程序,不再限制Flash里面已经下载的程序,就跟手机安装APP差不多,所以在H7-TOOL里面被广泛使用&#xff…

Go语言中三个输入函数(scanf,scan,scanln)的区别

Go语言中三个输入函数(scanf,scan,scanln)的区别 在 Go 语言中,fmt 包提供了三种输入函数:Scanf、Scan 和 Scanln。这三个函数都是用于从标准输入读取数据并存储到变量中,但是它们在处理输入的方式上有所不同。下面详细解读每个函数的特点和…

网站被浏览器提示“不安全”,如何快速解决

当网站被浏览器提示“不安全”时,这通常意味着网站存在某些安全隐患,需要立即采取措施进行解决。 一、具体原因如下: 1.如果网站使用的是HTTP协议,应立即升级HTTPS。HTTPS通过使用SSL证书加密来保护数据传输,提高了网…

CSS设置层叠样式时报红(identifier expected css/selector expected css)

不规范语法 如上图所示,在一个 css 文件中添加层叠样式时报红:at-rule or selector expected,意思就是说我们的语句不符合 css 的语法书写规范,虽然不会导致启动报错并且还能达到预期的样式效果,但是对于有强迫症的同学…

养狗为什么需要宠物空气净化器?宠物空气净化器排行榜公布!

刚开始养狗时候怎么没人跟我说要买宠物空气净化器呢?那时候什么都不懂,只买了狗粮、喂食碗、狗笼、狗窝、便盆、牵引绳以及一些狗狗玩具。结果一个星期就家里就被搞得狗毛乱飞、臭味熏天。最后在养狗博主的建议下买了一台宠物空气净化器,开了…

ffmpeg视频滤镜:压缩-deflate

滤镜简述 deflate 官网链接 > https://ffmpeg.org/ffmpeg-filters.html#deflate 压缩滤镜可以降低视频的质量&#xff0c;从而减少视频的大小&#xff0c;虽然一定程度上影响了观看体验&#xff0c;但是方便传输。 滤镜使用 参数 threshold0 <int> …

函数的力量:掌握C语言的基石

目录 前言 标准库&#xff1a;C语言的百宝箱 头文件&#xff1a;库函数的藏宝图 实例分析&#xff1a;计算平方根的sqrt函数 功能描述 头文件包含的重要性 库函数文档的一般格式 自定义函数&#xff1a;释放你的编程创造力 函数的语法形式 函数的比喻 函数的举例 简化…

FreeSSl 申请免费证书,ACME实现自动化续期(https证书自动续期)

网站&#xff1a;https://freessl.cn/ 参考&#xff1a;ACME自动化快速入门 注册/登录后 1 添加域名 2 申请证书 安装acme.sh curl https://get.acme.sh | sh -s emailmyexample.com执行ACME.sh 申请证书命令 cd ~/.acme.sh/ # 直接拷贝上面步骤生成的命令 ./acme.sh …

springboot诊所就医系统-计算机毕业设计源码16883

目 录 摘要 1 绪论 1.1 研究背景 1.2选题背景及意义 1.3论文结构与章节安排 2 诊所就医系统系统分析 2.1 可行性分析 2.1.1 技术可行性分析 2.1.2 经济可行性分析 2.1.3 法律可行性分析 2.2 系统功能分析 2.2.1 功能性分析 2.2.2 非功能性分析 2.3 系统用例分析 …

论文笔记:通用世界模型WorldDreamer

整理了WorldDreamer: Towards General World Models for Video Generation via Predicting Masked Tokens 论文的阅读笔记 背景模型实验 背景 现有的世界模型仅限于游戏或驾驶等特定场景&#xff0c;限制了它们捕捉一般世界动态环境复杂性的能力。针对这一挑战&#xff0c;本文…

雷池社区版有多个防护站点监听在同一个端口上,匹配顺序是怎么样的

如果域名处填写的分别为 IP 与域名&#xff0c;那么当使用进行 IP 请求时&#xff0c;则将会命中第一个配置的站点 以上图为例&#xff0c;如果用户使用 IP 访问&#xff0c;命中 example.com。 如果域名处填写的分别为域名与泛域名&#xff0c;除非准确命中域名&#xff0c;否…

关于写删除接口的一些理解

背景 在前两篇文章中&#xff0c;我讲了如何编写查询接口和新增接口。这篇文章将讲解如何编写删除接口。 “删除”接口的总体思路 一般情况下&#xff0c;删除接口的思路是通过记录的id来删除某一行。在实际工作中&#xff0c;我还没有遇到过使用其他字段来删除记录的情况&am…

TinTin Web3 动态精选:Vitalik 探讨以太坊协议,Solana ETN 开启质押功能

TinTin 快讯由 TinTinLand 开发者技术社区打造&#xff0c;旨在为开发者提供最新的 Web3 新闻、市场时讯和技术更新。TinTin 快讯将以周为单位&#xff0c; 汇集当周内的行业热点并以快讯的形式排列成文。掌握一手的技术资讯和市场动态&#xff0c;将有助于 TinTinLand 社区的开…

Unity-Editor扩展,引擎管理AudioClip,音乐音效快捷播放功能

目录 选择一个Audio 音频文件即会 关键在于三个快捷模式 播放&#xff0c; 自动播放 循环播放 根本不需要Editor扩展开发 没找到虚幻引擎的audio 的管理是怎么样的 参考&#xff1a; 本来&#xff0c;觉得没有快捷方式&#xff0c;播放很不爽 想自定义搞一个&#xff…

win10怎么卸载软件干净?电脑彻底删除软件的方法介绍,一键清理卸载残留!

电脑上经常会下载各种各样的软件来协助我们办公&#xff0c;不同的软件能够满足不同的需求。 但是不少软件可能使用频率没有那么高&#xff0c;甚至完全不使用。这个时候就需要将这些不常用的电脑软件卸载掉了&#xff0c;卸载软件能够释放一定的存储空间&#xff0c;提高电脑…

【WebSocket实战】——创建项目初始架构

这一篇文章主要是为了介绍如何在visual中创建一个项目并服务于我们要做的websockt项目&#xff0c;所以这里如果已经懂得的人&#xff0c;可以直接跳过。 目录 1&#xff09;创建空白解决方案 2&#xff09;创建asp.NET Core项目 3&#xff09;创建winform项目作为客户端1 …

纳斯达克大屏投放:为什么越来越多的企业要投放纳斯达克户外广告

纳斯达克大屏投放&#xff1a;为什么越来越多的企业要投放纳斯达克户外广告 一、纳斯达克户外大屏的独特魅力 在全球商业的舞台上&#xff0c;纳斯达克户外大屏以其无与伦比的影响力和曝光度&#xff0c;成为众多企业竞相追逐的广告投放目标。为什么越来越多的企业选择在纳斯…

react18中的函数组件底层渲染原理分析

react 中的函数组件底层渲染原理 react组件没有局部与全局之分&#xff0c;它是一个整体。这点跟vue的组件化是不同的。要实现 react 中的全局组件&#xff0c;可以将组件挂在react上&#xff0c;这样只要引入了react&#xff0c;就可以直接使用该组件。 函数式组件的创建 …