【深度学习基础】多层感知机 | 数值稳定性和模型初始化

在这里插入图片描述

【作者主页】Francek Chen
【专栏介绍】 ⌈ ⌈ PyTorch深度学习 ⌋ ⌋ 深度学习 (DL, Deep Learning) 特指基于深层神经网络模型和方法的机器学习。它是在统计机器学习、人工神经网络等算法模型基础上,结合当代大数据和大算力的发展而发展出来的。深度学习最重要的技术特征是具有自动提取特征的能力。神经网络算法、算力和数据是开展深度学习的三要素。深度学习在计算机视觉、自然语言处理、多模态数据分析、科学探索等领域都取得了很多成果。本专栏介绍基于PyTorch的深度学习算法实现。
【GitCode】专栏资源保存在我的GitCode仓库:https://gitcode.com/Morse_Chen/PyTorch_deep_learning。

文章目录

    • 一、梯度消失和梯度爆炸
      • (一)梯度消失
      • (二)梯度爆炸
      • (三)打破对称性
    • 二、参数初始化
      • (一)默认初始化
      • (二)Xavier初始化
      • (三)拓展阅读
    • 小结


  到目前为止,我们实现的每个模型都是根据某个预先指定的分布来初始化模型的参数。有人会认为初始化方案是理所当然的,忽略了如何做出这些选择的细节。甚至有人可能会觉得,初始化方案的选择并不是特别重要。相反,初始化方案的选择在神经网络学习中起着举足轻重的作用,它对保持数值稳定性至关重要。此外,这些初始化方案的选择可以与非线性激活函数的选择有趣的结合在一起。我们选择哪个函数以及如何初始化参数可以决定优化算法收敛的速度有多快。糟糕的选择可能会导致我们在训练时遇到梯度爆炸或梯度消失。本节将更详细地探讨这些主题,并讨论一些有用的启发式方法。这些启发式方法在整个深度学习生涯中都很有用。

一、梯度消失和梯度爆炸

  考虑一个具有 L L L层、输入 x \mathbf{x} x和输出 o \mathbf{o} o的深层网络。每一层 l l l由变换 f l f_l fl定义,该变换的参数为权重 W ( l ) \mathbf{W}^{(l)} W(l),其隐藏变量是 h ( l ) \mathbf{h}^{(l)} h(l)(令 h ( 0 ) = x \mathbf{h}^{(0)} = \mathbf{x} h(0)=x)。我们的网络可以表示为:
h ( l ) = f l ( h ( l − 1 ) ) 因此  o = f L ∘ … ∘ f 1 ( x ) (1) \mathbf{h}^{(l)} = f_l (\mathbf{h}^{(l-1)}) \text{ 因此 } \mathbf{o} = f_L \circ \ldots \circ f_1(\mathbf{x}) \tag{1} h(l)=fl(h(l1)) 因此 o=fLf1(x)(1)

  如果所有隐藏变量和输入都是向量,我们可以将 o \mathbf{o} o关于任何一组参数 W ( l ) \mathbf{W}^{(l)} W(l)的梯度写为下式:
∂ W ( l ) o = ∂ h ( L − 1 ) h ( L ) ⏟ M ( L ) = d e f ⋅ … ⋅ ∂ h ( l ) h ( l + 1 ) ⏟ M ( l + 1 ) = d e f ∂ W ( l ) h ( l ) ⏟ v ( l ) = d e f (2) \partial_{\mathbf{W}^{(l)}} \mathbf{o} = \underbrace{\partial_{\mathbf{h}^{(L-1)}} \mathbf{h}^{(L)}}_{ \mathbf{M}^{(L)} \stackrel{\mathrm{def}}{=}} \cdot \ldots \cdot \underbrace{\partial_{\mathbf{h}^{(l)}} \mathbf{h}^{(l+1)}}_{ \mathbf{M}^{(l+1)} \stackrel{\mathrm{def}}{=}} \underbrace{\partial_{\mathbf{W}^{(l)}} \mathbf{h}^{(l)}}_{ \mathbf{v}^{(l)} \stackrel{\mathrm{def}}{=}} \tag{2} W(l)o=M(L)=def h(L1)h(L)M(l+1)=def h(l)h(l+1)v(l)=def W(l)h(l)(2)

  换言之,该梯度是 L − l L-l Ll个矩阵 M ( L ) ⋅ … ⋅ M ( l + 1 ) \mathbf{M}^{(L)} \cdot \ldots \cdot \mathbf{M}^{(l+1)} M(L)M(l+1)与梯度向量 v ( l ) \mathbf{v}^{(l)} v(l)的乘积。因此,我们容易受到数值下溢问题的影响。当将太多的概率乘在一起时,这些问题经常会出现。在处理概率时,一个常见的技巧是切换到对数空间,即将数值表示的压力从尾数转移到指数。不幸的是,上面的问题更为严重:最初,矩阵 M ( l ) \mathbf{M}^{(l)} M(l) 可能具有各种各样的特征值。他们可能很小,也可能很大;他们的乘积可能非常大,也可能非常小。

  不稳定梯度带来的风险不止在于数值表示;不稳定梯度也威胁到我们优化算法的稳定性。我们可能面临一些问题。要么是梯度爆炸(gradient exploding)问题:参数更新过大,破坏了模型的稳定收敛;要么是梯度消失(gradient vanishing)问题:参数更新过小,在每次更新时几乎不会移动,导致模型无法学习。

(一)梯度消失

  曾经sigmoid函数 1 / ( 1 + exp ⁡ ( − x ) ) 1/(1 + \exp(-x)) 1/(1+exp(x))(多层感知机概述提到过)很流行,因为它类似于阈值函数。由于早期的人工神经网络受到生物神经网络的启发,神经元要么完全激活要么完全不激活(就像生物神经元)的想法很有吸引力。然而,它却是导致梯度消失问题的一个常见的原因,让我们仔细看看sigmoid函数为什么会导致梯度消失。

%matplotlib inline
import torch
from d2l import torch as d2lx = torch.arange(-8.0, 8.0, 0.1, requires_grad=True)
y = torch.sigmoid(x)
y.backward(torch.ones_like(x))d2l.plot(x.detach().numpy(), [y.detach().numpy(), x.grad.numpy()], legend=['sigmoid', 'gradient'], figsize=(4.5, 2.5))

在这里插入图片描述

  正如上图,当sigmoid函数的输入很大或是很小时,它的梯度都会消失。此外,当反向传播通过许多层时,除非我们在刚刚好的地方,这些地方sigmoid函数的输入接近于零,否则整个乘积的梯度可能会消失。当我们的网络有很多层时,除非我们很小心,否则在某一层可能会切断梯度。事实上,这个问题曾经困扰着深度网络的训练。因此,更稳定的ReLU系列函数已经成为从业者的默认选择(虽然在神经科学的角度看起来不太合理)。

(二)梯度爆炸

  相反,梯度爆炸可能同样令人烦恼。为了更好地说明这一点,我们生成100个高斯随机矩阵,并将它们与某个初始矩阵相乘。对于我们选择的尺度(方差 σ 2 = 1 \sigma^2=1 σ2=1),矩阵乘积发生爆炸。当这种情况是由于深度网络的初始化所导致时,我们没有机会让梯度下降优化器收敛。

M = torch.normal(0, 1, size=(4,4))
print('一个矩阵 \n',M)
for i in range(100):M = torch.mm(M,torch.normal(0, 1, size=(4, 4)))print('乘以100个矩阵后\n', M)

在这里插入图片描述

(三)打破对称性

  神经网络设计中的另一个问题是其参数化所固有的对称性。假设我们有一个简单的多层感知机,它有一个隐藏层和两个隐藏单元。在这种情况下,我们可以对第一层的权重 W ( 1 ) \mathbf{W}^{(1)} W(1)进行重排列,并且同样对输出层的权重进行重排列,可以获得相同的函数。第一个隐藏单元与第二个隐藏单元没有什么特别的区别。换句话说,我们在每一层的隐藏单元之间具有排列对称性。

  假设输出层将上述两个隐藏单元的多层感知机转换为仅一个输出单元。想象一下,如果我们将隐藏层的所有参数初始化为 W ( 1 ) = c \mathbf{W}^{(1)} = c W(1)=c c c c为常量,会发生什么?在这种情况下,在前向传播期间,两个隐藏单元采用相同的输入和参数,产生相同的激活,该激活被送到输出单元。在反向传播期间,根据参数 W ( 1 ) \mathbf{W}^{(1)} W(1)对输出单元进行微分,得到一个梯度,其元素都取相同的值。因此,在基于梯度的迭代(例如,小批量随机梯度下降)之后, W ( 1 ) \mathbf{W}^{(1)} W(1)的所有元素仍然采用相同的值。这样的迭代永远不会打破对称性,我们可能永远也无法实现网络的表达能力。隐藏层的行为就好像只有一个单元。请注意,虽然小批量随机梯度下降不会打破这种对称性,但暂退法正则化可以。

二、参数初始化

  解决(或至少减轻)上述问题的一种方法是进行参数初始化,优化期间的注意和适当的正则化也可以进一步提高稳定性。

(一)默认初始化

  在前面的部分中,例如在线性回归的简洁实现中,我们使用正态分布来初始化权重值。如果我们不指定初始化方法,框架将使用默认的随机初始化方法,对于中等难度的问题,这种方法通常很有效。

(二)Xavier初始化

  让我们看看某些没有非线性的全连接层输出(例如,隐藏变量) o i o_{i} oi的尺度分布。对于该层 n i n n_\mathrm{in} nin输入 x j x_j xj及其相关权重 w i j w_{ij} wij,输出由下式给出
o i = ∑ j = 1 n i n w i j x j (3) o_{i} = \sum_{j=1}^{n_\mathrm{in}} w_{ij} x_j \tag{3} oi=j=1ninwijxj(3)

  权重 w i j w_{ij} wij都是从同一分布中独立抽取的。此外,让我们假设该分布具有零均值和方差 σ 2 \sigma^2 σ2。请注意,这并不意味着分布必须是高斯的,只是均值和方差需要存在。现在,让我们假设层 x j x_j xj的输入也具有零均值和方差 γ 2 \gamma^2 γ2,并且它们独立于 w i j w_{ij} wij并且彼此独立。在这种情况下,我们可以按如下方式计算 o i o_i oi的平均值和方差:
E [ o i ] = ∑ j = 1 n i n E [ w i j x j ] = ∑ j = 1 n i n E [ w i j ] E [ x j ] = 0 V a r [ o i ] = E [ o i 2 ] − ( E [ o i ] ) 2 = ∑ j = 1 n i n E [ w i j 2 x j 2 ] − 0 = ∑ j = 1 n i n E [ w i j 2 ] E [ x j 2 ] = n i n σ 2 γ 2 (4) \begin{aligned} E[o_i] & = \sum_{j=1}^{n_\mathrm{in}} E[w_{ij} x_j] \\&= \sum_{j=1}^{n_\mathrm{in}} E[w_{ij}] E[x_j] \\&= 0 \\[1ex] \mathrm{Var}[o_i] & = E[o_i^2] - (E[o_i])^2 \\ & = \sum_{j=1}^{n_\mathrm{in}} E[w^2_{ij} x^2_j] - 0 \\ & = \sum_{j=1}^{n_\mathrm{in}} E[w^2_{ij}] E[x^2_j] \\ & = n_\mathrm{in} \sigma^2 \gamma^2 \end{aligned} \tag{4} E[oi]Var[oi]=j=1ninE[wijxj]=j=1ninE[wij]E[xj]=0=E[oi2](E[oi])2=j=1ninE[wij2xj2]0=j=1ninE[wij2]E[xj2]=ninσ2γ2(4)

  保持方差不变的一种方法是设置 n i n σ 2 = 1 n_\mathrm{in} \sigma^2 = 1 ninσ2=1。现在考虑反向传播过程,我们面临着类似的问题,尽管梯度是从更靠近输出的层传播的。使用与前向传播相同的推断,我们可以看到,除非 n o u t σ 2 = 1 n_\mathrm{out} \sigma^2 = 1 noutσ2=1,否则梯度的方差可能会增大,其中 n o u t n_\mathrm{out} nout是该层的输出的数量。这使得我们进退两难:我们不可能同时满足这两个条件。相反,我们只需满足:
1 2 ( n i n + n o u t ) σ 2 = 1 或等价于  σ = 2 n i n + n o u t . (5) \begin{aligned} \frac{1}{2} (n_\mathrm{in} + n_\mathrm{out}) \sigma^2 = 1 \text{ 或等价于 } \sigma = \sqrt{\frac{2}{n_\mathrm{in} + n_\mathrm{out}}}. \end{aligned} \tag{5} 21(nin+nout)σ2=1 或等价于 σ=nin+nout2 .(5)

  这就是现在标准且实用的Xavier初始化的基础,它以其提出者第一作者的名字命名。通常,Xavier初始化从均值为零,方差 σ 2 = 2 n i n + n o u t \sigma^2 = \frac{2}{n_\mathrm{in} + n_\mathrm{out}} σ2=nin+nout2的高斯分布中采样权重。我们也可以将其改为选择从均匀分布中抽取权重时的方差。注意均匀分布 U ( − a , a ) U(-a, a) U(a,a)的方差为 a 2 3 \frac{a^2}{3} 3a2。将 a 2 3 \frac{a^2}{3} 3a2代入到 σ 2 \sigma^2 σ2的条件中,将得到初始化值域:
U ( − 6 n i n + n o u t , 6 n i n + n o u t ) (6) U\left(-\sqrt{\frac{6}{n_\mathrm{in} + n_\mathrm{out}}}, \sqrt{\frac{6}{n_\mathrm{in} + n_\mathrm{out}}}\right) \tag{6} U(nin+nout6 ,nin+nout6 )(6)

  尽管在上述数学推理中,“不存在非线性”的假设在神经网络中很容易被违反,但Xavier初始化方法在实践中被证明是有效的。

(三)拓展阅读

  上面的推理仅仅触及了现代参数初始化方法的皮毛。深度学习框架通常实现十几种不同的启发式方法。此外,参数初始化一直是深度学习基础研究的热点领域。其中包括专门用于参数绑定(共享)、超分辨率、序列模型和其他情况的启发式算法。例如,Xiao等人演示了通过使用精心设计的初始化方法,可以无须架构上的技巧而训练10000层神经网络的可能性。

  如果有读者对该主题感兴趣,我们建议深入研究本模块的内容,阅读提出并分析每种启发式方法的论文,然后探索有关该主题的最新出版物。也许会偶然发现甚至发明一个聪明的想法,并为深度学习框架提供一个实现。

小结

  • 梯度消失和梯度爆炸是深度网络中常见的问题。在参数初始化时需要非常小心,以确保梯度和参数可以得到很好的控制。
  • 需要用启发式的初始化方法来确保初始梯度既不太大也不太小。
  • ReLU激活函数缓解了梯度消失问题,这样可以加速收敛。
  • 随机初始化是保证在进行优化前打破对称性的关键。
  • Xavier初始化表明,对于每一层,输出的方差不受输入数量的影响,任何梯度的方差不受输出数量的影响。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/6926.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数字图像处理:实验五

uu们!大家好,欢迎来到数字图像处理第五章节内容的学习,在本章中有关空间滤波的理论学习是十分重要的,所以建议大家要去用心的学习本章,在之后的传感器的相关图像采集时,不可避免的会有噪声等的影响&#xf…

CCF开源发展委员会开源供应链安全工作组2025年第1期技术研讨会顺利举行

点击蓝字 关注我们 CCF Opensource Development Committee 2025年1月17日,CCF开源发展委员会供应链安全工作组(CCF-ODC-OSS)2025年第一期技术研讨会——“大模型时代的开源供应链安全风控技术”于北京黄大年茶思屋顺利举行。本次研讨会邀请了…

如何进行市场调研?海外问卷调查有哪些类型和示例?

什么是市场研究? 市场研究的目的,就是调查消费者的行为和当时的经济趋势,帮助企业制定和调整经营理念和经营路线,通过收集和分析数据,帮助企业了解其目标市场。 市场调查是通过对潜在客户的分析,来判断品…

DX12 快速教程(4) —— 画钻石原矿

快速导航 新建项目 "004-DrawTexture"纹理贴图纹理采样纹理过滤邻近点采样双线性过滤Mipmap 多级渐远纹理三线性过滤各向异性过滤 纹理环绕LOD 细节层次 开始画钻石原矿吧加载纹理到内存中:LoadTexture什么是 WIC如何用 WIC 读取一帧图片获取图片格式并转…

FPGA实现任意角度视频旋转(二)视频90度/270度无裁剪旋转

本文主要介绍如何基于FPGA实现视频的90度/270度无裁剪旋转,旋转效果示意图如下: 为了实时对比旋转效果,采用分屏显示进行处理,左边代表旋转前的视频在屏幕中的位置,右边代表旋转后的视频在屏幕中的位置。 分屏显示的…

Blazor-选择循环语句

今天我们来说说Blazor选择语句和循环语句。 下面我们以一个简单的例子来讲解相关的语法,我已经创建好了一个Student类,以此类来进行语法的运用 因为我们需要交互性所以我们将类创建在*.client目录下 if 我们做一个学生信息的显示,Gender为…

数据结构——实验八·学生管理系统

嗨~~欢迎来到Tubishu的博客🌸如果你也是一名在校大学生,正在寻找各种编程资源,那么你就来对地方啦🌟 Tubishu是一名计算机本科生,会不定期整理和分享学习中的优质资源,希望能为你的编程之路添砖加瓦⭐&…

在 Ubuntu22.04 上安装 Splunk

ELK感觉太麻烦了,换个日志收集工具 Splunk 是一种 IT 工具,可帮助在任何设备上收集日志、分析、可视化、审计和创建报告。简单来说,它将“机器生成的数据转换为人类可读的数据”。它支持从虚拟机、网络设备、防火墙、基于 Unix 和基于 Windo…

【C++高并发服务器WebServer】-2:exec函数簇、进程控制

本文目录 一、exec函数簇介绍二、exec函数簇 一、exec函数簇介绍 exec 函数族的作用是根据指定的文件名找到可执行文件,并用它来取代调用进程的内容,换句话说,就是在调用进程内部执行一个可执行文件。 exec函数族的函数执行成功后不会返回&…

[ACTF2020 新生赛]Upload1

题目 以为是前端验证&#xff0c;试了一下PHP传不上去 可以创建一个1.phtml文件。对.phtml文件的解释: 是一个嵌入了PHP脚本的html页面。将以下代码写入该文件中 <script languagephp>eval($_POST[md]);</script><script languagephp>system(cat /flag);&l…

第24篇 基于ARM A9处理器用汇编语言实现中断<六>

Q&#xff1a;怎样设计ARM处理器汇编语言程序使用定时器中断实现实时时钟&#xff1f; A&#xff1a;此前我们曾使用轮询定时器I/O的方式实现实时时钟&#xff0c;而在本实验中将采用定时器中断的方式。新增第三个中断源A9 Private Timer&#xff0c;对该定时器进行配置&#…

SpringMVC新版本踩坑[已解决]

问题&#xff1a; 在使用最新版本springMVC做项目部署时&#xff0c;浏览器反复500&#xff0c;如下图&#xff1a; 异常描述&#xff1a; 类型异常报告 消息Request processing failed: java.lang.IllegalArgumentException: Name for argument of type [int] not specifie…

系统思考—复杂问题的根源分析

在企业中&#xff0c;许多问题看似简单&#xff0c;背后却潜藏着复杂的因果关系。传统的思维方式往往只能看到表面&#xff0c;而无法深入挖掘问题的真正根源。我们常常通过“表面解决”来应对眼前的症状&#xff0c;但这往往只是治标不治本。 比如&#xff0c;销量下降时&…

安装VMware17

一、VMware Workstation 简介 VMware Workstation是一款由VMware公司开发的功能强大的桌面虚拟化软件。它允许用户在单一的物理电脑上同时运行多个操作系统作为虚拟机&#xff08;VMs&#xff09;&#xff0c;每个虚拟机都可配置有自己的独立硬件资源&#xff0c;如CPU核心、内…

三、双链表

链表的种类有很多&#xff0c;单链表是不带头不循环单向链表&#xff0c;但双链表是带头循环双向链表&#xff0c;并且双链表还有一个哨兵位&#xff0c;哨兵位不是头节点 typedef int LTDataType;typedef struct ListNode{struct ListNode* next; //指针保存下⼀个结点的地址s…

【知识】可视化理解git中的cherry-pick、merge、rebase

转载请注明出处&#xff1a;小锋学长生活大爆炸[xfxuezhagn.cn] 如果本文帮助到了你&#xff0c;欢迎[点赞、收藏、关注]哦~ 这三个确实非常像&#xff0c;以至于对于初学者来说比较难理解。 总结对比 先给出对比&#xff1a; 特性git mergegit rebasegit cherry-pick功能合并…

SpringBoot开发(三)SpringBoot介绍、项目创建、运行

1. SpringBoot 1.1. SpringBoot介绍 Spring Boot给世界程序员带来了春天&#xff0c;越来越多的企业选择使用spring boot来开发他们的软件&#xff0c;因此学习spring boot是科技发展的必然趋势。本门课程将从web最基础的知识点开始讲起&#xff0c;逐步带你攻破spring boot的…

438. 找到字符串中所有字母异位词

【题目】&#xff1a;438. 找到字符串中所有字母异位词 class Solution { public:vector<int> findAnagrams(string s, string p) {vector<int> res;vector<int> curVec(26, 0); // 统计p中字母出现的次数for(char c : p) {curVec[c - a];}for(int l 0, r …

Leetcode-两数之和

1.暴力枚举 class Solution { public:vector<int> twoSum(vector<int>& nums, int target) {int lennums.size();int i,j;for(i0;i<len;i){for(ji1;j<len;j){if(nums[i]nums[j]target){return{i,j};}}}return {i,j};} }; 新知识&#xff1a; return {…

边缘网关具备哪些功能?

边缘网关&#xff0c;又称边缘计算网关&#xff0c;部署在网络边缘&#xff0c;它位于物联网设备与云计算平台之间&#xff0c;充当着数据流动的“守门员”和“处理器”。通过其强大的数据处理能力和多样化的通信协议支持&#xff0c;边缘网关能够实时分析、过滤和存储来自终端…