Llama模型结构解析(源码阅读)

目录

  • 1. LlamaModel整体结构流程图
  • 2. LlamaRMSNorm
  • 3. LlamaMLP
  • 4. LlamaRotaryEmbedding

  • 参考资料:
    https://zhuanlan.zhihu.com/p/636784644
    https://spaces.ac.cn/archives/8265 ——《Transformer升级之路:2、博采众长的旋转式位置编码》

前言:本次阅读代码位置,在transformers库底下的modeling_llama.py,具体位置在:transformers/models/llama/modeling_llama.py,如下图所示:在这里插入图片描述

1. LlamaModel整体结构流程图

在这里插入图片描述

2. LlamaRMSNorm

  • 代码如下
class LlamaRMSNorm(nn.Module):def __init__(self, hidden_size, eps=1e-6):"""LlamaRMSNorm is equivalent to T5LayerNorm"""super().__init__()self.weight = nn.Parameter(torch.ones(hidden_size))self.variance_epsilon = epsdef forward(self, hidden_states):input_dtype = hidden_states.dtypevariance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)return (self.weight * hidden_states).to(input_dtype)
  • RMSNorm的公式如下所示:
    x i 1 n ∑ i = 1 n x i 2 + e p s ∗ w e i g h t i \frac{x_i}{\sqrt{\frac{1}{n}\sum\limits_{i=1}^{n}{x_i}^2 + eps}} * weight_i n1i=1nxi2+eps xiweighti

    • 其中,公式与代码的对应关系如下:
      在这里插入图片描述

3. LlamaMLP

  • 代码如下:
class LlamaMLP(nn.Module):def __init__(self,hidden_size: int,intermediate_size: int,hidden_act: str,):super().__init__()self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)self.act_fn = ACT2FN[hidden_act]def forward(self, x):return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  • 流程图:
    在这里插入图片描述

  • 其中输入为x,输出为y

  • 代码中intermediate_size一般比hidden_size大,我们通过在jupyter notebook中打印Llama-13B的模型,可以看到如下所示:
    在这里插入图片描述

  • 总结:MLP模块就是几个nn.Linear的组合

4. LlamaRotaryEmbedding

  • 代码如下

class LlamaRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):super().__init__()inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq)# Build here to make `torch.jit.trace` work.self.max_seq_len_cached = max_position_embeddingst = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1)self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)def forward(self, x, seq_len=None):# x: [bs, num_attention_heads, seq_len, head_size]# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.if seq_len > self.max_seq_len_cached:self.max_seq_len_cached = seq_lent = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1).to(x.device)self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),)
  • 具体的使用,还调用了另外两个函数,如下所示:
def rotate_half(x):"""Rotates half the hidden dims of the input."""x1 = x[..., : x.shape[-1] // 2]x2 = x[..., x.shape[-1] // 2 :]return torch.cat((-x2, x1), dim=-1)def apply_rotary_pos_emb(q, k, cos, sin, position_ids):# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]q_embed = (q * cos) + (rotate_half(q) * sin)k_embed = (k * cos) + (rotate_half(k) * sin)return q_embed, k_embed
  • 注意这里的实现跟原始推导有点区别,这里实现的方式如下图所示:
    在这里插入图片描述

  • 原始推导如下图所示:
    在这里插入图片描述
    具体可以查看作者的博客:👉戳我👈

  • 总结:RoPE就是在attention计算时,K跟Q做内积之前,先给各自注入位置信息。

结束。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/119737.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

设计模式入门(二)观察者模式

设计模式入门 本系列所有内容参考自《HeadFirst设计模式》。因为书中的代码是采用java语言写的,博主这里用C语言改写。 这里采用讲故事的方式进行讲解。若有错误之处,非常欢迎大家指导。 设计模式:模式不是代码,而针对设计问题的…

深入理解作用域、作用域链和闭包

​ 🎬 岸边的风:个人主页 🔥 个人专栏 :《 VUE 》 《 javaScript 》 ⛺️ 生活的理想,就是为了理想的生活 ! ​ 目录 📚 前言 📘 1. 词法作用域 📖 1.2 示例 📖 1.3 词法作用域的…

Python学习教程:进程的调度

前言 嗨喽~大家好呀,这里是魔王呐 ❤ ~! 要想多个进程交替运行,操作系统必须对这些进程进行调度, 这个调度也不是随即进行的,而是需要遵循一定的法则,由此就有了进程的调度算法。 python更多源码/资料/解答/教程等 …

iOS练手项目知识点汇总

基础理解篇 Objective-C是一种面向对象的编程语言,它支持元编程。元编程是指编写程序来生成或操纵其他程序的技术。 Objective-C中,元编程可以使用Objective-C的动态特性来实现。例如可以使用Objective-C的运行时函数来动态地创建类、添加属性和方法等等…

NPM 常用命令(二)

目录 1、npm bugs 1.1 配置 browser registry 2、npm cache 2.1 概要 2.2 详情 2.3 关于缓存设计的说明 2.4 配置 cache 3、 npm ci 3.1 描述 3.2 配置 install-strategy legacy-bundling global-style omit strict-peer-deps foreground-scripts ignore-s…

Doris workload group实战

1.创建测试用户:创建一个用户名为test,密码为test 的用户: create user test% IDENTIFIED BY test;给测试用户赋权:给用户test赋予数据库test.* 权限 grant SELECT_PRIV,LOAD_PRIV,CREATE_PRIV,ALTER_PRIV ON test.* TO test;开…

AIoT+5G改变智慧城市:揭秘智慧公厕的奇妙魅力

AIoT5G的新型智慧城市应用带来了智慧公厕的全新体验。通过智能监测、高速网络、智能调控、智慧管理等技术应用,公厕的舒适性、便捷性和智慧化程度得到了极大提升。可以看到的是,智慧公厕正逐渐激活智慧城市的生活场景,为城市居民带来更好的生…

UmeTrack: Unified multi-view end-to-end hand tracking for VR 复现踩坑记录

在 github 上找到了开源代码:https://github.com/facebookresearch/UmeTrack/tree/main 环境配置 运行第三行,报错,缺少torch。改成先运行第四行,成功。 再运行第三行,报错,required to install pyproj…

RHCE——十七、文本搜索工具-grep、正则表达式

RHCE 一、文本搜索工具--grep1、作用2、格式3、参数4、注意5、示例5.1 操作对象文件:/etc/passwd5.2 grep过滤命令示例 二、正则表达式1、概念2、基本正则表达式2.1 常见元字符2.2 POSIX字符类2.3 示例 3、扩展正则表达式3.1 概念3.2 示例 三、作业1、作业一2、作业…

Redis一主一从Docker方式部署通过keepalived和 sentinel哨兵模式实现高可用

有两台服务器一台是主,master : 172.24.69.180 另外一台是从, slave :172.24.69.181 vip 地址: 172.24.69.185 1、关闭防火墙 两台服务器都关闭防火墙 systemctl disable --now firewalld firewall-cmd --state关闭SELinux setenforce 0 …

uniapp-秋云图表 ucharts echarts 对比与关系

科普: 秋云图表库,包含二种配置属性对应二种js配置文件。 一种是 :echarts.js,一种是 : ucharts。 二者的配置属性不一样! ucharts和echarts对比 ucharts和echarts都是用于数据可视化的开源JavaScript库,它…

考研408 | 【操作系统】终章

I/O设备的基本概念和分类 I/O设备: I/O设备的分类 1.按使用特性: 2.按传输速率分类: 3.按信息交换的单位分类: 总结: I/O控制器 I/O设备的机械部件: I/O设备的电子部件(I/O控制器&#…

工程师是怎样对待开源

工程师如何对待开源 本文是笔者作为一个在知名科技企业内从事开源相关工作超过 20 年的工程师,亲身经历或者亲眼目睹很多工程师对待开源软件的优秀实践,也看到了很多 Bad Cases,所以想把自己的一些心得体会写在这里,供工程师进行…

Python八大排序实现方法

前言 大家早好、午好、晚好吖 ❤ ~欢迎光临本文章 如果有什么疑惑/资料需要的可以点击文章末尾名片 1.基数排序 基数排序的基本思想是先将数字按照个位数上数字的大小进行排序, 排序之后再将已经排过序的数字再按照十位数上数字的大小进行排序,依次推…

使用acme,自动续签免费的SSL,无忧http升级https

使用acme自动续签免费的SSL 安装acme.sh颁发域名将证书安装到nginx下配置nginx的ssl自动续签 这里只进行最简单的操作 安装acme.sh 进入你的用户目录,如果你使用root登陆,那么你的用户目录就是 /root/ curl https://get.acme.sh | sh -s emailmyexam…

Python所有方向的学习路线图!!

学习路线图上面写的是某个方向建议学习和掌握的知识点汇总,举个例子,如果你要学习爬虫,那么你就去学Python爬虫学习路线图上面的知识点,这样学下来之后,你的知识体系是比较全面的,比起在网上找到什么就学什…

Silicon Labs BG22、xG24、BG27无线SoC比较及信驰达无线模块选型指南

作为安全、智能无线技术领域的前沿品牌,全球知名IC设计公司——Silicon Labs,在最近几年陆续推出了EFR32BG22、EFR32xG24、EFR32BG27等系列无线SoC。RF-star作为物联网行业领先的无线通信模组厂商,基于Silicon Labs的无线SoC推出了RF-BM-BG22…

iOS开发Swift-5-自动布局AutoLayout-摇骰子App

1.在iOS坐标系中,以向左、向下为正方向。图片以左上角为基准点。 2.打开之前的摇骰子App,对它的界面做一些适应所有iPhone机型的效果。 3.先对上方logo做一个y轴约束和一个宽高约束。 宽高约束: 水平居中: 对y轴进行约束。将虚线点…

window 常用基础命令

0、起步 0-1) 获取命令的参数指引 netstat /? 0-2) 关于两个斜杠: window 文件路径中使用反斜杠:\ linux 文件路径中使用:/ 1、开关机类指令 shutdown /s # 关机shutdown /r # 重启shutdown /l …

自然语言处理 微调大模型ChatGLM-6B

自然语言处理 微调大模型ChatGLM-6B 1、GLM设计原理2、大模型微调原理1、P-tuning v2方案2、LORA方案 1、GLM设计原理 bert的主要任务是随机的去除掉某个单词,使用上下文将其预测出来(相当于完形填空任务); GPT的主要任务是根据前…