目录
- 引言
- 原始注意力
- 线性注意力
- 因果模型存在的问题
- 累加求和操作的限制
- Lightning Attention
- Lightning Attention-1
- Lightning Attention-2
- 备注
引言
MiniMax-01: Scaling Foundation Models with Lightning Attention表明自己是第一个将线性注意力应用到如此大规模的模型,他所使用的核心技术就是Lightning Attention。
那为什么线性注意力20年在文章Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention中就提出了,现在才出第一个线性注意力的大模型呢?
本文就从线性注意力机制入手,详细探讨其起源、存在的显著局限性,以及Lightning Attention的具体实现细节。
原始注意力
现在主流的有两类模型,一种是应用双向注意力的bert
类模型,另一种是应用单向注意力的gpt
类模型,他们所使用的注意力其实是有细微差别的。
- 双向注意力(bert类),就是传统认知中标准的注意力
Attention ( Q , K , V ) = softmax ( Q K T d ) V \operatorname{Attention}(Q,K,V)=\operatorname{softmax}(\frac{QK^T}{\sqrt{d_\text{}}})V Attention(Q,K,V)=softmax(dQKT)V
- 单向注意力(因果模型,
gpt
类),只能看到当前和前面的token
,所有要在softmax
之前乘上一个掩码矩阵,M
为单向掩码矩阵
Attention ( Q , K , V ) = softmax ( Q K T ⊙ M d ) V \operatorname{Attention}(Q,K,V)=\operatorname{softmax}(\frac{QK^T\odot M}{\sqrt{d_\text{}}})V Attention(Q,K,V)=softmax(dQKT⊙M)V
其中Q、K、V
每个矩阵的维度都是[n, d]
,即[序列长度,隐层维度]
,此时 Q K T QK^T QKT的维度是[n, n]
,所以整体复杂度是 O ( n 2 d ) O(n^2d) O(n2d)。其中d是固定大小, n 2 n^2 n2随着序列长度平方增加,就主导了整体的复杂度。
线性注意力
原始注意力中softmax的作用主要是引入非线性(取概率化再与V乘都是次要的),那就可以将其换成其他的非线性激活函数。
Attention ( Q , K , V ) = ( ϕ ( Q ) ϕ ( K ) T ) V \operatorname{Attention}(Q,K,V)=(\phi(Q)\phi(K)^T)V Attention(Q,K,V)=(ϕ(Q)ϕ(K)T)V
这里的 ϕ \phi ϕ代表所使用的激活函数,有很多种可以选择(论文常用的有1+elu
)。这里的归一化就先省略掉了,有一些论文就将K矩阵的归一化放到分母上(或者说K矩阵归一化的逆)。
此时观察,使用softmax必须等 Q K T QK^T QKT先计算完,而使用其他的激活函数只对单个Q或者K进行运算,不需要绑定 Q K T QK^T QKT。所以就可以将左乘变成右乘
( ϕ ( Q ) ϕ ( K ) T ) V = ϕ ( Q ) ( ϕ ( K ) T V ) (\phi(Q)\phi(K)^T)V=\phi(Q)(\phi(K)^TV) (ϕ(Q)ϕ(K)T)V=ϕ(Q)(ϕ(K)TV)
此时 ϕ ( K ) T V \phi(K)^TV ϕ(K)TV的复杂度是 O ( d 2 ) O(d^2) O(d2),所以整体复杂度变成了 O ( n d 2 ) O(nd^2) O(nd2),随着序列长度n
线性增长,此时就是线性注意力了。
(可选):通常线性注意力的公式还有如下形式
O = Δ − 1 ∗ ( Q ∗ K T ∗ V ) O = Δ^{-1} * (Q * K^T * V) O=Δ−1∗(Q∗KT∗V)
(可选)其中,Δ起到了归一化的作用。Δ的每个对角元素是 K T ∗ 1 K^T*1 KT∗1的值,这反映了每个键向量的重要程度。将 Δ − 1 Δ^{-1} Δ−1乘到结果上,就相当于对注意力输出进行了逆归一化。相当于只对K归一化,Q本身就是一个合适的查询向量,不需要归一化。
因果模型存在的问题
注意上面的线性注意力是类bert模型的情况下,并没有与掩码矩阵相乘,此时可以顺畅的先右乘来降低复杂度。但现在的大模型都是生成模型,使用的因果模型结构,都是单向注意力,就必须要乘以掩码矩阵,所以不能顺畅的右乘了。
左乘线性注意力公式如下,输出为O,每个step的输出为当前的 q t q_t qt乘以前面的 k j k_j kj,再乘以 v j v_j vj累加求和。此时 Q K T QK^T QKT可以正常进行矩阵运算,然后使用 ⊙ \odot ⊙(Hadamard Product)进行逐元素相乘,得到掩码后的矩阵。
O = ( Q K T ⊙ M ) V O=(QK^T\odot M)V O=(QKT⊙M)V
o t = ∑ j = 1 t ( q t T k j ) v j o_t=\sum_{j=1}^t(q_t^Tk_j)v_j ot=j=1∑t(qtTkj)vj
此时注意,上面公式的运算涉及 ⊙ \odot ⊙,它不适用于矩阵乘法交换律和结合律,即无法 Q ( K T ⊙ M V ) Q(K^T\odot MV) Q(KT⊙MV)。 ⊙ \odot ⊙是逐元素相乘,所以两个矩阵的维度必须相同,即使将M的位置放到前面, K T V K^TV KTV的维度是[d, d],也无法与M逐元素相乘。
累加求和操作的限制
双向注意力模型(bert)中使用的线性注意力如下,可以先算KV
( ϕ ( Q ) ϕ ( K ) T ) V = ϕ ( Q ) ( ϕ ( K ) T V ) (\phi(Q)\phi(K)^T)V=\phi(Q)(\phi(K)^TV) (ϕ(Q)ϕ(K)T)V=ϕ(Q)(ϕ(K)TV)
QKV的维度都为[n, d],这里假设序列长度为4,双向和单向注意力如下图
- 双向注意力计算
K和V的矩阵如下,得到的 K T V K^TV KTV的维度是[d, d]
K T = [ k 1 T k 2 T k 3 T k 4 T ] = [ k 11 k 21 k 31 k 41 k 12 k 22 k 32 k 42 ⋮ ⋮ ⋮ ⋮ k 1 d k 2 d k 3 d k 4 d ] K^{T}= \begin{bmatrix} k_{1}^T & k_{2}^T & k_{3}^T & k_{4}^T \\ \end{bmatrix}= \begin{bmatrix} k_{11} & k_{21} & k_{31} & k_{41} \\ k_{12} & k_{22} & k_{32} & k_{42} \\ \vdots & \vdots & \vdots & \vdots \\ k_{1d} & k_{2d} & k_{3d} & k_{4d}\\ \end{bmatrix} KT=[k1Tk2Tk3Tk4T]= k11k12⋮k1dk21k22⋮k2dk31k32⋮k3dk41k42⋮k4d
V = [ v 1 v 2 v 3 v 4 ] = [ v 11 v 12 . . . v 1 d v 21 v 22 . . . v 2 d v 31 v 32 . . . v 3 d v 41 v 42 . . . v 4 d ] V= \begin{bmatrix} v_{1} \\ v_{2} \\ v_{3} \\ v_{4} \\ \end{bmatrix}= \begin{bmatrix} v_{11} & v_{12} & ... & v_{1d} \\ v_{21} & v_{22} & ... & v_{2d} \\ v_{31} & v_{32} & ... & v_{3d} \\ v_{41} & v_{42} & ... & v_{4d} \end{bmatrix} V= v1v2v3v4 = v11v21v31v41v12v22v32v42............v1dv2dv3dv4d
K T V = [ k 1 T v 1 + k 2 T v 2 + k 3 T v 3 + k 4 T v 4 ] = [ [ K T V ] 1 [ K T V ] 2 ⋮ [ K T V ] d ] K^{T}V= \begin{bmatrix} k_{1}^Tv_1 + k_{2}^Tv_2 + k_{3}^Tv_3 + k_{4}^Tv_4 \\ \end{bmatrix}= \begin{bmatrix} [K^{T}V]_{1} \\ [K^{T}V]_{2} \\ \vdots \\ [K^{T}V]_{d} \\ \end{bmatrix} KTV=[k1Tv1+k2Tv2+k3Tv3+k4Tv4]= [KTV]1[KTV]2⋮[KTV]d
此时计算 q 3 q_3 q3的注意力输出就可以使用以下方法。注意这是点积,q3是一个向量, K T V K^{T}V KTV是一个矩阵,向量在与矩阵点积的时候会进行广播拓展,复制成多份分别与矩阵中的向量点积。 [ K T V ] 1 [K^{T}V]_{1} [KTV]1是一个向量, q 3 [ K T V ] 1 q_3[K^{T}V]_{1} q3[KTV]1点积后会得到一个值,所以 q 3 K T V q_3K^{T}V q3KTV最终的结果是一个向量,长度为隐层维度d。
q 3 K T V = q 3 [ [ K T V ] 1 [ K T V ] 2 ⋮ [ K T V ] d ] = [ q 3 [ K T V ] 1 q 3 [ K T V ] 2 ⋮ q 3 [ K T V ] d ] q_3K^{T}V= q_3 \begin{bmatrix} [K^{T}V]_{1} \\ [K^{T}V]_{2} \\ \vdots \\ [K^{T}V]_{d} \\ \end{bmatrix}= \begin{bmatrix} q_3[K^{T}V]_{1} \\ q_3[K^{T}V]_{2} \\ \vdots \\ q_3[K^{T}V]_{d} \\ \end{bmatrix} q3KTV=q3 [KTV]1[KTV]2⋮[KTV]d = q3[KTV]1q3[KTV]2⋮q3[KTV]d
也可以使用以下代码测试
q3 = torch.tensor([1, 2, 3, 4, 5, 6])
print(q3)# [n, d] = [4, 6]
kT = torch.tensor([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4],[5, 5, 5, 5],[6, 6, 6, 6]])
v = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]])print('kT @ v', kT @ v)
# q与(k.T @ v)的点积
result = torch.matmul(q, kT @ v)
print('result', result)
此时 K T V K^TV KTV的结果是双向的, k 3 k_3 k3的输出矩阵中使用了 v 4 v_4 v4,这样双向注意力就可以顺畅的右乘得到 K T V K^TV KTV结果再与Q相乘,得到所有token的输出。
但因果模型的注意力是单向的, K T V K^TV KTV在计算的时候前面的K不能与后面的V相乘,所以只能一个一个算然后累加求和。
o 1 = q 1 ( k 1 T v 1 ) o_1 = q_1(k_1^Tv_1) o1=q1(k1Tv1)
o 2 = q 2 ( k 1 T v 1 + k 2 T v 2 ) o_2 = q_2(k_1^Tv_1+k_2^Tv_2) o2=q2(k1Tv1+k2Tv2)
o 3 = q 3 ( k 1 T v 1 + k 2 T v 2 + k 3 T v 3 ) o_3 = q_3(k_1^Tv_1+k_2^Tv_2+k_3^Tv_3) o3=q3(k1Tv1+k2Tv2+k3Tv3)
o 4 = q 4 ( k 1 T v 1 + k 2 T v 2 + k 3 T v 3 + k 4 T v 4 ) o_4 = q_4(k_1^Tv_1+k_2^Tv_2+k_3^Tv_3+k_4^Tv_4) o4=q4(k1Tv1+k2Tv2+k3Tv3+k4Tv4)
这样的累加操作无法进行高效的矩阵乘法,虽然计算复杂度降低了,但实际运算的效率并不高。
Lightning Attention
到这里可以引出MiniMax-01
中所使用的Lightning Attention
了,但其实这个注意力有两个版本,MiniMax-01
中所提到的就是是Lightning Attention-2
,那咱们先看看第一个版本做了什么。
Lightning Attention-1
源自:TransNormerLLM: A Faster and Better Large Language Model with Improved TransNormer
Lightning Attention-1针对于原始注意力取消了softmax,使用Swish激活函数代替。即先变成了
Attention ( Q , K , V ) = ( ϕ ( Q ) ϕ ( K ) T ⊙ M ) V \operatorname{Attention}(Q,K,V)=(\phi(Q)\phi(K)^T\odot M)V Attention(Q,K,V)=(ϕ(Q)ϕ(K)T⊙M)V
然后还是先左乘计算,并没有解决线性注意力的根本问题,但是借鉴了flash attention
中的硬件加速。
其前向和反向传播流程如下,就是将QKV切块,放到高速SRAM中去计算。虽然变快了,但此时的复杂度还是 O ( n 2 d ) O(n^2d) O(n2d)。
Lightning Attention-2
源自:Lightning Attention-2: A Free Lunch for Handling Unlimited Sequence Lengths in Large Language Models
Lightning Attention-2
解决了因果模型在计算单向注意力时,需要进行累加求和操作导致无法矩阵运算的情况,实现了单向注意力先计算右乘,成功将复杂度降为 O ( n d 2 ) O(nd^2) O(nd2)。
o 1 = q 1 ( k 1 T v 1 ) o_1 = q_1(k_1^Tv_1) o1=q1(k1Tv1)
o 2 = q 2 ( k 1 T v 1 + k 2 T v 2 ) o_2 = q_2(k_1^Tv_1+k_2^Tv_2) o2=q2(k1Tv1+k2Tv2)
o 3 = q 3 ( k 1 T v 1 + k 2 T v 2 + k 3 T v 3 ) o_3 = q_3(k_1^Tv_1+k_2^Tv_2+k_3^Tv_3) o3=q3(k1Tv1+k2Tv2+k3Tv3)
o 4 = q 4 ( k 1 T v 1 + k 2 T v 2 + k 3 T v 3 + k 4 T v 4 ) o_4 = q_4(k_1^Tv_1+k_2^Tv_2+k_3^Tv_3+k_4^Tv_4) o4=q4(k1Tv1+k2Tv2+k3Tv3+k4Tv4)
再将这个累加求和公式拿过来,配合下图观察发现,之前的问题是每次计算 Q K T QK^T QKT都在整个序列上计算,这样每次都是所有序列的token互相注意到。那如果在序列这个维度拆分成小份,比如图中右侧先计算 k 1 k_1 k1和 k 2 k_2 k2,然后用于 q 3 q_3 q3的计算就完全没有问题, k 4 k_4 k4后面的就不计算了。这样就既能矩阵运算,又能符合单向掩码。
公式中也可以发现,当前step之前的k和v是可以相乘的,比如 q 3 q_3 q3在计算时,可以将 k 1 T v 1 + k 2 T v 2 + k 3 T v 3 k_1^Tv_1+k_2^Tv_2+k_3^Tv_3 k1Tv1+k2Tv2+k3Tv3使用矩阵操作运算。所以Lightning Attention-2将大矩阵拆开,类似flash attention拆成多个block。
这些 block 不能拆分成 n 份,这样block的意义就没有了,for循环计算反而更慢。所以每个 block 中会有多个时间步的token。
此时这些 block 就可以分为两类,一类是块内(intra block),一类是块间(inter block)。块内代表当前块 q 的序列下标和 kv 序列下标相同,块间即不同。
块内在计算 q i q_i qi时直接矩阵右乘很容易算上 k i + 1 v i + 1 k_{i+1}v_{i+1} ki+1vi+1,所以块内使用传统的左乘并与掩码矩阵相乘。块间计算时就可以先右乘计算 K t V K^tV KtV,因为之前的kv是可以双向注意力的。然后将之前的kv结果缓存下来并更新,用于下一个step计算。
下图是Lightning Attention-2
的结构图, λ \lambda λ是它的模型所使用的位置编码,忽略即可。
以下是前向传播和反向传播流程。
问题:M矩阵维度是[B, B],相当于每一个块代表了多个序列步n,在对角线位置是1,那在这个块内前面的q就可以注意到后面的kv了
解答:M矩阵维度虽然是[B, B],但只是这么切割,其内部值仍然是下三角。
备注
个人理解,若有不对请指出,谢谢。