Llama架构及代码详解

Llama的框架图如图:
在这里插入图片描述
源码中含有大量分布式训练相关的代码,读起来比较晦涩难懂,所以我们对llama自顶向下进行了解析及复现,我们对其划分成三层,分别是顶层、中层、和底层,如下:

Llama的整体组成

由上图可知,Llama整体是由1个embedding层,n个transformer层,和1个RMSNorm层组成的,所以顶层代码如下:
顶层

class Llama(torch.nn.Module):def __init__(self, config: ModelArgs):super().__init__()self.config = config# embedding层self.tok_embeddings = torch.nn.Embedding(self.config.vocab_size, self.config.dim)# RMSNormself.norm = RMSNorm(config.dim, eps=config.norm_eps)# n层Transformerself.layers = torch.nn.ModuleList()for i in range(self.config.n_layers):self.layers.append(TransformerBlock(config))def forward(self, tokens):# 进行token的嵌入编码h = self.tok_embeddings(tokens)# decoder架构需要生成一个maskseqlen = h.shape[1]mask = torch.full((seqlen, seqlen), float('-inf'), device=tokens.device)mask = torch.triu(mask, diagonal=1)# 进行n层Transformerfor i in range(self.config.n_layers):h = self.layers[i](h, mask)# 进行RMSNormtoken_embeddings = self.norm(h)return token_embeddings

中层
我们首先进行RMSNorm的复现

class RMSNorm(torch.nn.Module):def __init__(self, dim, eps):super().__init__()self.eps = epsself.weight = torch.nn.Parameter(torch.ones(dim))def _norm(self, tensor):return tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + self.eps)def forward(self, tensor):output = self._norm(tensor)return output * self.weight

然后对Transformer进行复现,在Transformer中,Transformer包括两个RMSNorm层,一个多头attention层,一个全连接层。

class TransformerBlock(torch.nn.Module):def __init__(self, config):super().__init__()self.config = config# 多头注意力层self.attention = Attention(config)# Norm层self.attention_normal = RMSNorm(config.dim, config.norm_eps)self.ffn_norm = RMSNorm(config.dim, config.norm_eps)# 全连接层self.ffn = FeedForwad(self.config.dim, self.config.dim * 4)def forward(self, embeddings, mask):# normh = self.attention_normal(embeddings)# attentionh = self.attention(h, mask)# add & normh = self.ffn_norm(h + embeddings)# fnnf = self.ffn(h)# addreturn f + h

底层
在多头attention中,首先需要对token的嵌入进行空间映射,多头拆分,旋转位置编码,分数计算等操作

class Attention(torch.nn.Module):def __init__(self, config):super().__init__()self.config = configself.n_head = config.n_headsself.dim = config.dim // self.n_headself.k = torch.nn.Linear(config.dim, config.dim)self.q = torch.nn.Linear(config.dim, config.dim)self.v = torch.nn.Linear(config.dim, config.dim)def forward(self, embeddings, mask):bsz, seq_len, dim = embeddings.shapek_embeddings = self.k(embeddings)q_embeddings = self.q(embeddings)v_embeddings = self.v(embeddings)n_q_embeddings = q_embeddings.reshape(bsz, -1, self.n_head, self.dim).permute(0, 2, 1, 3)n_k_embeddings = k_embeddings.reshape(bsz, -1, self.n_head, self.dim).permute(0, 2, 1, 3)n_v_embeddings = v_embeddings.reshape(bsz, -1, self.n_head, self.dim).permute(0, 2, 1, 3)rotated_n_q_embeddings = compute_rotated_embedding(n_q_embeddings, self.dim, seq_len, self.config.rope_theta)rotated_n_k_embeddings = compute_rotated_embedding(n_k_embeddings, self.dim, seq_len, self.config.rope_theta)scores = torch.nn.functional.softmax(mask + rotated_n_q_embeddings @ rotated_n_k_embeddings.transpose(-1, -2)/ math.sqrt(self.dim), dim=-1)n_embeddings = scores @ n_v_embeddingsembeddings = n_embeddings.permute(0, 2, 1, 3).reshape(bsz, -1, self.config.dim)return embeddings
class FeedForwad(torch.nn.Module):def __init__(self, dim, hidden_dim):super().__init__()self.linear1 = torch.nn.Linear(dim, hidden_dim)self.linear2 = torch.nn.Linear(dim, hidden_dim)self.linear3 = torch.nn.Linear(hidden_dim, dim)def forward(self, embeddings):gate = torch.nn.functional.silu(self.linear1(embeddings))up_proj = self.linear2(embeddings) * gatereturn self.linear3(up_proj)

最后,我们复现旋转位置编码,至此我们捋清了llama的所有结构!

def compute_rotated_embedding(embedding, dim, m, base):# 计算所有嵌入位置的旋转角度all_theta = compute_all_theta(dim, m, base)# 旋转后嵌入位置 = 复数平面上初始位置 * 复数平面上角度坐标# 1、将嵌入投影到复数平面embedding_real_pair = embedding.reshape(*embedding.shape[:-1], -1, 2)embedding_complex_pair = torch.view_as_complex(embedding_real_pair)# 2、将旋转角度投影到复数平面all_theta = all_theta[: embedding.shape[-2]]theta_complex_pair = torch.polar(torch.ones_like(all_theta), all_theta)# 3、旋转后嵌入位置 = 复数平面上初始位置 * 复数平面上角度坐标rotated_complex_embedding = embedding_complex_pair * theta_complex_pair# 4、将复数平面的嵌入投影到实数平面rotated_real_embedding = torch.view_as_real(rotated_complex_embedding)rotated_real_embedding = rotated_real_embedding.reshape(*embedding.shape[:-1], -1)return rotated_real_embeddingdef compute_all_theta(dim, m, base):theta = 1 / (base ** (torch.arange(0, dim / 2).float() / (dim / 2)))m = torch.arange(0, m)all_theta = torch.outer(m, theta)return all_theta

附录:llama的config参数

@dataclass
class ModelArgs:dim: int = 4096n_layers: int = 32n_heads: int = 32n_kv_heads: Optional[int] = Nonevocab_size: int = -1multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2ffn_dim_multiplier: Optional[float] = Nonenorm_eps: float = 1e-5rope_theta: float = 500000max_batch_size: int = 32max_seq_len: int = 2048use_scaled_rope: bool = True

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/471077.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

sol机器人pump机器人如何实现盈利的?什么是Pump 扫链机器人?

什么是Pump 扫链机器人,它的盈利逻辑优化策略是什么? Pump 扫链机器人,通过智能化、自动化的买卖操作帮助投资者实现快速盈利。在此基础上,我们对该机器人的盈利逻辑进行了深度优化,涵盖了买入策略和止盈策略的各个方面…

三维测量与建模笔记 - 特征提取与匹配 - 4.2 梯度算子、Canny边缘检测、霍夫变换直线检测

从Roberts交叉算子的卷积核可以看出,它实际计算了对角线上元素之间的差值。 prewitt算子实际是对整行或整列、或者对角线两侧的像素进行差分计算。 Sobel算子改进了一下Prewitt算子,增加了权重,中心位置的像素权重为2。 中心权重为4的Laplac…

【2024软考架构案例题】你知道 Es 的几种分词器吗?Standard、Simple、WhiteSpace、Keyword 四种分词器你知道吗?

👉博主介绍: 博主从事应用安全和大数据领域,有8年研发经验,5年面试官经验,Java技术专家,WEB架构师,阿里云专家博主,华为云云享专家,51CTO 专家博主 ⛪️ 个人社区&#x…

1.7 JS性能优化

从输入url到页面加载完成都做了些什么 输入 URL - 资源定位符 http://www.zhaowa.com - http 协议 域名解析 https://www.zhaowa.com > ip 1. 切HOST? > 浏览器缓存映射、系统、路由、运营商、根服务器 2. 实际的静态文件存放? 大流量 > 多个…

Linux基础1

Linux基础1 Linux基础1学习笔记 ‍ 声明! ​​​学习视频来自B站up主 泷羽sec 有兴趣的师傅可以关注一下,如涉及侵权马上删除文章 笔记只是方便各位师傅的学习和探讨,文章所提到的网站以及内容,只做学习交流,其他…

【安全通信】告别信息泄露:搭建你的开源视频聊天系统briefing

文章目录 前言1.关于briefing2.本地部署briefing3.使用briefing4.cpolar内网穿透工具安装5.创建远程连接公网地址6.固定briefing公网地址 前言 在这个信息爆炸的时代,视频聊天几乎成了我们日常沟通的标配。但你是否曾在视频会议中感到不安,担心自己的私…

深度学习——优化算法、激活函数、归一化、正则化

文章目录 🌺深度学习面试八股汇总🌺优化算法方法梯度下降 (Gradient Descent, GD)动量法 (Momentum)AdaGrad (Adaptive Gradient Algorithm)RMSProp (Root Mean Square Propagation)Adam (Adaptive Moment Estimation)AdamW 优化算法总结 经验和实践建议…

Thread类及常见方法

目录 一、Thread常见构造方法 二、Thread常见属性 三、Thread常见方法 start() 获取当前线程 中断线程 join() 一、Thread常见构造方法 Thread类是JVM用来管理线程的一个类,每个线程都有唯一一个Thread对象与之对应,JVM会将这些对象组织起来&…

优化时钟网络之时钟抖动

Note:文章内容以Xilinx 7系列FPGA进行讲解 1、什么是时钟抖动 时钟抖动就是时钟周期之间出现的偏差。比如一个时钟周期为10ns的时钟,理想情况下,其上升沿会出现在0ns,10ns,20ns时刻,假设某个上升沿出现的时…

Vector 深度复制记录

有的时候数据得复制过去 有个疑问,自动分配内存吗? 不是估计有变化, 得在看看 指针作为值复制了 … … 挺好,修改原有的值 x86 的 SIM 程序 还有点问题 ; 无法直接绕过硬件错误 。。。 x86 gdb 没有问题 就是运行出现了问题,怎么解决;正常初始化没有问题…

贪心算法day03(最长递增序列问题)

目录 1.最长递增三元子序列 2.最长连续递增序列 1.最长递增三元子序列 题目链接:. - 力扣(LeetCode) 思路:我们只需要设置两个数进行比较就好。设a为nums[0],b 为一个无穷大的数,只要有比a小的数字就赋值…

基于Java Web的传智播客crm企业管理系统的设计与实现

项目描述 临近学期结束,还是毕业设计,你还在做java程序网络编程,期末作业,老师的作业要求觉得大了吗?不知道毕业设计该怎么办?网页功能的数量是否太多?没有合适的类型或系统?等等。这里根据疫情当下,你想解决的问…

马斯克万卡集群AI数据中心引发的科技涟漪:智算数据中心挑战与机遇的全景洞察

一、AI 爆发重塑数据中心格局 随着AI 技术的迅猛发展,尤其是大模型的崛起,其对数据中心产生了极为深远的影响。大模型以其数以亿计甚至更多的参数和对海量数据的处理需求,成为了 AI 发展的核心驱动力之一,同时也为数据中心带来了…

机器学习在医疗健康领域的应用

💓 博客主页:瑕疵的CSDN主页 📝 Gitee主页:瑕疵的gitee主页 ⏩ 文章专栏:《热点资讯》 机器学习在医疗健康领域的应用 机器学习在医疗健康领域的应用 机器学习在医疗健康领域的应用 引言 机器学习概述 定义与原理 发展…

学法减分交管12123模拟练习小程序源码前端和后端和搭建教程

交管推出个学法减分,每个驾驶员可以把被扣的6分,以看视频答题的形式学习回来,然后答题这个一共二十道题每道题60秒,有好多人不会,用咱们的小程序就可以模拟练习强化练习,还有拍照识别题目找到正确答案&…

AI大模型开发架构设计(18)——基于大模型构建企业知识库案例实战

文章目录 1 LLM 大模型在工作中的实际应用以及局限性LLM 大模型工作中实际应用大模型2点局限性 2 基于大模型和向量数据库的企业级知识库架构剖析向量数据库向量数据库选型知识库文档检索增强(Retrieval Augmented Generation)向量数据库应用技术总体架构向量数据库应用离线索引…

jmeter介绍、使用方法、性能测试、现参数化和数据驱动、分布式测试、压力测试、接口测试

目录 1.JMeter的组件介绍 2.JMeter介绍和使用方法 3.使用JMeter进行性能测试 4.JMeter如何实现参数化和数据驱动 5.使用JMeter进行分布式测试 6.使用JMeter完成压力测试 7.使用JMeter完成接口测试 下载并安装JMeter:从官方网站(https://jmeter.ap…

Zotero 6.0 安装包及安装教程

Zotero的界面友好,操作简单,对于科研小白来说,是一款非常实用的文献管理软件。它不仅可以帮助用户精确获取、整理、引用文献,而且在学术实践中不可或缺的一环。 安 装 步 骤 压缩包文件,鼠标右击解压得到安装包。 仅用…

Docker 篇-Docker 详细安装、了解和使用 Docker 核心功能(数据卷、自定义镜像 Dockerfile、网络)

🔥博客主页: 【小扳_-CSDN博客】 ❤感谢大家点赞👍收藏⭐评论✍ 文章目录 1.0 Docker 概述 1.1 Docker 主要组成部分 1.2 Docker 安装 2.0 Docker 常见命令 2.1 常见的命令介绍 2.2 常见的命令演示 3.0 数据卷 3.1 数据卷常见的命令 3.2 常见…

华为大变革?仓颉编程语言会代替ArkTS吗?

在华为鸿蒙生态系统中,编程语言的选择一直是开发者关注的焦点。近期,华为推出了自研的通用编程语言——仓颉编程语言,这引发了关于仓颉是否会取代ArkTS的讨论。本文将从多个角度分析这两种语言的特点、应用场景及未来趋势,探讨仓颉…