DenseNet: Oral Cancer Recognition

This article is an internal post of the 🔗365-day deep learning training camp.

Original author: K同学啊

1. DenseNet Architecture

Its basic idea is the same as ResNet's, but instead of connecting each layer only to the one before it, DenseNet builds dense connections from every layer to all preceding layers. Its other distinguishing feature is feature reuse, achieved by concatenating feature maps along the channel dimension.

2. Design Philosophy

ResNet passes information forward through element-wise addition: the output of layer l is x_l = H_l(x_{l-1}) + x_{l-1}. DenseNet instead concatenates: each layer receives the feature maps of every preceding layer as input, x_l = H_l([x_0, x_1, ..., x_{l-1}]), where [...] denotes concatenation along the channel dimension. Every layer therefore only needs to add a small number of new feature maps (the growth rate k), while features and gradients flow directly to all earlier layers.
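The contrast is easy to demonstrate on toy tensors. Below is a minimal standalone sketch (not part of the training code that follows): a residual shortcut requires the channel count to stay fixed so the tensors can be added, whereas a dense connection simply appends the new channels.

import torch

x = torch.randn(1, 64, 56, 56)              # N, C, H, W
h = torch.nn.Conv2d(64, 64, 3, padding=1)

# ResNet-style: element-wise addition, channel count stays at 64
res_out = h(x) + x
print(res_out.shape)                        # torch.Size([1, 64, 56, 56])

# DenseNet-style: channel concatenation, new features are appended
g = torch.nn.Conv2d(64, 32, 3, padding=1)   # growth rate k = 32
dense_out = torch.cat([x, g(x)], dim=1)
print(dense_out.shape)                      # torch.Size([1, 96, 56, 56])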

 

3. Structure

A DenseNet consists of a stem (7x7 convolution plus max pooling), a sequence of DenseBlocks in which feature maps keep their spatial size so they can be concatenated, and Transition layers between the blocks (1x1 convolution plus 2x2 average pooling) that compress the channels and halve the resolution. A final batch norm, global average pooling, and fully connected layer produce the classification output.

4. Code

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
import os, PIL, pathlib, random

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
data_dir = './data/'
data_dir = pathlib.Path(data_dir)
data_paths = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[1] for path in data_paths]  # Windows path separator; use path.name on Linux/macOS
classeNames

import matplotlib.pyplot as plt
from PIL import Image

# Path to the image folder
image_folder = './data/OSCC/'
# Collect every image file in the folder
image_files = [f for f in os.listdir(image_folder) if f.endswith((".jpg", ".png", ".jpeg"))]

# Create a Matplotlib figure
fig, axes = plt.subplots(3, 8, figsize=(16, 6))
# Load and display the images
for ax, img_file in zip(axes.flat, image_files):
    img_path = os.path.join(image_folder, img_file)
    img = Image.open(img_path)
    ax.imshow(img)
    ax.axis('off')

# Show the figure
plt.tight_layout()
plt.show()

total_datadir = './data/'

# For more on transforms.Compose see: https://blog.csdn.net/qq_38251616/article/details/124878863
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # resize every input image to a uniform size
    transforms.ToTensor(),          # convert PIL Image / numpy.ndarray to a tensor scaled to [0, 1]
    transforms.Normalize(           # standardize to a normal (Gaussian) distribution so the model converges more easily
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])  # mean/std computed from random samples of the ImageNet dataset
])

total_data = datasets.ImageFolder(total_datadir, transform=train_transforms)
total_data

# Split into training and test sets
train_size = int(0.7 * len(total_data))
test_size  = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
train_dataset, test_dataset

batch_size = 32

train_dl = torch.utils.data.DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=1)
test_dl = torch.utils.data.DataLoader(test_dataset,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=1)

for X, y in test_dl:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break
Shape of X [N, C, H, W]:  torch.Size([32, 3, 224, 224])
Shape of y:  torch.Size([32]) torch.int64
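ImageFolder assigns integer labels in alphabetical order of the class subfolders; if the mapping ever needs checking, it can be read off directly. A one-line sanity check (the class names shown are illustrative; the real ones depend on the folders under ./data/):

print(total_data.class_to_idx)  # e.g. {'Normal': 0, 'OSCC': 1}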
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import re
import torch
from torch.utils import model_zoo
# The pretrained-weight URL dict lives in torchvision's densenet module
# (removed in torchvision >= 0.13, where the checkpoint URL is exposed
# via DenseNet121_Weights instead)
from torchvision.models.densenet import model_urls
'''
The _DenseLayer class implements DenseNet's key mechanism:
it extracts features with batch normalization, ReLU activations and convolutions,
and promotes feature sharing and reuse through dense connections.
'''
class _DenseLayer(nn.Sequential):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        '''
        :param num_input_features: number of input feature channels
        :param growth_rate: number of feature channels each layer adds
        :param bn_size: bottleneck width multiplier for the 1x1 convolution
        :param drop_rate: dropout rate
        '''
        super(_DenseLayer, self).__init__()
        # Batch normalization (BatchNorm2d) to standardize the input features
        self.add_module("norm1", nn.BatchNorm2d(num_input_features))
        # ReLU activation
        self.add_module("relu1", nn.ReLU(inplace=True))
        # First convolution (Conv2d): num_input_features -> bn_size * growth_rate channels.
        # The 1x1 kernel acts as a bottleneck, reducing the feature-map dimensionality.
        self.add_module("conv1", nn.Conv2d(num_input_features, bn_size * growth_rate,
                                           kernel_size=1, stride=1, bias=False))
        # Second batch normalization, applied to the output of the first convolution
        self.add_module("norm2", nn.BatchNorm2d(bn_size * growth_rate))
        # Second ReLU activation, providing another non-linear transform
        self.add_module("relu2", nn.ReLU(inplace=True))
        # Second convolution: bn_size * growth_rate -> growth_rate channels.
        # The 3x3 kernel extracts the new features this layer contributes.
        self.add_module("conv2", nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1, bias=False))
        # Store the drop rate, used for dropout in the forward pass to curb overfitting
        self.drop_rate = drop_rate

    def forward(self, x):
        # Run the input through all layers added above via nn.Sequential's forward
        new_features = super(_DenseLayer, self).forward(x)
        # Apply dropout if the drop rate is positive
        if self.drop_rate > 0:
            # p is the drop probability; training indicates whether we are in training mode
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        # Concatenate the input x with the new features along the channel dimension (dim 1).
        # This dense connection lets the model reuse all earlier features.
        return torch.cat([x, new_features], 1)
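A quick sanity check on the class above (a standalone sketch; the numbers mirror the DenseNet-121 defaults used later): a single dense layer only appends growth_rate channels to its input.

layer = _DenseLayer(num_input_features=64, growth_rate=32, bn_size=4, drop_rate=0)
x = torch.randn(1, 64, 56, 56)
print(layer(x).shape)   # torch.Size([1, 96, 56, 56]): 64 input channels + 32 new ones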
'''
Builds a block containing several dense layers; each layer's input width grows with the
outputs of all preceding layers, forming a densely connected structure.
'''
class _DenseBlock(nn.Sequential):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        '''
        num_layers: number of dense layers in this block.
        num_input_features: number of input feature channels.
        bn_size: bottleneck width multiplier for batch-normalized 1x1 convolutions.
        growth_rate: number of feature channels each layer adds.
        drop_rate: dropout rate, used to curb overfitting.
        '''
        super(_DenseBlock, self).__init__()
        # Create one dense layer per iteration
        for i in range(num_layers):
            # The i-th layer sees num_input_features + i * growth_rate channels,
            # i.e. the block input plus the outputs of all preceding layers
            layer = _DenseLayer(num_input_features + i * growth_rate,
                                growth_rate, bn_size, drop_rate)
            # Register the layer as denselayer1, denselayer2, ... for easy access and debugging
            self.add_module("denselayer%d" % (i + 1,), layer)
'''
Builds a transition layer, used between dense blocks for feature transformation and downsampling.
'''
class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        # Batch normalization to standardize the input features
        self.add_module("norm", nn.BatchNorm2d(num_input_features))
        # ReLU activation
        self.add_module("relu", nn.ReLU(inplace=True))
        # 1x1 convolution mapping the input channels to the (compressed) output channels
        self.add_module("conv", nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        # 2x2 average pooling with stride 2, halving the spatial size of the feature maps
        self.add_module("pool", nn.AvgPool2d(2, stride=2))


class DenseNet(nn.Module):
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, compression_rate=0.5,
                 drop_rate=0, num_classes=1000):
        '''
        growth_rate: channels each layer adds inside a DenseBlock.
        block_config: tuple giving the number of layers in each DenseBlock.
        num_init_features: output channels of the first convolution.
        bn_size: bottleneck width multiplier for batch-normalized 1x1 convolutions.
        compression_rate: channel compression ratio in each Transition layer.
        drop_rate: dropout probability.
        num_classes: number of output classes.
        '''
        super(DenseNet, self).__init__()
        # Stem: first convolution, batch norm, ReLU and max pooling
        self.features = nn.Sequential(OrderedDict([
            ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7,
                                stride=2, padding=3, bias=False)),
            ("norm0", nn.BatchNorm2d(num_init_features)),
            ("relu0", nn.ReLU(inplace=True)),
            ("pool0", nn.MaxPool2d(3, stride=2, padding=1))
        ]))

        # DenseBlocks
        num_features = num_init_features
        # Build one DenseBlock per entry in block_config
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers, num_features, bn_size, growth_rate, drop_rate)
            self.features.add_module("denseblock%d" % (i + 1), block)
            # Each DenseBlock adds num_layers * growth_rate channels
            num_features += num_layers * growth_rate
            if i != len(block_config) - 1:
                # Transition layer between DenseBlocks: compress the channels by
                # compression_rate and halve the spatial resolution
                transition = _Transition(num_features, int(num_features * compression_rate))
                self.features.add_module("transition%d" % (i + 1), transition)
                num_features = int(num_features * compression_rate)

        # Final batch norm + ReLU after all DenseBlocks and Transition layers
        self.features.add_module("norm5", nn.BatchNorm2d(num_features))
        self.features.add_module("relu5", nn.ReLU(inplace=True))

        # Classification layer mapping the features to the class scores
        self.classifier = nn.Linear(num_features, num_classes)

        # Parameter initialization:
        #   Conv layers: Kaiming normal initialization.
        #   BatchNorm layers: bias 0, weight 1.
        #   Linear layers: bias 0.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        '''
        self.features(x): pass the input through all feature layers.
        F.avg_pool2d: global average pooling over the 7x7 feature maps.
        view(features.size(0), -1): flatten the pooled features.
        self.classifier(out): produce the class scores.
        '''
        features = self.features(x)
        out = F.avg_pool2d(features, 7, stride=1).view(features.size(0), -1)
        out = self.classifier(out)
        return out


def densenet121(pretrained=False, **kwargs):
    model = DenseNet(num_init_features=64, growth_rate=32,
                     block_config=(6, 12, 24, 16), num_classes=len(classeNames))
    if pretrained:
        # Old torchvision checkpoints name parameters like "denselayer1.norm.1.weight";
        # rewrite them to the "denselayer1.norm1.weight" form used by this model
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        # Download the DenseNet-121 pretrained weights into state_dict
        state_dict = model_zoo.load_url(model_urls['densenet121'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                # Build the new key from the two matched groups
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        # Load the processed weights into the model
        model.load_state_dict(state_dict)
    return model

model = densenet121().to(device)  # move the model to the GPU if one is available
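Before looking at the torchsummary dump, it helps to replay the channel bookkeeping of DenseNet.__init__ by hand; this is plain arithmetic mirroring the loop above, with the DenseNet-121 defaults:

num_features = 64                               # after the stem (num_init_features)
for i, num_layers in enumerate((6, 12, 24, 16)):
    num_features += num_layers * 32             # each DenseBlock adds num_layers * growth_rate channels
    if i != 3:
        num_features = int(num_features * 0.5)  # each Transition halves the channel count
print(num_features)                             # 1024: the width entering norm5 and the classifier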
model
import torchsummary as summary
summary.summary(model,(3,224,224))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
               ...                        ...                ...
     BatchNorm2d-365          [-1, 1024, 7, 7]           2,048
            ReLU-366          [-1, 1024, 7, 7]               0
          Linear-367                   [-1, 2]           2,050
================================================================
Total params: 6,955,906
Trainable params: 6,955,906
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 294.57
Params size (MB): 26.53
Estimated Total Size (MB): 321.68
----------------------------------------------------------------
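The totals above can be cross-checked directly in PyTorch. A one-line sanity check (the count is lower than the stock 1000-class DenseNet-121 only because the classifier here has len(classeNames) = 2 outputs):

print(sum(p.numel() for p in model.parameters()))   # 6955906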
loss_fn    = nn.CrossEntropyLoss()  # loss function
learn_rate = 1e-4                   # learning rate
opt        = torch.optim.SGD(model.parameters(), lr=learn_rate)

# Training loop
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # size of the training set
    num_batches = len(dataloader)   # number of batches per epoch
    train_loss, train_acc = 0, 0    # accumulators for loss and accuracy

    for X, y in dataloader:  # fetch a batch of images and labels
        X, y = X.to(device), y.to(device)

        # Compute the prediction error
        pred = model(X)          # network output
        loss = loss_fn(pred, y)  # gap between the prediction and the ground truth

        # Backpropagation
        optimizer.zero_grad()  # zero out the gradients
        loss.backward()        # backpropagate
        optimizer.step()       # update the parameters

        # Accumulate accuracy and loss
        train_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc  /= size
    train_loss /= num_batches

    return train_acc, train_loss


def test(dataloader, model, loss_fn):
    size        = len(dataloader.dataset)  # size of the test set
    num_batches = len(dataloader)          # number of batches per epoch
    test_loss, test_acc = 0, 0

    # No gradient updates during evaluation; this saves memory
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            # Compute the loss
            target_pred = model(imgs)
            loss        = loss_fn(target_pred, target)

            test_loss += loss.item()
            test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc  /= size
    test_loss /= num_batches

    return test_acc, test_loss


epochs     = 20
train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)

    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss,
                          epoch_test_acc*100, epoch_test_loss))
print('Done')
Epoch: 1, Train_acc:47.7%, Train_loss:0.725, Test_acc:47.4%,Test_loss:0.708
Epoch: 2, Train_acc:50.2%, Train_loss:0.697, Test_acc:52.7%,Test_loss:0.690
Epoch: 3, Train_acc:56.1%, Train_loss:0.686, Test_acc:59.9%,Test_loss:0.681
Epoch: 4, Train_acc:58.5%, Train_loss:0.679, Test_acc:60.7%,Test_loss:0.675
Epoch: 5, Train_acc:60.9%, Train_loss:0.673, Test_acc:60.1%,Test_loss:0.671
Epoch: 6, Train_acc:61.7%, Train_loss:0.670, Test_acc:62.6%,Test_loss:0.664
Epoch: 7, Train_acc:62.4%, Train_loss:0.665, Test_acc:63.5%,Test_loss:0.659
Epoch: 8, Train_acc:63.0%, Train_loss:0.660, Test_acc:64.8%,Test_loss:0.653
Epoch: 9, Train_acc:64.2%, Train_loss:0.656, Test_acc:65.5%,Test_loss:0.649
Epoch:10, Train_acc:64.9%, Train_loss:0.652, Test_acc:65.6%,Test_loss:0.644
Epoch:11, Train_acc:65.4%, Train_loss:0.649, Test_acc:66.6%,Test_loss:0.641
Epoch:12, Train_acc:65.0%, Train_loss:0.646, Test_acc:66.6%,Test_loss:0.638
Epoch:13, Train_acc:64.8%, Train_loss:0.643, Test_acc:67.5%,Test_loss:0.634
Epoch:14, Train_acc:65.7%, Train_loss:0.641, Test_acc:67.3%,Test_loss:0.633
Epoch:15, Train_acc:65.9%, Train_loss:0.638, Test_acc:67.8%,Test_loss:0.629
Epoch:16, Train_acc:66.3%, Train_loss:0.635, Test_acc:67.6%,Test_loss:0.626
Epoch:17, Train_acc:67.3%, Train_loss:0.632, Test_acc:67.8%,Test_loss:0.624
Epoch:18, Train_acc:67.1%, Train_loss:0.628, Test_acc:68.2%,Test_loss:0.618
Epoch:19, Train_acc:67.3%, Train_loss:0.628, Test_acc:68.9%,Test_loss:0.618
Epoch:20, Train_acc:67.9%, Train_loss:0.624, Test_acc:68.4%,Test_loss:0.614
Done
import matplotlib.pyplot as plt
# Hide warnings
import warnings
warnings.filterwarnings("ignore")               # ignore warning messages
plt.rcParams['font.sans-serif']    = ['SimHei'] # render Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False      # render minus signs correctly
plt.rcParams['figure.dpi']         = 100        # figure resolution

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))

plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
