大模型推理 memory bandwidth bound (3) - MLA

系列文章目录

大模型推理 & memory bandwidth bound (1) - 性能瓶颈与优化概述
大模型推理 & memory bandwidth bound (2) - Multi-Query Attention
大模型推理 & memory bandwidth bound (3) - MLA


文章目录

  • 系列文章目录
  • 前言
  • 一、原理
    • 1.低秩压缩 & 动机
    • 2.矩阵吸收
    • 3.RoPE解耦
    • 4.计算优化
  • 二、代码
    • 1.优化前
    • 2.优化后
  • 总结
  • 参考


前言

“MLA guarantees efficient inference through significantly compressing the Key-Value (KV) cache into a latent vector, while DeepSeekMoE enables training strong models at an economical cost through sparse computation.” —— 《DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model》

在上一篇,我们解读了《Fast Transformer Decoding: One Write-Head is All You Need》这篇文章,确认了MHADecode阶段memory bandwidth bound的问题,分析了MQA方法带来的优化。具体来说,MQAKV Cache压缩成了原来 1 h \frac{1}{h} h1 h h h 为注意力头的个数),增加了计算强度,缓解了memory bandwidth bound问题;同时,KV Cache的减少意味着可以增加batch size,也即可以同时处理更多的requests
今天要讨论的是Multi-head Latent Attention (MLA),它是另一种极致压缩KV Cache的方法,在DeepSeek V2中被提出,也是当前爆火的DeepSeek-V3/R1模型的核心模块之一。


一、原理

1.低秩压缩 & 动机

在这里插入图片描述
如上图所示,不同于MHA以及多头共用KVGQAMQA,MLA使用的是KV联合低秩压缩方法,在推理时只需要存储压缩的这部分,也就是图中的Compressed Latent KV。通过投影矩阵计算得到MLA中的KV(图中蓝色长条)。用公式表达如下:
c t K V = W D K V h t \textbf{c}_t^{KV} = W^{DKV}\textbf{h}_t ctKV=WDKVht
k t = W U K c t K V \textbf{k}_t = W^{UK}\textbf{c}_t^{KV} kt=WUKctKV
v t = W U V c t K V \textbf{v}_t = W^{UV}\textbf{c}_t^{KV} vt=WUVctKV

其中下标 t t t 表示第 t t t 个token, h t \textbf{h}_t ht 是注意力模块的输入, c t K V \textbf{c}_t^{KV} ctKV是低秩压缩的隐向量,也就是推理时需要被缓存的部分。 W D K V ∈ R d c × d W^{DKV} \in \mathbb{R}^{d_c \times d} WDKVRdc×d 是 down-projection matrix(不知道怎么翻译),即降维投影矩阵,要求 d c < d d_c < d dc<d W U K , W U V ∈ R d h n h × d c W^{UK}, W^{UV} \in \mathbb{R}^{d_hn_h \times d_c} WUK,WUVRdhnh×dc 是 up-projection matrix,升维投影矩阵,满足 d c < d h n h d_c< d_hn_h dc<dhnh d d d 是输入的维度, d c d_c dcCompressed Latent KV的维度, n h n_h nh 是注意力头的个数, d h d_h dh 是每个注意力头的维度。
我们来猜一下设计MLA的动机。想压缩KV Cache是显而易见的,这一点和MQA接近;MQA相较于MHA的劣势是有性能损失,而MLA最后的形式等同于MHA所以我们斗胆下个结论:MLAMQAMHA的杂交品种,要同时继承MQA压缩KV cacheMHA多个注意力头的优异性能这两个优点。

2.矩阵吸收

但你有没有发现一个问题,尽管上图告诉你MLA只需要缓存Compressed Latent KV,但是它最后展开成MHA了,这部分 keys 和 values 还是要占用显存,同时又多了一步低秩压缩,增加了内存操作。那么怎么解决这个问题呢?论文中给出了答案 —— 矩阵吸收。
我们先来梳理一下MLA的计算过程,相应的公式按照神经网络中的表达方式(不同于上面或者论文中公式矩阵左乘的方式),比如nn.linear()对应的表示为
y = x W ⊤ y= xW^\top y=xW
那么 Q K ⊤ QK^\top QK 的一般表达式为
Q K ⊤ = ( x Q W Q ⊤ ) ( x K W K ⊤ ) ⊤ = ( x Q W Q ⊤ ) ( W K x K ⊤ ) = ( x Q W Q ⊤ W K ) x K ⊤ \begin{aligned} QK^\top & = (x_QW_Q^\top)(x_KW_K^\top)^\top \\ & = (x_QW_Q^\top)(W_Kx_K^\top) \\ & = (x_QW_Q^\top W_K)x_K^\top \end{aligned} QK=(xQWQ)(xKWK)=(xQWQ)(WKxK)=(xQWQWK)xK

放到MLA场景中进行变量替换 W K → W U K W_K \rightarrow W^{UK} WKWUK x K → c K V x_K \rightarrow c^{KV} xKcKV 得到

Q K ⊤ = ( x Q W Q ⊤ W U K ) c K V ⊤ QK^\top = (x_QW_Q^\top W^{UK}){c^{KV}}^\top QK=(xQWQWUK)cKV

注意这边偷懒了,因为 W U K , W U V ∈ R d h n h × d c W^{UK}, W^{UV} \in \mathbb{R}^{d_hn_h \times d_c} WUK,WUVRdhnh×dc ,实际是需要展开成 n h n_h nh 个头的。 由于矩阵乘法结合律,我们可以让括号中的作为一个整体,也就是将矩阵 W U K W^{UK} WUK 进行了吸收。为什么要这样做,我们等会解释。

接下来的计算输出
P = S o f t m a x ( Q K ⊤ d h ) P = Softmax(\frac{QK^\top}{\sqrt{d_h}}) P=Softmax(dh QK)
O = L i n e a r ( P V ) = ( P ( x V W V ⊤ ) ) W O ⊤ \begin{aligned} O & = Linear(PV)\\ & = (P(x_V W_V ^\top))W_O^\top \end{aligned} O=Linear(PV)=(P(xVWV))WO

放到MLA场景中进行变量替换得到

O = P c K V ( W U V ⊤ W O ⊤ ) O = Pc^{KV} ({W^{UV}} ^\top{W^O}^\top) O=PcKV(WUVWO)

同样的,这边的计算实际也要reshape成 n h n_h nh 个头。 因此对于V,我们也可以使用矩阵吸收,只不过矩阵 W U V W^{UV} WUV 是在attention输出部分进行吸收。
这样的好处是缓存和计算都基于联合压缩 c K V c^{KV} cKVKV cache显存占用减少),计算形式上是KV共用的MQA(计算强度增加),但拥有MHA多个注意力头这样更加优异的性能(这部分信息隐藏在矩阵 W U K W^{UK} WUK W U V W^{UV} WUV 中)。

3.RoPE解耦

上述分析忽略了旋转位置编码RoPE。考虑到RoPE为矩阵 R m ∈ R d h × d h R_m \in \mathbb{R}^{d_h \times d_h} RmRdh×dh m m m 表示位置索引, d h d_h dh 和上面一样,是每个头的维度,有 R m R n ⊤ = R m − n R_mR_n^\top=R_{m-n} RmRn=Rmn 。这边考虑单个头(没办法,不然不好表示) q m k n ⊤ q_m k_n^ \top qmkn 的计算, m m m n n n 是位置索引。

q m k n ⊤ = ( x m W q ⊤ R m ⊤ ) ( x n W k ⊤ R n ⊤ ) ⊤ = x m ( W q ⊤ R m ⊤ R n W k ) x n ⊤ = x m ( W q ⊤ R n − m ⊤ W k ) x n ⊤ \begin{aligned} q_mk_n^\top & = (x_mW_q^\top R_m^ \top)(x_nW_k^\top R_n^ \top)^\top \\ & = x_m(W_q^\top R_m^ \top R_n W_k)x_n^\top \\ & = x_m(W_q^\top R_{n-m}^ \top W_k)x_n^\top \end{aligned} qmkn=(xmWqRm)(xnWkRn)=xm(WqRmRnWk)xn=xm(WqRnmWk)xn

这样一来,矩阵吸收就破产了,因为中间的 R n − m R_{n-m} Rnm 是和位置差相关的,即随着位置差是变动的。
DeepSeek团队给出的解决方案是对RoPE解耦,导致我们最终看到的MLA架构如下图所示。简单解释一下,下图中上标带 R R R 的表示施加了RoPE,而上标带 C C C 的则是未施加RoPE的部分。比如 k t , i C {k_{t,i}^C} kt,iC 表示没有施加RoPE的第 t t t 个token的第 i i i 个注意力头的key k t R {k_{t}^R} ktR 表示施加RoPE的第 t t t 个token的key(这部分共享一个注意力头)。完整的key由这两部分拼接而成,这样就完成了RoPE解耦。此时缓存部分多了一个 k t R k_t^R ktR ,由于只有一个头,所以增加的缓存并不多。
在这里插入图片描述
另外,论文中还提到在训练时还对 Q Q Q 做了低秩投影,减少了参数量。

4.计算优化

当然,在实际计算的时候我们不能按照上面的架构图进行计算,而是将施加和未施加RoPE的两部分劈开,单独计算attn_weights,之后两部分相加。具体来说,这样对于未施加RoPE的部分,我们才能进行矩阵吸收,以MQA的形式出现;而对于施加RoPE的部分,计算前不需要像上面架构图中将 k t R k_t^R ktR 复制 n h n_h nh 份与 { k t , i C } \{k_{t, i}^C\} {kt,iC} 拼接,节省了内存操作。
再说一下矩阵吸收顺序,是 Q K ⊤ = ( ( x Q W Q ⊤ ) W U K ) c K V ⊤ QK^\top = ((x_QW_Q^\top) W^{UK}){c^{KV}}^\top QK=((xQWQ)WUK)cKV 还是 Q K ⊤ = ( x Q ( W Q ⊤ W U K ) ) c K V ⊤ QK^\top = (x_Q(W_Q^\top W^{UK})){c^{KV}}^\top QK=(xQ(WQWUK))cKV ,前者表示临时吸收,而后者表示预先计算 W Q ⊤ W U K W_Q^\top W^{UK} WQWUK 。后者的效果是将两个低秩矩阵拼成了一个更大的矩阵,其实整个计算过程是抬升了计算量,同时显存占用也更多。

二、代码

1.优化前

为了深入理解MLA,我们可以找到DeepSeek-V2或者16B参数的DeepSeek-V2-Lite进行学习。我选择了量化版本DeepSeek-V2-Lite-AWQ,运行它只需要一张3090。
Attention部分的代码如下,已经给出了详细注释,不理解的小伙伴可以和论文中的公式对一下。需要注意的是,开源版本的代码完全按照架构图中的方式,也即以MHA的形式进行了计算,KV Cache部分缓存的也不是 c t K V c_t^{KV} ctKV k t R k_t^R ktR 。可想而知,它没有节省缓存,也不能保证更高效的推理。

class DeepseekV2Attention(nn.Module):"""Multi-headed attention from 'Attention Is All You Need' paper"""def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None):super().__init__()self.config = configself.layer_idx = layer_idxif layer_idx is None:logger.warning_once(f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will ""to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` ""when creating this class.")# dropout ratio,0.0self.attention_dropout = config.attention_dropout# 隐藏层维度,2048self.hidden_size = config.hidden_size# MLA注意力头的个数,16self.num_heads = config.num_attention_heads# 最大序列长度, 163840 (160K)self.max_position_embeddings = config.max_position_embeddings# 旋转位置编码的base, 10000self.rope_theta = config.rope_theta# q对应的lora_rank,降维后的维度,None(通常训练时使用)self.q_lora_rank = config.q_lora_rank# q施加rope部分的每个注意力头的维度,k对应的维度与q相同, 64self.qk_rope_head_dim = config.qk_rope_head_dim# k(v)对应的lora_rank,降维后的维度,512self.kv_lora_rank = config.kv_lora_rank# v在每个注意力头中的维度,128self.v_head_dim = config.v_head_dim# q(k)未施加rope部分的每个注意力头的维度,128self.qk_nope_head_dim = config.qk_nope_head_dim# q的维度,等于施加和未施加rope的两部分维度之和, 128+64=192self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim# 是否单向self.is_causal = Trueif self.q_lora_rank is None:# q不做低秩压缩self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False)else:# q做低秩压缩# down_proj, h_{t} -> c_{t}^{Q}self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias)self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank)# up_projself.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)# kv做低秩压缩,# down_proj, h_{t} -> c_{t}^{KV} + 未施加位置编码前的k_{t}^{R}self.kv_a_proj_with_mqa = nn.Linear(self.hidden_size,config.kv_lora_rank + config.qk_rope_head_dim,bias=config.attention_bias,)self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank)# up_proj, c_{t}^{KV} -> k_{t}^{C} + v_{t}^{C}# (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim) # = config.qk_nope_head_dim + self.v_head_dim,表示 k 未施加rope的部分 以及 v 的维度之和self.kv_b_proj = nn.Linear(config.kv_lora_rank,self.num_heads* (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),bias=False,)  # output_projself.o_proj = nn.Linear(self.num_heads * self.v_head_dim,self.hidden_size,bias=config.attention_bias,)# 初始化位置编码,计算好指定seq_len每个位置的sin和cos值self._init_rope()# 1/√dself.softmax_scale = self.q_head_dim ** (-0.5)# rope扩展方式(linear、DynamicNTK、YaRN等),YaRN在softmax_scale上需要特殊处理if self.config.rope_scaling is not None:mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)scaling_factor = self.config.rope_scaling["factor"]if mscale_all_dim:mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)self.softmax_scale = self.softmax_scale * mscale * mscaledef _init_rope(self):"""根据不同的rope_scaling使用不同的RoPE(扩展)这里使用的是YaRN"""if self.config.rope_scaling is None:self.rotary_emb = DeepseekV2RotaryEmbedding(self.qk_rope_head_dim,max_position_embeddings=self.max_position_embeddings,base=self.rope_theta,)else:scaling_type = self.config.rope_scaling["type"]scaling_factor = self.config.rope_scaling["factor"]if scaling_type == "linear":self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding(self.qk_rope_head_dim,max_position_embeddings=self.max_position_embeddings,scaling_factor=scaling_factor,base=self.rope_theta,)elif scaling_type == "dynamic":self.rotary_emb = DeepseekV2DynamicNTKScalingRotaryEmbedding(self.qk_rope_head_dim,max_position_embeddings=self.max_position_embeddings,scaling_factor=scaling_factor,base=self.rope_theta,)elif scaling_type == "yarn":kwargs = {key: self.config.rope_scaling[key]for key in ["original_max_position_embeddings","beta_fast","beta_slow","mscale","mscale_all_dim",]if key in self.config.rope_scaling}self.rotary_emb = DeepseekV2YarnRotaryEmbedding(self.qk_rope_head_dim,max_position_embeddings=self.max_position_embeddings,scaling_factor=scaling_factor,base=self.rope_theta,**kwargs,)else:raise ValueError(f"Unknown RoPE scaling type {scaling_type}")def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):return (tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2).contiguous())def forward(self,hidden_states: torch.Tensor,attention_mask: Optional[torch.Tensor] = None,position_ids: Optional[torch.LongTensor] = None,past_key_value: Optional[Cache] = None,output_attentions: bool = False,use_cache: bool = False,**kwargs,) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:if "padding_mask" in kwargs:warnings.warn("Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`")bsz, q_len, _ = hidden_states.size()# q的投影层,可以选择低秩压缩if self.q_lora_rank is None:# (bsz, q_len, hidden_size) -> (bsz, q_len, num_heads * q_head_dim)q = self.q_proj(hidden_states)else:q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))# (bsz, q_len, num_heads * q_head_dim) -> (bsz, num_heads, q_len, q_head_dim)q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)# q_nope: (bsz, num_heads, q_len, qk_nope_head_dim)# q_pe: (bsz, num_heads, q_len, qk_pe_head_dim)q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)# kv lora低秩压缩# (bsz, q_len, hidden_size) -> (bsz, q_len, kv_lora_rank + qk_rope_head_dim)compressed_kv = self.kv_a_proj_with_mqa(hidden_states)# compressed_kv:(bsz, q_len, kv_lora_rank),无需位置编码的部分# k_pe: (bsz, q_len, qk_rope_head_dim),需要施加位置编码的部分compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)# -> (bsz, 1, q_len, qk_rope_head_dim) 1是headk_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)# up_proj,此时的 k 还没有concat施加位置编码的部分# (bsz, q_len, kv_lora_rank) -> (bsz, q_len, num_heads, qk_nope_head_dim + v_head_dim)# -> (bsz, num_heads, q_len, qk_nope_head_dim + v_head_dim)kv = (self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2))k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)# 当前=q_lenkv_seq_len = value_states.shape[-2]# 使用kvcacheif past_key_value is not None:if self.layer_idx is None:raise ValueError(f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} ""for auto-regressive decoding with k/v caching, please make sure to initialize the attention class ""with a layer index.")kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)# 计算位置编码cos和sin值cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)# 对(需要位置编码的部分)q和k施加位置编码q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)# q_nope和q_pe合并query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)query_states[:, :, :, : self.qk_nope_head_dim] = q_nopequery_states[:, :, :, self.qk_nope_head_dim :] = q_pe# k_nope和k_pe合并key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)key_states[:, :, :, : self.qk_nope_head_dim] = k_nopekey_states[:, :, :, self.qk_nope_head_dim :] = k_pe# kvcacheif past_key_value is not None:cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE modelskey_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)# (Q * K^{T}) / √dattn_weights = (torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale)if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):raise ValueError(f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"f" {attn_weights.size()}")# 添加maskassert attention_mask is not Noneif attention_mask is not None:if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):raise ValueError(f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}")attn_weights = attn_weights + attention_mask# upcast attention to fp32attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)attn_output = torch.matmul(attn_weights, value_states)if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"f" {attn_output.size()}")attn_output = attn_output.transpose(1, 2).contiguous()attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)# 输出attn_output = self.o_proj(attn_output)if not output_attentions:attn_weights = Nonereturn attn_output, attn_weights, past_key_value

2.优化后

优化版本的代码参考deepseekv2-profile,作者层层递进,给出了多个优化版本的MLA。其中absorbed_cache_compressed_move_elision也即A_CC_ME版本的MLA效果最佳,对应的优化就是我们在原理部分所提及的。所以我们一步到位,直接贴上这部分代码(使用的参数对应236B的版本)。如对多个优化版本对比细节感兴趣的,可自行阅读。

class DeepseekAttention(nn.Module):def __init__(self, hidden_size: int, num_attention_heads: int, q_lora_rank: int, qk_rope_head_dim: int, kv_lora_rank: int, v_head_dim: int, qk_nope_head_dim: int, max_position_embeddings: int,torch_dtype: torch.dtype, attention_bias: bool = False, *args, **kwargs):super().__init__()q_head_dim = qk_nope_head_dim + qk_rope_head_dimself.hidden_size = hidden_sizeself.num_heads = num_attention_headsself.q_lora_rank = q_lora_rankself.kv_lora_rank = kv_lora_rankself.qk_rope_head_dim = qk_rope_head_dimself.qk_nope_head_dim = qk_nope_head_dimself.q_head_dim = q_head_dimself.v_head_dim = v_head_dimself.softmax_scale = torch.tensor(self.q_head_dim).to(torch_dtype).rsqrt()self.q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=attention_bias, dtype=torch_dtype)self.q_a_layernorm = DeepseekV2RMSNorm(q_lora_rank).to(torch_dtype)self.q_b_proj = nn.Linear(q_lora_rank, num_attention_heads * q_head_dim, bias=False, dtype=torch_dtype)self.kv_a_proj_with_mqa = nn.Linear(hidden_size, kv_lora_rank + qk_rope_head_dim, bias=attention_bias, dtype=torch_dtype)self.kv_a_layernorm = DeepseekV2RMSNorm(kv_lora_rank).to(torch_dtype)self.kv_b_proj = nn.Linear(kv_lora_rank, num_attention_heads * (qk_nope_head_dim + v_head_dim), bias=False, dtype=torch_dtype)self.o_proj = nn.Linear(num_attention_heads * v_head_dim, hidden_size, bias=attention_bias, dtype=torch_dtype)self.rotary_emb = DeepseekV2RotaryEmbedding(self.qk_rope_head_dim, max_position_embeddings=max_position_embeddings).to(torch_dtype)def compress_kv(self, hidden_states_kv: torch.Tensor, kv_position_ids: torch.LongTensor) -> torch.Tensor:# return the RoPE'ed & compressed kvbsz, kv_seq_len, _ = hidden_states_kv.size()compressed_kv = self.kv_a_proj_with_mqa(hidden_states_kv) compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)compressed_kv = self.kv_a_layernorm(compressed_kv)k_pe = k_pe.view(bsz, kv_seq_len, 1, self.qk_rope_head_dim).transpose(1, 2)cos, sin = self.rotary_emb(k_pe) k_pe = apply_rotary_pos_emb(k_pe, cos, sin, kv_position_ids).view(bsz, kv_seq_len, self.qk_rope_head_dim)return torch.cat([compressed_kv, k_pe],dim=-1)def forward(self, hidden_states_q: torch.Tensor, q_position_ids: torch.LongTensor, compressed_kv: torch.Tensor):'''Attention masks and past cache are removed.Input: - hidden_states_q: [bsz, q_len, hidden_size]- compressed_kv: [bsz, kv_len, kv_lora_rank]- position_ids: [bsz, q_len]'''bsz, q_len, _ = hidden_states_q.size()q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states_q)))q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)# q_nope: (bsz, num_heads, q_len, qk_nope_head_dim)q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)kv_seq_len = compressed_kv.size(1)compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)k_pe = k_pe.view(bsz, 1, kv_seq_len, self.qk_rope_head_dim)kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)q_absorb = kv_b_proj[:, :self.qk_nope_head_dim,:]out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :]cos, sin = self.rotary_emb(q_pe)q_pe = apply_rotary_pos_emb(q_pe, cos, sin, q_position_ids)# 这边attn_weights是pe和nope两部分分开计算的,然后相加的。# (bsz, num_heads, q_len, qk_nope_head_dim) * (num_heads, qk_nope_head_dim, kv_lora_rank) -> (bsz, num_heads, q_len, kv_lora_rank)q_nope = torch.matmul(q_nope, q_absorb) attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scaleif attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):raise ValueError(f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"f" {attn_weights.size()}")# upcast attention to fp32attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_nope.dtype)attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)attn_output = torch.matmul(attn_output, out_absorb.mT) # torch.einsum('bhqc,hdc->bhqd', attn_output, out_absorb)if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"f" {attn_output.size()}")attn_output = attn_output.transpose(1, 2).contiguous()attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)attn_output = self.o_proj(attn_output)return attn_output

总结

本篇讲解了DeepSeek-V2/V3/R1核心模块之一MLA,解释了其极致压缩KV Cache却能避免性能损失的原因,同时说明了矩阵吸收等计算优化技巧,印证论文中efficient inference的说法。还有一些对MLA分析更细致或者进行更多工程优化的文章,我放到参考列表中,感兴趣的自行阅读。

参考

[1] 缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA
[2] DeepSeek-V2 高性能推理 (1):通过矩阵吸收十倍提速 MLA 算子
[3] Deepseek MLA 一定要做吸收吗?
[4] DeepSeek-V2 MLA KV Cache 真的省了吗?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/34641.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CTP开发爬坑指北(九)

CTP API开发中有很多需要注意的小细节&#xff0c;稍有不慎就会出问题&#xff0c;不然&#xff0c;轻则表现与预期不符&#xff0c;重则程序崩溃影响策略盈利。本系列将容易遇到的坑列出来&#xff0c;以供开发时参考&#xff0c;如有疑义之处&#xff0c;欢迎指正。 在国内期…

python_巨潮年报pdf下载

目录 前置&#xff1a; 步骤&#xff1a; step one: pip安装必要包&#xff0c;获取年报url列表 step two: 将查看url列表转换为pdf url step three: 多进程下载pdf 前置&#xff1a; 1 了解一些股票的基本面需要看历年年报&#xff0c;在巨潮一个个下载比较费时间&…

量化交易backtrader实践(五)_策略综合篇(3)_经典策略复盘

01_经典策略复盘 在某款股票软件手机版App上&#xff0c;有一项“复盘”的功能&#xff0c;这个功能很强大&#xff0c;它能把这支股票近1年的走势&#xff0c;用设置好的六个策略去回测&#xff0c;得到每个策略的近一年的收益率&#xff0c;并做了从最好到最差的排序。这就能…

蓝桥与力扣刷题(蓝桥 字符统计)

题目&#xff1a;给定一个只包含大写字母的字符出 S, 请你输出其中出现次数最多的字符。如果有多个字母均出现了最多次, 按字母表顺序依次输出所有这些字母。 输入格式 一个只包含大写字母的字等串 S. 输出格式 若干个大写字母&#xff0c;代表答案。 样例输入 BABBACAC样…

protobuf安装

安装 github官方链接 https://github.com/protocolbuffers/protobuf/ 以protobuf21为例 https://github.com/protocolbuffers/protobuf/releases/download/v21.11/protobuf-all-21.11.zip windows 解压好文件夹后,使用cmake,vs,qt creator等工具打开该项目,进行编译,编译需…

Compose 实践与探索八 —— LayoutModifier 解析

前面几节讲的 Modifier 都是起辅助作用的&#xff0c;比如 Modifier 的伴生对象、CombinedModifier、 ComposedModifier 以及几乎所有 Modifier 的父接口 Modifier.Element。本篇我们开始讲具有直接功效的 Modifier&#xff0c;分为几个大类&#xff1a;LayoutModifier、DrawMo…

stl之string的详解

一&#xff0c;string定义的方式 &#xff0c;string定义了多种函数重载的方式&#xff0c;常用的构造函数如下&#xff1a; string(); string(const string& str); string(const string& str, size_t pos, size_t len npos); string(const char* s); string(const …

Leetcode-131.Palindrome Partitioning [C++][Java]

目录 一、题目描述 二、解题思路 【C】 【Java】 Leetcode-131.Palindrome Partitioninghttps://leetcode.com/problems/palindrome-partitioning/description/131. 分割回文串 - 力扣&#xff08;LeetCode&#xff09;131. 分割回文串 - 给你一个字符串 s&#xff0c;请你…

InternVL:论文阅读 -- 多模态大模型(视觉语言模型)

更多内容&#xff1a;XiaoJ的知识星球 文章目录 InternVL: 扩展视觉基础模型与通用视觉语言任务对齐1.概述2.InternVL整体架构1&#xff09;大型视觉编码器&#xff1a;InternViT-6B2&#xff09;语言中间件&#xff1a;QLLaMA。3&#xff09;训练策略&#xff08;1&#xff09…

【AWS入门】AWS云计算简介

【AWS入门】AWS云计算简介 A Brief Introduction to AWS Cloud Computing By JacksonML 什么是云计算&#xff1f;云计算能干什么&#xff1f;我们如何利用云计算&#xff1f;云计算如何实现&#xff1f; 带着一系列问题&#xff0c;我将做一个普通布道者&#xff0c;引领广…

二分算法刷题

1. 初识 总结&#xff1a;二分算法题的细节非常多&#xff0c;容易写出死循环。使用算法的条件不一定是数组有序&#xff0c;而是具有“二断性”&#xff1b;模板三种后面会讲。 朴素二分二分查找左端点二分查找右端点 2. 朴素二分 题目链接&#xff1a;704. 二分查找 - 力扣…

itsdangerous加解密源码分析|BUG汇总

这是我这两天的思考 早知道密码学的课就不旷那么多了 纯个人见解 如需转载&#xff0c;标记出处 目录 一、官网介绍 二、事例代码 源码分析&#xff1a; 加密函数dump源码使用的函数如下&#xff1a; 解密 ​编辑 ​编辑 关于签名&#xff1a; 为什么这个数字签名没有…

深度解析React Native底层核心架构

React Native 工作原理深度解析 一、核心架构&#xff1a;三层异构协作体系 React Native 的跨平台能力源于其独特的 JS层-Shadow层-Native层 架构设计&#xff0c;三者在不同线程中协同工作&#xff1a; JS层 运行于JavaScriptCore&#xff08;iOS&#xff09;或Hermes&…

前端内存优化实战指南:从内存泄漏到性能巅峰

前端内存优化实战指南&#xff1a;从内存泄漏到性能巅峰 一、内存问题引发的场景 1.1 典型内存灾难现场 // 经典内存泄漏示例 const zombieElements new Set();function createLeak() {const div document.createElement(div);zombieElements.add(div); // 元素永不释放div…

【工作记录】pytest使用总结

1、 fixture夹具 可参考&#xff1a; python3.x中 pytest之fixture - 漂泊的小虎 - 博客园 fixture是指夹具&#xff08;把用例夹在中间&#xff09;&#xff0c;它包括前置工作和后置工作&#xff0c;前置是用例代码的准备阶段&#xff0c;后置是用例执行之后的清理阶段,用…

C++基础笔记

1. C关键字 这个不多说&#xff0c;以后接触得到&#xff0c;但这里做个总结&#xff1a; 2. 命名空间 一般类型&#xff1a; namespace Xianyu {// 命名空间中可以定义变量/函数/类型int rand 10;int Add(int left, int right){return left right;}struct Node{struct No…

生活中的可靠性小案例12:类肤材质老化发粘问题

我一直觉得我买的某品牌车载吸尘器很好用&#xff0c;用了几年&#xff0c;目前性能也是杠杠的。然而它现在有个最大的问题&#xff0c;就是表面发粘了&#xff0c;用起来粘手&#xff0c;非常不舒服。 这一类问题在生活中不少见&#xff0c;尤其是一些用了类肤材质涂层的物件。…

黑马node.js教程(nodejs教程)——AJAX-Day01-04.案例_地区查询——查询某个省某个城市所有地区(代码示例)

文章目录 代码示例效果 代码示例 axiosTest.html <!DOCTYPE html> <!-- 文档类型声明&#xff0c;告诉浏览器这是一个HTML5文档 --> <html lang"en"> <!-- HTML根元素&#xff0c;设置文档语言为英语 --><head> <!-- 头部区域&am…

Ollama+OpenWebUI本地部署大模型

OllamaOpenWebUI本地部署大模型 前言Ollama使用Ollama安装Ollama修改配置Ollama 拉取远程大模型Ollama 构建本地大模型Ollama 运行本地模型&#xff1a;命令行交互Api调用Web 端调用 总结 前言 Ollama是一个开源项目&#xff0c;用于在本地计算机上运行大型语言模型&#xff0…

【NeurIPS 2024】LLM-ESR:用大语言模型破解序列推荐的长尾难题

标题期刊年份关键词LLM-ESR: Large Language Models Enhancement for Long-tailed Sequential RecommendationNeurIPS2024Large Language Models, Sequential Recommendation, Long-tailed &#x1f4da;研究背景 在电商和社交媒体的世界里&#xff0c;序列推荐系统&#xff…