Series Navigation
LLM Inference & Memory Bandwidth Bound (1) - Overview of Performance Bottlenecks and Optimizations
LLM Inference & Memory Bandwidth Bound (2) - Multi-Query Attention
LLM Inference & Memory Bandwidth Bound (3) - MLA
Table of Contents
- Series Navigation
- Preface
- I. Principles
  - 1. Low-rank compression & motivation
  - 2. Matrix absorption
  - 3. RoPE decoupling
  - 4. Computation optimization
- II. Code
  - 1. Before optimization
  - 2. After optimization
- Summary
- References
Preface
"MLA guarantees efficient inference through significantly compressing the Key-Value (KV) cache into a latent vector, while DeepSeekMoE enables training strong models at an economical cost through sparse computation." (DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model)
In the previous post we walked through "Fast Transformer Decoding: One Write-Head is All You Need", confirmed that MHA is memory-bandwidth bound during the decode phase, and analyzed the optimization that MQA brings. Specifically, MQA compresses the KV cache to $\frac{1}{h}$ of its original size ($h$ is the number of attention heads), which raises arithmetic intensity and alleviates the memory-bandwidth-bound problem; at the same time, a smaller KV cache means the batch size can be increased, i.e., more requests can be served concurrently.
Today's topic is Multi-head Latent Attention (MLA), another method that compresses the KV cache aggressively. It was proposed in DeepSeek-V2 and is one of the core modules of the wildly popular DeepSeek-V3/R1 models.
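Before going further, a rough back-of-the-envelope comparison may help. The sketch below estimates the per-token, per-layer KV-cache footprint of MHA, MQA, and MLA; the dimensions are illustrative assumptions in the spirit of DeepSeek-V2 (128 heads of size 128, a 512-dim latent plus a 64-dim shared RoPE key), not measured numbers.

```python
# Rough per-token, per-layer KV-cache footprint (illustrative numbers, not measurements).
n_heads = 128        # assumed number of attention heads
d_h     = 128        # assumed per-head dimension
d_c     = 512        # assumed MLA latent dimension (kv_lora_rank)
d_rope  = 64         # assumed shared RoPE key dimension (qk_rope_head_dim)
bytes_per_elem = 2   # fp16 / bf16

mha = 2 * n_heads * d_h * bytes_per_elem   # full K and V for every head
mqa = 2 * 1 * d_h * bytes_per_elem         # one shared K/V head
mla = (d_c + d_rope) * bytes_per_elem      # compressed latent + shared RoPE key

print(f"MHA: {mha} B, MQA: {mqa} B, MLA: {mla} B per token per layer")
```

Under these assumptions MLA sits between MQA and MHA in cache size while, as we will see, keeping MHA-level quality.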
I. Principles
1. Low-rank compression & motivation
As the figure above shows, unlike MHA, and unlike GQA and MQA where multiple heads share K/V, MLA uses a joint low-rank compression of keys and values: at inference time only the compressed part, the Compressed Latent KV in the figure, has to be stored. MLA's per-head K and V (the blue bars in the figure) are then recovered from it through projection matrices. In formulas:

$$\mathbf{c}_t^{KV} = W^{DKV}\mathbf{h}_t$$
$$\mathbf{k}_t = W^{UK}\mathbf{c}_t^{KV}$$
$$\mathbf{v}_t = W^{UV}\mathbf{c}_t^{KV}$$

where the subscript $t$ denotes the $t$-th token, $\mathbf{h}_t$ is the input to the attention module, and $\mathbf{c}_t^{KV}$ is the low-rank latent vector, i.e., the part that needs to be cached at inference time. $W^{DKV} \in \mathbb{R}^{d_c \times d}$ is the down-projection matrix, with $d_c < d$; $W^{UK}, W^{UV} \in \mathbb{R}^{d_h n_h \times d_c}$ are the up-projection matrices, with $d_c < d_h n_h$. Here $d$ is the input dimension, $d_c$ is the dimension of the Compressed Latent KV, $n_h$ is the number of attention heads, and $d_h$ is the dimension of each head.
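As a minimal PyTorch sketch of these three equations (row-vector / nn.Linear convention; the dimensions and variable names below are illustrative, roughly DeepSeek-V2-Lite-sized, and are not taken from the model code):

```python
import torch
import torch.nn as nn

d, d_c, n_h, d_h = 2048, 512, 16, 128            # illustrative dimensions

W_DKV = nn.Linear(d, d_c, bias=False)            # down-projection: h_t -> c_t^{KV}
W_UK  = nn.Linear(d_c, n_h * d_h, bias=False)    # up-projection:   c_t^{KV} -> k_t
W_UV  = nn.Linear(d_c, n_h * d_h, bias=False)    # up-projection:   c_t^{KV} -> v_t

h_t  = torch.randn(1, d)                         # attention-module input for one token
c_kv = W_DKV(h_t)                                # (1, d_c) -- the only tensor that must be cached
k_t  = W_UK(c_kv).view(1, n_h, d_h)              # per-head keys recovered from the latent
v_t  = W_UV(c_kv).view(1, n_h, d_h)              # per-head values recovered from the latent
print(c_kv.shape, k_t.shape, v_t.shape)
```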
Let's guess at the motivation behind MLA. The desire to compress the KV cache is obvious and is shared with MQA; MQA's drawback relative to MHA is a loss in model quality, whereas MLA's final form is equivalent to MHA. So let's venture a conclusion: MLA is a hybrid of MQA and MHA, meant to inherit both of their strengths, namely MQA's KV-cache compression and the superior quality of MHA's multiple attention heads.
2. Matrix absorption
But have you noticed a problem? Although the figure above says MLA only needs to cache the Compressed Latent KV, it is ultimately expanded back into MHA form, so those keys and values still take up GPU memory, and the extra low-rank compression step adds memory operations on top. How is this solved? The paper gives the answer: matrix absorption.
Let's first walk through MLA's computation. The formulas below follow the neural-network convention (row vectors, unlike the left-multiplication convention used above and in the paper); for example, nn.Linear() corresponds to

$$y = xW^\top$$

The general expression for $QK^\top$ is then

$$\begin{aligned} QK^\top & = (x_QW_Q^\top)(x_KW_K^\top)^\top \\ & = (x_QW_Q^\top)(W_Kx_K^\top) \\ & = (x_QW_Q^\top W_K)x_K^\top \end{aligned}$$

Substituting the MLA quantities $W_K \rightarrow W^{UK}$ and $x_K \rightarrow c^{KV}$ gives

$$QK^\top = (x_QW_Q^\top W^{UK}){c^{KV}}^\top$$

Note that we are cutting a corner here: since $W^{UK}, W^{UV} \in \mathbb{R}^{d_h n_h \times d_c}$, the computation actually has to be unfolded into $n_h$ heads. By the associativity of matrix multiplication we can treat the bracketed term as a single unit, which is to say the matrix $W^{UK}$ is absorbed into the query side. Why do this? We will explain shortly.
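Here is a quick numerical sanity check of this identity for a single head, with toy dimensions and random weights of my own choosing: expanding the keys from $c^{KV}$ first, versus absorbing $W^{UK}$ into the query side, gives the same scores up to floating-point error.

```python
import torch

d, d_c, d_h, t = 64, 16, 32, 5           # toy dims: input, latent, head, cached tokens
x_q  = torch.randn(1, d)                 # current query token
c_kv = torch.randn(t, d_c)               # cached latent vectors of t tokens
W_Q  = torch.randn(d_h, d)               # query projection (nn.Linear weight layout)
W_UK = torch.randn(d_h, d_c)             # key up-projection

# (1) naive: expand the keys first, then dot with the query
k = c_kv @ W_UK.T                        # (t, d_h)
scores_naive = (x_q @ W_Q.T) @ k.T       # (1, t)

# (2) absorbed: fold W_UK into the query, keep the keys as the latent c^{KV}
q_absorbed = (x_q @ W_Q.T) @ W_UK        # (1, d_c)
scores_absorbed = q_absorbed @ c_kv.T    # (1, t)

print(torch.allclose(scores_naive, scores_absorbed, rtol=1e-4, atol=1e-4))  # True
```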
Next, the output computation:

$$P = \mathrm{Softmax}\left(\frac{QK^\top}{\sqrt{d_h}}\right)$$

$$\begin{aligned} O & = \mathrm{Linear}(PV)\\ & = (P(x_V W_V^\top))W_O^\top \end{aligned}$$

Substituting the MLA quantities gives

$$O = Pc^{KV} ({W^{UV}}^\top{W^O}^\top)$$

Again, this computation actually has to be reshaped into $n_h$ heads. So matrix absorption applies to $V$ as well, except that the matrix $W^{UV}$ is absorbed on the attention-output side.
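The output side can be checked the same way (again a single head with toy dimensions of my own): applying $W^{UV}$ to the latents before the weighted sum, or after it at the output projection, yields the same result.

```python
import torch

d, d_c, d_h, t = 64, 16, 32, 5                    # toy dims
P    = torch.softmax(torch.randn(1, t), dim=-1)   # attention weights over t cached tokens
c_kv = torch.randn(t, d_c)                        # cached latents
W_UV = torch.randn(d_h, d_c)                      # value up-projection (nn.Linear weight layout)
W_O  = torch.randn(d, d_h)                        # output projection

# (1) naive: reconstruct per-token values, take the weighted sum, then project out
v = c_kv @ W_UV.T                                 # (t, d_h)
out_naive = (P @ v) @ W_O.T                       # (1, d)

# (2) absorbed: attend over the latents first, apply W_UV only at the output side
out_absorbed = ((P @ c_kv) @ W_UV.T) @ W_O.T      # (1, d)

print(torch.allclose(out_naive, out_absorbed, rtol=1e-4, atol=1e-4))  # True
```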
The payoff: both the cache and the computation operate on the joint compression $c^{KV}$ (less KV-cache memory), the computation takes the form of an MQA in which K and V are shared (higher arithmetic intensity), yet the model keeps the superior quality of MHA's multiple attention heads (that information now lives inside the matrices $W^{UK}$ and $W^{UV}$).
3. RoPE decoupling
The analysis above ignored the rotary position embedding (RoPE). Treat RoPE as a matrix $R_m \in \mathbb{R}^{d_h \times d_h}$, where $m$ is the position index and $d_h$, as before, is the per-head dimension, satisfying $R_mR_n^\top=R_{m-n}$. Consider a single head (otherwise the notation gets unwieldy) and the computation of $q_m k_n^\top$, where $m$ and $n$ are position indices:

$$\begin{aligned} q_mk_n^\top & = (x_mW_q^\top R_m^\top)(x_nW_k^\top R_n^\top)^\top \\ & = x_m(W_q^\top R_m^\top R_n W_k)x_n^\top \\ & = x_m(W_q^\top R_{n-m} W_k)x_n^\top \end{aligned}$$

This bankrupts matrix absorption: the middle factor $R_{n-m}$ depends on the position difference, so $W_q^\top R_{n-m} W_k$ changes with relative position and cannot be folded into a single fixed matrix.
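A tiny illustration with 2x2 rotation matrices (my own toy setup, not the model's RoPE implementation) makes the point: the factor $W_q^\top R_m^\top R_n W_k$ is the same matrix only when the offset $n-m$ is the same, so there is no single matrix to precompute.

```python
import torch

def rot(pos: int, theta: float = 0.1) -> torch.Tensor:
    """2x2 RoPE-style rotation matrix for position `pos`."""
    a = torch.tensor(pos * theta)
    return torch.stack([torch.stack([torch.cos(a), -torch.sin(a)]),
                        torch.stack([torch.sin(a),  torch.cos(a)])])

W_q = torch.randn(2, 2)
W_k = torch.randn(2, 2)

def middle(m: int, n: int) -> torch.Tensor:
    # the factor we would like to absorb: W_q^T R_m^T R_n W_k
    return W_q.T @ rot(m).T @ rot(n) @ W_k

print(torch.allclose(middle(3, 7), middle(10, 14), atol=1e-5))  # True  (same offset 4)
print(torch.allclose(middle(3, 7), middle(3, 9), atol=1e-5))    # False (offset 4 vs 6)
```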
The solution from the DeepSeek team is to decouple RoPE, which leads to the final MLA architecture shown in the figure below. Briefly: a superscript $R$ marks the part with RoPE applied, while a superscript $C$ marks the part without RoPE. For example, $k_{t,i}^C$ is the non-RoPE key of the $i$-th attention head for the $t$-th token, and $k_t^R$ is the RoPE'd key of the $t$-th token (this part shares a single attention head). The complete key is the concatenation of these two pieces, which completes the RoPE decoupling. The cache now additionally holds $k_t^R$, but since it has only one head, the extra cache is small.
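A shape-level sketch of the decoupled layout, using DeepSeek-V2-Lite-like sizes as assumptions: the cache holds only $c_t^{KV}$ and the single-head $k_t^R$, while the full per-head key that the architecture diagram draws is conceptually the concatenation $[k_{t,i}^C ; k_t^R]$.

```python
import torch

n_h, d_c, d_nope, d_rope = 16, 512, 128, 64   # illustrative (DeepSeek-V2-Lite-like) dims
t = 10                                        # number of cached tokens

# what MLA actually caches per layer:
c_kv   = torch.randn(t, d_c)      # compressed latent KV, shared by all heads
k_rope = torch.randn(t, d_rope)   # RoPE'd key, a single shared "head"

# what the diagram expands this into (never materialized once absorption is used):
W_UK   = torch.randn(n_h * d_nope, d_c)
k_nope = (c_kv @ W_UK.T).view(t, n_h, d_nope)                  # per-head non-RoPE keys
k_full = torch.cat([k_nope,
                    k_rope.unsqueeze(1).expand(t, n_h, d_rope)], dim=-1)

print(c_kv.shape, k_rope.shape)   # cached: (10, 512) and (10, 64)
print(k_full.shape)               # conceptual full key: (10, 16, 192)
```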
In addition, the paper applies a low-rank projection to $Q$ as well, which it motivates by the reduction in activation memory during training (it does not shrink the KV cache).
4. Computation optimization
Of course, in practice we do not compute things as laid out in the architecture diagram. Instead, the RoPE and non-RoPE parts are split apart, their attention weights are computed separately, and the two are summed. Concretely, only with this split can the non-RoPE part use matrix absorption and run in MQA form; and for the RoPE part there is no need to replicate $k_t^R$ $n_h$ times and concatenate it with $\{k_{t,i}^C\}$ as the diagram suggests, which saves memory operations.
One more word on the order of absorption: should it be $QK^\top = ((x_QW_Q^\top) W^{UK}){c^{KV}}^\top$ or $QK^\top = (x_Q(W_Q^\top W^{UK})){c^{KV}}^\top$? The former absorbs on the fly, while the latter precomputes $W_Q^\top W^{UK}$. The latter effectively merges two low-rank matrices into one larger matrix, which increases the amount of computation and uses more GPU memory.
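Counting multiply-adds for a single decode-step query makes the trade-off concrete. The numbers below are illustrative (Lite-like sizes) and count one multiply-add per element of a matrix product; they are not profiler results.

```python
# Compare the two absorption orders for one decode-step query, per attention head.
d, d_nope, d_c = 2048, 128, 512   # illustrative: input dim, non-RoPE head dim, latent dim

# (x_Q @ W_Q^T) @ W_UK : absorb on the fly, weights kept separate
on_the_fly_flops  = d * d_nope + d_nope * d_c   # two skinny matmuls, q_len = 1
on_the_fly_params = d * d_nope + d_nope * d_c   # weight elements stored per head

# x_Q @ (W_Q^T @ W_UK) : precompute the merged (d x d_c) matrix
pre_merged_flops  = d * d_c
pre_merged_params = d * d_c

print(f"per-head multiply-adds: {on_the_fly_flops} vs {pre_merged_flops}")
print(f"per-head weight elems : {on_the_fly_params} vs {pre_merged_params}")
# The pre-merged form costs roughly 3x more compute and memory here, which is why
# the optimized implementation below absorbs on the fly instead of precomputing the product.
```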
II. Code
1. Before optimization
To dig into MLA, we can study DeepSeek-V2 or the 16B-parameter DeepSeek-V2-Lite. I picked the quantized DeepSeek-V2-Lite-AWQ, which runs on a single RTX 3090.
The attention code is shown below with detailed comments; if anything is unclear, match it against the formulas in the paper. Note that this open-source version computes exactly as in the architecture diagram, i.e., in MHA form, and what the KV cache stores is not $c_t^{KV}$ and $k_t^R$. As you would expect, it neither saves cache nor delivers the promised inference efficiency.
```python
class DeepseekV2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead "
                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        # dropout ratio, 0.0
        self.attention_dropout = config.attention_dropout
        # hidden dimension, 2048
        self.hidden_size = config.hidden_size
        # number of MLA attention heads, 16
        self.num_heads = config.num_attention_heads
        # maximum sequence length, 163840 (160K)
        self.max_position_embeddings = config.max_position_embeddings
        # RoPE base, 10000
        self.rope_theta = config.rope_theta
        # lora_rank of q, i.e. its down-projected dimension; None here (usually used during training)
        self.q_lora_rank = config.q_lora_rank
        # per-head dimension of the RoPE part of q (k uses the same), 64
        self.qk_rope_head_dim = config.qk_rope_head_dim
        # lora_rank of k(v), i.e. the down-projected dimension, 512
        self.kv_lora_rank = config.kv_lora_rank
        # per-head dimension of v, 128
        self.v_head_dim = config.v_head_dim
        # per-head dimension of the non-RoPE part of q(k), 128
        self.qk_nope_head_dim = config.qk_nope_head_dim
        # per-head dimension of q, sum of the RoPE and non-RoPE parts, 128 + 64 = 192
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
        # causal (unidirectional) attention
        self.is_causal = True

        if self.q_lora_rank is None:
            # no low-rank compression for q
            self.q_proj = nn.Linear(
                self.hidden_size, self.num_heads * self.q_head_dim, bias=False
            )
        else:
            # low-rank compression for q
            # down_proj, h_{t} -> c_{t}^{Q}
            self.q_a_proj = nn.Linear(
                self.hidden_size, config.q_lora_rank, bias=config.attention_bias
            )
            self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank)
            # up_proj
            self.q_b_proj = nn.Linear(
                config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
            )

        # low-rank compression for kv
        # down_proj, h_{t} -> c_{t}^{KV} + k_{t}^{R} before RoPE is applied
        self.kv_a_proj_with_mqa = nn.Linear(
            self.hidden_size,
            config.kv_lora_rank + config.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank)
        # up_proj, c_{t}^{KV} -> k_{t}^{C} + v_{t}^{C}
        # (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim)
        # = config.qk_nope_head_dim + self.v_head_dim, i.e. the non-RoPE part of k plus the dimension of v
        self.kv_b_proj = nn.Linear(
            config.kv_lora_rank,
            self.num_heads
            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )
        # output_proj
        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )

        # initialize the position embedding: precompute sin/cos for every position up to seq_len
        self._init_rope()

        # 1/√d
        self.softmax_scale = self.q_head_dim ** (-0.5)
        # RoPE extension method (linear, DynamicNTK, YaRN, ...); YaRN needs special handling of softmax_scale
        if self.config.rope_scaling is not None:
            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
            scaling_factor = self.config.rope_scaling["factor"]
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.softmax_scale = self.softmax_scale * mscale * mscale

    def _init_rope(self):
        """Pick the RoPE (extension) variant according to rope_scaling; YaRN is used here."""
        if self.config.rope_scaling is None:
            self.rotary_emb = DeepseekV2RotaryEmbedding(
                self.qk_rope_head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = DeepseekV2DynamicNTKScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "yarn":
                kwargs = {
                    key: self.config.rope_scaling[key]
                    for key in [
                        "original_max_position_embeddings",
                        "beta_fast",
                        "beta_slow",
                        "mscale",
                        "mscale_all_dim",
                    ]
                    if key in self.config.rope_scaling
                }
                self.rotary_emb = DeepseekV2YarnRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                    **kwargs,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        bsz, q_len, _ = hidden_states.size()

        # projection for q, optionally with low-rank compression
        if self.q_lora_rank is None:
            # (bsz, q_len, hidden_size) -> (bsz, q_len, num_heads * q_head_dim)
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        # (bsz, q_len, num_heads * q_head_dim) -> (bsz, num_heads, q_len, q_head_dim)
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        # q_nope: (bsz, num_heads, q_len, qk_nope_head_dim)
        # q_pe:   (bsz, num_heads, q_len, qk_rope_head_dim)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        # low-rank (LoRA-style) compression for kv
        # (bsz, q_len, hidden_size) -> (bsz, q_len, kv_lora_rank + qk_rope_head_dim)
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        # compressed_kv: (bsz, q_len, kv_lora_rank), the part that needs no position encoding
        # k_pe:          (bsz, q_len, qk_rope_head_dim), the part that gets RoPE
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        # -> (bsz, 1, q_len, qk_rope_head_dim), the 1 is the (single) head
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        # up_proj; at this point k has not yet been concatenated with the RoPE part
        # (bsz, q_len, kv_lora_rank) -> (bsz, q_len, num_heads, qk_nope_head_dim + v_head_dim)
        # -> (bsz, num_heads, q_len, qk_nope_head_dim + v_head_dim)
        kv = (
            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            .transpose(1, 2)
        )
        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )
        # currently equals q_len
        kv_seq_len = value_states.shape[-2]
        # using the kv cache
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        # compute the cos/sin values of the position embedding
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        # apply RoPE to the position-dependent parts of q and k
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        # concatenate q_nope and q_pe
        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

        # concatenate k_nope and k_pe
        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe

        # kv cache
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # (Q * K^{T}) / √d
        attn_weights = (
            torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
        )

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )
        # apply the mask
        assert attention_mask is not None
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
        # output projection
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
```
2. After optimization
The optimized code comes from deepseekv2-profile, whose author builds up several progressively optimized versions of MLA. The absorbed_cache_compressed_move_elision variant, a.k.a. A_CC_ME, performs best, and its optimizations are exactly what we covered in the principles section. So we go straight to that code (the parameters correspond to the 236B model). If you are interested in the detailed comparison across the intermediate versions, the repo is worth reading on its own.
```python
class DeepseekAttention(nn.Module):
    def __init__(self, hidden_size: int, num_attention_heads: int, q_lora_rank: int,
                 qk_rope_head_dim: int, kv_lora_rank: int, v_head_dim: int,
                 qk_nope_head_dim: int, max_position_embeddings: int,
                 torch_dtype: torch.dtype, attention_bias: bool = False, *args, **kwargs):
        super().__init__()
        q_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.hidden_size = hidden_size
        self.num_heads = num_attention_heads
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.q_head_dim = q_head_dim
        self.v_head_dim = v_head_dim
        self.softmax_scale = torch.tensor(self.q_head_dim).to(torch_dtype).rsqrt()
        self.q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=attention_bias, dtype=torch_dtype)
        self.q_a_layernorm = DeepseekV2RMSNorm(q_lora_rank).to(torch_dtype)
        self.q_b_proj = nn.Linear(q_lora_rank, num_attention_heads * q_head_dim, bias=False, dtype=torch_dtype)
        self.kv_a_proj_with_mqa = nn.Linear(hidden_size, kv_lora_rank + qk_rope_head_dim, bias=attention_bias, dtype=torch_dtype)
        self.kv_a_layernorm = DeepseekV2RMSNorm(kv_lora_rank).to(torch_dtype)
        self.kv_b_proj = nn.Linear(kv_lora_rank, num_attention_heads * (qk_nope_head_dim + v_head_dim), bias=False, dtype=torch_dtype)
        self.o_proj = nn.Linear(num_attention_heads * v_head_dim, hidden_size, bias=attention_bias, dtype=torch_dtype)
        self.rotary_emb = DeepseekV2RotaryEmbedding(self.qk_rope_head_dim, max_position_embeddings=max_position_embeddings).to(torch_dtype)

    def compress_kv(self, hidden_states_kv: torch.Tensor, kv_position_ids: torch.LongTensor) -> torch.Tensor:
        # return the RoPE'ed & compressed kv
        bsz, kv_seq_len, _ = hidden_states_kv.size()
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states_kv)
        compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        compressed_kv = self.kv_a_layernorm(compressed_kv)
        k_pe = k_pe.view(bsz, kv_seq_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        cos, sin = self.rotary_emb(k_pe)
        k_pe = apply_rotary_pos_emb(k_pe, cos, sin, kv_position_ids).view(bsz, kv_seq_len, self.qk_rope_head_dim)
        return torch.cat([compressed_kv, k_pe], dim=-1)

    def forward(self, hidden_states_q: torch.Tensor, q_position_ids: torch.LongTensor, compressed_kv: torch.Tensor):
        '''
        Attention masks and past cache are removed.
        Input:
        - hidden_states_q: [bsz, q_len, hidden_size]
        - compressed_kv: [bsz, kv_len, kv_lora_rank]
        - position_ids: [bsz, q_len]
        '''
        bsz, q_len, _ = hidden_states_q.size()
        q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states_q)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        # q_nope: (bsz, num_heads, q_len, qk_nope_head_dim)
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        kv_seq_len = compressed_kv.size(1)
        compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = k_pe.view(bsz, 1, kv_seq_len, self.qk_rope_head_dim)

        kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
        q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :]
        out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :]

        cos, sin = self.rotary_emb(q_pe)
        q_pe = apply_rotary_pos_emb(q_pe, cos, sin, q_position_ids)

        # attn_weights is computed separately for the RoPE and non-RoPE parts and then summed.
        # (bsz, num_heads, q_len, qk_nope_head_dim) * (num_heads, qk_nope_head_dim, kv_lora_rank)
        # -> (bsz, num_heads, q_len, kv_lora_rank)
        q_nope = torch.matmul(q_nope, q_absorb)
        attn_weights = (torch.matmul(q_pe, k_pe.mT) +
                        torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_nope.dtype)
        attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
        attn_output = torch.matmul(attn_output, out_absorb.mT)  # torch.einsum('bhqc,hdc->bhqd', attn_output, out_absorb)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
                f" {attn_output.size()}"
            )
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
        attn_output = self.o_proj(attn_output)
        return attn_output
```
Summary
This post covered MLA, one of the core modules of DeepSeek-V2/V3/R1, explained how it compresses the KV cache aggressively while avoiding the quality loss, and walked through computational tricks such as matrix absorption, corroborating the paper's claim of efficient inference. Articles with finer-grained MLA analysis or further engineering optimizations are listed in the references for interested readers.
References
[1] 缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA
[2] DeepSeek-V2 高性能推理 (1):通过矩阵吸收十倍提速 MLA 算子
[3] Deepseek MLA 一定要做吸收吗?
[4] DeepSeek-V2 MLA KV Cache 真的省了吗?