霹雳吧啦Wz《pytorch图像分类》-p5ResNet网络

《pytorch图像分类》p5ResNet网络结构

  • 1 网络中的亮点
    • 1.1 超深的网络结构
    • 1.2 residual模块
    • 1.3 Batch Normalization
    • 1.4 迁移学习简介
  • 2 模块类代码
    • 2.1 BasicBlock(18 & 32 layers)
    • 2.2 Bottleneck(50 & 101 & 152 layers)
    • 2.3 ResNet
  • 3 课程代码
    • 3.1 modle.py
    • 3.2 train.py
    • 3.3 predict.py

1 网络中的亮点

论文连接:Deep Residual Learning for Image Recognition

1.1 超深的网络结构

AlexNet,VggNet等之前学习的网络当中,深度只有十几层到二十几层,如果层数加深,会产生梯度消失或者梯度爆炸现象,错误率反而会增高(如左图),ResNet是添加了残差模块之后才能搭建足够深的网络模型(如右图)。
在这里插入图片描述
网络深度突破了1000层
在这里插入图片描述

1.2 residual模块

残差模块(Residual module)是一种常用于深度卷积神经网络(CNN)中的构建块,用于解决梯度消失或爆炸的问题,并帮助网络更有效地进行训练。
残差模块的核心思想是通过引入跳跃连接 shortcut connections(论文中写的术语)来绕过一些卷积层,将输入直接添加到输出中

在这里插入图片描述
当输入特征图和输出特征图在维度上不匹配时,可以通过1x1卷积层调整维度,以确保它们可以相加。
在这里插入图片描述
Down sampling is performed by conv3_1, conv4_1, and conv5_1 with a stride of 2.
下采样由conv3_1、conv4_1和conv5_1执行,步长为2。
在这里插入图片描述

1.3 Batch Normalization

Batch Normalization(批归一化)是一种在深度神经网络中用于提高训练速度和稳定性的技术。它通过对每个小批次(batch)的输入数据进行归一化处理,以使输入的均值和方差保持在稳定范围内

举个例子:我们在烹饪中做蛋糕的过程。每次用相同的配方烤蛋糕时,我们都会按照配方中的比例加入相同的材料,所以无论做几个蛋糕,它们都会有相似的味道和口感。
类似地,神经网络的训练过程中,我们可以将输入数据看作是蛋糕的材料。批归一化通过对每个批次中的数据进行标准化,使得它们具有零均值和单位方差。这相当于将每个批次的数据转化为类似的“蛋糕材料”,这样神经网络在处理不同的批次时,每个批次的数据都可以在相似的范围内变化。这样有利于网络的学习和训练。
(我还得找资料更深一步去学习)

1.4 迁移学习简介

  1. 能够快速的训练出一个理想的结果
  2. 当数据集比较小的时候也能训练出理想的效果

常见的迁移学习方式:

  1. 载入权重后训练所有参数
  2. 载入权重后只训练最后几层参数
  3. 载入权重后在原网络基础上再添加一层全连接层,仅训练最后一个全连接层

2 模块类代码

2.1 BasicBlock(18 & 32 layers)

针对18层和34层的残差结构
expansion是指每个残差块的输出通道数相对于输入通道数的倍数,主分支所采用的卷积核的个数有没有发生变化,expansion=1也就是卷积核个数没有发生变化。
在这里插入图片描述

downsample=None对应的是实线的结构
虚线结构即是跳跃连接shortcut connections,见本篇博客1.2residual结构的第二张图,比如conv3、conv4和conv5的第一层

def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):

在这里插入图片描述

2.2 Bottleneck(50 & 101 & 152 layers)

针对50层、101层和152层的残差结构
expansion=4即输出通道数是输入通道数的4倍。
在这里插入图片描述

def __init__(self, in_channel, out_channel, stride=1, downsample=None,groups=1, width_per_group=64):

2.3 ResNet

在这里插入图片描述

class ResNet(nn.Module):def __init__(self,block,blocks_num,num_classes=1000,include_top=True,groups=1,width_per_group=64):

3 课程代码

3.1 modle.py

import torch.nn as nn
import torchclass BasicBlock(nn.Module):expansion = 1def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channel)self.relu = nn.ReLU()self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channel)self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out += identityout = self.relu(out)return outclass Bottleneck(nn.Module):"""注意:原论文中,在虚线残差结构的主分支上,第一个1x1卷积层的步距是2,第二个3x3卷积层步距是1。但在pytorch官方实现过程中是第一个1x1卷积层的步距是1,第二个3x3卷积层步距是2,这么做的好处是能够在top1上提升大概0.5%的准确率。可参考Resnet v1.5 https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch"""expansion = 4def __init__(self, in_channel, out_channel, stride=1, downsample=None,groups=1, width_per_group=64):super(Bottleneck, self).__init__()width = int(out_channel * (width_per_group / 64.)) * groupsself.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,kernel_size=1, stride=1, bias=False)  # squeeze channelsself.bn1 = nn.BatchNorm2d(width)# -----------------------------------------self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,kernel_size=3, stride=stride, bias=False, padding=1)self.bn2 = nn.BatchNorm2d(width)# -----------------------------------------self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,kernel_size=1, stride=1, bias=False)  # unsqueeze channelsself.bn3 = nn.BatchNorm2d(out_channel*self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += identityout = self.relu(out)return outclass ResNet(nn.Module):def __init__(self,block,blocks_num,num_classes=1000,include_top=True,groups=1,width_per_group=64):super(ResNet, self).__init__()self.include_top = include_topself.in_channel = 64self.groups = groupsself.width_per_group = width_per_groupself.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,padding=3, bias=False)self.bn1 = nn.BatchNorm2d(self.in_channel)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, blocks_num[0])self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)if self.include_top:self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)self.fc = nn.Linear(512 * block.expansion, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')def _make_layer(self, block, channel, block_num, stride=1):downsample = Noneif stride != 1 or self.in_channel != channel * block.expansion:downsample = nn.Sequential(nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(channel * block.expansion))layers = []layers.append(block(self.in_channel,channel,downsample=downsample,stride=stride,groups=self.groups,width_per_group=self.width_per_group))self.in_channel = channel * block.expansionfor _ in range(1, block_num):layers.append(block(self.in_channel,channel,groups=self.groups,width_per_group=self.width_per_group))return nn.Sequential(*layers)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)if self.include_top:x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return xdef resnet34(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnet34-333f7ec4.pthreturn ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)def resnet50(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnet50-19c8e357.pthreturn ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)def resnet101(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnet101-5d3b4d8f.pthreturn ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)def resnext50_32x4d(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pthgroups = 32width_per_group = 4return ResNet(Bottleneck, [3, 4, 6, 3],num_classes=num_classes,include_top=include_top,groups=groups,width_per_group=width_per_group)def resnext101_32x8d(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pthgroups = 32width_per_group = 8return ResNet(Bottleneck, [3, 4, 23, 3],num_classes=num_classes,include_top=include_top,groups=groups,width_per_group=width_per_group)

3.2 train.py

import os
import sys
import jsonimport torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdmfrom model import resnet34def main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("using {} device.".format(device))data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root pathimage_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set pathassert os.path.exists(image_path), "{} path does not exist.".format(image_path)train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform=data_transform["train"])train_num = len(train_dataset)# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}flower_list = train_dataset.class_to_idxcla_dict = dict((val, key) for key, val in flower_list.items())# write dict into json filejson_str = json.dumps(cla_dict, indent=4)with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 16nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=nw)validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),transform=data_transform["val"])val_num = len(validate_dataset)validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=False,num_workers=nw)print("using {} images for training, {} images for validation.".format(train_num,val_num))net = resnet34()# load pretrain weights# download url: https://download.pytorch.org/models/resnet34-333f7ec4.pthmodel_weight_path = "resnet34-pre.pth"assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))# for param in net.parameters():#     param.requires_grad = False# change fc layer structurein_channel = net.fc.in_featuresnet.fc = nn.Linear(in_channel, 5)net.to(device)# define loss functionloss_function = nn.CrossEntropyLoss()# construct an optimizerparams = [p for p in net.parameters() if p.requires_grad]optimizer = optim.Adam(params, lr=0.0001)epochs = 3best_acc = 0.0save_path = './resNet34.pth'train_steps = len(train_loader)for epoch in range(epochs):# trainnet.train()running_loss = 0.0train_bar = tqdm(train_loader, file=sys.stdout)for step, data in enumerate(train_bar):images, labels = dataoptimizer.zero_grad()logits = net(images.to(device))loss = loss_function(logits, labels.to(device))loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss)# validatenet.eval()acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():val_bar = tqdm(validate_loader, file=sys.stdout)for val_data in val_bar:val_images, val_labels = val_dataoutputs = net(val_images.to(device))# loss = loss_function(outputs, test_labels)predict_y = torch.max(outputs, dim=1)[1]acc += torch.eq(predict_y, val_labels.to(device)).sum().item()val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,epochs)val_accurate = acc / val_numprint('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %(epoch + 1, running_loss / train_steps, val_accurate))if val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('Finished Training')if __name__ == '__main__':main()

3.3 predict.py

import os
import jsonimport torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as pltfrom model import resnet34def main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")data_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# load imageimg_path = "1.jpg"assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)img = Image.open(img_path)plt.imshow(img)# [N, C, H, W]img = data_transform(img)# expand batch dimensionimg = torch.unsqueeze(img, dim=0)# read class_indictjson_path = './class_indices.json'assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)with open(json_path, "r") as f:class_indict = json.load(f)# create modelmodel = resnet34(num_classes=5).to(device)# load model weightsweights_path = "./resNet34.pth"assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)model.load_state_dict(torch.load(weights_path, map_location=device))# predictionmodel.eval()with torch.no_grad():# predict classoutput = torch.squeeze(model(img.to(device))).cpu()predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],predict[predict_cla].numpy())plt.title(print_res)for i in range(len(predict)):print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],predict[i].numpy()))plt.show()if __name__ == '__main__':main()

1.jpg预测的玫瑰花
在这里插入图片描述
2.jpg预测的雏菊
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/230390.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue-cli创建项目时由esLint校验导致报错或警告的问题及解决

vue-cli创建项目时由esLint校验导致报错或警告的问题及解决 一、万能办法 一、万能办法 //就是在报错的JS文件中第一行写上 /* eslint-disable */链接: https://www.yii666.com/blog/288808.html 其它的方法我遇见了再补充

docker的安装的详细教程,以及出现错的解决办法(阿里云)

docker的安装与使用 1.安装dnf sudo yum -y install dnf Repository extras is listed more than once in the configuration 错误:无法为仓库 appstream 找到一个有效的 baseurl 出现这个错误这是由于阿里云的版本导致的 在阿里云开发者社区有答案&#xff01…

什么是软件安全性测试?如何进行安全测试?

一、什么是软件安全性测试? 软件安全性测试是指对软件系统中的安全漏洞进行检测和评估的过程。其目的是为了确保软件系统在面对各种安全威胁时能够保持其功能的完整性、可用性和机密性。 二、软件安全性测试可以通过以下几个步骤来进行: 1. 需求分析&a…

Django 学习教程- Hello world入门案例

系列 Django 学习教程-介绍与安装-CSDN博客 欢迎来到第Djagno学习教程第二章Hello World 入门案例。 在本教程中,我将引导您完成django的Hello World入门案例。 让我们开始吧! 版本 Django 5.0Python 3.10 创建项目 安装 Django 之后&#xff0…

数字孪生与物联网(IoT)技术的结合

数字孪生与物联网(IoT)技术的结合可以在多个领域实现更智能、更高效的应用。以下是数字孪生在物联网技术中的一些应用,希望对大家有所帮助。北京木奇移动技术有限公司,专业的软件外包开发公司,欢迎交流合作。 1.实时监…

把苹果手机上的备忘录转为长图片,分享给别人方法教程

在这个信息爆炸的时代,手机备忘录几乎成了我随身携带的“记忆宝库”。每当我脑海中闪现出一个想法、灵感或是需要记住的重要事项,我都会第一时间打开苹果手机的备忘录,将它们一一记录下来。备忘录的简洁界面和高效操作总能让我在忙碌的生活中…

gradle --腾讯国内镜像源

distributionUrlhttps\://mirrors.cloud.tencent.com/gradle/gradle-7.3.3-bin.zip 1.进入到自己工程目录下的wrapper文件夹。 2.编辑gradle-wrapper文件 使用https://mirrors.cloud.tencent.com/gradle/gradle-4.6-all.zip来代替原来的 https\://services.gradle.org/distri…

CDD文件的制作

CDD文件 1、核查诊断调查表2、制作CDD3、Diva测试 1、核查诊断调查表 ECU级别:包括文档相关、控制器的诊断ID和时间参数,支持的服务,DTC、DID、刷写流程。 2、制作CDD 2.1、cddt编辑思路(每一步都要根据调查表进行操作&#xf…

QT C++调用python传递RGB图像和三维数组,并接受python返回值(图像)

目的: 用QT调用python代码,将QT读取的图像(Qimage)作为参数传入python中,将QT的三维数组作为参数传递给python,python接收QT传入的图像进行计算,将结果返回给QT并显示。 一 .pro 头文件的配置,和lib库的…

很实用的ChatGPT网站—在线编程模块增补篇

很实用的ChatGPT网站(http://chat-zh.com/)——增补篇 今天介绍一个好兄弟开发的ChatGPT网站,网址[http://chat-zh.com/]。这个网站功能模块很多,包含生活、学习、医疗、法律、经济等很多方面。今天跟大家分享一下,新…

Vue:Vue与VueComponent的关系图

1.一个重要的内置关系&#xff1a;VueComponent.prototype.proto Vue.prototype 2.为什么要有这个关系&#xff1a;让组件实例对象&#xff08;vc&#xff09;可以访问到 Vue原型上的属性、方法。 案例证明&#xff1a; <!DOCTYPE html> <html lang"en"&…

errors包返回堆栈信息的性能测试

errors包返回堆栈信息的性能测试 上一篇Golang中使用errors返回调用堆栈信息 讲了使用第三方开源库的errors github.com/go-errors/errors&#xff0c;错误信息带调用栈&#xff0c;方便定位错误的抛出位置。 通过堆栈的信息来定位是方便了&#xff0c;性能怎么样&#xff0c…

【计算机算法设计与分析】n皇后问题(C++_回溯法)

文章目录 题目描述测试样例算法原理算法实现参考资料 题目描述 在nxn格的棋盘上放置彼此不受攻击的n格皇后。按照国际象棋的规则&#xff0c;皇后可以攻击与之处在同一行或同一列或同一斜线上的棋子。n后问题等价于在nxn格的棋盘上放置n个皇后&#xff0c;任何2个皇后不放在同…

智能分析网关V4智慧港口码头可视化视频智能监管方案

一、需求背景 近年来&#xff0c;水利港口码头正在进行智能化建设&#xff0c;现场管理已经是重中之重。港口作为货物、集装箱堆放及中转机构&#xff0c;具有昼夜不歇、天气多变、环境恶劣等特性&#xff0c;安全保卫工作显得更加重要。港口码头的巡检现场如何高效、快捷地对…

设计模式篇章(1)——理论基础

设计模式&#xff1a;在软件开发中会面临许多不断重复发生的问题&#xff0c;这些问题可能是代码冗余、反复修改旧代码、重写以前的代码、在旧代码上不断堆新的代码&#xff08;俗称屎山&#xff09;等难以扩展、不好维护的问题。因此1990年有四位大佬&#xff08;GoF组合&…

连接GaussDB(DWS)报错:Invalid or unsupported by client SCRAM mechanisms

用postgres方式连接GaussDB(DWS)报错&#xff1a;Invalid or unsupported by client SCRAM mechanisms 报错内容 [2023-12-27 21:43:35] Invalid or unsupported by client SCRAM mechanisms org.postgresql.util.PSQLException: Invalid or unsupported by client SCRAM mec…

Tinker 环境下数据表的用法

如果我们要自己手动创建一个模型文件&#xff0c;最简单的方式是通过 make:model 来创建。 php artisan make:model Article 删除模型文件 rm app/Models/Article.php 创建模型的同时顺便创建数据库迁移 php artisan make:model Article -m Eloquent 表命名约定 在该文件中&am…

MySQL基础篇(一)SQL

视频地址: 黑马程序员 MySQL数据库入门到精通&#xff0c;从mysql安装到mysql高级、mysql优化全囊括 SQL&#xff0c;全称 Structured Query Language&#xff0c;结构化查询语言。操作关系型数据库的编程语言&#xff0c;定义了一套操作关系型数据库统一 标准。 一、SQL通用语…

python练习题

1. 找出1到20内的所有质数 提示&#xff1a;质数是指大于1的自然数&#xff0c;除了1和它本身以外没有任何正因数&#xff08;除了1和它本身外不能被其他整数整除&#xff09;。换句话说&#xff0c;质数是只有两个正因数的数&#xff0c;这两个因数就是1和它自己。 for num …

CAN通信-报文信号格式(Inter、Motorola)

DBC 1、Inter格式和Motorola格式2、制作DBC 1、Inter格式和Motorola格式 Inter格式(小端模式)&#xff1a;高位字节存放在高地址中&#xff0c;低位字节存放在低地址中&#xff0c;数据表现&#xff1a;以一个字节为例&#xff0c;前半个字节为地位。 Motorola格式(大端模式)&…