每天一个知识点——Normalization

这里结合大模型的学习,主要分析Layer Norm、RMS Norm和Deep Norm的异同,与此同时,究竟是在之前执行Normalization(Pre-Norm)还是之后执行(Post-Norm),也是一个比较喜欢拿来讨论的知识点。

一、为什么要做Normalization?

        ICS问题出现的根本原因在于神经网络每层之间,无法满足基本假设"独立同分布"。深度神经网络涉及到很多层的叠加,而每一层的参数更新会导致上层的输入数据分布发生变化,通过层层叠加,高层的输入分布变化会非常剧烈,这就使得高层需要不断去重新适应底层的参数更新。为了训好模型,我们需要非常谨慎地去设定学习率、初始化权重、以及尽可能细致的参数更新策略。

ICS( Internal Covariate Shift)问题导致的后果:

  • 上层参数需要不断适应新的输入数据分布,降低学习速度;
  • 下层输入的变化可能趋向于变大或者变小,导致上层落入饱和区,使得学习过早停止;
  • 每层的更新都会影响到其它层,因此每层的参数更新策略需要尽可能的谨慎;

二、Pre or Post?

Transformer的block模块里面包含了多处Normalization,具体结构如下: 

从图中可以看出,经过多头(multi-head)的self-attention之后结合残差网络,再做一次归一化。简化一下,如下图所示,

Pre-Norm的公式为:

x_{t+1}=x_t+F_t(Norm(x_t)) 

Post-Norm的公式为:

x_{t+1}=Norm(x_t+F_t(x_t))

问题一、为什么Pre Norm的效果不如Post Norm?Pre Norm的深度有“水分”!也就是说,一个L层的Pre Norm模型,其实际等效层数不如L层的Post Norm模型,而层数少了导致效果变差了?

原因:

\begin{aligned} x_{t+1}&=x_t+F_t(Norm(x_t)) \\ &=x_{t-1} + F_{t-1}(Norm(x_{t-1})) + F_t(Norm(x_t)) \\ &=... \\ &=x_0 + F_0(Norm(x_0)) + ... + F_{t-1}(Norm(x_{t-1})) + F_t(Norm(x_t)) \end{aligned}

所以在Pre Norm中多层叠加的结果更多是增加宽度而不是深度,层数越多,这个层就越“虚”。说白了,Pre Norm结构无形地增加了模型的宽度而降低了模型的深度,而我们知道深度通常比宽度更重要,所以是无形之中的降低深度导致最终效果变差了!

post-norm和pre-norm其实各有优势,post-norm在残差之后做归一化,对参数正则化的效果更强,进而模型的鲁棒性也会更好pre-norm相对于post-norm,因为有一部分参数直接加在了后面,不需要对这部分参数进行正则化,正好可以防止模型的梯度爆炸或者梯度消失,因此,这里笔者可以得出的一个结论是如果层数少post-norm的效果其实要好一些,如果要把层数加大,为了保证模型的训练,pre-norm显然更好一些。

问题二:为什么Layer Normalization要加在F的前面,而不是F的后面呢?

因为做完Layer Normalization之后的数据不能和平常的数据加在一起,如果这样做的话残差中从上一层出来的信息会占很大比重,这显然并不合理。

三、BatchNorm or LayerNorm?

这个问题争论了很久,也是面试官比较喜欢问的一个问题,貌似有点不死不休的意思。到底哪个好一定是有个定论吗?还是说不同场景下的选择?抑或是仅是理论上的讨论,在实操上并没有明显的性能差异?

首先列举一些网络中的解释:

        1. BatchNorm适用于CV,而LayerNorm适用于NLP,这是由两个任务的本质差异决定的,视觉的特征是客观存在的特征,而语义特征更多是由上下文语义决定的一种统计特征,因此他们的标准化方法也会有所不同。

        2. layernorm更容易并行训练。当每个device用不同的minbatch训练,我们需要额外地同步各个device上的batchnorm,用layernorm则不需要。

        3. layernorm所带来的hidden layer分布上的稳定性,促进了更平滑的梯度,更快的训练速度,更好的模型泛化能力等等。

        4. LayerNorm是后起之秀,挑战前辈BatchNorm必然是解决它的某些痛点或者劣势,那么BatchNorm有哪些痛点呢?

  • batch非常小,比如训练资源有限无法应用较大的batch,也比如在线学习等使用单例进行模型参数更新的场景
  • 对于rnn等动态的网络结构,同一个batch中训练实例有长有短,导致每一个时间步长必须维持各自的统计量,这使得BN并不能正确的使用。

        5. 约定俗称,比如NLP几个经典模型都是用的LayerNorm,之后很多的追随者也沿用了这一套。

另外需要注意的一个点是由于现在训练模型都是采用的mini-batch的方式,所以在batchnorm在估计训练集(推理阶段)整体的均值方差时,常采用EMA( exponential moving average),指数移动平均来估计:

\mu _{EMA}\leftarrow \lambda \mu _{EMA}+(1-\lambda )\mu _{B}

\sigma _{EMA}^{2}\leftarrow \lambda \sigma _{EMA}^{2}+(1-\lambda )\sigma _{B}^{2}

其中\mu _{B}\sigma _{B}^{2}为训练阶段的均值和方差,

\mu _{B}=X.mean(dim=0, keepdim=True)

\sigma _{B}^{2}=X.var(dim=0, keepdim=True)

然后对特征x进行归一化,

\hat{x}=\frac{x-\mu }{\sqrt{\sigma ^{2}+\epsilon }}

除了归一化,BatchNorm还包括对各个channel的特征做affine tranform(增加特征表达能力,有待考证):

y=\alpha \hat{x}+\beta

\alpha\beta都是可训练的参数。

四、RMS Norm和Deep Norm

RMS:Root Mean Square Layer Normalization

与layerNorm相比,RMS Norm的主要区别在于去掉了减去均值的部分,计算公式为:

\bar{a}_{i}=\frac{a_{i}}{RMS(a)}g_{i},  其中RMS(a)=\sqrt{\frac{1}{n}\sum_{i=1}^{n}a_{i}^{2}}

1、可以在梯度下降时令损失更加平滑?(有待考证)?

2、可以在各个模型上减少约 7%∼64% 的计算时间。

Deep Norm是对Post-LN的的改进,代码如下,

e135fbad38253919370571c9587b6152.png

作者认为 Post-LN 的不稳定性部分来自于梯度消失以及太大的模型更新,而DeepNorm可以缓解这个问题

  • DeepNorm在进行Layer Norm之前会以 \alpha参数扩大残差连接
  • 在Xavier参数初始化过程中以 \beta减小部分参数的初始化范围

说一千道一万,对于没法做到理论功底那么深厚的人来说,效果为王,谁效果好就用谁

a5c81cce8bead488474cbc1b0d2c5ac0.png

五、参考资料 

  1. DEEPNORM:千层transformer...

  2. 昇腾大模型|结构组件-1——Layer Norm、RMS Norm、Deep Norm

  3. BatchNorm与LayerNorm

  4. BatchNorm与LayerNorm的理解

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/83518.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

编织人工智能:机器学习发展历史与关键技术全解析

文章目录 1. 引言1.1 机器学习的定义1.2 重要性和应用场景重要性应用场景 2. 机器学习的早期历史2.1 初期理论与算法感知机决策树 2.2 早期突破支持向量机神经网络初探 3. 21世纪初期的发展3.1 集成学习方法随机森林XGBoost 3.2 深度学习的崛起卷积神经网络(CNN&…

Sketch打不开AI文件?转换方法在这里

1、对比设计软件 Sketch 与 AI 软件功能 Sketch 与 Illustrator 都是行业内优秀的矢量图形设计软件,各有千秋。Sketch 从 2010 年面世,专注 APP 界面设计,深受初学者与专业人士喜爱。Illustrator 拥有更悠久的历史,是处理复杂图标…

问道管理:A股缩量整理 新股上演久违暴涨模式

周三,大盘低开后震动,三大指数小幅跌落,创业板指相对偏强。 早盘开盘后,沪指、深证成指弱势震动,创业板指探底上升翻红,盘面热门乏善可陈。午后三大指数震动走弱,创业板指再度翻绿。医药板块活…

基于百度语音识别API智能语音识别和字幕推荐系统——深度学习算法应用(含全部工程源码)+测试数据集

目录 前言总体设计系统整体结构图系统流程图 运行环境模块实现1. 数据预处理2. 翻译3. 格式转换4. 音频切割5. 语音识别6. 文本切割7. main函数 系统测试工程源代码下载其它资料下载 前言 本项目基于百度语音识别API,结合了语音识别、视频转换音频识别以及语句停顿…

性能测试jmeter连接数据库jdbc(sql server举例)

一、下载第三方工具包驱动数据库 1. 因为JMeter本身没有提供链接数据库的功能,所以我们需要借助第三方的工具包来实现。 (有这个jar包之后,jmeter可以发起jdbc请求,没有这个jar包,也有jdbc取样器,但不能发起…

大模型的数据隐私问题有解了,浙江大学提出联邦大语言模型

作者 | 小戏、Python 理想化的 Learning 的理论方法作用于现实世界总会面临着诸多挑战,从模型部署到模型压缩,从数据的可获取性到数据的隐私问题。而面对着公共领域数据的稀缺性以及私有领域的数据隐私问题,联邦学习(Federated Le…

ACM算法竞赛中在编辑器中使用输入输出样例-CPH

通用方法 我们可以在编辑器中创建三个文件,一个是main.cpp,一个是test.in,一个是test.out分别用来写代码,输入输入数据,显示输出数据 这种方法的好处是不需要插件,在任何编辑器中都可以实现,例如Devc,sublime,vscode…

HBase-读流程

创建连接同写流程。 (1)读取本地缓存中的Meta表信息;(第一次启动客户端为空) (2)向ZK发起读取Meta表所在位置的请求; (3)ZK正常返回Meta表所在位置&#x…

SQL常见命令语句

1.连接数据库 mysql (-h IP) -u root -p 密码2.查看数据库 show databases3.使用数据库 use db_name4.查看表 show tables [from db_name]5.查看表结构 desc tb_name6.创建、删除、选择数据库 create database db_namedrop database db_nameuse db_name7.数据类型 参考链…

Python-OpenCV中的图像处理-几何变换

Python-OpenCV中的图像处理-几何变换 几何变换图像缩放图像平移图像旋转仿射变换透视变换 几何变换 对图像进行各种几个变换,例如移动,旋转,仿射变换等。 图像缩放 cv2.resize() cv2.INTER_AREAv2.INTER_CUBICv2.INTER_LINEAR res cv2.r…

优化理论 | Time-Sharing Condition

版权声明 原创作品,整理不易,转载请标明出处。本篇推送更详细的内容介绍,可参见本人微信公众号“优化与博弈的数学原理”,公众号二维码参见文末。 编者按 OFDM系统中的功率分配问题是通信领域中的研究热点。本文重点考虑了面向…

7.5 详解批量规范化 对某个维度取平均值代码解读

一.举例计算均值、方差 假设我们有以下一组数据:[10, 15, 20, 25, 30]首先,我们计算均值,即将所有数据相加后除以数据的数量: **均值** (10 15 20 25 30) / 5 100 / 5 201.1标准差 接下来,我们计算标准差&…

笔记本电脑如何把sd卡数据恢复

在使用笔记本电脑过程中,如果不小心将SD卡里面的重要数据弄丢怎么办呢?别着急,本文将向您介绍SD卡数据丢失常见原因和恢复方法。 ▌一、SD卡数据丢失常见原因 - 意外删除:误操作或不小心将文件或文件夹删除。 - 误格式化&#…

跨境干货|TikTok变现的9种方法

在这个流量为王的时代,哪里有流量,哪里就有商机。TikTok作为近几年最火爆的社媒平台之一,在全球范围都具有一定的影响力。随着TikTok Shop等商务功能加持上线,更是称为跨境电商的新主场之一。 在这样的UGC平台,想要变…

适配器模式-java实现

意图 复用已经存在的接口,与所需接口不一致的类。即将一个类(通常是旧系统中的功能类),通过适配器转化成另一个接口的实现。(简单来说,就是复用旧系统的功能,去实现新的接口) 我们举…

API 测试 | 了解 API 接口概念|电商平台 API 接口测试指南

什么是 API? API 是一个缩写,它代表了一个 pplication P AGC 软件覆盖整个房间。API 是用于构建软件应用程序的一组例程,协议和工具。API 指定一个软件程序应如何与其他软件程序进行交互。 例行程序:执行特定任务的程序。例程也称…

深度学习:使用卷积神经网络CNN实现MNIST手写数字识别

引言 本项目基于pytorch构建了一个深度学习神经网络,网络包含卷积层、池化层、全连接层,通过此网络实现对MINST数据集手写数字的识别,通过本项目代码,从原理上理解手写数字识别的全过程,包括反向传播,梯度…

【UE4 RTS】04-Camera Pan

前言 本篇实现了CameraPawn的旋转功能。 效果 步骤 1. 打开项目设置,添加两个操作映射 2. 打开玩家控制器“RTS_PlayerController_BP”,新建一个浮点型变量,命名为“PanSpeed” 在事件图表中添加如下节点 此时运行游戏可以发现当鼠标移动…

ReSharper C++ 2023 Crack

ReSharper C 2023 Crack ReSharper的AI助手会考虑项目中使用的语言和技术。这种上下文感知可以一开始就调整其响应,为您节省时间和精力。 您可以在查询中包含部分源代码。ReSharper将检测你发送或粘贴到聊天中的代码,并正确格式化,而人工智能…

【数据结构OJ题】合并两个有序数组

原题链接:https://leetcode.cn/problems/merge-sorted-array/ 目录 1. 题目描述 2. 思路分析 3. 代码实现 1. 题目描述 2. 思路分析 看到这道题,我们注意到nums1[ ]和nums2[ ]两个数组都是非递减的。所以我们很容易想到额外开一个数组tmp[ ]&#x…