最强英文开源模型LLaMA架构探秘,从原理到源码

导读:
LLaMA 65B是由Meta AI(原Facebook AI)发布并宣布开源的真正意义上的千亿级别大语言模型,发布之初(2023年2月24日)曾引起不小的轰动。LLaMA的横空出世,更像是模型大战中一个搅局者。虽然它的效果(performance)和GPT-4仍存在差距,但GPT-4毕竟是闭源的商业模型,LLaMA系列的开源给了世界上其他团队研究和使用千亿大语言模型的机会。

读完本文,你可能觉得LLaMA会开源并不令人惊讶,因为它的架构可以说是站在巨人肩膀上摘苹果——基本上可以说使用其他模型的组件作为“积木”搭了一个新模型出来,并没有太多实质意义上的创新,但这种敢于开源的勇气和做法使得LLaMA足以在大语言模型上的开源发展历程上成为一个标志性的里程碑

在这里插入图片描述Introducing LLaMA: A foundational, 65-billion-parameter large language model
LLaMA开源地址:https://github.com/facebookresearch/llama(llama在llama_v1代码分支上)

正文

llama英文中指大羊驼,是一种分布在南美洲的骆驼科羊驼属动物

在这里插入图片描述

LLaMA是一个基于transformer架构的大语言模型,同Google的PaLM一样,针对原始的transformer架构进行了一些“小改进”。整体而言,初版LLaMA的架构和原始transformer有3个大的差异点:

  • 前置归一化(Pre-Normalization)[受GPT3启发]:为了提升训练时的稳定性,LLaMA归一化了transformer子层的输入而不是输出,具体使用的正则化方法是RMSNorm
  • SwiGLU激活函数 [受PaLM启发]:LLaMA使用了和PaLM一样的SwiGLU激活函数来替代原始的ReLU以提升模型效果。细节上,LLaMA使用dimension为 2 3 4 d \frac{2}{3}4d 324d而不是 4 d 4d 4d
  • 旋转位置编码(Rotary Embedding, Rotary Position Embedding)[受GPTNeo启发]:LLaMA没有使用绝对位置编码(BERT的位置 s i n sin sin c o s cos cos编码是一种绝对位置编码),而是使用了相对位置编码RoPE

除此之外,一些训练上的细节:

  • LLaMA使用adamW优化器,设置超参数 β 1 = 0.9 \beta_1=0.9 β1=0.9 β 2 = 0.95 \beta_2=0.95 β2=0.95
  • 使用cosine学习率调度,即最终的学习率是最大学习率的10%
  • 权重衰减设置为0.1
  • 梯度剪枝设置为1
  • 2000步热启动(warmup)。
  • 不同尺寸的模型使用不同的学习率batch size

下面我们来深入了解一下架构上3个差异点的技术细节。

RMSNorm

详细推到过程见原论文:Root Mean Square Layer Normalization

前置归一化(Pre-Normalization)可以使得训练过程更加稳定,这种设计将第一层的归一化设在多头注意力层之前,第二层的归一化移动到全连接层之前,同时将shortcut设置在multi-attention层与FNN层之间。如下如所示:
在这里插入图片描述
LLaMA在归一化过程中使用RMSNorm,针对输入向量 a a a,RMSNorm的计算公式如下:
R M S ( a ) = 1 n ∑ i = 1 n a i 2 RMS(a)=\sqrt{\frac{1}{n}\sum_{i=1}^{n}}a_i^2 RMS(a)=n1i=1n ai2
a i ˉ = a i R M S ( a ) \bar{a_i}=\frac{a_i}{RMS(a)} aiˉ=RMS(a)ai

相较于原始的RMSNorm,LLaMA加入了一个缩放因子 g i g_i gi和一个偏移参数 b i b_i bi(均为可学习参数),最终得到:
a i ˉ = a i R M S ( a ) g i + b i \bar{a_i}=\frac{a_i}{RMS(a)}g_i+b_i aiˉ=RMS(a)aigi+bi

HuggingFace Transformer 库中的LLaMA RMSNorm实现如下:

class LlamaRMSNorm(nn.Module):def __init__(self, hidden_size, eps=1e-6):"""LlamaRMSNorm is equivalent to T5LayerNorm"""super().__init__()self.weight = nn.Parameter(torch.ones(hidden_size))self.variance_epsilon = eps # eps 防止取倒数之后分母为 0def forward(self, hidden_states):input_dtype = hidden_states.dtypevariance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) # weight 是末尾乘的可训练参数, 即 g_ireturn (self.weight * hidden_states).to(input_dtype)

SwiGLU激活函数

详细推导过程见原论文:GLU Variants Improve Transformer

LLaMA使用的SwiGLU激活函数同时也在PaLM等多个LLM应用,相较于ReLU能在很多评测数据集上提升明显。

LLaMA全连接层使用SwiGLU激活函数的计算公式如下:
F F N S w i G L U ( x , W , V , W 2 ) = S w i G L U ( x , W , V ) W 2 FFN_{SwiGLU}(x,W,V,W_2)=SwiGLU(x,W,V)W_2 FFNSwiGLU(x,W,V,W2)=SwiGLU(x,W,V)W2
S w i G L U ( x , W , V ) = S w i s h β ( x W ) ⊗ x V SwiGLU(x,W,V)=Swish_\beta(xW) \otimes xV SwiGLU(x,W,V)=Swishβ(xW)xV
S w i s h β = x σ ( β x ) Swish_\beta=x\sigma(\beta x) Swishβ=xσ(βx)

其中 σ \sigma σ即sigmoid函数。

S w i s h β Swish_\beta Swishβ函数在参数 β \beta β取值不同时形状不同,如下图:
在这里插入图片描述

  • β → 0 \beta \rightarrow 0 β0时, S w i s h β → 直线 y = x Swish_\beta \rightarrow 直线 y=x Swishβ直线y=x
  • β → ∞ \beta \rightarrow \infin β时, S w i s h β → R e L U Swish_\beta \rightarrow ReLU SwishβReLU

LLaMA中 β = 1 \beta=1 β=1,维度缩放为 2 3 4 d \frac{2}{3}4d 324d
在这里插入图片描述

SwishGLU一定程度上引入了Gating机制,原论文实验结果证明了基于Gating的方法普遍优于单纯的激活函数(ReLU
/GELU/Swish)

旋转位置编码 RoPE (Rotary Position Embeddings)

详细推导过程见原论文:ROFORMER: ENHANCED TRANSFORMER WITH ROTARY POSITION EMBEDDING

LLaMA使用RoPE代替原有的绝对位置编码(指BERT的 s i n sin sin c o s cos cos位置编码是按固定值计算的,逻辑上表示的位置也是固定的),以取得更好效果。RoPE的数学推导借助了复数的思想,原作者期望通过数学方法基于绝对位置编码的方式实现相对位置编码,进一步讲,存在向量 q q q k k k,通过如下运算可以给它们添加绝对位置信息:
q ~ m = f ( q , m ) , k ~ n \tilde{\mathbf{q}}_m=f(\mathbf{q},m),\tilde{\mathbf{k}}_n q~m=f(q,m),k~n
q ~ m \tilde{\mathbf{q}}_m q~m k ~ n \tilde{\mathbf{k}}_n k~n具备了 m m m n n n的绝对位置信息。

f ( q , m ) ) f(\mathbf{q},m)) f(q,m))经推导如下:
f ( q , m ) ) = R f ( q , m ) e i Θ f ( q , m ) = ∣ ∣ q ∣ ∣ e i ( Θ ( q ) + m θ ) = q e i m θ f(\mathbf{q},m))=R_f(\mathbf{q},m)e^{i\varTheta_f(\mathbf{q},m)}=||\mathbf{q}||e^{i(\varTheta(q)+m\theta)}=\mathbf{q}e^{im\theta} f(q,m))=Rf(q,m)eiΘf(q,m)=∣∣q∣∣ei(Θ(q)+mθ)=qeimθ
(详细推导过程参见源论文)

根据复数乘法的几何意义,上述变换实际上对应向量旋转操作,因而得名“旋转位置编码”,矩阵形式可能能提供不一样的理解:
f ( q , m ) ) = ( c o s m θ − s i n c o s m θ s i n m θ c o s m θ ) ( q 0 q 1 ) f(\mathbf{q},m))=\begin{pmatrix} cos \ m\theta & -sin \ cos \ m\theta \\ sin \ m\theta & cos \ m\theta \end{pmatrix} \begin{pmatrix} \mathbf{q_0} \\ \mathbf{q_1} \end{pmatrix} f(q,m))=(cos mθsin mθsin cos mθcos mθ)(q0q1)

根据内积满足线性叠加的性质,任意偶数维上的RoPE,都可以表示为二维情形的拼接,进一步将公式转化为:
在这里插入图片描述
上述稀疏矩阵可以使用逐位相乘 ⊗ \otimes 加快计算速度,因而RoPE在HuggingFace Transformer 库中代码实现如下所示:

class LlamaRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):super().__init__()inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq)# Build here to make `torch.jit.trace` work.self.max_seq_len_cached = max_position_embeddingst = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device,dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation# in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1)dtype = torch.get_default_dtype()self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)def forward(self, x, seq_len=None):# x: [bs, num_attention_heads, seq_len, head_size]# This `if` block is unlikely to be run after we build sin/cos in `__init__`. # Keep the logic here just in case.if seq_len > self.max_seq_len_cached:self.max_seq_len_cached = seq_lent = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation# in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1).to(x.device) self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype),persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype),persistent=False)return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),)def rotate_half(x):"""Rotates half the hidden dims of the input."""x1 = x[..., : x.shape[-1] // 2]x2 = x[..., x.shape[-1] // 2 :]return torch.cat((-x2, x1), dim=-1)def apply_rotary_pos_emb(q, k, cos, sin, position_ids):# The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]q_embed = (q * cos) + (rotate_half(q) * sin)k_embed = (k * cos) + (rotate_half(k) * sin)return q_embed, k_embed

不同参数规模的LLaMA模型

基于我们前面讲解的内容,可以实现一个完整的LLaMA Decoder,HuggingFace Transformer库中的实现代码实现如下所示:

class LlamaDecoderLayer(nn.Module):def __init__(self, config: LlamaConfig):super().__init__()self.hidden_size = config.hidden_sizeself.self_attn = LlamaAttention(config=config)self.mlp = LlamaMLP( hidden_size=self.hidden_size,intermediate_size=config.intermediate_size,hidden_act=config.hidden_act,)self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)def forward(self, hidden_states: torch.Tensor,attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:residual = hidden_stateshidden_states = self.input_layernorm(hidden_states)# Self Attentionhidden_states, self_attn_weights, present_key_value = self.self_attn(hidden_states=hidden_states,attention_mask=attention_mask,position_ids=position_ids,past_key_value=past_key_value,output_attentions=output_attentions,use_cache=use_cache,)hidden_states = residual + hidden_states# Fully Connectedresidual = hidden_stateshidden_states = self.post_attention_layernorm(hidden_states)hidden_states = self.mlp(hidden_states)hidden_states = residual + hidden_statesoutputs = (hidden_states,)if output_attentions:outputs += (self_attn_weights,) if use_cache:outputs += (present_key_value,) return outputs

再按架构即可可以实现整个LLaMA模型。

Meta一共发布了4种尺寸的LLaMA,不同尺寸模型的的细节区别如下:
在这里插入图片描述

预训练 Pre-Training

预训练数据集对模型效果有深刻影响,LLaMA使用的混合数据集配比以及大小如下:
在这里插入图片描述
预训练数据集经token化之后总计1.4T个token,对于大多数预训练token仅使用一次,但Wikipedia和Books数据集训练了2轮。

指令精调 Instruction Finetuning

在LLaMA论文里,原作者尝试对LLaMA做了一个简单的指令精调,结果在MMLU数据集上有5.4%提升:
在这里插入图片描述
指令精调的细节参见:Scaling Instruction-Finetuned Language Models,作者为了针对模型效果作对比采用了同样的流程。

结语

LLaMA的架构探秘止步于此。

随着大模型的参数逐步增大,模型的整体架构已不足以对最终效果决定性影响,反而数据集和架构上的一些小细节决定了模型的最终效果。LLaMA虽然没有特别亮眼的创新,但是它的一些实验性的结论,也对后面的模型设计和训练提供了良好的借鉴意义。作为第一个开源的由业界顶尖公司发布的大模型,LLaMA实际上起到了大模型开源进程的奠基作用。

希望未来能看到越来越多的大模型开源,也希望自然语言处理能真正为人类的生产力带来更多可实地落地的突破。

参考文献

  1. LLaMA: Open and Efficient Foundation Language Models
  2. Introducing LLaMA: A foundational, 65-billion-parameter large language model
  3. 大规模语言模型:从原理到实践(复旦NLP教材)
  4. 大规模预训练语言模型方法与实践 (崔一鸣 北京·BAAI 2023年8月26日)
  5. Root Mean Square Layer Normalization
  6. GLU Variants Improve Transformer
  7. ROFORMER: ENHANCED TRANSFORMER WITH ROTARY POSITION EMBEDDING
  8. Scaling Instruction-Finetuned Language Models

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/167383.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

doc与docx文档转html,格式样式不变(包含图片转换)

最近做一个富文本的需求,要求把文档内容转换到富文本内,文档中的格式也好,样式也好,图片啥的都要一致展示;踩了不少坑,据说word文档其实是一个压缩包,我不是特别清楚但是也能理解,自…

pymysql连接Mariadb/Mysql出现错误(配置正确情况下)解决办法

场景:在kali中使用python中pymysql对Mariadb进行连接,在整个过程中配置全部正确,但是就是无法进行连接,提示结果如下: Access denied for user rootlocalhost解决办法:进入数据库中,将默认密码…

自然语言处理---huggingface平台使用指南

1 huggingface介绍 Huggingface总部位于纽约,是一家专注于自然语言处理、人工智能和分布式系统的创业公司。他们所提供的聊天机器人技术一直颇受欢迎,但更出名的是他们在NLP开源社区上的贡献。Huggingface一直致力于自然语言处理NLP技术的平民化(democr…

C# 文件 校验:MD5、SHA1、SHA256、SHA384、SHA512、CRC32、CRC64

文件 校验 算法:MD5、SHA1、SHA256、SHA384、SHA512、CRC32、CRC64 (免费) 编程语言:C# 功能:文件 哈希 属性 校验算法:MD5、SHA1、SHA256、SHA384、SHA512、CRC32、CRC64。 下载(免费):htt…

瑞萨e2studio(27)----使用EZ-CUBE3烧录

瑞萨e2studio.27--使用EZ-CUBE3烧录 概述视频教学样品申请引脚配置EZ-CUBE3 仿真器开关设置对RA族MCU进行Flash编程蓝色 LED 指示灯的状态信息 概述 EZ-CUBE3(CYRCNEZCUBE03)是具有Flash存储器编程功能的片上调试仿真器,可以用于调试MCU程序…

Vue2基础知识(一) 认识Vue

💌 所属专栏:【Vue2】😀 作 者:长安不及十里💻工作:目前从事电力行业开发🌈目标:全栈开发🚀 个人简介:一个正在努力学技术的Java工程师,专注基础和…

【Javascript】构造函数之new的作用

目录 new的作用 把对象返回了回来 无new 有new 把构造函数的this指向了要返回的对象 无new​编辑 有new new的执行流程 new的作用 创建了新空对象将构造函数的作用域赋值给新对象(this指向新对象)执行构造函数代码 (为这个新对象添加属性)返回新对…

Java EE-使用Servlet搭建一个简单的前后端交互程序

上述前端和后端的代码如下&#xff1a; 前端&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"vie…

优测云测试平台 | 有效的单元测试

一、前言 本文作者提出了一种评价单元测试用例的质量的思路&#xff0c;即判断用例是否达到测试的“四大目标”。掌握识别好的用例的能力&#xff0c;可以帮助我们高效地写出高质量的测试用例。 评判冰箱的好坏&#xff0c;并不需要有制造一台冰箱的能力。在开始写测试用例之…

经典链表问题:解析链表中的关键挑战

这里写目录标题 公共子节点采用集合或者哈希采用栈拼接两个字符串差和双指针 旋转链表 公共子节点 例如这样一道题&#xff1a;给定两个链表&#xff0c;找出它们的第一个公共节点。 具体的题目描述我们来看看牛客的一道题&#xff1a; 这里我们有四种解决办法&#xff1a; …

晶振与晶体

文章目录 基础知识无源晶振 & 有源晶振 博文链接 基础知识 无源晶振 & 有源晶振 博文链接 晶振原理解析

Flutter的Constructors for public widgets should have a named ‘key‘ parameter警告

文章目录 问题描述问题原因修改方法详细解释 问题描述 Constructors for public widgets should have a named ‘key’ parameter. 如下图&#xff1a; 原本的代码 class MyTabPage extends StatefulWidget {override_MyTabPageState createState() > _MyTabPageState(…

大数据测试用例分析

基于大数据分析&#xff0c;对业务系统产生的日志进行智能分析&#xff0c;能够识别日志中的接口、参数、业务流&#xff0c;并依据分析的结果生成测试用例。 问题与背景 业务复杂 业务系统的复杂性&#xff0c;对测试人员的业务能力提出严格要求&#xff0c;加重测试成本。 …

【深度学习-第4篇】使用MATLAB快速实现CNN多变量回归预测

上一篇我们讲了使用CNN进行分类的MATLAB代码。 这一篇我们讲CNN的多变量回归预测。 是的&#xff0c;同样是傻瓜式的快速实现。 一、什么是多变量回归预测 多变量回归预测则是指同时考虑多个输入特征进行回归预测。举几个例子&#xff1a; 房价预测&#xff1a;给定一组房…

搜索问答技术学习:基于知识图谱+基于搜索和机器阅读理解(MRC)

目录 一、问答系统应用分析 二、搜索问答技术与系统 &#xff08;一&#xff09;需求和信息分析 问答需求类型 多样的数据源 文本组织形态 &#xff08;二&#xff09;主要问答技术介绍 发展和成熟度分析 重点问答技术基础&#xff1a;KBQA和DeepQA KBQA&#xff08;…

CentOS 系统安装和使用Docker服务

系统环境 使用下面的命令&#xff0c;可以查看CentOS系统的版本。 lsb_release -a结果&#xff1a; 说明我的系统是7.9.2009版本的 安装Docker服务 依次执行下面的指令&#xff1a; yum install -y yum-utilsyum install -y docker即可安装docker服务 如果这样安装不成功…

[ Windows-Nginx ]Windows服务器,Tomcat容器部署项目,整合Nginx

一、官网下载Nginx http://nginx.org/en/download.html 稳定版&#xff1a;windows的stable版本 注意&#xff1a;Nginx安装包不要放在中文目录下 二、conf目录下&#xff0c;修改nginx.conf文件 修改Nginx服务端口&#xff1a; 默认端口为80&#xff0c;即外界访问的入口…

mysql优化之explain详解

mysql的explain&#xff08;执行计划&#xff09;用于解释sql的执行的过程&#xff0c;然后把sql的执行过程用一张表格表示出来&#xff0c;它并不真正的执行sql&#xff0c;如下图。explain能够为我们优化sql提供很好参考作用。 下面我来看下执行计划表中各个字段是什么意思 i…

FFmpeg和rtsp服务器搭建视频直播流服务

下面使用的是ubuntu的&#xff0c;window系统可以参考&#xff1a; 通过rtsp-simple-server和ffmpeg实现录屏并发布视频直播_rtsp simple server_病毒宇宇的博客-CSDN博客 一、安装rtsp-simple-server &#xff08;1&#xff09;下载rtsp-simple-server 下载地址&#xff1a;R…

第 368 场 LeetCode 周赛题解

A 元素和最小的山形三元组 I 前后缀操作&#xff1a;求出前后缀上的最小值数组&#xff0c;然后枚举 j j j class Solution { public:int minimumSum(vector<int> &nums) {int n nums.size();vector<int> l(n), r(n);//l[i]min{nums[0],...,nums[i]}, r[i]mi…