【深度学习实验】卷积神经网络(七):实现深度残差神经网络ResNet

目录

一、实验介绍

二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入必要的工具包

1. Residual(残差连接)

__init__(初始化)

forward(前向传播)

2. resnet_block(残差网络块)

3. ResNet(网络模型)

__init__(初始化)

forward(前向传播)

4. 代码整合


一、实验介绍

        本实验实现了实现深度残差神经网络ResNet

        残差网络(ResNet)是一种深度神经网络架构,用于解决深层网络训练过程中的梯度消失和梯度爆炸问题。通过引入残差连接(residual connection)来构建网络层与层之间的跳跃连接,使得网络可以更好地优化深层结构。

        残差网络的一个重要应用是在图像识别任务中,特别是在深度卷积神经网络(CNN)中。通过使用残差模块,可以构建非常深的网络,例如ResNet,其在ILSVRC 2015图像分类挑战赛中取得了非常出色的成绩。

        在ResNet中,每个残差块由一个或多个卷积层组成,其中包含了跳跃连接。跳跃连接将输入直接添加到残差块的输出中,从而使得网络可以学习残差函数,即残差块只需学习将输入的变化部分映射到输出,而不需要学习完整的映射关系。这种设计有助于减轻梯度消失问题,使得网络可以更深地进行训练。

二、实验环境

        本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

0. 导入必要的工具包

from torch import nn
import torch.nn.functional as F

1. Residual(残差连接)

class Residual(nn.Module):def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):super().__init__()self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides)self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)if use_1x1conv:self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)else:self.conv3 = None# 批量归一化层,将会在第7章讲到self.bn1 = nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)def forward(self, X):Y = F.relu(self.bn1(self.conv1(X)))Y = self.bn2(self.conv2(Y))if self.conv3:X = self.conv3(X)Y += Xreturn F.relu(Y)

__init__(初始化)

  • 参数:
    • 输入通道数`input_channels`
    • 输出通道数`num_channels`
    • 是否使用1x1卷积`use_1x1conv`
    • 步幅`strides`
  • 在初始化过程中,创建了两个卷积层`conv1`和`conv2`,分别使用不同的输入和输出通道数,并指定了卷积核的大小、填充和步幅。
  • 如果`use_1x1conv`为True,则创建一个1x1卷积层`conv3`,用于进行维度匹配;
  • 否则,将`conv3`设为None。
  • 创建两个批量归一化层`bn1`和`bn2`,用于对卷积层的输出进行批量归一化操作。

forward(前向传播)

  • 将输入`X`通过`conv1`进行卷积操作,然后经过批量归一化层`bn1`和ReLU激活函数。
  • 将输出通过`conv2`进行卷积操作,再经过批量归一化层`bn2`。
  • 如果`conv3`不为None,则将输入`X`通过`conv3`进行卷积操作,用于进行维度匹配。
  • 最后,将经过卷积和批量归一化的结果与输入相加,得到残差连接的输出。
  • 通过ReLU激活函数处理输出,并返回结果。

2. resnet_block(残差网络块)

        生成由多个残差块组成的残差网络块。

def resnet_block(input_channels, num_channels, num_residuals, first_block=False):blk = []for i in range(num_residuals):if i == 0 and not first_block:blk.append(Residual(input_channels, num_channels,use_1x1conv=True, strides=2))else:blk.append(Residual(num_channels, num_channels))return blk
  • 参数
    • input_channels:输入通道数,即每个残差块的输入的通道数。
    • num_channels:每个残差块中卷积层的输出通道数,也是每个残差块内部卷积层的通道数。
    • num_residuals:残差块的数量。
    • first_block:一个布尔值,表示是否为整个 ResNet 中的第一个残差块。
  • 创建一个空列表 blk,用于存储构建的残差块。
  • 通过一个循环迭代 num_residuals 次,每次迭代都构建一个残差块并将其添加到 blk 列表中。
    • 在每个迭代中,首先检查是否为第一个残差块且 first_block 为 False。
      • 如果是,则创建一个具有下采样(strides=2)的残差块,并将其添加到 blk 列表中。这是为了在整个 ResNet 中的第一个残差块中进行下采样。
      • 如果不是第一个残差块或者 first_block 为 True,则创建一个普通的残差块,并将其添加到 blk 列表中。
  • 返回构建好的残差块列表 blk

3. ResNet(网络模型

        ResNet 网络模型,包含了多个残差块,用于实现图像分类任务。

class ResNet(nn.Module):def __init__(self, num_classes):super().__init__()self.b1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))self.b3 = nn.Sequential(*resnet_block(64, 128, 2))self.b4 = nn.Sequential(*resnet_block(128, 256, 2))self.b5 = nn.Sequential(*resnet_block(256, 512, 2))self.head = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(512, num_classes))def forward(self, x):net = nn.Sequential(self.b1, self.b2, self.b3, self.b4, self.b5, self.head)return net(x)

__init__(初始化)

  • 参数:
    • num_classes,表示分类的类别数目
  • 调用父类的构造函数 `super().__init__()`。
  • self.b1是一个包含了卷积层、批归一化层、ReLU激活函数和最大池化层的序列。它对输入数据进行卷积操作,然后进行批归一化、ReLU激活和最大池化,用于提取输入图像的特征。
    • nn.Conv2d,使用 7x7 的卷积核对输入进行卷积操作,输出通道数为 64,步长为 2,填充为 3。
    • nn.BatchNorm2d 层,用于进行批归一化操作。
    •  ReLU 激活函数层 nn.ReLU()。
    • nn.MaxPool2d`层,使用 3x3 的池化核进行最大池化操作,步长为 2,填充为 1。
  • self.b2self.b3self.b4self.b5分别是几个残差块(resnet_block)的序列。这些残差块包含了卷积层、批归一化层和ReLU激活函数,用于进一步提取输入数据的特征。
    • self.b2使用构建了 2 个残差块,输入通道数为 64,输出通道数也为 64,并且指定 `first_block=True`,表示它是第一个残差块;
    • ……
  • self.head是一个包含自适应平均池化层(AdaptiveAvgPool2d)、展平层(Flatten)和全连接层(Linear)的序列。它将输入数据进行自适应平均池化,然后展平为一维向量,并通过全连接层将特征映射到分类的类别数目上:
    • 自适应平均池化层nn.AdaptiveAvgPool2d:将输入的特征图池化为大小为 1x1 的特征图。
    • 展平层nn.Flatten,将池化后的特征图展平成一维向量。
    • 全连接层nn.Linear,将展平后的特征映射到输出类别的数量。

forward(前向传播)

        输入数据通过上述序列模块self.b1self.b2self.b3self.b4self.b5self.head进行处理,最终输出分类结果

4. 代码整合

# 导入必要的工具包
from torch import nn
import torch.nn.functional as F#  残差连接, 输入和输出的维度有时是相同的, 有时是不同的, 所以需要 use_1x1conv来判断是否需要
class Residual(nn.Module):def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):super().__init__()self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides)self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)if use_1x1conv:self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)else:self.conv3 = None# 批量归一化层,将会在第7章讲到self.bn1 = nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)def forward(self, X):Y = F.relu(self.bn1(self.conv1(X)))Y = self.bn2(self.conv2(Y))if self.conv3:X = self.conv3(X)Y += Xreturn F.relu(Y)# 残差网络是由几个不同的残差块组成的
def resnet_block(input_channels, num_channels, num_residuals, first_block=False):blk = []for i in range(num_residuals):if i == 0 and not first_block:blk.append(Residual(input_channels, num_channels,use_1x1conv=True, strides=2))else:blk.append(Residual(num_channels, num_channels))return blkclass ResNet(nn.Module):def __init__(self, num_classes):super().__init__()self.b1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))self.b3 = nn.Sequential(*resnet_block(64, 128, 2))self.b4 = nn.Sequential(*resnet_block(128, 256, 2))self.b5 = nn.Sequential(*resnet_block(256, 512, 2))self.head = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(512, num_classes))def forward(self, x):net = nn.Sequential(self.b1, self.b2, self.b3, self.b4, self.b5, self.head)return net(x)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/153283.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HTML基础入门01

目录 1.HTML基础 1.1HTML标签 1.2HTML 文件基本结构 1.3标签层次结构 1.4快速生成代码框架 2.HTML 常见标签 2.1注释标签 2.2标题标签: h1-h6 2.3段落标签: p 2.4.换行标签: br 3.综合案例: 展示博客 1.HTML基础 1.1HTML标签 HTML 代码是由 "标签" 构成…

0基础学习VR全景平台篇 第104篇:720全景后期软件安装

上课!全体起立~ 大家好,欢迎观看蛙色官方系列全景摄影课程! 摄影进入数码时代,后期软件继承“暗房工艺”,成为摄影师表达内在情感的必备工具。 首先说明,全景摄影与平面摄影的一个显著的区别是全景图片需…

MyCat-web安装文档:安装Zookeeper、安装Mycat-web

安装Zookeeper A. 上传安装包 zookeeper-3.4.6.tar.gzB. 解压 #解压到当前目录,之后会生成一个安装后的目录 tar -zxvf zookeeper-3.4.6.tar.gz#加上-c 代表解压到指定目录 tar -zxvf zookeeper-3.4.6.tar.gz -C /usr/local/C. 在安装目录下,创建数据…

CasA:用于点云 3D 目标检测的级联注意力网络

论文摘要 LiDAR 收集的数据通常表现出稀疏和不规则的分布。 3D 空间中的 LiDAR 扫描并不均匀。近处和远处的物体之间存在巨大的分布差距。 CasA(Cascade Attention) 由 RPN(Region proposal Network)和 CRN(cascade refinement Network&…

uni-app:引入echarts(使用renderjs)

效果 代码 <template><view click"echarts.onClick" :prop"option" :change:prop"echarts.updateEcharts" id"echarts" class"echarts"></view> </template><script>export default {data()…

用netty实现简易rpc

文章目录 rpc介绍&#xff1a;rpc调用流程:代码&#xff1a; rpc介绍&#xff1a; RPC是远程过程调用&#xff08;Remote Procedure Call&#xff09;的缩写形式。SAP系统RPC调用的原理其实很简单&#xff0c;有一些类似于三层构架的C/S系统&#xff0c;第三方的客户程序通过接…

谷歌 Chrome 浏览器正推进“追踪保护”功能

导读近日消息&#xff0c;根据国外科技媒体 Windows Latest 报道&#xff0c;谷歌计划在 Chrome 浏览器中推进“追踪保护”&#xff08;Tracking Protection&#xff09;功能&#xff0c;整合浏览器现有隐私功能&#xff0c;保护用户被网站跟踪。 根据一项 Chromium 提案&…

unity 控制玩家物体

创建场景 放上一个plane&#xff0c;放上一个球 sphere&#xff0c;假定我们的球就是我们的玩家&#xff0c;使用控制键w a s d 来控制球也就是玩家移动。增加一个材质&#xff0c;把颜色改成绿色&#xff0c;把材质赋给plane&#xff0c;区分我们增加的白球。 增加组件和脚…

06_Node.js服务器开发

1 服务器开发的基本概念 1.1 为什么学习服务器开发 Node.js开发属于服务器开发&#xff0c;那么作为一名前端工程师为什么需要学习服务器开发呢&#xff1f; 为什么学习服务器开发&#xff1f; 能够和后端程序员更加紧密配合网站业务逻辑前置扩宽知识视野 1.2 服务器开发可…

全息投影技术服务公司【盟云全息】收入急剧下降,存在风险

来源&#xff1a;猛兽财经 作者&#xff1a;猛兽财经 全息投影技术服务公司【盟云全息】MicroCloud Hologram(HOLO)的股价表现今年以来异常出色&#xff0c;年初至今已经上涨了334%以上&#xff0c;猛兽财经将在本文中分析MicroCloud Hologram股价上涨的原因&#xff0c;以及它…

京东获得店铺的所有商品 API接口分享

item_search_shop-获得店铺的所有商品 公共参数 注册API测试key 名称类型必须描述keyString是调用key&#xff08;必须以GET方式拼接在URL中&#xff09;secretString是调用密钥api_nameString是API接口名称&#xff08;包括在请求地址中&#xff09;[item_search,item_get,…

深眸科技自研AI视觉分拣系统,实现物流行业无序分拣场景智慧应用

在机器视觉应用环节中&#xff0c;物体分拣是建立在识别、检测之后的一个环节&#xff0c;通过机器视觉系统对图像进行处理&#xff0c;并结合机械臂的使用实现产品分类。 通过引入视觉分拣技术&#xff0c;不仅可以实现自动化作业&#xff0c;还能提高生产线的生产效率和准确…

react管理系统layOut简单搭建

一、新建立react文件夹&#xff0c;生成项目 npx create-react-app my-app cd my-app npm start 二、安装react-router-dom npm install react-router-dom 三、安装Ant Design of React&#xff08;UI框架库&#xff0c;可根据需求进行安装&#xff09; npm install antd …

短剧是什么?短剧平台的经营模式。短剧平台需要什么资质?

​ 短剧&#xff0c;顾名思义&#xff0c;是一种时长较短的戏剧表演形式。它通常以一个简短的故事或情境为主线&#xff0c;通过角色的表演&#xff0c;展现出一段生动、有趣或富有深意的故事。 短剧的时长通常较短&#xff0c;一般在几分钟到半小时之间&#xff0c;但这并不是…

Nginx + PHP 异常排查,open_basedir 异常处理

新上一个网站&#xff0c;通过域名访问失败&#xff0c;排查方法如下&#xff1a; 开启异常日志 开启域名下&#xff0c;nginx的异常日志&#xff0c;并查看日志 tail -f /var/log/nginx/nginx.localhost.error.log开启php的异常日志&#xff0c;该配置位于php.ini文件下 …

青菜餐饮公司资产负债表之二

一、背景 上一篇文章 青菜餐饮公司资产负债表之一 用一个案例来引导学习资产负债表&#xff0c;要做好资产负债表&#xff0c;最关键的是找准科目&#xff0c;左右平衡。资产和负债的概念基本上都比较好理解&#xff0c;对于没有接触过财务知识的人&#xff0c;所有者权益相对…

nginx+HTTPS证书

申请ssl下载证书 阿里云购买免费证书&#xff0c;可免费申请20个&#xff0c;需要配置域名&#xff0c;域名为单个域名&#xff0c;比如www.xxx.com&#xff0c;必须带前缀。 申请完之后需要创建证书 注&#xff1a;创建证书时阿里云购买的域名可以直接给配好解析&#xff0…

自监督DINO论文笔记

论文名称&#xff1a;Emerging Properties in Self-Supervised Vision Transformers 发表时间&#xff1a;CVPR2021 作者及组织&#xff1a; Facebook AI Research GitHub&#xff1a;https://github.com/facebookresearch/dino/tree/main 问题与贡献 作者认为self-supervise…

Linux系列---【查看mac地址】

查看mac地址命令 查看所有网卡命令 nmcli connection show 查看物理网卡mac地址 ifconfig 删除网卡 nmcli connection delete virbr0 禁用libvirtd.service systemctl disable libvirtd.service 启用libvirtd.service systemctl enable libvirtd.service

基于Spring Boot的网上租贸系统设计与实现(源码+lw+部署文档+讲解等)

文章目录 前言具体实现截图论文参考详细视频演示代码参考源码获取 前言 &#x1f497;博主介绍&#xff1a;✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师、全栈领域优质创作者&#xff0c;博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技…