Pytorch中layernorm实现详解

      平时我们在编写神经网络时,经常会用到layernorm这个函数来加快网络的收敛速度。那layernorm到底在哪个维度上进行归一化的呢? 

一、问题描述

    首先借用知乎上的一张图,原文写的也非常好,大家有空可以去阅读一下,链接放在参考文献里了。如左图所示,假设现在输入的维度是(bs,seq_len, embedding),其中bs代表batch_size, seq_len代表序列长度 ,embedding表示嵌入大小。

    那在layernorm时,我们是对(seq_len, embedding)这个矩阵取均值和方差(上图);还是只对embedding这个维度取均值和方差呢(下图)?前者会得到bs个均值和方差,而后者会得到bs * seq_len 个均值和方差。下面我们进行编程验证。

二、编程实现

import torchbatch_size, seq_size, dim = 2, 3, 4
embedding = torch.randn(batch_size, seq_size, dim)layer_norm = torch.nn.LayerNorm(dim, elementwise_affine = False)
print("用pytorch的layer_norm所得结果\n", layer_norm(embedding))print("自己编写layer_norm所得结果")
eps: float = 0.00001
mean = torch.mean(embedding[:, :, :], dim=(-1), keepdim=True)
var = torch.square(embedding[:, :, :] - mean).mean(dim=(-1), keepdim=True)print("mean: ", mean.shape)
print("y_custom: ", (embedding[:, :, :] - mean) / torch.sqrt(var + eps))

结果:

用pytorch的layer_norm所得结果tensor([[[ 0.7475, -1.7061,  0.6676,  0.2910],[ 0.1144, -0.6476,  1.5753, -1.0421],[-1.0278, -0.7498,  0.2559,  1.5218]],[[-1.0527, -0.8723,  1.3354,  0.5895],[-0.6403, -1.1399,  1.4842,  0.2961],[ 0.7352, -0.8236, -1.1342,  1.2226]]])
自己编写layer_norm所得结果
mean:  torch.Size([2, 3, 1])
y_custom:  tensor([[[ 0.7475, -1.7061,  0.6676,  0.2910],[ 0.1144, -0.6476,  1.5753, -1.0421],[-1.0278, -0.7498,  0.2559,  1.5218]],[[-1.0527, -0.8723,  1.3354,  0.5895],[-0.6403, -1.1399,  1.4842,  0.2961],[ 0.7352, -0.8236, -1.1342,  1.2226]]])

结果的相等的。可以看到,我们在取均值和方差时,是对最后一个维度取的。所以我们会得到 (N,C)个均值与方差。假设二是正确的。 

而实际上这种实现方法和Instance Norm是相同的

from torch.nn import InstanceNorm2d
instance_norm = InstanceNorm2d(3, affine=False)
x = torch.randn(2, 3, 4)
output = instance_norm(embedding.reshape(2,3,4,1)) #InstanceNorm2D需要(N,C,H,W)的shape作为输入
print(output.reshape(2,3,4))layer_norm = torch.nn.LayerNorm(4, elementwise_affine = False)
print(layer_norm(x))

结果:

tensor([[[ 0.7475, -1.7061,  0.6676,  0.2910],[ 0.1144, -0.6476,  1.5753, -1.0421],[-1.0278, -0.7498,  0.2559,  1.5218]],[[-1.0527, -0.8723,  1.3354,  0.5895],[-0.6403, -1.1399,  1.4842,  0.2961],[ 0.7352, -0.8236, -1.1342,  1.2226]]])
tensor([[[ 0.1293, -1.0034,  1.5760, -0.7018],[-1.3981, -0.4828,  1.0876,  0.7933],[-1.7034,  0.8545,  0.4876,  0.3612]],[[-1.4750,  1.2212, -0.2607,  0.5144],[ 0.7017, -0.8350,  1.2502, -1.1169],[-1.7273,  0.6965,  0.5147,  0.5161]]])

三、参考文献

(45 封私信 / 80 条消息) 为什么Transformer要用LayerNorm? - 知乎 (zhihu.com)https://www.zhihu.com/question/487766088/answer/2644783144

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/37555.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

六十天前端强化训练之第二十五天之组件生命周期大师级详解(Vue3 Composition API 版)

欢迎来到编程星辰海的博客讲解 看完可以给一个免费的三连吗,谢谢大佬! 目录 一、生命周期核心知识 1.1 生命周期全景图 1.2 生命周期钩子详解 1.2.1 初始化阶段 1.2.2 挂载阶段 1.2.3 更新阶段 1.2.4 卸载阶段 1.3 生命周期执行顺序 1.4 父子组…

Burp Suite 代理配置与网络通信

目录 1. 引言 2. Burp 代理基础配置 2.1 浏览器代理设置 2.2 Burp 监听端口配置 2.3 常见错误排查 3. 网络问题解决 3.1 端口占用检查 3.2 防火墙配置 3.3 证书信任问题 4. 虚拟机环境配置 4.1 NAT 模式与端口转发 4.2 桥接模式配置 4.3 跨设备访问测试 5. 技术概…

numpy学习笔记16: 1000 次独立随机游走实验(绘制其分布直方图,同时叠加理论正态分布曲线)

numpy学习笔记16: 1000 次独立随机游走实验(绘制其分布直方图,同时叠加理论正态分布曲线) 以下是这段代码(全部代码在最后)的详细分步解释,结合统计学原理和可视化技巧: 1. 代码功能概述 这段代码通过 1000 次独立随机游走实验&…

C# 项目06-计算程序运行时间

实现需求 记录程序运行时间,当程序退出后,保存程序运行时间,等下次程序再次启动时,继续记录运行时间 运行环境 Visual Studio 2022 知识点 TimeSpan 表示时间间隔。两个日期之间的差异的 TimeSpan 对象 TimeSpan P_TimeSpa…

KNN算法

一、KNN算法介绍 KNN 算法,也称 k邻近算法,是 有监督学习 中的 分类算法 。它可以用于分类或回归问题,但它通常用作分类算法。 二、KNN算法流程 1.计算已知类别数据集中的点与当前点的距离 2.按照距离增次序排序 3.选取与当前点距离最小…

星越L_可调悬挂使用讲解

目录 1.可变阻尼设置 1.可变阻尼设置

G-Star 校园开发者计划·黑科大|开源第一课之 Git 入门

万事开源先修 Git。Git 是当下主流的分布式版本控制工具,在软件开发、文档管理等方面用处极大。它能自动记录文件改动,简化合并流程,还特别适合多人协作开发。学会 Git,就相当于掌握了一把通往开源世界的钥匙,以后参与…

html5炫酷3D立体文字效果实现详解

炫酷3D立体文字效果实现详解 这里写目录标题 炫酷3D立体文字效果实现详解项目概述技术实现要点1. 基础布局设置2. 动态背景效果3. 文字渐变效果4. 立体阴影效果5. 悬浮动画效果 技术难点及解决方案1. 文字渐变动画2. 立体阴影效果3. 性能优化 浏览器兼容性总结 项目概述 在这个…

《白帽子讲 Web 安全》之开发语言安全深度解读

目录 引言 1.PHP 安全 1.1变量覆盖 1.2空字节问题 1.3弱类型 1.4反序列化 1.5安全配置 2Java 安全 2.1Security Manager 2.2反射 2.3反序列化 3Python 安全 3.1反序列化 3.2代码保护 4.JavaScript 安全 4.1第三方 JavaScript 资源 4.2JavaScript 框架 5.Node.…

django设置admin的排列顺序,耗3小时【躲坑指南】

django 项目中,这个数据栏目的显示排列顺序我希望更贴近业务 比如要让【商品货品信息】中的9个数据表根据人为规定来进行排序 结果:工程量很大。 能够实现人为的自定义排序 最简单的设置就是给模型添加号数标记 主应用中创建admin–设置了其中一个应用…

macOS使用brew切换Python版本【超详细图解】

目录 一、更新Homebrew仓库 二、安装pyenv 三、将pyenv添加到bash_profile文件中 四、使.bash_profile文件的更改生效 五、安装需要的Python版本 六、设置全局使用的Python版本 七、检查Python版本是否切换成功 pyenv常用命令 一、更新Homebrew仓库 brew update 这个…

【深度学习新浪潮】AI ISP技术与手机厂商演进历史

本文是关于AI ISP(人工智能图像信号处理器)的技术解析、与传统ISP(图像信号处理器)的区别、近三年研究进展,以及各大手机厂商在该领域演进历史的详细报告。本报告综合多个权威来源的信息,力求全面、深入地呈现相关技术发展脉络与行业动态。 第一部分:AI ISP的定义及与传…

如何让节卡机器人精准对点?

如何让节卡机器人精准对点? JAKA Zu 软件主界面主要由功能栏、开关栏、菜单栏构成。 菜单栏:控制柜管理,机器人管理与软件管理组成。主要功能为对控制柜关机、APP 设置、机器人本体设 置、控制柜设置、连接机器人和机器人显示等功能。 开关…

QAI AppBuilder 快速上手(7):目标检测应用实例

YOLOv8_det是YOLO 系列目标检测模型,专为高效、准确地检测图像中的物体而设计。该模型通过引入新的功能和改进点,如因式分解卷积(factorized convolutions)和批量归一化(batch normalization),在…

c#难点整理

1.何为托管代码,何为非托管代码 托管代码就是.net框架下的代码 非托管代码,就是非.net框架下的代码 2.委托的关键知识点 将方法作为参数进行传递 3.多维数组 4.锯齿数组 5.多播委托的使用 6.is运算符 相当于逻辑运算符是 7.as 起到转换的作用 8.可…

二次向用户申请授权

HarmonyOS 5.0.3(15) 版本的配套文档,该版本API能力级别为API 15 Release 文章目录 当应用通过requestPermissionsFromUser()拉起弹框请求用户授权时,用户拒绝授权。应用将无法再次通过requestPermissionsFromUser()拉起弹框,需要用户在系统应…

JetsonNano —— 4、Windows下对JetsonNano板卡烧录刷机Ubuntu20.04版本(官方教程)

介绍 NVIDIA Jetson Nano™ 开发者套件是一款面向创客、学习者和开发人员的小型 AI 计算机。按照这个简短的指南,你就可以开始构建实用的 AI 应用程序、酷炫的 AI 机器人等了。 烧录刷机 1、下载 Jetson Nano开发者套件SD卡映像 解压出.img文件并记下它在计算机上的…

Chapter 8 Charge Pump

Chapter 8 Charge Pump 8.1 Introduction 电荷泵就是capacitive DC–DC converters, 一般把小功率bias generation叫做电荷泵, 传输大功率的称为switched-capacitor DC–DC converters. 但其实两者都是通过电容网络的charge redistribution, 来实现电压的倍增或者倍降. 8.1.…

计算机网络——通信基础和传输介质

物理层任务:实现相邻节点之间比特(0或1)的传输 到了数据链路层之后,它会以帧为单位,把若干个比特交给物理层,物理层需要把这些比特信息转化成信号,在物理传输媒体上进行传输 通信基础基本概念 信…

【架构】单体架构 vs 微服务架构:如何选择最适合你的技术方案?

文章目录 ⭐前言⭐一、架构设计的本质差异🌟1、代码与数据结构的对比🌟2、技术栈的灵活性 ⭐二、开发与维护的成本博弈🌟1、开发效率的阶段性差异🌟2、维护成本的隐形陷阱 ⭐三、部署与扩展的实战策略🌟1、部署模式的本…