深度学习Day-18:ResNet50V2算法实战与解析

 🍨 本文为:[🔗365天深度学习训练营] 中的学习记录博客
 🍖 原作者:[K同学啊 | 接辅导、项目定制]

要求:

  1. 根据本文Tensorflow代码,编写对应的Pytorch代码
  2. 了解ResNetV2与ResNetV的区别

一、 基础配置

  • 语言环境:Python3.8
  • 编译器选择:Pycharm
  • 深度学习环境:
    • torch==1.12.1+cu113
    • torchvision==0.13.1+cu113

二、 前期准备 

1.设置GPU

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import pathlib, warningswarnings.filterwarnings("ignore")  # 忽略警告信息device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

2. 导入数据

本项目所采用的数据集未收录于公开数据中,故需要自己在文件目录中导入相应数据集合,并设置对应文件目录,以供后续学习过程中使用。

运行下述代码:

data_dir = './data/bird_photos/'
data_dir = pathlib.Path(data_dir)data_paths = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[2] for path in data_paths]
print(classeNames)

得到如下输出:

['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']

接下来,我们通过transforms.Compose对整个数据集进行预处理:

train_transforms = transforms.Compose([transforms.Resize([224, 224]),      # 将输入图片resize成统一尺寸# transforms.RandomHorizontalFlip(), # 随机水平翻转transforms.ToTensor(),              # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间transforms.Normalize(               # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])      # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])test_transform = transforms.Compose([transforms.Resize([224, 224]),      # 将输入图片resize成统一尺寸transforms.ToTensor(),              # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间transforms.Normalize(               # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])      # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])total_data = datasets.ImageFolder("./data/bird_photos/", transform=train_transforms)
print(total_data.class_to_idx)

得到如下输出:

{'Bananaquit': 0, 'Black Skimmer': 1, 'Black Throated Bushtiti': 2, 'Cockatoo': 3}

3. 划分数据集

 此处数据集需要做按比例划分的操作:

train_size = int(0.8 * len(total_data))
test_size  = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])

接下来,根据划分得到的训练集和验证集对数据集进行包装:

batch_size = 32
train_dl = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=0)
test_dl = torch.utils.data.DataLoader(test_dataset,batch_size=batch_size,shuffle=True,num_workers=0)

并通过:

for X, y in test_dl:print("Shape of X [N, C, H, W]: ", X.shape)print("Shape of y: ", y.shape, y.dtype)break

输出测试数据集的数据分布情况:

Shape of X [N, C, H, W]:  torch.Size([32, 3, 224, 224])
Shape of y:  torch.Size([32]) torch.int64

4.搭建模型

1.模型搭建

class Block2(nn.Module):def __init__(self, in_channel, filters, kernel_size=3, stride=1, conv_shortcut=False):super(Block2, self).__init__()self.preact = nn.Sequential(nn.BatchNorm2d(in_channel),nn.ReLU(True))self.shortcut = conv_shortcutif self.shortcut:self.short = nn.Conv2d(in_channel, 4 * filters, 1, stride=stride, padding=0, bias=False)elif stride > 1:self.short = nn.MaxPool2d(kernel_size=1, stride=stride, padding=0)else:self.short = nn.Identity()self.conv1 = nn.Sequential(nn.Conv2d(in_channel, filters, 1, stride=1, bias=False),nn.BatchNorm2d(filters),nn.ReLU(True))self.conv2 = nn.Sequential(nn.Conv2d(filters, filters, kernel_size, stride=stride, padding=1, bias=False),nn.BatchNorm2d(filters),nn.ReLU(True))self.conv3 = nn.Conv2d(filters, 4 * filters, 1, stride=1, bias=False)def forward(self, x):x1 = self.preact(x)if self.shortcut:x2 = self.short(x1)else:x2 = self.short(x)x1 = self.conv1(x1)x1 = self.conv2(x1)x1 = self.conv3(x1)x = x1 + x2return xclass Stack2(nn.Module):def __init__(self, in_channel, filters, blocks, stride=2):super(Stack2, self).__init__()self.conv = nn.Sequential()self.conv.add_module(str(0), Block2(in_channel, filters, conv_shortcut=True))for i in range(1, blocks - 1):self.conv.add_module(str(i), Block2(4 * filters, filters))self.conv.add_module(str(blocks - 1), Block2(4 * filters, filters, stride=stride))def forward(self, x):x = self.conv(x)return xclass ResNet50V2(nn.Module):def __init__(self,include_top=True,  # 是否包含位于网络顶部的全链接层preact=True,  # 是否使用预激活use_bias=True,  # 是否对卷积层使用偏置input_shape=[224, 224, 3],classes=1000,pooling=None):  # 用于分类图像的可选类数super(ResNet50V2, self).__init__()self.conv1 = nn.Sequential()self.conv1.add_module('conv', nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=use_bias, padding_mode='zeros'))if not preact:self.conv1.add_module('bn', nn.BatchNorm2d(64))self.conv1.add_module('relu', nn.ReLU())self.conv1.add_module('max_pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.conv2 = Stack2(64, 64, 3)self.conv3 = Stack2(256, 128, 4)self.conv4 = Stack2(512, 256, 6)self.conv5 = Stack2(1024, 512, 3, stride=1)self.post = nn.Sequential()if preact:self.post.add_module('bn', nn.BatchNorm2d(2048))self.post.add_module('relu', nn.ReLU())if include_top:self.post.add_module('avg_pool', nn.AdaptiveAvgPool2d((1, 1)))self.post.add_module('flatten', nn.Flatten())self.post.add_module('fc', nn.Linear(2048, classes))else:if pooling == 'avg':self.post.add_module('avg_pool', nn.AdaptiveAvgPool2d((1, 1)))elif pooling == 'max':self.post.add_module('max_pool', nn.AdaptiveMaxPool2d((1, 1)))def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = self.conv4(x)x = self.conv5(x)x = self.post(x)return xmodel = ResNet50V2().to(device)

2.查看模型信息

import torchsummary as summary
summary.summary(model, (3, 224, 224))

得到如下输出:

----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Conv2d-1         [-1, 64, 112, 112]           9,472MaxPool2d-2           [-1, 64, 56, 56]               0BatchNorm2d-3           [-1, 64, 56, 56]             128ReLU-4           [-1, 64, 56, 56]               0Conv2d-5          [-1, 256, 56, 56]          16,384Conv2d-6           [-1, 64, 56, 56]           4,096BatchNorm2d-7           [-1, 64, 56, 56]             128ReLU-8           [-1, 64, 56, 56]               0Conv2d-9           [-1, 64, 56, 56]          36,864BatchNorm2d-10           [-1, 64, 56, 56]             128ReLU-11           [-1, 64, 56, 56]               0Conv2d-12          [-1, 256, 56, 56]          16,384Block2-13          [-1, 256, 56, 56]               0BatchNorm2d-14          [-1, 256, 56, 56]             512ReLU-15          [-1, 256, 56, 56]               0Identity-16          [-1, 256, 56, 56]               0Conv2d-17           [-1, 64, 56, 56]          16,384BatchNorm2d-18           [-1, 64, 56, 56]             128ReLU-19           [-1, 64, 56, 56]               0Conv2d-20           [-1, 64, 56, 56]          36,864BatchNorm2d-21           [-1, 64, 56, 56]             128ReLU-22           [-1, 64, 56, 56]               0Conv2d-23          [-1, 256, 56, 56]          16,384Block2-24          [-1, 256, 56, 56]               0BatchNorm2d-25          [-1, 256, 56, 56]             512ReLU-26          [-1, 256, 56, 56]               0MaxPool2d-27          [-1, 256, 28, 28]               0Conv2d-28           [-1, 64, 56, 56]          16,384BatchNorm2d-29           [-1, 64, 56, 56]             128ReLU-30           [-1, 64, 56, 56]               0Conv2d-31           [-1, 64, 28, 28]          36,864BatchNorm2d-32           [-1, 64, 28, 28]             128ReLU-33           [-1, 64, 28, 28]               0Conv2d-34          [-1, 256, 28, 28]          16,384Block2-35          [-1, 256, 28, 28]               0Stack2-36          [-1, 256, 28, 28]               0BatchNorm2d-37          [-1, 256, 28, 28]             512ReLU-38          [-1, 256, 28, 28]               0Conv2d-39          [-1, 512, 28, 28]         131,072Conv2d-40          [-1, 128, 28, 28]          32,768BatchNorm2d-41          [-1, 128, 28, 28]             256ReLU-42          [-1, 128, 28, 28]               0Conv2d-43          [-1, 128, 28, 28]         147,456BatchNorm2d-44          [-1, 128, 28, 28]             256ReLU-45          [-1, 128, 28, 28]               0Conv2d-46          [-1, 512, 28, 28]          65,536Block2-47          [-1, 512, 28, 28]               0BatchNorm2d-48          [-1, 512, 28, 28]           1,024ReLU-49          [-1, 512, 28, 28]               0Identity-50          [-1, 512, 28, 28]               0Conv2d-51          [-1, 128, 28, 28]          65,536BatchNorm2d-52          [-1, 128, 28, 28]             256ReLU-53          [-1, 128, 28, 28]               0Conv2d-54          [-1, 128, 28, 28]         147,456BatchNorm2d-55          [-1, 128, 28, 28]             256ReLU-56          [-1, 128, 28, 28]               0Conv2d-57          [-1, 512, 28, 28]          65,536Block2-58          [-1, 512, 28, 28]               0BatchNorm2d-59          [-1, 512, 28, 28]           1,024ReLU-60          [-1, 512, 28, 28]               0Identity-61          [-1, 512, 28, 28]               0Conv2d-62          [-1, 128, 28, 28]          65,536BatchNorm2d-63          [-1, 128, 28, 28]             256ReLU-64          [-1, 128, 28, 28]               0Conv2d-65          [-1, 128, 28, 28]         147,456BatchNorm2d-66          [-1, 128, 28, 28]             256ReLU-67          [-1, 128, 28, 28]               0Conv2d-68          [-1, 512, 28, 28]          65,536Block2-69          [-1, 512, 28, 28]               0BatchNorm2d-70          [-1, 512, 28, 28]           1,024ReLU-71          [-1, 512, 28, 28]               0MaxPool2d-72          [-1, 512, 14, 14]               0Conv2d-73          [-1, 128, 28, 28]          65,536BatchNorm2d-74          [-1, 128, 28, 28]             256ReLU-75          [-1, 128, 28, 28]               0Conv2d-76          [-1, 128, 14, 14]         147,456BatchNorm2d-77          [-1, 128, 14, 14]             256ReLU-78          [-1, 128, 14, 14]               0Conv2d-79          [-1, 512, 14, 14]          65,536Block2-80          [-1, 512, 14, 14]               0Stack2-81          [-1, 512, 14, 14]               0BatchNorm2d-82          [-1, 512, 14, 14]           1,024ReLU-83          [-1, 512, 14, 14]               0Conv2d-84         [-1, 1024, 14, 14]         524,288Conv2d-85          [-1, 256, 14, 14]         131,072BatchNorm2d-86          [-1, 256, 14, 14]             512ReLU-87          [-1, 256, 14, 14]               0Conv2d-88          [-1, 256, 14, 14]         589,824BatchNorm2d-89          [-1, 256, 14, 14]             512ReLU-90          [-1, 256, 14, 14]               0Conv2d-91         [-1, 1024, 14, 14]         262,144Block2-92         [-1, 1024, 14, 14]               0BatchNorm2d-93         [-1, 1024, 14, 14]           2,048ReLU-94         [-1, 1024, 14, 14]               0Identity-95         [-1, 1024, 14, 14]               0Conv2d-96          [-1, 256, 14, 14]         262,144BatchNorm2d-97          [-1, 256, 14, 14]             512ReLU-98          [-1, 256, 14, 14]               0Conv2d-99          [-1, 256, 14, 14]         589,824BatchNorm2d-100          [-1, 256, 14, 14]             512ReLU-101          [-1, 256, 14, 14]               0Conv2d-102         [-1, 1024, 14, 14]         262,144Block2-103         [-1, 1024, 14, 14]               0BatchNorm2d-104         [-1, 1024, 14, 14]           2,048ReLU-105         [-1, 1024, 14, 14]               0Identity-106         [-1, 1024, 14, 14]               0Conv2d-107          [-1, 256, 14, 14]         262,144BatchNorm2d-108          [-1, 256, 14, 14]             512ReLU-109          [-1, 256, 14, 14]               0Conv2d-110          [-1, 256, 14, 14]         589,824BatchNorm2d-111          [-1, 256, 14, 14]             512ReLU-112          [-1, 256, 14, 14]               0Conv2d-113         [-1, 1024, 14, 14]         262,144Block2-114         [-1, 1024, 14, 14]               0BatchNorm2d-115         [-1, 1024, 14, 14]           2,048ReLU-116         [-1, 1024, 14, 14]               0Identity-117         [-1, 1024, 14, 14]               0Conv2d-118          [-1, 256, 14, 14]         262,144BatchNorm2d-119          [-1, 256, 14, 14]             512ReLU-120          [-1, 256, 14, 14]               0Conv2d-121          [-1, 256, 14, 14]         589,824BatchNorm2d-122          [-1, 256, 14, 14]             512ReLU-123          [-1, 256, 14, 14]               0Conv2d-124         [-1, 1024, 14, 14]         262,144Block2-125         [-1, 1024, 14, 14]               0BatchNorm2d-126         [-1, 1024, 14, 14]           2,048ReLU-127         [-1, 1024, 14, 14]               0Identity-128         [-1, 1024, 14, 14]               0Conv2d-129          [-1, 256, 14, 14]         262,144BatchNorm2d-130          [-1, 256, 14, 14]             512ReLU-131          [-1, 256, 14, 14]               0Conv2d-132          [-1, 256, 14, 14]         589,824BatchNorm2d-133          [-1, 256, 14, 14]             512ReLU-134          [-1, 256, 14, 14]               0Conv2d-135         [-1, 1024, 14, 14]         262,144Block2-136         [-1, 1024, 14, 14]               0BatchNorm2d-137         [-1, 1024, 14, 14]           2,048ReLU-138         [-1, 1024, 14, 14]               0MaxPool2d-139           [-1, 1024, 7, 7]               0Conv2d-140          [-1, 256, 14, 14]         262,144BatchNorm2d-141          [-1, 256, 14, 14]             512ReLU-142          [-1, 256, 14, 14]               0Conv2d-143            [-1, 256, 7, 7]         589,824BatchNorm2d-144            [-1, 256, 7, 7]             512ReLU-145            [-1, 256, 7, 7]               0Conv2d-146           [-1, 1024, 7, 7]         262,144Block2-147           [-1, 1024, 7, 7]               0Stack2-148           [-1, 1024, 7, 7]               0BatchNorm2d-149           [-1, 1024, 7, 7]           2,048ReLU-150           [-1, 1024, 7, 7]               0Conv2d-151           [-1, 2048, 7, 7]       2,097,152Conv2d-152            [-1, 512, 7, 7]         524,288BatchNorm2d-153            [-1, 512, 7, 7]           1,024ReLU-154            [-1, 512, 7, 7]               0Conv2d-155            [-1, 512, 7, 7]       2,359,296BatchNorm2d-156            [-1, 512, 7, 7]           1,024ReLU-157            [-1, 512, 7, 7]               0Conv2d-158           [-1, 2048, 7, 7]       1,048,576Block2-159           [-1, 2048, 7, 7]               0BatchNorm2d-160           [-1, 2048, 7, 7]           4,096ReLU-161           [-1, 2048, 7, 7]               0Identity-162           [-1, 2048, 7, 7]               0Conv2d-163            [-1, 512, 7, 7]       1,048,576BatchNorm2d-164            [-1, 512, 7, 7]           1,024ReLU-165            [-1, 512, 7, 7]               0Conv2d-166            [-1, 512, 7, 7]       2,359,296BatchNorm2d-167            [-1, 512, 7, 7]           1,024ReLU-168            [-1, 512, 7, 7]               0Conv2d-169           [-1, 2048, 7, 7]       1,048,576Block2-170           [-1, 2048, 7, 7]               0BatchNorm2d-171           [-1, 2048, 7, 7]           4,096ReLU-172           [-1, 2048, 7, 7]               0Identity-173           [-1, 2048, 7, 7]               0Conv2d-174            [-1, 512, 7, 7]       1,048,576BatchNorm2d-175            [-1, 512, 7, 7]           1,024ReLU-176            [-1, 512, 7, 7]               0Conv2d-177            [-1, 512, 7, 7]       2,359,296BatchNorm2d-178            [-1, 512, 7, 7]           1,024ReLU-179            [-1, 512, 7, 7]               0Conv2d-180           [-1, 2048, 7, 7]       1,048,576Block2-181           [-1, 2048, 7, 7]               0Stack2-182           [-1, 2048, 7, 7]               0BatchNorm2d-183           [-1, 2048, 7, 7]           4,096ReLU-184           [-1, 2048, 7, 7]               0
AdaptiveAvgPool2d-185           [-1, 2048, 1, 1]               0Flatten-186                 [-1, 2048]               0Linear-187                 [-1, 1000]       2,049,000
================================================================
Total params: 25,549,416
Trainable params: 25,549,416
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 241.69
Params size (MB): 97.46
Estimated Total Size (MB): 339.73
----------------------------------------------------------------

三、 训练模型 

1. 编写训练函数

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)  # 训练集的大小num_batches = len(dataloader)   # 批次数目, (size/batch_size,向上取整)train_loss, train_acc = 0, 0    # 初始化训练损失和正确率for X, y in dataloader:         # 获取图片及其标签X, y = X.to(device), y.to(device)# 计算预测误差pred = model(X)             # 网络输出loss = loss_fn(pred, y)     # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失# 反向传播optimizer.zero_grad()       # grad属性归零loss.backward()             # 反向传播optimizer.step()            # 每一步自动更新# 记录acc与losstrain_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_loss

2. 编写测试函数

测试函数和训练函数大致相同,但是由于不进行梯度下降对网络权重进行更新,所以不需要传入优化器

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)  # 测试集的大小num_batches = len(dataloader)   # 批次数目test_loss, test_acc = 0, 0# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 计算losstarget_pred = model(imgs)loss = loss_fn(target_pred, target)test_loss += loss.item()test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss

3.正式训练

import copyoptimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()  # 创建损失函数epochs = 10train_loss = []
train_acc = []
test_loss = []
test_acc = []best_acc = 0  # 设置一个最佳准确率,作为最佳模型的判别指标for epoch in range(epochs):# 更新学习率(使用自定义学习率时使用)# adjust_learning_rate(optimizer, epoch, learn_rate)model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)# scheduler.step() # 更新学习率(调用官方动态学习率接口时使用)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)# 保存最佳模型到 best_modelif epoch_test_acc > best_acc:best_acc = epoch_test_accbest_model = copy.deepcopy(model)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 获取当前的学习率lr = optimizer.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss,epoch_test_acc * 100, epoch_test_loss, lr))# 保存最佳模型到文件中
PATH = './best_model.pth'  # 保存的参数文件名
torch.save(model.state_dict(), PATH)print('Done')

得到如下输出:

Epoch: 1, Train_acc:38.3%, Train_loss:5.263, Test_acc:31.9%, Test_loss:5.151, Lr:1.00E-04
Epoch: 2, Train_acc:71.9%, Train_loss:1.759, Test_acc:31.0%, Test_loss:3.492, Lr:1.00E-04
Epoch: 3, Train_acc:82.7%, Train_loss:0.822, Test_acc:85.8%, Test_loss:0.620, Lr:1.00E-04
Epoch: 4, Train_acc:89.2%, Train_loss:0.478, Test_acc:83.2%, Test_loss:0.762, Lr:1.00E-04
Epoch: 5, Train_acc:89.2%, Train_loss:0.444, Test_acc:86.7%, Test_loss:0.629, Lr:1.00E-04
Epoch: 6, Train_acc:91.2%, Train_loss:0.359, Test_acc:73.5%, Test_loss:0.802, Lr:1.00E-04
Epoch: 7, Train_acc:95.1%, Train_loss:0.173, Test_acc:79.6%, Test_loss:0.689, Lr:1.00E-04
Epoch: 8, Train_acc:96.5%, Train_loss:0.141, Test_acc:80.5%, Test_loss:0.704, Lr:1.00E-04
Epoch: 9, Train_acc:98.5%, Train_loss:0.089, Test_acc:78.8%, Test_loss:0.879, Lr:1.00E-04
Epoch:10, Train_acc:95.8%, Train_loss:0.196, Test_acc:81.4%, Test_loss:0.718, Lr:1.00E-04
Done
预测结果是:Bananaquit
0.8672566371681416 0.5955437496304512
0.8672566371681416Process finished with exit code 0

四、 结果可视化

1. Loss&Accuracy

import matplotlib.pyplot as plt
# 隐藏警告
import warningswarnings.filterwarnings("ignore")  # 忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100  # 分辨率epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

得到的可视化结果:

 2. 指定图片进行预测

首先,先定义出一个用于预测的函数:

from PIL import Imageclasses = list(total_data.class_to_idx)def predict_one_image(image_path, model, transform, classes):test_img = Image.open(image_path).convert('RGB')plt.imshow(test_img)  # 展示预测的图片test_img = transform(test_img)img = test_img.to(device).unsqueeze(0)model.eval()output = model(img)_, pred = torch.max(output, 1)pred_class = classes[pred]print(f'预测结果是:{pred_class}')

接着调用函数对指定图片进行预测:

predict_one_image(image_path='./data/bird_photos/Bananaquit/007.jpg',model=model,transform=train_transforms,classes=classes)

得到如下结果:

预测结果是:Bananaquit

3.模型评估

将模型调至评估模式:

best_model.eval()
epoch_test_acc, epoch_test_loss = test(test_dl, best_model, loss_fn)
print(epoch_test_acc, epoch_test_loss)

得到如下输出:

0.8672566371681416 0.5955437496304512

 观察得到和前文中一致。

五、个人理解

除完成Tensorflow与pytorch之间的代码转换,还需了解ResNetV2及ResNetV之间的关系及区别:

首先,对比两个残差结构:

可以看出(a)结构先卷积后进行 BN 和激活函数计算,最后执行 addition 后再进行ReLU 计算; (b)结构先进行 BN 和激活函数计算后卷积,把 addition 后的 ReLU 计算放到了残差结构内部。

ResNetV2的最终确定经过了两轮尝试:

5.1关于残差结构的尝试

作者用不同 shortcut 结构的 ResNet-110 在 CIFAR-10 数据集上做测试,发现最原始的(a)original 结构是最好的,也就是 identity mapping 恒等映射是最好的

5.2关于激活的尝试

 经实验发现,最好的结果是(e)full pre-activation,其次到(a)original。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/331023.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

小红书云原生 Kafka 技术剖析:分层存储与弹性伸缩

面对 Kafka 规模快速增长带来的成本、效率和稳定性挑战时,小红书大数据存储团队采取云原生架构实践:通过引入冷热数据分层存储、容器化技术以及自研的负载均衡服务「Balance Control」,成功实现了集群存储成本的显著降低、分钟级的集群弹性迁…

开放式耳机2024超值推荐!教你如何选择蓝牙耳机!

开放式耳机的便利性让它在我们的日常生活中变得越来越重要。它让我们摆脱了传统耳机的限制,享受到了更多的自由。不过,市面上的开放式耳机种类繁多,挑选一款既实用又实惠的产品确实需要一些小窍门。作为一位对开放式耳机颇有研究的用户&#…

民国漫画杂志《时代漫画》第18期.PDF

时代漫画18.PDF: https://url03.ctfile.com/f/1779803-1248612707-27e56b?p9586 (访问密码: 9586) 《时代漫画》的杂志在1934年诞生了,截止1937年6月战争来临被迫停刊共发行了39期。 ps:资源来源网络!

内网穿透--Frp-简易型(速成)-上线

免责声明:本文仅做技术交流与学习... 目录 frp项目介绍: 一图通解: ​编辑 1-下载frp 2-服务端(server)开启frp口 3-kali客户端(client)连接frp服务器 4-kali生成马子 5-kali监听 6-马子执行-->成功上线 frp项目介绍: GitHub - fatedier/frp: A fast reverse proxy…

回溯大法总结

前言 本篇博客将分两步来进行,首先谈谈我对回溯法的理解,然后通过若干道题来进行讲解,最后总结 对回溯法的理解 回溯法可以看做蛮力法的升级版,它在解决问题时的每一步都尝试所有可能的选项,最终找出所以可行的方案…

从0开始实现一个博客系统 (SSM 实现)

相关技术 Spring Spring Boot Spring MVC MyBatis Html Css JS pom 文件我就不放出来了, 之前用的 jdk8 做的, MySQL 用的 5.7, 都有点老了, 你们自己看着配版本就好 实现功能 用户注册 - 密码加盐加密 (md5 加密)前后端用户信息存储 - 令牌技术用户登录 - (使用 拦截…

Xilinx(AMD) FPGA通过ICAP原语读取芯片IDCODE实现方法

1 概述 Xilinx每种型号的FPGA芯片都有一个唯一的IDCODE与之对应,同一型号不同封装的IDCODE是相同的。IDCODE的获取方法包括JTAG、ICAP原语、AXI_HWICAP IP核等。获取IDCODE常用于根据芯片型号改变代码的功能,或者对代码进行授权保护,只能在指…

从《红楼梦》的视角看大模型知识库 RAG 服务的 Rerank 调优

背景介绍 在之前的文章 有道 QAnything 源码解读 中介绍了有道 RAG 的一个主要亮点在于对 Rerank 机制的重视。 从目前来看,Rerank 确实逐渐成为 RAG 的一个重要模块,在这篇文章中就希望能讲清楚为什么 RAG 服务需要 Rerank 机制,以及如何选…

现代密码学——消息认证和哈希函数

1.概述 1.加密-->被动攻击(获取消息内容、业务流分析) 消息认证和数字签名-->主动攻击(假冒、重放、篡改、业务拒绝) 2.消息认证作用: 验证消息源的真实性, 消息的完整性(未被篡改…

集合、Collection接口特点和常用方法

1、集合介绍 对于保存多个数据使用的是数组,那么数组有不足的地方。比如, 长度开始时必须指定,而且一旦制定,不能更改。 保存的必须为同一类型的元素。 使用数组进行增加/删除元素的示意代码,也就是比较麻烦。 为…

分布式数据库HBase入门指南

目录 概述 HBase 的主要特点包括: HBase 的典型应用场景包括: 访问接口 1. Java API: 2. REST API: 3. Thrift API: 4. 其他访问接口: HBase 数据模型 概述 该模型具有以下特点: 1. 面向列: 2. 多维: 3. 稀疏: 数据存储: 数据访问: HBase 的数据模型…

你真的会使用Vue3的onMounted钩子函数吗?Vue3中onMounted的用法详解

目录 一、onMounted的前世今生 1.1、onMounted是什么 1.2、onMounted在vue2中的前身 1.2.1、vue2中的onMounted 1.2.2、Vue2与Vue3的onMounted对比 1.3、vue3中onMounted的用法 1.3.1、基础用法 1.3.2、顺序执行异步操作 1.3.3、并行执行多个异步操作 1.3.4、执行一次…

基于STM32实现智能光照控制系统

目录 引言环境准备智能光照控制系统基础代码示例:实现智能光照控制系统 光照传感器数据读取PWM控制LED亮度用户界面与显示应用场景:智能家居与农业自动化问题解决方案与优化收尾与总结 1. 引言 本教程将详细介绍如何在STM32嵌入式系统中使用C语言实现智…

纯代码如何实现WordPress搜索包含评论内容?

WordPress自带的搜索默认情况下是不包含评论内容的,不过有些WordPress网站评论内容比较多,而且也比较有用,所以想要让用户在搜索时也能够同时搜索到评论内容,那么应该怎么做呢? 网络上很多教程都是推荐安装SearchWP插…

数据结构----堆的实现(附代码)

当大家看了鄙人的上一篇博客栈后,稍微猜一下应该知道鄙人下一篇想写的博客就是堆了吧。毕竟堆栈在C语言中常常是一起出现的。那么堆是什么,是如何实现的嘞。接下来我就带大家去尝试实现一下堆。 堆的含义 首先我们要写出一个堆,那么我们就需…

基于地理坐标的高阶几何编辑工具算法(4)——线分割面

文章目录 工具步骤应用场景算法输入算法输出算法示意图算法原理 工具步骤 选中待分割面,点击“线分割面”工具,绘制和面至少两个交点的线,双击结束,执行分割操作 应用场景 快速切分大型几何面,以降低面的复杂度&…

数据结构篇其三---链表分类和双向链表

​ 前言 数据结构篇其二实现了一个简单的单链表,链表的概念,单链表具体实现已经说明,如下: 单链表 事实上,前面的单链表本质上是无头单向不循环链表。此篇说明的双向链表可以说完全反过来了了。无论是之前的单链表还…

ElasticSearch - 删除已经设置的认证密码(7.x)

文章目录 Pre版本号 7.x操作步骤检查当前Elasticsearch安全配置停止Elasticsearch服务修改Elasticsearch配置文件删除密码重启Elasticsearch服务验证配置 小结 Pre Elasticsearch - Configuring security in Elasticsearch 开启用户名和密码访问 版本号 7.x ES7.x 操作步骤 …

从ES到ClickHouse,Bonree ONE平台更轻更快!

本文字数:8052;估计阅读时间:21 分钟 作者:博睿数据 李骅宸(太道)& 娄志强(冬青) 本文在公众号【ClickHouseInc】首发 本系列第一篇内容: 100%降本增效!…

01-02.Vue的常用指令(二)

01-02.Vue的常用指令(二) 前言v-model:双向数据绑定v-model举例:实现简易计算器Vue中通过属性绑定为元素设置class 类样式引入方式一:数组写法二:在数组中使用三元表达式写法三:在数组中使用 对…