RESNET

ResNet

文章目录

  • ResNet
    • 主要内容
      • 开发背景
      • 解决两个问题:
        • 1. 梯度消失和梯度爆炸
        • 2. 退化问题:
      • 解决方法
        • 1. BN(Batch Normalization)层
        • 2. 残差块
    • Pytorch实现
      • BasicBlock
      • BottleNeck
      • ResNet

主要内容

开发背景

残差神经网络(ResNet)是由微软研究院的何恺明、张祥雨、任少卿、孙剑等人提出的, 斩获2015年ImageNet竞赛中分类任务第一名, 目标检测第一名。 残差神经网络的主要贡献是发现了“退化现象(Degradation)”,并针对退化现象发明了 “直连边/短连接(Shortcut connection)”,极大的消除了深度过大的神经网络训练困难问题。神经网络的“深度”首次突破了100层、最大的神经网络甚至超过了1000层。

论文地址:Deep Residual Learning for Image Recognition

解决两个问题:

1. 梯度消失和梯度爆炸

梯度消失:若每一层的误差梯度小于1,反向传播时,网络越深,梯度越趋近于0
梯度爆炸:若每一层的误差梯度大于1,反向传播时,网络越深,梯度越来越大

2. 退化问题:

随着层数的增加预测效果反而越来越差。

在这里插入图片描述

随着网络层数增加,出现了新的问题:退化问题,在训练集上准确率甚至下降了。这个不能解释为过拟合,因为过拟合表现为在训练集上表现更好才对。退化问题说明了深度网络不能很简单地被很好地优化。作者通过实验说明:通过浅层网络y=x 等同映射构造深层模型,结果深层模型并没有比浅层网络有更低甚至等同的错误率,推断退化问题可能是因为深层的网络很那难通过训练利用多层网络拟合同等函数。

解决方法

1. BN(Batch Normalization)层

为了解决梯度消失或梯度爆炸问题,ResNet论文提出通过数据的预处理以及在网络中使用 BN(Batch Normalization)层来解决。

2. 残差块

ResNet团队分别构建了带有“直连边(Shortcut Connection)”的ResNet残差块、以及降采样的ResNet残差块,区别是降采样残差块的直连边增加了一个1×1的卷积操作。对于直连边,当输入和输出维度一致时,可以直接将输入加到输出上,这相当于简单执行了同等映射,不会产生额外的参数,也不会增加计算复杂度。但是当维度不一致时,这就不能直接相加,通过添加1×1卷积调整通道数。这种残差学习结构可以通过前向神经网络+直连边实现, 而且整个网络依旧可以通过端到端的反向传播训练。结构如下图所示:

在这里插入图片描述

从数学角度解释:

在这里插入图片描述

深度残差网络。如果深层网络的后面那些层是恒等映射,那么模型就退化为一个浅层网络。所以要解决的就是学习恒等映射函数。但是直接让一些层去拟合一个潜在的恒等映射函数H(x) = x,比较困难,这可能就是深层网络难以训练的原因。但是,如果把网络设计为H(x) = F(x) + x。我们可以转换为学习一个残差函数F(x) = H(x) - x. 只要F(x)=0,就构成了一个恒等映射H(x) = x. 此外,拟合残差会更加容易。

总的来说,一是其导数总比原导数加1,这样即使原导数很小时,也能传递下去,能解决梯度消失的问题; 二是y=f(x)+x式子中引入了恒等映射(当f(x)=0时,y=x),解决了深度增加时神经网络的退化问题。

Pytorch实现

ResNet实现cifar100分类的代码放在GitHub: pytorch-cifar100了。该部分代码在项目中models/resnet.py里面。

参考github项目pytorch-cifar100

在ResNet中最重要的就是残差结构,一共提供了两种:

  • BasicBlock:两层3 * 3卷积用于实现18-layer和34-layer
  • BottleNeck:用于实现更深层的网络,如50-layer, 101-layer, 152-layer

BasicBlock

在这里插入图片描述

import torch
import torch.nn as nnclass BasicBlock(nn.Module):expansion = 1def __init__(self, in_channels, out_channels, stride=1, *args, **kwargs) -> None:super().__init__(*args, **kwargs)#residual functionself.residual_function = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3,stride=stride,padding=1,bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels*BasicBlock.expansion, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels*BasicBlock.expansion))self.shortcut = nn.Sequential()# 判断输出输入维度是否一致,不一致则使用1 * 1卷积进行升维或降维。 if stride != 1 or in_channels != BasicBlock.expansion * out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels*BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels*BasicBlock.expansion))def forward(self,x):return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))

BottleNeck

在这里插入图片描述

class BottleNeck(nn.Module):expansion = 4 def __init__(self, in_channels, out_channels, stride=1, *args, **kwargs) -> None:super().__init__(*args, **kwargs)self.residual_function = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1,bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels*BottleNeck.expansion, kernel_size=1, bias=False),nn.BatchNorm2d(out_channels*BottleNeck.expansion))self.shortcut = nn.Sequential()if stride != 1 or in_channels != BottleNeck.expansion * out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels*BottleNeck.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels*BottleNeck.expansion))def forward(self,x):return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))

ResNet

最后我们按照下图的网络结构来构建ResNet

在这里插入图片描述

class ResNet(nn.Module):def __init__(self, block, num_block, num_classes = 100, *args, **kwargs) -> None:super().__init__(*args, **kwargs)self.inchannels = 64self.conv1 = nn.Sequential(nn.Conv2d(3,64, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(64),nn.ReLU(inplace=True))self.conv2_x = self._maker_layer(block, 64, num_block[0], 1)self.conv3_x = self._maker_layer(block, 128, num_block[1], 2)self.conv4_x = self._maker_layer(block, 256, num_block[2], 2)self.conv5_x = self._maker_layer(block, 512, num_block[3], 2)self.avg_pool = nn.AdaptiveAvgPool2d((1,1))self.fc = nn.Linear(512 * block.expansion, num_classes)def _maker_layer(self,block, out_channels, num_blocks, stride):strides = [stride] + [1] * (num_blocks -1)layers = []for stride in strides:layers.append(block(self.inchannels, out_channels, stride))self.inchannels = out_channels * block.expansionreturn nn.Sequential(*layers)def resnet18():return ResNet(BasicBlock, [2, 2, 2, 2])def resnet34():return ResNet(BasicBlock, [3, 4, 6, 3])def resnet50():return ResNet(BottleNeck, [3, 4, 6, 3])def resnet101():return ResNet(BottleNeck, [3, 4, 23, 3])def resnet152():return ResNet(BottleNeck, [3, 8, 36, 3]) 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/387406.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LLM大模型:十大人工智能大模型技术介绍

十大人工智能大模型技术的简介: 深度学习模型 深度学习是人工智能领域中一种重要的机器学习技术,通过构建深度神经网络来模拟人脑的认知过程。深度学习模型能够自动提取数据的特征,并在海量数据中进行学习和优化,从而在语音识别…

搭建规范化的vue2项目

项目包含的库 Vue2VuexRouterEslintPrettier 环境 vue:2.6.14 eslint:7.32.0 prettier:2.4.1 eslint-plugin-prettier:4.0.0 eslint-plugin-vue:8.0.3 vue/cli:5.0.8 步骤 全局安装cli工具 npm in…

CAE仿真软件怎样下载和安装?

CAE仿真软件是一类专业工程软件,通过数值计算和仿真技术,帮助工程师和设计师在计算机上模拟和分析各种工程问题,如结构强度、热传导、流体力学等,从而优化产品设计、减少试验成本,提高产品性能和质量。HyperWorks是常见…

快手文生图模型-Kolors快速上手

Kolors是什么 可图(Kolors):用于真实感文本到图像合成的扩散模型的有效训练 可图,是快手开源的一个文生图模型,架构上使用了chatglm,比普通的sd模型在中文理解上要强大很多,以往sd模型的提示词理解能力往往只有两种 …

二进制部署k8s集群之cni网络插件flannel和calico工作原理

3、部署 CNI 网络组件 在 master01 节点上操作 上传flannel-v0.21.5.zip并解压 unzip flannel-v0.21.5.zipscp flannel*.tar 192.168.80.20:/opt/k8s/ scp flannel*.tar 192.168.80.30:/opt/k8s/ node两个节点操作 cd /opt/k8s/ docker load -i flannel.tar docker load -i …

nginx代理设置时能获取到源IP地址的方法

nginx通过http_x_forwarded_for限制来访IP示例_ngnix 根据header的x-forwarded-for限制接入-CSDN博客 名称ip客户端地址10.0.23.90nginx服务器地址110.0.202.48:18888,代理到10.0.204.82:8888nginx服务器地址210.0.204.82:8888,代理到10.0.204.82:8887后…

自写ApiTools工具,功能参考Postman和ApiPost

近日在使用ApiPost的时候,发现新版本8和7不兼容,也就是说8不支持离线操作,而7可以。 我想说,我就是因为不想登录使用才从Postman换到ApiPost的。 众所周知,postman时国外软件,登录经常性抽风,…

leetcode 1555 银行账号概要(postgresql)

需求 用户表: Users --------------------- | Column Name | Type | --------------------- | user_id | int | | user_name | varchar | | credit | int | --------------------- user_id 是这个表的主键。 表中的每一列包含每一个用户当前的额度信息。 交易表&…

使用 Elastic Observability 中的 OpenTelemetry 进行基础设施监控

作者:来自 Elastic ISHLEEN KAUR 将 OpenTelemetry 与 Elastic Observability 相结合,形成应用程序和基础设施监控解决方案。 在 Elastic,我们最近决定全面采用 OpenTelemetry 作为首要的数据收集框架。作为一名可观察性工程师,我…

分享5款ai头像工具,助你轻松实现社交新形象

如今,无论是社交媒体上的个人形象塑造,还是虚拟世界中的角色扮演,一个独特而吸引人的AI头像都能成为你个性化的代表。 例如,ai头像男古风通常代表着一种对传统文化的尊重和热爱;而现代简约头像可能代表着一种追求简洁…

Mongodb集合操作

文章目录 1、进入容器2、如果数据库不存在,则创建数据库,否则切换到指定数据库3、在 MongoDB 中,创建集合不是必须操作。当你插入一些文档时,MongoDB 会自动创建集合。4、查看数据库列表5、查看集合6、显示创建集合7、删除集合 1、…

百度竞价托管如何判断关键词出价是否偏高

在百度竞价推广中,关键词出价的高低直接影响着广告的展示位置、点击率以及最终的转化效果。然而,过高的出价不仅会增加推广成本,还可能导致预算的浪费。因此,作为百度竞价托管 www.pansem.com 的专业团队,如何准确判断…

springboot校园跑腿服务系统-计算机毕业设计源码15157

摘要 本文介绍了一种基于Springboot和uniapp的校园跑腿服务系统的设计与实现。该系统旨在为大学校园提供一种方便快捷的跑腿服务,满足学生和教职员工的日常需求。首先,系统采用了Springboot作为后端框架,利用其轻量级、高效的特性&#xff0c…

httpx,一个网络请求的 Python 新宠儿

大家好!我是爱摸鱼的小鸿,关注我,收看每期的编程干货。 一个简单的库,也许能够开启我们的智慧之门, 一个普通的方法,也许能在危急时刻挽救我们于水深火热, 一个新颖的思维方式,也许能…

计算机网络-七层协议栈介绍

之前介绍了网络世界的构成,从宏观角度介绍了网络设备和网络架构,链接: link,但是这种认识过于粗糙,过于肤浅。网络本质上是用于主机之间的通信,是端对端的连接通信,两台计算机可能距离很远,主机…

thinkPHP开发的彩漂网站源码,含pc端和手机端

源码简介 后台thinkPHP架构,页面程序双分离,Mysql数据库严谨数据结构、多重数据审核机制、出票机制和监控机制,html5前端技术适用移动端,后台逻辑更多以server接口可快捷实现对接pc和ap,下载会有少量图片素材丢失,附件有下载说明前端demo账户密码和后台管理地址管理员账户密码…

C 语言动态链表

线性结构->顺序存储->动态链表 一、理论部分 从起源中理解事物,就是从本质上理解事物。 -杜勒鲁奇 动态链表是通过结点(Node)的集合来非连续地存储数据,结点之间通过指针相互连接。 动态链表本身就是一种动态分配内存的…

Java 8-函数式接口

目录 一、概述 二、 函数式接口作为方法的参数 三、函数式接口作为方法的返回值 四、 常用的函数式接口 简单总结 简单示例 4.1 Consumer接口 简单案例 自我练习 实际应用场景 多线程处理 4.2 Supplier接口 简单案例 自我练习 实际应用场景 配置管理 4.3 Func…

TypeError: Components is not a function

Vue中按需引入Element-plus时,报错TypeError: Components is not a function。 1、参考Element-plus官方文档 安装unplugin-vue-components 和 unplugin-auto-import这两款插件 2、然后需要在vue.config.js中配置webPack打包plugin配置 3、重新启动项目会报错 T…

Java----反射

什么是反射? 反射就是允许对成员变量、成员方法和构造方法的信息进行编程访问。换句话来讲,就是通过反射,我们可以在不需要创建其对象的情况下就可以获取其定义的各种属性值以及方法。常见的应用就是IDEA中的提示功能,当我…