《动手学深度学习(PyTorch版)》笔记7.6

注:书中对代码的讲解并不详细,本文对很多细节做了详细注释。另外,书上的源代码是在Jupyter Notebook上运行的,较为分散,本文将代码集中起来,并加以完善,全部用vscode在python 3.9.18下测试通过,同时对于书上部分章节也做了整合。

Chapter7 Modern Convolutional Neural Networks

7.6 Residual Networks(ResNet)

随着我们设计越来越深的网络,深刻理解“新添加的层如何提升神经网络的性能”变得至关重要。

7.6.1 Function Class

首先,假设有一类特定的神经网络架构 F \mathcal{F} F,它包括学习速率和其他超参数设置。对于所有 f ∈ F f \in \mathcal{F} fF,存在一些参数集(例如权重和偏置),这些参数可以通过在合适的数据集上进行训练而获得。现在假设 f ∗ f^* f是我们真正想要找到的函数,如果是 f ∗ ∈ F f^* \in \mathcal{F} fF,那我们可以轻而易举的训练得到它,但通常我们不会那么幸运。我们将尝试找到一个函数 f F ∗ f^*_\mathcal{F} fF,这是我们在 F \mathcal{F} F中的最佳选择。例如,给定一个具有 X \mathbf{X} X特性和 y \mathbf{y} y标签的数据集,我们可以尝试通过解决以下优化问题来找到它:

f F ∗ : = a r g m i n f L ( X , y , f ) ,  f ∈ F . f^*_\mathcal{F} := \mathop{\mathrm{argmin}}_f L(\mathbf{X}, \mathbf{y}, f) \text{ , } f \in \mathcal{F}. fF:=argminfL(X,y,f) , fF.

为了得到更近似真正 f ∗ f^* f的函数,唯一合理的可能性是设计一个更强大的架构 F ′ \mathcal{F}' F。换句话说,我们预计 f F ′ ∗ f^*_{\mathcal{F}'} fF f F ∗ f^*_{\mathcal{F}} fF“更近似”。然而,如果 F ⊈ F ′ \mathcal{F} \not\subseteq \mathcal{F}' FF,则无法保证新的体系“更近似”。事实上, f F ′ ∗ f^*_{\mathcal{F}'} fF可能更糟:如下图所示,对于非嵌套函数(non-nested function)类,较复杂的函数类并不总是向“真”函数 f ∗ f^* f靠拢(复杂度由 F 1 \mathcal{F}_1 F1 F 6 \mathcal{F}_6 F6递增)。在下图的左边,虽然 F 3 \mathcal{F}_3 F3 F 1 \mathcal{F}_1 F1更接近 f ∗ f^* f,但 F 6 \mathcal{F}_6 F6却离的更远了。相反,对于下图右边的嵌套函数(nested function)类 F 1 ⊆ … ⊆ F 6 \mathcal{F}_1 \subseteq \ldots \subseteq \mathcal{F}_6 F1F6,我们可以避免上述问题。
在这里插入图片描述

因此,只有当较复杂的函数类包含较小的函数类时,我们才能确保提高它们的性能。对于深度神经网络,如果我们能将新添加的层训练成恒等映射(identity function) f ( x ) = x f(\mathbf{x}) = \mathbf{x} f(x)=x,新模型和原模型将同样有效。同时,由于新模型可能得出更优的解来拟合训练数据集,因此添加层似乎更容易降低训练误差。针对这一问题,何恺明等人提出了残差网络(ResNet)。其核心思想是:每个附加层都应该更容易地包含原始函数作为其元素之一。于是,残差块(residual blocks)便诞生了,这个设计对如何建立深层神经网络产生了深远的影响。

7.6.2 Residual Blocks

在这里插入图片描述

如上图所示,假设我们的原始输入为 x x x,而希望学出的理想映射为 f ( x ) f(\mathbf{x}) f(x)。上图左边是一个正常块,虚线框中的部分需要直接拟合出该映射 f ( x ) f(\mathbf{x}) f(x),而右边是ResNet的基础架构–残差块(residual block),虚线框中的部分则需要拟合出残差映射 f ( x ) − x f(\mathbf{x}) - \mathbf{x} f(x)x。残差映射在现实中往往更容易优化。以恒等映射作为理想映射 f ( x ) f(\mathbf{x}) f(x),只需将上图右边虚线框内上方的加权运算(如仿射)的权重和偏置参数设成0,那么 f ( x ) f(\mathbf{x}) f(x)即为恒等映射。实际上,当理想映射 f ( x ) f(\mathbf{x}) f(x)极接近于恒等映射时,残差映射也易于捕捉恒等映射的细微波动。在残差块中,输入可通过跨层数据线路更快地向前传播,且可以避免某些梯度消失或梯度爆炸的问题。

在这里插入图片描述

ResNet沿用了VGG完整的 3 × 3 3\times 3 3×3卷积层设计。残差块里首先有2个有相同输出通道数的 3 × 3 3\times 3 3×3卷积层,每个卷积层后接一个批量规范化层和ReLU激活函数,然后我们通过跨层数据通路,跳过这2个卷积运算,将输入直接加在最后的ReLU激活函数前。这样的设计要求2个卷积层的输出与输入形状一样,从而使它们可以相加。如果想改变通道数,就需要引入一个额外的 1 × 1 1\times 1 1×1卷积层来将输入变换成需要的形状后再做相加运算。

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
import matplotlib.pyplot as pltclass Residual(nn.Module):  #@savedef __init__(self, input_channels,num_channels,use_1x1conv=False, strides=1):super().__init__()self.conv1 = nn.Conv2d(input_channels, num_channels,kernel_size=3, padding=1, stride=strides)self.conv2 = nn.Conv2d(num_channels, num_channels,kernel_size=3, padding=1)if use_1x1conv:self.conv3 = nn.Conv2d(input_channels, num_channels,kernel_size=1, stride=strides)else:self.conv3 = Noneself.bn1 = nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)def forward(self, X):Y = F.relu(self.bn1(self.conv1(X)))Y = self.bn2(self.conv2(Y))if self.conv3:X = self.conv3(X)Y += Xreturn F.relu(Y)

如下图所示,此代码生成两种类型的网络:当use_1x1conv=False时,应用ReLU非线性函数之前,将输入添加到输出;当use_1x1conv=True时,使用 1 × 1 1 \times 1 1×1卷积调整通道和分辨率。

在这里插入图片描述

blk = Residual(3,3)#输入和输出形状一致
X = torch.rand(4, 3, 6, 6)
Y = blk(X)
print(Y.shape)blk = Residual(3,6, use_1x1conv=True, strides=2)#增加输出通道数的同时,减半输出的高和宽
print(blk(X).shape)#定义ResNet的模块
#b2-b5各有4个卷积层(不包括恒等映射的1x1卷积层),加上第一个7x7卷积层和最后一个全连接层,共有18层,因此这种模型通常被称为ResNet-18
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))def resnet_block(input_channels, num_channels, num_residuals,first_block=False):blk = []for i in range(num_residuals):if i == 0 and not first_block:blk.append(Residual(input_channels, num_channels,use_1x1conv=True, strides=2))else:blk.append(Residual(num_channels, num_channels))return blkb2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))net = nn.Sequential(b1, b2, b3, b4, b5,nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(), nn.Linear(512, 10))X = torch.rand(size=(1, 1, 224, 224))
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape:\t', X.shape)lr, num_epochs, batch_size = 0.05, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
plt.show()

训练结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/254128.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

排序算法---冒泡排序

原创不易,转载请注明出处。欢迎点赞收藏~ 冒泡排序是一种简单的排序算法,其原理是重复地比较相邻的两个元素,并将顺序不正确的元素进行交换,使得每次遍历都能将一个最大(或最小)的元素放到末尾。通过多次遍…

疑似针对安全研究人员的窃密与勒索

前言 笔者在某国外开源样本沙箱平台闲逛的时候,发现了一个有趣的样本,该样本伪装成安全研究人员经常使用的某个渗透测试工具的破解版压缩包,对安全研究人员进行窃密与勒索双重攻击,这种双重攻击的方式也是勒索病毒黑客组织常用的…

RibbonOpenFeign源码(待完善)

Ribbon流程图 OpenFeign流程图

mac协议远程管理软件:Termius for Mac 8.4.0激活版

Termius是一款远程访问和管理工具,旨在帮助用户轻松地远程连接到各种服务器和设备。它适用于多种操作系统,包括Windows、macOS、Linux和移动设备。 该软件提供了一个直观的界面,使用户可以通过SSH、Telnet和Mosh等协议连接到远程设备。它还支…

【SpringBoot】JWT令牌

📝个人主页:五敷有你 🔥系列专栏:SpringBoot ⛺️稳重求进,晒太阳 什么是JWT JWT简称JSON Web Token,也就是通过JSON形式作为Web应用的令牌,用于各方面之间安全的将信息作为JSON对象传输…

本地部署TeamCity打包发布GitLab管理的.NET Framework 4.5.2的web项目

本地部署TeamCity 本地部署TeamCity打包发布GitLab管理的.NET Framework 4.5.2的web项目部署环境配置 TeamCity 服务器 URLTeamCity 上 GitLab 的相关配置GitLab 链接配置SSH 配置项目构建配置创建项目配置构建步骤构建触发器结语本地部署TeamCity打包发布GitLab管理的.NET Fra…

详细分析Redis性能监控指标 附参数解释(全)

目录 前言1. 基本指标2. 监控命令3. 实战演示 前言 对于Redis的相关知识推荐阅读: Redis框架从入门到学精(全)Python操作Redis从入门到精通附代码(全)Redis相关知识 1. 基本指标 Redis 是一个高性能的键值存储系统…

网络分析仪的防护技巧

VNA的一些使用防护技巧,虽不全面,但非常实用: [1] 一定要使用正规接地的三相交流电源线缆进行供电,地线不可悬浮,并且,火线和零线不可反接; [2] 交流供电必须稳定,如220V供电&#x…

【开源】SpringBoot框架开发桃花峪滑雪场租赁系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 游客服务2.2 雪场管理 三、数据库设计3.1 教练表3.2 教练聘请表3.3 押金规则表3.4 器材表3.5 滑雪场表3.7 售票表3.8 器材损坏表 四、系统展示五、核心代码5.1 查询教练5.2 教练聘请5.3 查询滑雪场5.4 滑雪场预定5.5 新…

LabVIEW动平衡测试与振动分析系统

LabVIEW动平衡测试与振动分析系统 介绍了利用LabVIEW软件和虚拟仪器技术开发一个动平衡测试与振动分析系统。该系统旨在提高旋转机械设备的测试精度和可靠性,通过精确测量和分析设备的振动数据,以识别和校正不平衡问题,从而保证机械设备的高…

Mac 使用AccessClient打开 windows 堡垒机的方式

使用AccessClient打开连接到 windows 页面 需要下载Microsoft remote Desktop 远程连接工具 在国内,无法下载正式版,beta 版本不需要从 app Store 下载 macOS 客户端下载地址 | Microsoft Learn 在浏览器点击对应的windows机器打开即可,会自动唤醒 Microsoft remote Desktop 进…

【MySQL】_JDBC编程

目录 1. JDBC原理 2. 导入JDBC驱动包 3. 编写JDBC代码实现Insert 3.1 创建并初始化一个数据源 3.2 和数据库服务器建立连接 3.3 构造SQL语句 3.4 执行SQL语句 3.5 释放必要的资源 4. JDBC代码的优化 4.1 从控制台输入 4.2 避免SQL注入的SQL语句 5. 编写JDBC代码实现…

HiveSQL——条件判断语句嵌套windows子句的应用

注:参考文章: SQL条件判断语句嵌套window子句的应用【易错点】--HiveSql面试题25_sql剁成嵌套判断-CSDN博客文章浏览阅读920次,点赞4次,收藏4次。0 需求分析需求:表如下user_idgood_namegoods_typerk1hadoop1011hive1…

OJ_计算不带括号的表达式

题干 C实现 #define _CRT_SECURE_NO_WARNINGS #include <stdio.h> #include <stack> #include <string> #include <map> using namespace std;int main() {char str[1000] { 0 };map<char, int> priority {{\0,0},{,1},{-,1},{*,2},{/,2}};wh…

使用代理IP有风险吗?如何安全使用代理IP?

代理IP用途无处不在。它们允许您隐藏真实IP地址&#xff0c;从而实现匿名性和隐私保护。这对于保护个人信息、绕过地理受限的内容或访问特定网站都至关重要。 然而&#xff0c;正如任何技术工具一样&#xff0c;代理IP地址也伴随着潜在的风险和威胁。不法分子可能会滥用代理IP…

Golang 学习(二)进阶使用

二、进阶使用 性能提升——协程 GoRoutine go f();一个 Go 线程上&#xff0c;可以起多个协程&#xff08;有独立的栈空间、共享程序堆空间、调度由用户控制&#xff09;主线程是一个物理线程&#xff0c;直接作用在 cpu 上的。是重量级的&#xff0c;非常耗费 cpu 资源。协…

On the Spectral Bias of Neural Networks论文阅读

1. 摘要 众所周知&#xff0c;过度参数化的深度神经网络(DNNs)是一种表达能力极强的函数&#xff0c;它甚至可以以100%的训练精度记忆随机数据。这就提出了一个问题&#xff0c;为什么他们不能轻易地对真实数据进行拟合呢。为了回答这个问题&#xff0c;研究人员使用傅里叶分析…

基于BatchNorm的模型剪枝【详解+代码】

文章目录 1、BatchNorm&#xff08;BN&#xff09;2、L1与L2正则化2.1 L1与L2的导数及其应用2.2 论文核心点 3、模型剪枝的流程 ICCV经典论文&#xff0c;通俗易懂&#xff01;论文题目&#xff1a;Learning Efficient Convolutional Networks through Network Slimming卷积后能…

Linux系统安全之iptables防火墙

目录 一、iptables防火墙的基本介绍 二、iptables的四表五链 三、iptables的配置 四、添加&#xff0c;查看&#xff0c;删除规则 一、iptables防火墙的基本介绍 iptables是一个Linux系统上的防火墙工具&#xff0c;它用于配置和管理网络数据包的过滤规则。它可以通过定义…

STM32——LCD(1)认识

目录 一、初识LCD 1. LCD介绍 2. 显示器的分类 3. 像素 4. LED和OLED显示器 5. 显示器的基本参数 &#xff08;1&#xff09;像素 &#xff08;2&#xff09;分辨率 &#xff08;3&#xff09;色彩深度 &#xff08;4&#xff09;显示器尺寸 &#xff08;5&#xff…