导读:本文可以看作是对分析transformer模型的参数量、计算量、中间激活、KV cache的详细说明
定性分析
GPU上都存了哪些东西
首先我们来从全局整体的角度看一看,在训练阶段GPU显存上都有哪些内容:
- Model States:模型训练过程中必须存储的states
- params(下面有时也叫做weights):模型参数,记参数量为 Φ \Phi Φ
- grads:模型梯度,梯度数量同参数量 Φ \Phi Φ
- optimizer states:Adam优化器中的momentum和variance,数量分别是 Φ \Phi Φ,共 2 Φ 2\Phi 2Φ
- Residual States:模型训练过程中,中间临时的、动态产生的states
- activation:中间激活值,这个部分可能在训练过程中占据很大一部分显存,下面会详细分析。但是激活值不是必须存储的,可以使用重计算(recompute,也叫做activation checkpoint),在反向算梯度的时候,再重新算一遍,当然计算增加了,时间换空间,实际使用中可以部分选择性的进行重计算。
- temporary buffers:临时存储,比如cuda、nccl等临时申请的显存。
- unusable fragment memory:内存碎片导致的内存浪费
推理阶段就相对简单一些,最主要的是Model States中的params和Residual States中的activation。
参考:图解大模型训练之:数据并行下篇( DeepSpeed ZeRO,零冗余优化)
混合精度训练
上面只是列出了训练过程中,显存中存放的内容和保存的数值数量,但是实际训练过程中,为了节省显存,以及考虑到训练过程中间某些过程对精度不是特别敏感,所以中间有些部分会使用fp32,有些部分会使用fp16/bf16。下面以Megatron为例,简单分析混合精度训练的一个大致流程。
首先我们来看一下不使用混合精度训练的场景,数值精度全使用fp32,作为一个分析的baseline。具体过程是:
占用显存为: 4 Φ 4\Phi 4Φ(fp32 weights)+ 4 Φ 4\Phi 4Φ(fp32 momentum)+ 4 Φ 4\Phi 4Φ(fp32 variance)+ 4 Φ 4\Phi 4Φ(fp32 grad)+fp32 activation(可能很大)= 16 Φ 16\Phi 16Φ Bytes + fp32 activation(4代表fp32的4Bytes,2代表fp16/bf16的2Bytes)
如果使用fp16的混合精度训练(bf16应该也可以,但是实际Megatron有点不同,下面会提到),具体过程是:
占用显存为: 4 Φ 4\Phi 4Φ(fp32 weights)+ 4 Φ 4\Phi 4Φ(fp32 momentum)+ 4 Φ 4\Phi 4Φ(fp32 variance)+ 2 Φ 2\Phi 2Φ(fp16 grad)+ 2 Φ 2\Phi 2Φ(fp16 scaled grad)+ 4 Φ 4\Phi 4Φ(fp32 unscaled and cliped grad)+fp16 activation(可能很大)= 20 Φ 20\Phi 20Φ Bytes + fp16 activation
需要说明的有两点:
- 当fp16 scaled grad转为为fp32 unscaled and cliped grad后,fp16 scaled grad就没用了,但是此时Megatron中仍然保留着一份fp16 scaled grad,所以显存占用中这两部分都会计算在内,这也符合Megatron offical readme中的描述:
-
注意到上面流程中多了一个scale/unscale的操作,这叫做“loss scaling”
在使用混合精度训练时,如果直接使用fp16的grad来更新fp16的梯度,一是会产生舍入误差(比如梯度很小,权重更新后,由于精度不够,累加上的lr * grad被舍入,权重没变,一句话来说就是大数吃小数),二是会产生梯度下溢(比如梯度过小,fp16范围不够,导致很小的梯度下溢成为0,而这样的小梯度占比很大,一句话来说就是下溢成0)。对于舍入误差,可以在更新权重时,将fp16的梯度转换为fp32,再更新fp32的权重,从而避免精度问题。对于梯度下溢,需要使用loss scale。
loss scale就是FWD计算出loss后,对loss放大若干倍,由于求导的链式法则,放大的若干倍同样会传导到fp16梯度,这样fp16梯度就不会产生梯度下溢。在更新权重时,将fp16的梯度转换为fp32,同时进行unscale。
刚才说到bf16有一点点特殊,我们看相应的代码:(Megatron中的arguments.py)
注意到如果使用bf16,那么会强行设置accumulate_allreduce_grads_in_fp32=True,这与上面Megatron offical readme截图(Distributed Optimizer)表格中的第二行【bf16 param, fp32 grads】相对应。具体过程应该是(not for sure, hope for discuss):
accumulate_allreduce_grads_in_fp32:If true, do the gradient accumulation and communication in fp32. from here
gradient accumulation:在若干次iteration中,每次都会反向得到一份梯度,将这若干次iteration得到的梯度进行累加、求平均,在最后一次iteration才更新权重。gradient accumulation与data parallel是等价的,gradient accumulation在时间维度上训练多个mini-batch,而data parallel在相同时间内将不同mini-batch放在不同的机器上训练,结果都是一样的。
参考:
聊聊梯度累加(Gradient Accumulation)
梯度累积算法
Hugging Face:Performing gradient accumulation with 🤗 Accelerate
这里找到一个为什么要将bf16与accumulate_allreduce_grads_in_fp32绑定的issue,里面提到“We found this to lead to more stable training before, but you could also try to perform the all-reduce in bf16
(it might hurt convergence but will be faster).”
参考:
- 图解大模型训练之:数据并行下篇( DeepSpeed ZeRO,零冗余优化)
- 图解大模型训练系列之:Megatron源码解读3,分布式混合精度训练
- NVIDIA Docs Hub:Train With Mixed Precision
- 全网最全-混合精度训练原理
量化分析
transformer结构详解
LLM中的transformer一般是decoder-only结构,所以下面的transformer block主要是decoder,但是与Vanilla Transformer中的decoder不同的是,这里没有了cross-attn,因此结构看起来反而有点像encoder(但不是,因为有casual mask)。
下面图中的Transformer,没有上kv-cache、GQA等优化,这部分后面会分析。其中,参数量 Φ \Phi Φ表示有多少个参数;中间激活值 A A A的单位是Bytes,主要参考的是分析transformer模型的参数量、计算量、中间激活、KV cache
在Reducing Activation Recomputation in Large Transformer Models 4.1节中也对transformer激活值进行了一个分析,但是该论文中,self-attention block部分softmax之前没有加mask,上图中添加了mask,具体在Attention部分stage SA_3,其中mask由于是整个transformer共享的,所以就省略了, Q K T QK^T QKT的乘积被mask原地修改,所以 w b a s 2 wbas^2 wbas2也省略了,这样激活值与原论文中仍然是一样的。
KV cache对参数量、计算量、激活值的影响
关于KV Cache的来龙去脉,Encoder Decoder和decoder Only架构训练和推理浅析中简单捋了一下。简单来说,kv cache在推理过程中使用,而且模型只能是decoder-only架构。由于自回归的方式逐token生成,self-attention部分必须使用casual mask,因此Q矩阵部分只需要计算最新token的q向量即可,K、V矩阵部分只需要拼接新token的k、v向量即可:
上面又重新回顾了一下kv cache。首先kv cache不会对参数量有影响,kv cache主要是用来减少不必要的计算的,显存因此也可能有相应的减少,上面只是一个示意图,中间省略了一些部分,详细的量化分析见下图,需要说明的有两点:
- kv cache使用场景是推理场景,LLM推理分为prefill阶段和decode阶段,prefill阶段创建kv-cache,decode阶段更新kv-cache。在输入prompt的这个prefill阶段中,with kv-cache和without kv-cache的计算量是相同的(显存占用由于分配kv-cache,可能with kv-cache会更多一点)。计算量的减少主要体现在decode阶段,因此下面的分析主要是针对单次decode阶段的,因此固定 s = = 1 s==1 s==1
- 下图中说的“相对于原来“指的是without kv-cache时,每次都输入之前所有的token,计算完整的attention-score方阵,因而此时的序列长度 s = s n ≤ s m s=s_n \le s_m s=sn≤sm。在最终分析时,取最大值 s = s m s=s_m s=sm进行比较,对应decode阶段的最后一个token的生成过程,有的博客可能会将输入序列长度(prompt长度)和输出序列长度分开,这里合起来了,注意区别。
原来(without kv-cache) | 现在(with kv-cache) | 变化 | |
---|---|---|---|
参数量 | 2 V h + ( 12 h 2 + 13 h ) l 2Vh+(12h^2+13h)l 2Vh+(12h2+13h)l | 2 V h + ( 12 h 2 + 13 h ) l 2Vh+(12h^2+13h)l 2Vh+(12h2+13h)l | 不变 |
中间激活 | 2 b s h + ( 34 b s m h + 5 b a s m 2 ) l 2bsh+(34bs_mh+5bas_m^2)l 2bsh+(34bsmh+5basm2)l | 2 b s h + ( 30 b h + 4 b s m h + 5 b a s m ) l 2bsh+(30bh+4bs_mh+5bas_m)l 2bsh+(30bh+4bsmh+5basm)l | 减少了 ( 30 b h ( s m − 1 ) + 5 b a s m ( s m − 1 ) ) l (30bh(s_m-1)+5bas_m(s_m-1))l (30bh(sm−1)+5basm(sm−1))l,原来中间激活是最长序列长度 s m s_m sm的二次方,现在随着 s m s_m sm线性增长 |
计算量 | ( 24 h + 4 s m ) b s m h l + 2 b s m h V (24h+4s_m)bs_mhl+2bs_mhV (24h+4sm)bsmhl+2bsmhV | ( 24 h + 4 s m ) b h l + 2 b h V (24h+4s_m)bhl+2bhV (24h+4sm)bhl+2bhV | 减少了 ( 24 h + 4 s m ) b h l ( s m − 1 ) + 2 b h V ( s m − 1 ) (24h+4s_m)bhl(s_m-1)+2bhV(s_m-1) (24h+4sm)bhl(sm−1)+2bhV(sm−1),原来计算量是最长序列长度 s m s_m sm的二次方,现在随着 s m s_m sm线性增长 |
code: from 【手撕LLM-KVCache】显存刺客的前世今生–文末含代码
# author: xiaodongguaAIGC
# KV-Cache + Generation + decoder import torch
import torch.nn.functional as F
from transformers import LlamaModel, LlamaConfig, LlamaForCausalLMD = 128 # single-head-dim
V = 64 # vocab_sizeclass xiaodonggua_kv_cache(torch.nn.Module):def __init__(self, D, V): super().__init__()self.D = Dself.V = Vself.Embedding = torch.nn.Embedding(V,D)self.Wq = torch.nn.Linear(D,D) self.Wk = torch.nn.Linear(D,D) self.Wv = torch.nn.Linear(D,D)self.lm_head = torch.nn.Linear(D,V) # LM_headself.cache_K = self.cache_V = None # initialdef forward(self,X):X = self.Embedding(X)Q,K,V = self.Wq(X),self.Wk(X),self.Wv(X)print("input_Q:", Q.shape)print("input_K:", K.shape)print("input_V:", V.shape)# Easy KV_Cacheif self.cache_K == None: # first timeself.cache_K = Kself.cache_V = Velse:self.cache_K = torch.cat((self.cache_K, K), dim = 1)self.cache_V = torch.cat((self.cache_V, V), dim = 1)K = self.cache_KV = self.cache_Vprint("cache_K:", self.cache_K.shape)print("cache_V:", self.cache_K.shape)# ignore proj/MLP/scaled/mask/multi-head when calculate Attentionattn =Q@K.transpose(1,2)@V# outputoutput=self.lm_head(attn)return outputmodel = xiaodonggua_kv_cache(D,V)# 创建数据、不使用tokenizer
X = torch.randint(0, 64, (1,10))
print(X.shape)for i in range(4):print(f"\nGeneration {i} step input_shape: {X.shape}:")output = model.forward(X) print(output.shape)next_token = torch.argmax(F.softmax(output, dim = -1),-1)[:,-1]print(next_token.shape)X = next_token.unsqueeze(0)
reference and more reading:
【大模型理论篇】Transformer KV Cache原理深入浅出
大模型推理优化技术-KV Cache
一文读懂KVCache
MQA和GQA对显存占用的影响
在实际推理场景中,kv-cache已经是默认的选项。但是kv-cache是很占显存的,占用显存为 2 w k v b s m ( a h a ) l 2 w_{kv} b s_m (a h_a) l 2wkvbsm(aha)l(其中 h = a ∗ h a h=a * h_a h=a∗ha),后面会有case study分析。针对kv cache的各种优化层出不穷,下面的参考中有几篇博客总结了一下对kv cache的各种优化,简单来说,从上面的显存分析入手,有以下几种优化方法:
- 针对attention 窗口(或者叫做context,上下文,或者当作最长序列长度 s m s_m sm) s m s_m sm的优化,比如window attention,sparse attention,StreamingLLM
- 针对注意力头 a a a的优化,比如MQA,GQA共享kv-cache(sharing)
- 针对层数 l l l的优化,比如YOCO层间共享kv-cache(sharing)
- 针对精度 w k v w_{kv} wkv的优化,比如kv-cache采用int8量化
- 针对内存分配的优化,减少内存碎片等,比如PagedAttention
- 其他优化。。。
其中MQA/GQA在LLM中广泛使用,比如Llama2中就使用到了GQA。下面简单分析一下。
GQA方法很简单,原来MHA中每个q向量对应一个k向量和v向量,进行attention计算;现在好几个q向量对应(或者说共享)一个k向量和v向量,这“好几个q向量”构成一组,一共有g组,每组就有 a g \frac{a}{g} ga个q向量。如果g=1,那么就是MQA,a个q向量构成一组,共享一个k、v向量;如果g=a,那么就是MHA,每个q向量构成一组,对应一个k、v向量。实际场景中,往往g=8,比如推理场景中单卡放不下,正好单机八卡,每张卡对应一组q向量。
虽然MQA/GQA是针对推理过程中kv-cache的优化,但是在训练中也能用,也能省显存。下面对GQA在推理场景中的使用(with kv_cache)进行一个量化分析。
因为GQA只影响self-attention计算部分,因此其他部分省略,下面的表格也是只分析这个变化的部分。可以看出,由于kv-cache在长序列的情况下会占用很多显存,GQA针对中间激活的优化与序列长度相关,实际上GQA对中间激活的优化就是将kv-cache变为原来的 g a \frac{g}{a} ag倍。
原来(MHA)-现在(GQA) | 说明 | |
---|---|---|
参数量 | [ 3 ( h 2 + h ) ] l − [ ( 2 g a + 1 ) ( h 2 + h ) ] l = 2 ( 1 − g a ) ( h 2 + h ) l \left [3(h^2+h) \right ]l - \left [ (\frac{2g}{a}+1)(h^2+h) \right ]l=2(1-\frac{g}{a})(h^2+h)l [3(h2+h)]l−[(a2g+1)(h2+h)]l=2(1−ag)(h2+h)l | |
中间激活 | [ w b s h + 2 w k v b s m h ] l − [ w b s h + 2 w k v b s m h × g a ] l = 2 w k v b s m h l ( 1 − g a ) \left [ wbsh+2w_{kv}bs_mh \right]l - \left [ wbsh + 2w_{kv}bs_mh \times\frac{g}{a} \right ]l = 2w_{kv}bs_mhl(1-\frac{g}{a}) [wbsh+2wkvbsmh]l−[wbsh+2wkvbsmh×ag]l=2wkvbsmhl(1−ag) | 尤其当长序列( b s m bs_m bsm较大),大模型( h l hl hl较大)时,前面系数较大,整体激活减少比较可观 |
计算量 | $\left [ 6bsh^2 \right ]l - \left [ 2bsh^2 (\frac{2g}{a}+1) \right ] l = 4bsh^2l(1-\frac{g}{a}) \overset{s=1}{=} 4bh^2l(1-\frac{g}{a}) $ |
在训练场景中,同样给出量化分析。需要说明的是,上述分析是在推理场景+kv_cache+GQA的情况下进行的分析,下面公式是针对的是训练场景+GQA。
code: from MHA,MQA,GQA注意力
import torch
import torch.nn as nnclass GroupedQueryAttention(nn.Module):def __init__(self, embed_dim, num_heads, num_groups):super().__init__()self.num_heads = num_headsself.num_groups = num_groupsself.head_dim = embed_dim // num_heads# attention weightsself.wq = nn.Linear(embed_dim, embed_dim)self.wk = nn.Linear(embed_dim, num_groups * self.head_dim)self.wv = nn.Linear(embed_dim, num_groups * self.head_dim)self.wo = nn.Linear(embed_dim, embed_dim)def split_heads(self, x: torch.Tensor, num_groups=None):# n == num_heads or num_groupsx = x.view(x.size(0), x.size(1), -1, self.head_dim) # (batch_size, seq_len, n, head_dim)batch_size, seq_len, n, head_dim = x.size()if num_groups is not None:x = x.unsqueeze(dim=2)x = x.expand(size=(batch_size, seq_len, self.num_heads // num_groups, n, head_dim))x = x.reshape(batch_size, seq_len, self.num_heads, head_dim)x = x.permute(0, 2, 1, 3) # (batch_size, num_heads, seq_len, head_dim)return xdef merge_heads(self, x: torch.Tensor):""":param x: (batch_size, num_heads, seq_len, head_dim)"""x = x.permute(0, 2, 1, 3).contiguous() # (batch_size, seq_len, num_heads, head_dim)x = x.view(x.size(0), x.size(1), -1) # ( batch_size, seq_len, embed_dim)return xdef forward(self, hidden_states: torch.Tensor, causal_mask=None):q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)# 分割注意力头q = self.split_heads(q)k = self.split_heads(k, num_groups=self.num_groups)v = self.split_heads(v, num_groups=self.num_groups)# 注意力计算attn_weights = torch.matmul(q, k.transpose(-1, -2)) / torch.tensor(k.size(-1), dtype=q.dtype)# causal maskmask_value = torch.finfo(attn_weights.dtype).minif causal_mask is None:seq_len = hidden_states.size(1)causal_mask = torch.tril(torch.ones((1, 1, seq_len, seq_len), dtype=torch.bool))attn_weights = torch.where(causal_mask, attn_weights, mask_value)# 归一化attn_weights = torch.softmax(attn_weights, dim=-1)attn_output = torch.matmul(attn_weights, v)# 合并注意力头attn_output = self.merge_heads(attn_output)attn_output = self.wo(attn_output)return attn_output
参考:
大模型百倍推理加速之KV cache篇
LLM(二十):漫谈 KV Cache 优化方法,深度理解 StreamingLLM
[KV Cache优化]🔥MQA/GQA/YOCO/CLA/MLKV笔记: 层内和层间KV Cache共享
大模型推理加速:KV Cache 和 GQA
case study
我们以GPT和Llama为例,进行case study。
关于参数量的分析
GPT-3
GPT-3模型结构就大致上面【transformer结构详解】中的结构,但是多了一个可学习的position embedding,包含 n c t x ∗ h n_{ctx} * h nctx∗h个参数,其中 n c t x = 2048 n_{ctx}=2048 nctx=2048,rectified这一列是加上这些参数后的参数量。
params | h | l | a | b | V from GPT-2 | calculated params= V h + ( 12 h 2 + 13 h ) l Vh+(12h^2+13h)l Vh+(12h2+13h)l | rectified |
---|---|---|---|---|---|---|---|
GPT-3 Small: 125M | 768 | 12 | 64 | 0.5M | 50257 | 123651840 ≈ \approx ≈ 123.7M | 125224704 ≈ \approx ≈ 125.2M |
GPT-3 Medium: 350M | 1024 | 24 | 64 | 0.5M | 50257 | 353772544 $\approx$353.8M | 355869696 ≈ \approx ≈ 355.9M |
GPT-3 Large: 760M | 1536 | 24 | 96 | 0.5M | 50257 | 757151232 ≈ \approx ≈ 757.1M | 760296960 ≈ \approx ≈ 760.3M |
GPT-3 2.7B | 2560 | 32 | 80 | 1M | 50257 | 2646305280 ≈ \approx ≈ 2.64B | 2651548160 ≈ \approx ≈ 2.65B |
GPT-3 6.7B | 4096 | 32 | 128 | 2M | 50257 | 6650007552 ≈ \approx ≈ 6.65B | 6658396160 ≈ \approx ≈ 6.67B |
GPT-3 13B | 5140 | 40 | 128 | 2M | 50257 | 12942401780 ≈ \approx ≈ 12.94B | 12952928500 ≈ \approx ≈ 12.95B |
GPT-3 175B | 12288 | 96 | 128 | 3.2M | 50257 | 174579068928 ≈ \approx ≈ 174.58B | 174604234752 ≈ \approx ≈ 174.60B |
说明:
- GPT-3词表大小V在论文中没找到,所以用的GPT-2的词表大小,这里论文中是提到的
more relative reading:
- How does GPT-3 spend its 175B parameters?
Llama 1: LLaMa: Open and Efficient Foundation Language Models
模型结构:from hugging face transformers LLaMA
论文中说,该模型与Vanilla Transformer有三处区别:
-
Pre-normalization and RMSNorm
原始Transformer中使用post-norm居多,后来使用pre-norm居多,而且往往在FFN之前也加一个norm。尤其在大模型中,可能在通过LN之后MHA之前,Q和K还要加上旋转位置编码。
参考:【重新了解Transformer模型系列_1】PostNorm/PreNorm的差别
-
SwiGLU activation function
SwiGLU激活函数不太像传统的ReLU等激活函数那样简单,比如ReLU都不带参数,而SwiGLU乍一看上去不明觉厉,实际上将SwiGLU理解成对传统FFM的替换,感觉更合适一些。直接看公式有点懵,看图更容易理解,下面是FFM和SwiGLU的对比
SwiGLU写成公式就是 S w i G L U ( x ) = [ S i G U ( g a t e _ p r o j ( x ) ) ⊙ u p _ p r o j ( x ) ] × d o w n _ p r o j ( x ) SwiGLU(x) = \left [ SiGU \left( gate\_proj(x) \right) \odot up\_proj(x) \right] \times down\_proj(x) SwiGLU(x)=[SiGU(gate_proj(x))⊙up_proj(x)]×down_proj(x),其中可能有点困惑的是这个 8 h 3 \frac{8h}{3} 38h是怎么来的,实际上就是为了左右这两个结构的参数量相等: 2 × h × 4 h ≡ 2 × h × 8 h 3 + 8 h 3 × h 2 \times h \times 4h \equiv 2 \times h \times \frac{8h}{3} + \frac{8h}{3} \times h 2×h×4h≡2×h×38h+38h×h
-
Rotary Embedding
下面是模型配置,验证一下前面推出来的参数量相关的公式能否对上:
params | h | l | a | b | V | intermediate_size | calculated params= 2 V h + ( 12 h 2 + 13 h ) l 2Vh+(12h^2+13h)l 2Vh+(12h2+13h)l |
---|---|---|---|---|---|---|---|
6.7B | 4096 | 32 | 32 | 4M | 32K | 11008 | 6706298880 ≈ \approx ≈ 6.71B |
13.0B | 5120 | 40 | 40 | 4M | 32K | 13824 | 12913254400 ≈ \approx ≈ 12.91B |
32.5B | 6656 | 60 | 52 | 4M | 32K | 17920 | 32328857600 ≈ \approx ≈ 32.33B |
65.2B | 8192 | 80 | 64 | 4M | 32K | 22016 | 64957317120 ≈ \approx ≈ 64.96B |
每次总是差一点,但是差的不多,差在了哪里呢?MLP部分,理论上intermediate_size= 8 h 3 \frac{8h}{3} 38h,但是实际上可能会比这个值大一些,往往向上取到256、512、1024等的倍数,对矩阵乘法性能更好,因此来修正一下参数量、计算量、激活值的量化分析:
重新计算一下,这次参数量就很接近了
params | h | l | a | b | V | intermediate_size | calculated params= 2 V h + ( 4 h + 4 + 3 I ) h l 2Vh+(4h+4+3I)hl 2Vh+(4h+4+3I)hl |
---|---|---|---|---|---|---|---|
6.7B | 4096 | 32 | 32 | 4M | 32K | 11008 | 6738673664 ≈ \approx ≈ 6.74B |
13.0B | 5120 | 40 | 40 | 4M | 32K | 13824 | 13016268800 ≈ \approx ≈ 13.02B |
32.5B | 6656 | 60 | 52 | 4M | 32K | 17920 | 32529735680 ≈ \approx ≈ 32.53B |
65.2B | 8192 | 80 | 64 | 4M | 32K | 22016 | 65286963200 ≈ \approx ≈ 65.29B |
Llama 2: Llama 2: Open Foundation and Fine-Tuned Chat Models
Llama2在模型结构方面与Llama1相差不大,只是将MHA替换为GQA,将attention的context length从2k提升到4k。下面是Llama2的模型配置
config | h | l | a | b | V | intermediate_size | MHA or GQA | calculated params= 2 V h + ( 12 h 2 + 13 h ) l 2Vh+(12h^2+13h)l 2Vh+(12h2+13h)l | calculated params= 2 V h + ( 4 h + 4 + 3 I ) h l 2Vh+(4h+4+3I)hl 2Vh+(4h+4+3I)hl |
---|---|---|---|---|---|---|---|---|---|
7B, config | 4096 | 32 | 32 | 4M | 32K | 11008 | MHA | 6706298880 ≈ \approx ≈ 6.71B | 6738673664 ≈ \approx ≈ 6.74B |
13B, config | 5120 | 40 | 40 | 4M | 32K | 13824 | MHA | 12913254400 ≈ \approx ≈ 12.91B | 13016268800 ≈ \approx ≈ 13.02B |
至于70B的config(h=8192, l=80, a=64, b=4M, V=32K, intermediate_size=28672, g=8)使用了group=8的GQA,只有attention部分的参数量会发生一些变化,调整公式后,分别计算一下:
- calculated params= 2 V h + [ 10 h 2 + 11 h + 2 g a ( h 2 + h ) ] l 2Vh+\left[ 10h^2 + 11h + \frac{2g}{a}(h^2+h)\right] l 2Vh+[10h2+11h+a2g(h2+h)]l = 5556092928 ≈ \approx ≈ 55.56B,相差较大
- llama calculated params= 2 V h + [ ( 2 + 2 g a ) h 2 + 4 h + 3 h I ] l 2Vh + \left [ (2+\frac{2g}{a}) h ^ 2 + 4h + 3hI \right ] l 2Vh+[(2+a2g)h2+4h+3hI]l = 68977950720 ≈ \approx ≈ 68.98B,比较接近了
因此,对于transformer而言,
- 如果MLP是传统FFN那样的结构,calculated params= 2 V h + ( 12 h 2 + 13 h ) l 2Vh+(12h^2+13h)l 2Vh+(12h2+13h)l
- 如果attention部分使用了GQA,则calculated params= 2 V h + [ 10 h 2 + 11 h + 2 g a ( h 2 + h ) ] l 2Vh+\left[ 10h^2 + 11h + \frac{2g}{a}(h^2+h)\right] l 2Vh+[10h2+11h+a2g(h2+h)]l
- 如果MLP是SwiGLU那样的结构,calculated params= 2 V h + ( 4 h + 4 + 3 I ) h l 2Vh+(4h+4+3I)hl 2Vh+(4h+4+3I)hl
- 如果attention部分使用了GQA,则calculated params= 2 V h + [ ( 2 + 2 g a ) h 2 + 4 h + 3 h I ] l 2Vh + \left [ (2+\frac{2g}{a}) h ^ 2 + 4h + 3hI \right ] l 2Vh+[(2+a2g)h2+4h+3hI]l
但是总的来说,transformer的复杂度还是 O ( h 2 l ) O(h^2l) O(h2l)级别的
more relative reading:
“Mastering Llama Math (Part-1): A Step-by-Step Guide to Counting Parameters in Llama-2”
LLM - Transformer && LLaMA2 结构分析与 LoRA 详解
Llama 3: The Llama 3 Herd of Models
Llama3的改进相对于Llama2和Llama1,主要体现在使用了更高质量的数据和更大规模的训练,模型结构基本没变。下面是模型配置,
config | h | l | a | b | V | intermediate_size | GQA group | calculated params= 2 V h + [ ( 2 + 2 g a ) h 2 + 4 h + 3 h I ] l 2Vh + \left [ (2+\frac{2g}{a}) h ^ 2 + 4h + 3hI \right ] l 2Vh+[(2+a2g)h2+4h+3hI]l |
---|---|---|---|---|---|---|---|---|
8B, config | 32 | 4096 | 32 | 4M->8M->16M | 128K | 14336 | 8 | 8028422144 ≈ \approx ≈ 8.03B |
70B, config | 80 | 8192 | 64 | 4M->8M->16M | 128K | 28672 | 8 | 70550814720 ≈ \approx ≈ 70.55B |
405B | 126 | 16384 | 128 | 4M->8M->16M | 128K | 53248 | 8 | 405849112576 ≈ \approx ≈ 405.85B |
参考:
LLaMa-1/2/3 原理+源码——拆解 (KV-Cache, RoPE, RMSNorm, GQA, SwiGLU)
关于激活的分析
前面总说中间激活可能很占显存,我们来分析几个case。
GPT-3
config | h | l | a | b | s | V from GPT-2 | activation ≈ ( 34 b s h + 5 b a s 2 ) l \approx (34bsh+5bas^2)l ≈(34bsh+5bas2)l | activation (with GQA) ≈ [ ( 28 + 4 g a ) b s h + 5 b a s 2 ] l \approx \left [ (28+\frac{4g}{a})bsh+5bas^2\right]l ≈[(28+a4g)bsh+5bas2]l |
---|---|---|---|---|---|---|---|---|
GPT-3 Small: 125M | 768 | 12 | 64 | 1 | 2048 | 50257 | 15972.0MB ≈ 67.0 × 2 Φ \approx 67.0 \times 2\Phi ≈67.0×2Φ | 15873.0MB ≈ 66.58 × 2 Φ \approx 66.58 \times 2\Phi ≈66.58×2Φ |
GPT-3 Medium: 350M | 1024 | 24 | 64 | 1 | 2048 | 50257 | 32352.0MB ≈ 48.5 × 2 Φ \approx 48.5 \times 2\Phi ≈48.5×2Φ | 32088.0 ≈ 48.1 × 2 Φ \approx 48.1 \times 2\Phi ≈48.1×2Φ |
GPT-3 Large: 760M | 1536 | 24 | 96 | 1 | 2048 | 50257 | 48528.0 MB ≈ 33.5 × 2 Φ \approx 33.5 \times 2\Phi ≈33.5×2Φ | 48120.0MB ≈ 33.2 × 2 Φ \approx 33.2 \times 2\Phi ≈33.2×2Φ |
GPT-3 2.7B | 2560 | 32 | 80 | 1 | 2048 | 50257 | 55.3GB ≈ 11.0 × 2 Φ \approx 11.0 \times 2\Phi ≈11.0×2Φ wrong | 54.4GB ≈ 10.82 × 2 Φ \approx 10.82 \times 2\Phi ≈10.82×2Φ |
GPT-3 6.7B | 4096 | 32 | 128 | 1 | 2048 | 50257 | 88.5GB ≈ 7.10 × 2 Φ \approx 7.10 \times 2\Phi ≈7.10×2Φ | 87.1GB ≈ 6.98 × 2 Φ \approx 6.98 \times 2\Phi ≈6.98×2Φ |
GPT-3 13B | 5140 | 40 | 128 | 1 | 2048 | 50257 | 113.3GB ≈ 4.68 × 2 Φ \approx 4.68 \times 2\Phi ≈4.68×2Φ | 111.1GB ≈ 4.59 × 2 Φ \approx 4.59 \times 2\Phi ≈4.59×2Φ |
GPT-3 175B | 12288 | 96 | 128 | 1 | 2048 | 50257 | 316.5GB ≈ 0.97 × 2 Φ \approx 0.97 \times 2\Phi ≈0.97×2Φ | 303.6GB ≈ 0.93 × 2 Φ \approx 0.93 \times 2\Phi ≈0.93×2Φ |
GPT-3 175B | 12288 | 96 | 128 | 8 | 2048 | 50257 | 2532.0GB ≈ 7.77 × 2 Φ \approx 7.77 \times 2\Phi ≈7.77×2Φ | 2428.5GB ≈ 7.45 × 2 Φ \approx 7.45 \times 2\Phi ≈7.45×2Φ |
GPT-3 175B | 12288 | 96 | 128 | 64 | 2048 | 50257 | 19.78TB ≈ 62.14 × 2 Φ \approx 62.14 \times 2\Phi ≈62.14×2Φ | 18.97TB ≈ 59.60 × 2 Φ \approx 59.60 \times 2 \Phi ≈59.60×2Φ |
Llama-2:
config | h | l | a | b | s | V | intermediate_size | GQA: group | activation (with GQA) ≈ [ ( 13 + 4 g a ) b s h + 5 b a s 2 + 6 b s I ] l \approx \left [ (13+\frac{4g}{a})bsh+5bas^2 + 6bsI\right]l ≈[(13+a4g)bsh+5bas2+6bsI]l |
---|---|---|---|---|---|---|---|---|---|
7B, config | 4096 | 32 | 32 | 1 | 4096 | 32K | 11008 | 32(MHA) | 96.6GB ≈ 7.4 × 2 Φ \approx 7.4 \times 2\Phi ≈7.4×2Φ |
13B, config | 5120 | 40 | 40 | 1 | 4096 | 32K | 13824 | 40(MHA) | 150.9GB ≈ 6.2 × 2 Φ \approx 6.2 \times 2\Phi ≈6.2×2Φ |
70B, config | 8192 | 80 | 64 | 1 | 4096 | 32K | 28672 | 8 | 486.25GB ≈ 3.7 × 2 Φ \approx 3.7 \times 2\Phi ≈3.7×2Φ |
70B, config | 8192 | 80 | 64 | 8 | 4096 | 32K | 28672 | 8 | 3890.0GB ≈ 29.8 × 2 Φ \approx 29.8 \times 2\Phi ≈29.8×2Φ |
70B, config | 8192 | 80 | 64 | 64 | 4096 | 32K | 28672 | 8 | 30.39TB ≈ 238.7 × 2 Φ \approx 238.7 \times 2\Phi ≈238.7×2Φ |
由于前面分析过,intermediate_size往往会略微大于 8 h 3 \frac{8h}{3} 38h,因此根据前面分析的llama结构,重新推导一下激活的计算公式,这里省略了。
可以看出,当大batch、长序列的情况下,中间激活可以是模型参数所占显存的很多倍,即使使用了GQA。
上面都是在训练场景下的激活值分析,在推理阶段中,可以使用kv-cache减少模型计算量,同时中间激活也大幅度减少,kv-cache的大小为 2 w k v b s m h 2w_{kv}bs_mh 2wkvbsmh(单层),我们也来量化分析一下(假设 w k v w_{kv} wkv=2,且s=1,推理context长度最后一个token的情况,即最坏情况)
config | b | s m s_m sm | h | a | l | kv_cache size= 2 w k v b s m h l 2w_{kv}bs_mhl 2wkvbsmhl | without kv-cache activation ≈ ( 34 b s m h + 5 b a s m 2 ) l \approx (34bs_mh+5bas_m^2)l ≈(34bsmh+5basm2)l | with kv-cache activation ≈ ( 30 b h + 4 b s m h + 5 b a s m ) l \approx (30bh+4bs_mh+5bas_m)l ≈(30bh+4bsmh+5basm)l |
---|---|---|---|---|---|---|---|---|
GPT-3 Small: 125M | 1 | 2048 | 768 | 64 | 12 | 72MB ≈ 0.30 × 2 Φ \approx 0.30 \times 2\Phi ≈0.30×2Φ | 15972.0MB ≈ 67.0 × 2 Φ \approx 67.0 \times 2\Phi ≈67.0×2Φ | 79.8MB ≈ 0.33 × 2 Φ \approx 0.33 \times 2\Phi ≈0.33×2Φ |
GPT-3 Medium: 350M | 1 | 2048 | 1024 | 64 | 24 | 192MB ≈ 0.29 × 2 Φ \approx 0.29 \times 2\Phi ≈0.29×2Φ | 32352.0MB ≈ 48.5 × 2 Φ \approx 48.5 \times 2\Phi ≈48.5×2Φ | 207.7MB ≈ 0.31 × 2 Φ \approx 0.31 \times 2\Phi ≈0.31×2Φ |
GPT-3 Large: 760M | 1 | 2048 | 1536 | 96 | 24 | 288MB ≈ 0.20 × 2 Φ \approx 0.20 \times 2\Phi ≈0.20×2Φ | 48528.0MB ≈ 33.5 × 2 Φ \approx 33.5 \times 2\Phi ≈33.5×2Φ | 311.6MB ≈ 0.21 × 2 Φ \approx 0.21 \times 2\Phi ≈0.21×2Φ |
GPT-3 2.7B | 1 | 2048 | 2560 | 80 | 32 | 640MB ≈ 0.12 × 2 Φ \approx 0.12 \times 2\Phi ≈0.12×2Φ | 55.3GB ≈ 11.0 × 2 Φ \approx 11.0 \times 2\Phi ≈11.0×2Φ | 667.3MB ≈ 0.13 × 2 Φ \approx 0.13 \times 2\Phi ≈0.13×2Φ |
GPT-3 6.7B | 1 | 2048 | 4096 | 128 | 40 | 1280MB ≈ 0.1 × 2 Φ \approx 0.1 \times 2\Phi ≈0.1×2Φ | 110.6GB ≈ 8.9 × 2 Φ \approx 8.9 \times 2 \Phi ≈8.9×2Φ | 1334.7MB ≈ 0.1 × 2 Φ \approx 0.1 \times 2 \Phi ≈0.1×2Φ |
GPT-3 13B | 1 | 2048 | 5140 | 128 | 96 | 3.76GB ≈ 0.15 × 2 Φ \approx 0.15 \times 2\Phi ≈0.15×2Φ | 272.0GB ≈ 11.2 × 2 Φ \approx 11.2 \times 2\Phi ≈11.2×2Φ | 3.89GB ≈ 0.16 × 2 Φ \approx 0.16 \times 2\Phi ≈0.16×2Φ |
GPT-3 175B | 1 | 2048 | 12288 | 128 | 96 | 9.0GB ≈ 0.02 × 2 Φ \approx 0.02 \times 2\Phi ≈0.02×2Φ | 316.5GB $\approx 0.97\times 2\Phi $ | 9.15GB ≈ 0.03 × 2 Φ \approx 0.03 \times 2\Phi ≈0.03×2Φ |
GPT-3 175B | 8 | 2048 | 12288 | 128 | 96 | 72.0GB ≈ 0.22 × 2 Φ \approx 0.22 \times 2\Phi ≈0.22×2Φ | 2532.0GB ≈ 7.77 × 2 Φ \approx 7.77 \times 2\Phi ≈7.77×2Φ | 73.2GB ≈ 0.22 × 2 Φ \approx 0.22 \times 2\Phi ≈0.22×2Φ |
GPT-3 175B | 64 | 2048 | 12288 | 128 | 96 | 576.0GB ≈ 1.77 × 2 Φ \approx 1.77 \times 2\Phi ≈1.77×2Φ | 19.78TB ≈ 62.1 × 2 Φ \approx 62.1 \times 2\Phi ≈62.1×2Φ | 585.6GB ≈ 1.80 × 2 Φ \approx 1.80 \times 2\Phi ≈1.80×2Φ |
可以看出在推理时,kv-cache大幅度减少了中间激活。而且使用了kv-cache以后,kv-cache在激活中占据了绝大部分的比例,kv-cache甚至可以超过模型所占内存。
关于计算量的分析
量化分析模型的计算量,主要是为了预估模型训练时间。根据前面的分析,一个FWD+BWD的iteration训练过程中,计算量FLOPs= 6 × Φ × 输入 t o k e n s 数量 6 \times \Phi \times 输入tokens数量 6×Φ×输入tokens数量,因此可以大致估计训练时间= 6 × Φ × 输入 t o k e n s 数量 G P U 数量 × G P U 算力 ( f l o p s ) × M F U \frac{6 \times \Phi \times 输入tokens数量}{GPU数量\times GPU算力(flops) \times MFU} GPU数量×GPU算力(flops)×MFU6×Φ×输入tokens数量。
其他说明
1. LayerNorm的计算
LayerNorm的计算过程见pytorch LayerNorm参数详解,计算过程,总结一下就是:
- 比如输入是
[b,s,h]
,LN的normalized_shape=[h]
,此时就是对每一个大小为h
的向量分别进行归一化(一共b*s
个) - 然后如果LN的
elementwise_affine=True
,就需要对每个大小为h
的向量elementwise的乘上 γ : [ h ] \gamma: [h] γ:[h],再elementwise的加上 β : [ h ] \beta:[h] β:[h], γ \gamma γ和 β \beta β就是该LN层的两个可学习的参数。如果LN的elementwise_affine=False
,则只会进行第一步的归一化,不会进行第二步的affine
一个有趣的问题是,Transformer中的LayerNorm可以并行吗?
关键词: Welford online Algorithm,当一个集合新增加一个元素 x N x_N xN的时候,可以通过前N-1个样本的corrected sum of squares( ∑ i = 1 N − 1 ( x i − x ˉ ) 2 \sum_{i=1}^{N-1}(x_i-\bar{x})^2 ∑i=1N−1(xi−xˉ)2),计算出前N个样本的corrected sum of squares,从而只需要one pass就可以完成LN的计算(之前navie的方法是two pass)
2. 关于dropout的位置
一共(可能)在有四个地方有dropout:
- 在PositionalEmbedding中有一个dropout:
dropout(x + PositionEmbedding(x))
,不过好像LLM现在使用旋转位置编码RoPE多一些,在计算attention之前在Q和K上加上RoPE,一开始输入的embedding不加PositionalEmbedding了 - 在softmax计算得到的attention score之后有一个droput: d r o p o u t ( s o f t m a x ( Q K T s c a l e + c a s u a l _ m a s k ) ) dropout( softmax(\frac{QK^T}{scale}+casual\_mask) ) dropout(softmax(scaleQKT+casual_mask))
- 在sublayer(Attention和MLP)计算完之后,各有一个dropout:
x+dropout(sublayer(norm(x)))
总结
transformer的参数量的复杂度是 O ( h 2 l ) O(h^2l) O(h2l)级别的,粗略估计可以认为是 12 h 2 l 12h^2l 12h2l或者 ( 4 h + 3 I ) h l (4h+3I)hl (4h+3I)hl,如果要详细分析,就要看一看每个部分的结构,是否使用了bias,使用的不同优化,比如:
- 如果MLP是传统FFN那样的结构,calculated params= 2 V h + ( 12 h 2 + 13 h ) l 2Vh+(12h^2+13h)l 2Vh+(12h2+13h)l
- 如果attention部分使用了GQA,则calculated params= 2 V h + [ 10 h 2 + 11 h + 2 g a ( h 2 + h ) ] l 2Vh+\left[ 10h^2 + 11h + \frac{2g}{a}(h^2+h)\right] l 2Vh+[10h2+11h+a2g(h2+h)]l
- 如果MLP是SwiGLU那样的结构,calculated params= 2 V h + ( 4 h + 4 + 3 I ) h l 2Vh+(4h+4+3I)hl 2Vh+(4h+4+3I)hl
- 如果attention部分使用了GQA,则calculated params= 2 V h + [ ( 2 + 2 g a ) h 2 + 4 h + 3 h I ] l 2Vh + \left [ (2+\frac{2g}{a}) h ^ 2 + 4h + 3hI \right ] l 2Vh+[(2+a2g)h2+4h+3hI]l
对transformer中间激活的分析要分训练场景和推理场景
- 在训练场景中,中间激活可以是模型参数所占显存的很多倍,尤其在大batch、长序列的情况下。
- 中间激活值所占显存粗略估计可以认为是 ( 34 b s h + 5 b a s 2 ) l (34bsh+5bas^2)l (34bsh+5bas2)l或者 ( 17 b s h + 5 b a s 2 + 6 b s I ) l (17bsh+5bas^2+6bsI)l (17bsh+5bas2+6bsI)l,可以看出与输入token数量(batch和seq_len)、隐藏层维度、头数、intermediate_size、层数相关,因此相对参数量的分析稍微复杂一点。
- 在推理场景中,prefill阶段基本同训练场景,decode阶段每次输入的序列长度为1,而且默认使用kv-cache。由于使用kv-cache,中间激活相对于训练时的中间激活大幅度减小,但是在大batch、长序列的情况下,kv-cache的显存占用仍然可能超过模型参数的显存占用。还有一点需要注意,推理场景中kv-cache在中间激活中占据了绝大部分。
- 中间激活值所占显存粗略估计可以认为是 ( 30 b h + 4 b s m h + 5 b a s m ) l (30bh+4bs_mh+5bas_m)l (30bh+4bsmh+5basm)l或者 ( 13 b h + 4 b s m h + 5 b s m a + 6 b I ) l (13bh+4bs_mh+5bs_ma+6bI)l (13bh+4bsmh+5bsma+6bI)l
对transformer的计算量的分析比较简单,transformer中计算较为规整,计算量体现在若干个大块矩阵的乘法。一般量化分析计算量主要是为了预估模型训练时间,所以一般分析的不多(一般也没有机会训练大模型,如果训练普通规模的网络,尝试跑几个iteration就能估计)。