分类神经网络2:ResNet模型复现

目录

ResNet网络架构

ResNet部分实现代码


ResNet网络架构

论文原址:https://arxiv.org/pdf/1512.03385.pdf

残差神经网络(ResNet)是由微软研究院的何恺明、张祥雨、任少卿、孙剑等人提出的,通过引入残差学习解决了深度网络训练中的退化问题,同时该网络结构特点主要表现为3点,超深的网络结构(超过1000层)、提出residual(残差结构)模块和使用Batch Normalization 加速训练(丢弃dropout)。ResNet网络的模型结构如下:

ResNet网络通过“跳跃连接”,将靠前若干层的某一层数据输出直接跳过多层引入到后面数据层的输入部分。即神经网络学习到该层的参数是冗余的时候,它可以选择直接走这条“跳接”曲线,跳过这个冗余层,而不需要再去拟合参数使得H(x)=F(x)=x。同时通过这种连接方式不仅保护了信息的完整性(避免卷积层堆叠存在的信息丢失),整个网络也只需要学习输入、输出差别的部分,这克服了由于网络深度加深而产生的学习效率变低与准确率无法有效提升的问题(梯度消失或梯度爆炸)。

残差模块如下图示:

残差结构有两种,常规残差和瓶颈残差

常规残差:由2个3x3卷积层堆叠而成,当输入和输出维度一致时,可以直接将输入加到输出上,这相当于简单执行了同等映射,不会产生额外的参数,也不会增加计算复杂度(随着网络深度的加深,这种残差模块在实践中并不十分有效)。

瓶颈残差:依次由1x1 、3x3 、1x1个卷积层构成,这里1x1卷积,能够对通道数channel起到升维或降维的作用,从而令3x3 的卷积,以相对较低维度的输入进行卷积运算,提高计算效率。

ResNet网络的具体配置信息如下:

在构建神经网络时,首先采用了步长为2的卷积层进行图像尺寸缩减,即下采样操作,紧接着是多个残差结构,在网络架构的末端,引入了一个全局平均池化层,用于整合特征信息,最后是一个包含1000个类别的全连接层,并在该层后应用了softmax激活函数以进行多分类任务。值得注意的是,通过引入残差连接模块,其最深的网络结构达到了152层,同时在50层后均使用的是瓶颈残差结构。

ResNet部分实现代码

老样子,直接上代码,建议大家阅读代码时结合网络结构理解

import torch
import torch.nn as nn__all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']class ConvBNReLU(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):super(ConvBNReLU, self).__init__()self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)self.bn = nn.BatchNorm2d(num_features=out_channels)self.relu = nn.ReLU(inplace=True)def forward(self, x):x = self.conv(x)x = self.bn(x)x = self.relu(x)return xclass ConvBN(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):super(ConvBN, self).__init__()self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)self.bn = nn.BatchNorm2d(num_features=out_channels)def forward(self, x):x = self.conv(x)x = self.bn(x)return xdef conv3x3(in_channels, out_channels, stride=1, groups=1, dilation=1):"""3x3 convolution with padding:捕捉局部特征和空间相关性,学习更复杂的特征和抽象表示"""return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,padding=dilation, groups=groups, bias=False, dilation=dilation)def conv1x1(in_channels, out_channels, stride=1):"""1x1 convolution:实现降维或升维,调整通道数和执行通道间的线性变换"""return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)class BasicBlock(nn.Module):expansion = 1def __init__(self, in_channels, out_channels, stride=1, downsample=None):super(BasicBlock, self).__init__()self.convbnrelu1 = ConvBNReLU(in_channels, out_channels, kernel_size=3, stride=stride)self.convbn1 = ConvBN(out_channels, out_channels, kernel_size=3)self.relu = nn.ReLU(inplace=True)self.downsample = downsampleself.stride = strideself.conv_down = nn.Sequential(conv1x1(in_channels, out_channels * self.expansion, self.stride),nn.BatchNorm2d(out_channels * self.expansion),)def forward(self, x):residual = xout = self.convbnrelu1(x)out = self.convbn1(out)if self.downsample:residual = self.conv_down(x)out += residualout = self.relu(out)return outclass Bottleneck(nn.Module):expansion = 4def __init__(self, in_channels, out_channels, stride=1, downsample=None):super(Bottleneck, self).__init__()groups = 1base_width = 64dilation = 1width = int(out_channels * (base_width / 64.)) * groups   # wide = out_channelsself.conv1 = conv1x1(in_channels, width)       # 降维通道数self.bn1 = nn.BatchNorm2d(width)self.conv2 = conv3x3(width, width, stride, groups, dilation)self.bn2 = nn.BatchNorm2d(width)self.conv3 = conv1x1(width, out_channels * self.expansion)   # 升维通道数self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampleself.stride = strideself.conv_down = nn.Sequential(conv1x1(in_channels, out_channels * self.expansion, self.stride),nn.BatchNorm2d(out_channels * self.expansion),)def forward(self, x):residual = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)if self.downsample:residual = self.conv_down(x)out += residualout = self.relu(out)return outclass ResNet(nn.Module):def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,groups=1, width_per_group=64):super(ResNet, self).__init__()self.inplanes = 64self.dilation = 1replace_stride_with_dilation = [False, False, False]self.groups = groupsself.base_width = width_per_groupself.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)self.bn1 = nn.BatchNorm2d(self.inplanes)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, layers[0])self.layer2 = self._make_layer(block, 128, layers[1], stride=2,dilate=replace_stride_with_dilation[0])self.layer3 = self._make_layer(block, 256, layers[2], stride=2,dilate=replace_stride_with_dilation[1])self.layer4 = self._make_layer(block, 512, layers[3], stride=2,dilate=replace_stride_with_dilation[2])self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)# Zero-initialize the last BN in each residual branch,# so that the residual branch starts with zeros, and each residual block behaves like an identity.# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677if zero_init_residual:for m in self.modules():if isinstance(m, Bottleneck):nn.init.constant_(m.bn3.weight, 0)elif isinstance(m, BasicBlock):nn.init.constant_(m.bn2.weight, 0)def _make_layer(self, block, planes, blocks, stride=1, dilate=False):downsample = Falseif dilate:self.dilation *= stridestride = 1if stride != 1 or self.inplanes != planes * block.expansion:downsample = Truelayers = nn.ModuleList()layers.append(block(self.inplanes, planes, stride, downsample))self.inplanes = planes * block.expansionfor _ in range(1, blocks):layers.append(block(self.inplanes, planes))return layersdef forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)for layer in self.layer1:x = layer(x)for layer in self.layer2:x = layer(x)for layer in self.layer3:x = layer(x)for layer in self.layer4:x = layer(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return xdef resnet18(num_classes, **kwargs):return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, **kwargs)def resnet34(num_classes, **kwargs):return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, **kwargs)def resnet50(num_classes, **kwargs):return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, **kwargs)def resnet101(num_classes, **kwargs):return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, **kwargs)def resnet152(num_classes, **kwargs):return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, **kwargs)if __name__=="__main__":import torchsummarydevice = 'cuda' if torch.cuda.is_available() else 'cpu'input = torch.ones(2, 3, 224, 224).to(device)net = resnet50(num_classes=4)net = net.to(device)out = net(input)print(out)print(out.shape)torchsummary.summary(net, input_size=(3, 224, 224))# Total params: 134,285,380

希望对大家能够有所帮助呀!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/315199.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

解决 Tomcat 跨域问题 - Tomcat 配置静态文件和 Java Web 服务(Spring MVC Springboot)同时允许跨域

解决 Tomcat 跨域问题 - Tomcat 配置静态文件和 Java Web 服务(Spring MVC Springboot)同时允许跨域 Tomcat 配置允许跨域Web 项目配置允许跨域Tomcat 同时允许静态文件和 Web 服务跨域 偶尔遇到一个 Tomcat 部署项目跨域问题,因为已经处理过…

Odoo:全球排名第一的免费开源PLM管理系统介绍

概述 利用开源智造OdooPLM产品生命周期管理应用,重塑创新 实现产品生命周期管理数字化,高效定义、开发、交付和管理创新的可持续产品,拥抱数字化供应链。 通过开源智造基于Odoo开源技术平台打造数字化的产品生命周期管理(PLM&am…

基于OpenMV 双轴机械臂 机器学习

文章目录 一、项目简要二、目标追踪1. 色块识别与最大色块筛选2. PID位置闭环 三、机器学习1. Device12. Device2 四、效果演示 一、项目简要 两套二维云台设备,Device1通过摄像头捕捉目标物块点位进行实时追踪,再将自身点位传到Device2,Dev…

开发 Chrome 浏览器插件入门

前言 简介 Chrome 插件是扩展 Chrome 浏览器的功能的软件程序。它们可以执行各种任务,例如阻止广告、增强隐私、添加新功能等等。 要开始编写 Chrome 插件,你需要掌握以下: 1.JavaScript语言 2.html 3.css 4.会使用chrome扩展开发手册…

opencv绘制线段------c++

绘制线段 bool opencvTool::drawLines(std::string image_p, std::vector<cv::Point> points) {cv::Mat ima cv::imread(image_p.c_str()); // 读取图像&#xff0c;替换为你的图片路径 cv::Scalar red cv::Scalar(0, 0, 255); // Red color int thickness 2;// 遍…

微信小程序开发:2.小程序组件

常用的视图容器类组件 View 普通的视图区域类似于div常用来进行布局效果 scroll-view 可以滚动的视图&#xff0c;常用来进行滚动列表区域 swiper and swiper-item 轮播图的容器组件和轮播图的item项目组件 View组件的基本使用 案例1 <view class"container"&…

Jrebel无法启动项目

使用Jrebel的debug模式无法启动项目&#xff0c;而使用Jrebel的普通模式可以。 debug模式输出日志如下&#xff0c;只有jrebel&#xff0c;没有出现系统启动日志&#xff0c;开始以为是断点&#xff0c;禁用也无法正常启动。而其他项目则jrebel 的debug可以 解决办法&#xff…

【算法刷题 | 贪心算法02】4.24(摆动序列)

文章目录 3.摆动序列3.1题目3.2解法&#xff1a;贪心3.2.1贪心思路3.2.2代码实现 3.摆动序列 3.1题目 如果连续数字之间的差严格地在正数和负数之间交替&#xff0c;则数字序列称为 摆动序列 。 第一个差&#xff08;如果存在的话&#xff09;可能是正数或负数。仅有一个元素…

【综述】DSP处理器芯片

文章目录 TI DSP C2000系列 TMS320F28003X 典型应用 开发工具链 参考资料 TI DSP TI C2000系列 控制领域 TI C5000系列 通信领域 TI C6000系列 图像领域 C2000系列 第三代集成了C28浮点DSP内核&#xff0c;采用了65nm工艺&#xff08;上一代180nm&#xff09; 第四代正在…

uniapp制作安卓原生插件踩坑

1.uniapp和Android工程互相引用讲解 uniapp原生Android插件开发入门教程 &#xff08;最新版&#xff09;_uniapp android 插件开发-CSDN博客 2.uniapp引用原生aar目录结构 详细尝试步骤1完成后生成的aar使用&#xff0c;需要新建nativeplugins然后丢进去 3.package.json示例…

数据聚类:Mean-Shift和EM算法

目录 1. 高斯混合分布2. Mean-Shift算法3. EM算法4. 数据聚类5. 源码地址 1. 高斯混合分布 在高斯混合分布中&#xff0c;我们假设数据是由多个高斯分布组合而成的。每个高斯分布被称为一个“成分”&#xff08;component&#xff09;&#xff0c;这些成分通过加权和的方式来构…

Bert基础(十八)--Bert实战:NER命名实体识别

1、命名实体识别介绍 1.1 简介 命名实体识别&#xff08;NER&#xff09;是自然语言处理&#xff08;NLP&#xff09;中的一项关键技术&#xff0c;它的目标是从文本中识别出具有特定意义或指代性强的实体&#xff0c;并对这些实体进行分类。这些实体通常包括人名、地名、组织…

小龙虾优化算法(Crayfish Optimization Algorithm,COA)

小龙虾优化算法&#xff08;Crayfish Optimization Algorithm&#xff0c;COA&#xff09; 前言一、小龙虾优化算法的实现1.初始化阶段2.定义温度和小龙虾的觅食量3.避暑阶段&#xff08;探索阶段&#xff09;4.竞争阶段&#xff08;开发阶段&#xff09;5.觅食阶段&#xff08…

高速风筒电源IC辉芒微FT8440E /FT8440A

深圳市三佛科技有限公司分享相关资料&#xff0c;高精度、高效率、低成本离线式功率开关 FT8440x是-款高性能、高精度、低成本的非隔离PWM功率开关。它包含一个专门]的电流模PWM控制器和一个高压功率开关管。内置的误差放大器经过优化保证优越的动态响应。高精度的内部分压电阻…

深度学习-线性代数

目录 标量向量矩阵特殊矩阵特征向量和特征值 标量由只有一个元素的张量表示将向量视为标量值组成的列表通过张量的索引来访问任一元素访问张量的长度只有一个轴的张量&#xff0c;形状只有一个元素通过指定两个分量m和n来创建一个形状为mn的矩阵矩阵的转置对称矩阵的转置逻辑运…

【LLMOps】小白详细教程,在Dify中创建并使用自定义工具

文章目录 博客详细讲解视频点击查看高清脑图 1. 搭建天气查询http服务1.1. flask代码1.2. 接口优化方法 2. 生成openapi json schema2.1. 测试接口2.2. 生成openapi schema 3. 在dify中创建自定义工具3.1. 导入schema3.2. 设置工具认证信息3.3. 测试工具 4. 调用工具4.1. Agent…

【JavaWeb】Day51.Mybatis动态SQL(一)

什么是动态SQL 在页面原型中&#xff0c;列表上方的条件是动态的&#xff0c;是可以不传递的&#xff0c;也可以只传递其中的1个或者2个或者全部。 而在我们刚才编写的SQL语句中&#xff0c;我们会看到&#xff0c;我们将三个条件直接写死了。 如果页面只传递了参数姓名name 字…

Multitouch 1.27.28 免激活版 mac电脑多点触控手势增强工具

Multitouch 应用程序可让您将自定义操作绑定到特定的魔术触控板或鼠标手势。例如&#xff0c;三指单击可以执行粘贴。通过执行键盘快捷键、控制浏览器的选项卡、单击鼠标中键等来改进您的工作流程。 Multitouch 1.27.28 免激活版下载 强大的手势引擎 精心打造的触控板和 Magic …

iOS 模拟请求 (本地数据调试)

简介 在iOS 的日常开发中经常会遇到一下情况&#xff1a;APP代码已编写完成&#xff0c;但后台的接口还无法使用&#xff0c;这时 APP开发就可能陷入停滞。此时iOS 模拟请求就派上用场了&#xff0c;使用模拟请求来调试代码&#xff0c;如果调试都通过了&#xff0c;等后台接口…

迁移学习基础知识

简介 使用迁移学习的优势&#xff1a; 1、能够快速的训练出一个理想的结果 2、当数据集较小时也能训练出理想的效果。 注意&#xff1a;在使用别人预训练的参数模型时&#xff0c;要注意别人的预处理方式。 原理&#xff1a; 对于浅层的网络结构&#xff0c;他们学习到的…