【llm对话系统】大模型 Llama 源码分析之归一化方法 RMS Norm

1. 引言

在深度学习中,归一化 (Normalization) 是一种常用的技术,它可以加速模型的训练并提高模型的性能。常见的归一化方法包括 Batch Normalization (BatchNorm)、Layer Normalization (LayerNorm) 等。Llama 模型采用了一种称为 RMS Norm 的归一化方法,它是一种对 LayerNorm 的简化和改进。

本文将深入 Llama 源码,分析 RMS Norm 的实现逻辑,并探讨其相比于其他归一化方法的优势。

2. 归一化方法回顾

2.1 Batch Normalization (BatchNorm)

BatchNorm 对每个 mini-batch 的数据进行归一化,使其均值为 0,方差为 1。它引入了两个可学习的参数:缩放因子 (scale) 和偏移因子 (shift)。

公式:

y = (x - mean(x)) / sqrt(variance(x) + epsilon) * scale + shift

优点:

  • 加速训练。
  • 具有一定的正则化效果。

缺点:

  • 依赖于 batch size,当 batch size 较小时,效果较差。
  • 不适用于 RNN 等序列模型。

2.2 Layer Normalization (LayerNorm)

LayerNorm 对每个样本的特征进行归一化,使其均值为 0,方差为 1。它也引入了两个可学习的参数:缩放因子 (scale) 和偏移因子 (shift)。

公式:

y = (x - mean(x)) / sqrt(variance(x) + epsilon) * scale + shift

优点:

  • 不依赖于 batch size。
  • 适用于 RNN 等序列模型。

缺点:

  • 计算量比 BatchNorm 略大。

3. RMS Norm 原理

RMS Norm (Root Mean Square Normalization) 可以看作是 LayerNorm 的一个特例。它只对输入进行 均方根 (Root Mean Square) 归一化,并保留了可学习的缩放因子,但 去除了偏移因子

公式:

y = x / sqrt(mean(x^2) + epsilon) * scale

其中:

  • x 是输入向量。
  • mean(x^2)x 各元素的平方的平均值。
  • epsilon 是一个很小的常数,用于防止除零错误。
  • scale 是可学习的缩放因子,通常初始化为 1。

与 LayerNorm 的比较:

  • RMS Norm 没有减去均值 (即没有中心化)。
  • RMS Norm 没有偏移因子。

4. Llama 中 RMS Norm 的实现

Llama 源码中 RMS Norm 的实现位于 llama/model.py 文件中,定义在 RMSNorm 类中:

import torch
import torch.nn as nnclass RMSNorm(nn.Module):def __init__(self, dim: int, eps: float = 1e-6):"""初始化 RMSNorm.Args:dim: 输入的维度eps: 用于数值稳定的小常数"""super().__init__()self.eps = epsself.weight = nn.Parameter(torch.ones(dim))def _norm(self, x):"""执行 RMS 归一化.Args:x: 输入张量 (..., dim)Returns:归一化后的张量 (..., dim)"""return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)def forward(self, x):"""前向传播.Args:x: 输入张量 (..., dim)Returns:归一化并缩放后的张量 (..., dim)"""output = self._norm(x.float()).type_as(x)return output * self.weight

代码解释:

  1. __init__ 函数:

    • dim:输入的维度。
    • eps:用于数值稳定的小常数,默认为 1e-6
    • weight:可学习的缩放因子,初始化为全 1 的张量。
  2. _norm 函数:

    • 计算输入 x 的均方根的倒数:torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
      • x.pow(2):计算 x 每个元素的平方。
      • .mean(-1, keepdim=True):沿着最后一个维度计算平均值,并保持维度不变。
      • torch.rsqrt():计算平方根的倒数。
    • x 与均方根的倒数相乘,实现归一化。
  3. forward 函数:

    • 调用 _norm 函数进行归一化。
    • 将归一化后的结果与可学习的 weight 相乘,进行缩放。
    • .type_as(x):将结果转换为与输入 x 相同的类型。

使用示例:

# 假设输入维度为 512
dim = 512
rms_norm = RMSNorm(dim)# 模拟一个输入张量
x = torch.randn(1, 10, dim)# 进行 RMS Norm 归一化
y = rms_norm(x)print(y.shape)  # 输出: torch.Size([1, 10, 512])

5. RMS Norm 的优势

  • 计算效率高:RMS Norm 比 LayerNorm 少了均值计算和偏移操作,计算速度更快。
  • 性能相当:实验表明,RMS Norm 的性能与 LayerNorm 相当,甚至在某些任务上略有提升。
  • 更稳定:RMS Norm 对输入的缩放更加鲁棒,因为它只依赖于输入的平方的平均值,而不依赖于输入的均值。

为什么 RMS Norm 可以去掉偏移因子?

在 Transformer 架构中,通常在 RMS Norm 之后会跟一个线性层 (例如,多头注意力机制中的 Q, K, V 投影)。这个线性层可以学习到偏移的效果。因此,RMS Norm 中的偏移因子就显得多余了。

6. 总结

RMS Norm 是一种高效且有效的归一化方法,它通过对 LayerNorm 进行简化,去除了均值计算和偏移因子,提高了计算效率并保持了良好的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/11412.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】2.5 高级索引应用:图像处理中的区域提取

2.5 高级索引应用:图像处理中的区域提取 目录/提纲 #mermaid-svg-BI09xc20YqcpUam7 {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-BI09xc20YqcpUam7 .error-icon{fill:#552222;}#mermaid-svg-BI09xc20…

[免费]微信小程序智能商城系统(uniapp+Springboot后端+vue管理端)【论文+源码+SQL脚本】

大家好,我是java1234_小锋老师,看到一个不错的微信小程序智能商城系统(uniappSpringboot后端vue管理端),分享下哈。 项目视频演示 【免费】微信小程序智能商城系统(uniappSpringboot后端vue管理端) Java毕业设计_哔哩哔哩_bilibili 项目介绍…

本地部署DeepSeek-R1保姆级教程

近期,我国一款开源模型 DeepSeek-R1以低成本和高性能震撼了全球科技界。该模型的开源性使开发者能够在本地环境中部署和运行,提供了更高的灵活性和控制力。如果你也想在本地部署 DeepSeek-R1,可以参考以下完整的教程,涵盖Mac 版本…

仿真设计|基于51单片机的贪吃蛇游戏

目录 具体实现功能 设计介绍 51单片机简介 资料内容 仿真实现(protues8.7) 程序(Keil5) 全部内容 资料获取 具体实现功能 利用单片机8*8点阵实现贪吃蛇游戏的控制。 仿真演示视频: 51-基于51单片机的贪吃蛇游…

【4Day创客实践入门教程】Day2 探秘微控制器——单片机与MicroPython初步

Day2 探秘微控制器——单片机与MicroPython初步 目录 Day2 探秘微控制器——单片机与MicroPython初步MicroPython语言基础开始基础语法注释与输出变量模块与函数 单片机基础后记 Day0 创想启程——课程与项目预览Day1 工具箱构建——开发环境的构建Day2 探秘微控制器——单片机…

ubuntu 下使用deepseek

安装Ollama sudo snap install ollama 执行 ollama run deepseek-coder 然后进行等待。。。

消息队列应用示例MessageQueues-STM32CubeMX-FreeRTOS《嵌入式系统设计》P343-P347

消息队列 使用信号量、事件标志组和线标志进行任务同步时,只能提供同步的时刻信息,无法在任务之间进行数据传输。要实现任务间的数据传输,一般使用两种方式: 1. 全局变量 在 RTOS 中使用全局变量时,必须保证每个任务…

本地缓存~

前言 Caffeine是使用Java8对Guava缓存的重写版本,在Spring Boot 2.0中取而代之,基于LRU算法实现,支持多种缓存过期策略。 以下摘抄于https://github.com/ben-manes/caffeine/wiki/Benchmarks-zh-CN 基准测试通过使用Java microbenchmark ha…

Unity Shader Graph 2D - 角色身体电流覆盖效果

在游戏中,通常会有游戏角色受到“电击”的效果,此时游戏角色身体上会覆盖有电流,该效果能表明游戏角色的当前状态,让玩家能够获得更直观更好的体验。 那么如何实现呢 首先创建一个ShaderGraph文件,命名为Current,再创建对应的材质球M_Current。 基础的资源显示 老规矩,…

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】2.9 广播陷阱:形状不匹配的深层隐患

2.9 广播陷阱:形状不匹配的深层隐患 目录 #mermaid-svg-F0AgBChfSCGzOqa7 {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-F0AgBChfSCGzOqa7 .error-icon{fill:#552222;}#mermaid-svg-F0AgBChfSCGzOqa7 …

解锁豆瓣高清海报(二) 使用 OpenCV 拼接和压缩

解锁豆瓣高清海报(二): 使用 OpenCV 拼接和压缩 脚本地址: 项目地址: Gazer PixelWeaver.py pixel_squeezer_cv2.py 前瞻 继上一篇“解锁豆瓣高清海报(一) 深度爬虫与requests进阶之路”成功爬取豆瓣电影海报之后,本文将介绍如何使用 OpenCV 对这些海报进行智…

vue入门到实战 二

目录 2.1 计算属性computed 2.1.1什么是计算属性 2.1.2 只有getter方法的计算属性 2.1.3 定义有getter和setter方法的计算属性 2.1.4 计算属性和methods的对比 2.2 监听器属性watch 2.2.1 watch属性的用法 2.2.2 computed属性和watch属性的对比 2.1 计算属性computed…

【DeepSeek】本地快速搭建DeepSeek

博主未授权任何人或组织机构转载博主任何原创文章,感谢各位对原创的支持! 博主链接 博客内容主要围绕: 5G/6G协议讲解 高级C语言讲解 Rust语言讲解 文章目录 本地快速搭建DeepSeek一、安装及配置ollama二、DeepSeek模型…

Spring WebFlux揭秘:下一代响应式编程框架,与Spring MVC有何不同?

Spring WebFlux和Spring MVC都是Spring家族里的成员,它们都能帮助我们开发Web应用,但工作方式有所不同。 可以把Spring MVC想象成一个服务员,每次有客人(请求)来,它就会专门找一个服务员(线程&a…

基于微信小程序的实习记录系统设计与实现(LW+源码+讲解)

专注于大学生项目实战开发,讲解,毕业答疑辅导,欢迎高校老师/同行前辈交流合作✌。 技术范围:SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容:…

MySQL5.5升级到MySQL5.7

【卸载原来的MySQL】 cmd打开命令提示符窗口(管理员身份)net stop mysql(先停止MySQL服务) 3.卸载 切换到原来5.5版本的bin目录,输入mysqld remove卸载服务 测试mysql -V查看Mysql版本还是5.5 查看了环境变量里的…

TensorFlow 简单的二分类神经网络的训练和应用流程

展示了一个简单的二分类神经网络的训练和应用流程。主要步骤包括: 1. 数据准备与预处理 2. 构建模型 3. 编译模型 4. 训练模型 5. 评估模型 6. 模型应用与部署 加载和应用已训练的模型 1. 数据准备与预处理 在本例中,数据准备是通过两个 Numpy 数…

使用朴素贝叶斯对散点数据进行分类

本文将通过一个具体的例子,展示如何使用 Python 和 scikit-learn 库中的 GaussianNB 模型,对二维散点数据进行分类,并可视化分类结果。 1. 数据准备 假设我们有两个类别的二维散点数据,每个类别包含若干个点。我们将这些点分别存…

AI视频编码器(3.2) 《Swin Transformer V2: Scaling Up Capacity and Resolution》

arxiv链接自监督训练用到了SimMIM 论文链接。我觉得,SimMIM与MAE的区别在于,前者只是一个1-layer的prediction head,而后者是多层transformer结构的decoder。可参考Swin Transformer V2(CVPR 2022)论文与代码解读。总结 图中展示了三个创新,从左到右有三处红色结构,分别…

前端进阶:深度剖析预解析机制

一、预解析是什么? 在前端开发中,我们常常会遇到一些看似不符合常规逻辑的代码执行现象,比如为什么在变量声明之前访问它,得到的结果是undefined,而不是报错?为什么函数在声明之前就可以被调用&#xff1f…