MiniMax-01中Lightning Attention的由来(线性注意力进化史)

目录

  • 引言
  • 原始注意力
  • 线性注意力
  • 因果模型存在的问题
  • 累加求和操作的限制
  • Lightning Attention
    • Lightning Attention-1
    • Lightning Attention-2
  • 备注

引言

MiniMax-01: Scaling Foundation Models with Lightning Attention表明自己是第一个将线性注意力应用到如此大规模的模型,他所使用的核心技术就是Lightning Attention。

那为什么线性注意力20年在文章Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention中就提出了,现在才出第一个线性注意力的大模型呢?

本文就从线性注意力机制入手,详细探讨其起源、存在的显著局限性,以及Lightning Attention的具体实现细节。

原始注意力

现在主流的有两类模型,一种是应用双向注意力的bert类模型,另一种是应用单向注意力的gpt类模型,他们所使用的注意力其实是有细微差别的。

  • 双向注意力(bert类),就是传统认知中标准的注意力

Attention ⁡ ( Q , K , V ) = softmax ⁡ ( Q K T d ) V \operatorname{Attention}(Q,K,V)=\operatorname{softmax}(\frac{QK^T}{\sqrt{d_\text{}}})V Attention(Q,K,V)=softmax(d QKT)V

  • 单向注意力(因果模型,gpt类),只能看到当前和前面的token,所有要在softmax之前乘上一个掩码矩阵,M为单向掩码矩阵

Attention ⁡ ( Q , K , V ) = softmax ⁡ ( Q K T ⊙ M d ) V \operatorname{Attention}(Q,K,V)=\operatorname{softmax}(\frac{QK^T\odot M}{\sqrt{d_\text{}}})V Attention(Q,K,V)=softmax(d QKTM)V

其中Q、K、V每个矩阵的维度都是[n, d],即[序列长度,隐层维度],此时 Q K T QK^T QKT的维度是[n, n],所以整体复杂度是 O ( n 2 d ) O(n^2d) O(n2d)。其中d是固定大小, n 2 n^2 n2随着序列长度平方增加,就主导了整体的复杂度。

线性注意力

原始注意力中softmax的作用主要是引入非线性(取概率化再与V乘都是次要的),那就可以将其换成其他的非线性激活函数。
Attention ⁡ ( Q , K , V ) = ( ϕ ( Q ) ϕ ( K ) T ) V \operatorname{Attention}(Q,K,V)=(\phi(Q)\phi(K)^T)V Attention(Q,K,V)=(ϕ(Q)ϕ(K)T)V
这里的 ϕ \phi ϕ代表所使用的激活函数,有很多种可以选择(论文常用的有1+elu)。这里的归一化就先省略掉了,有一些论文就将K矩阵的归一化放到分母上(或者说K矩阵归一化的逆)。

此时观察,使用softmax必须等 Q K T QK^T QKT先计算完,而使用其他的激活函数只对单个Q或者K进行运算,不需要绑定 Q K T QK^T QKT。所以就可以将左乘变成右乘
( ϕ ( Q ) ϕ ( K ) T ) V = ϕ ( Q ) ( ϕ ( K ) T V ) (\phi(Q)\phi(K)^T)V=\phi(Q)(\phi(K)^TV) (ϕ(Q)ϕ(K)T)V=ϕ(Q)(ϕ(K)TV)
此时 ϕ ( K ) T V \phi(K)^TV ϕ(K)TV的复杂度是 O ( d 2 ) O(d^2) O(d2),所以整体复杂度变成了 O ( n d 2 ) O(nd^2) O(nd2),随着序列长度n线性增长,此时就是线性注意力了。

(可选):通常线性注意力的公式还有如下形式

O = Δ − 1 ∗ ( Q ∗ K T ∗ V ) O = Δ^{-1} * (Q * K^T * V) O=Δ1(QKTV)

(可选)其中,Δ起到了归一化的作用。Δ的每个对角元素是 K T ∗ 1 K^T*1 KT1的值,这反映了每个键向量的重要程度。将 Δ − 1 Δ^{-1} Δ1乘到结果上,就相当于对注意力输出进行了逆归一化。相当于只对K归一化,Q本身就是一个合适的查询向量,不需要归一化。

因果模型存在的问题

注意上面的线性注意力是类bert模型的情况下,并没有与掩码矩阵相乘,此时可以顺畅的先右乘来降低复杂度。但现在的大模型都是生成模型,使用的因果模型结构,都是单向注意力,就必须要乘以掩码矩阵,所以不能顺畅的右乘了。
左乘线性注意力公式如下,输出为O,每个step的输出为当前的 q t q_t qt乘以前面的 k j k_j kj,再乘以 v j v_j vj累加求和。此时 Q K T QK^T QKT可以正常进行矩阵运算,然后使用 ⊙ \odot (Hadamard Product)进行逐元素相乘,得到掩码后的矩阵。

O = ( Q K T ⊙ M ) V O=(QK^T\odot M)V O=(QKTM)V

o t = ∑ j = 1 t ( q t T k j ) v j o_t=\sum_{j=1}^t(q_t^Tk_j)v_j ot=j=1t(qtTkj)vj

此时注意,上面公式的运算涉及 ⊙ \odot ,它不适用于矩阵乘法交换律和结合律,即无法 Q ( K T ⊙ M V ) Q(K^T\odot MV) Q(KTMV) ⊙ \odot 是逐元素相乘,所以两个矩阵的维度必须相同,即使将M的位置放到前面, K T V K^TV KTV的维度是[d, d],也无法与M逐元素相乘。

累加求和操作的限制

双向注意力模型(bert)中使用的线性注意力如下,可以先算KV

( ϕ ( Q ) ϕ ( K ) T ) V = ϕ ( Q ) ( ϕ ( K ) T V ) (\phi(Q)\phi(K)^T)V=\phi(Q)(\phi(K)^TV) (ϕ(Q)ϕ(K)T)V=ϕ(Q)(ϕ(K)TV)

QKV的维度都为[n, d],这里假设序列长度为4,双向和单向注意力如下图

在这里插入图片描述

  • 双向注意力计算
    K和V的矩阵如下,得到的 K T V K^TV KTV的维度是[d, d]

K T = [ k 1 T k 2 T k 3 T k 4 T ] = [ k 11 k 21 k 31 k 41 k 12 k 22 k 32 k 42 ⋮ ⋮ ⋮ ⋮ k 1 d k 2 d k 3 d k 4 d ] K^{T}= \begin{bmatrix} k_{1}^T & k_{2}^T & k_{3}^T & k_{4}^T \\ \end{bmatrix}= \begin{bmatrix} k_{11} & k_{21} & k_{31} & k_{41} \\ k_{12} & k_{22} & k_{32} & k_{42} \\ \vdots & \vdots & \vdots & \vdots \\ k_{1d} & k_{2d} & k_{3d} & k_{4d}\\ \end{bmatrix} KT=[k1Tk2Tk3Tk4T]= k11k12k1dk21k22k2dk31k32k3dk41k42k4d

V = [ v 1 v 2 v 3 v 4 ] = [ v 11 v 12 . . . v 1 d v 21 v 22 . . . v 2 d v 31 v 32 . . . v 3 d v 41 v 42 . . . v 4 d ] V= \begin{bmatrix} v_{1} \\ v_{2} \\ v_{3} \\ v_{4} \\ \end{bmatrix}= \begin{bmatrix} v_{11} & v_{12} & ... & v_{1d} \\ v_{21} & v_{22} & ... & v_{2d} \\ v_{31} & v_{32} & ... & v_{3d} \\ v_{41} & v_{42} & ... & v_{4d} \end{bmatrix} V= v1v2v3v4 = v11v21v31v41v12v22v32v42............v1dv2dv3dv4d

K T V = [ k 1 T v 1 + k 2 T v 2 + k 3 T v 3 + k 4 T v 4 ] = [ [ K T V ] 1 [ K T V ] 2 ⋮ [ K T V ] d ] K^{T}V= \begin{bmatrix} k_{1}^Tv_1 + k_{2}^Tv_2 + k_{3}^Tv_3 + k_{4}^Tv_4 \\ \end{bmatrix}= \begin{bmatrix} [K^{T}V]_{1} \\ [K^{T}V]_{2} \\ \vdots \\ [K^{T}V]_{d} \\ \end{bmatrix} KTV=[k1Tv1+k2Tv2+k3Tv3+k4Tv4]= [KTV]1[KTV]2[KTV]d

此时计算 q 3 q_3 q3的注意力输出就可以使用以下方法。注意这是点积,q3是一个向量, K T V K^{T}V KTV是一个矩阵,向量在与矩阵点积的时候会进行广播拓展,复制成多份分别与矩阵中的向量点积。 [ K T V ] 1 [K^{T}V]_{1} [KTV]1是一个向量, q 3 [ K T V ] 1 q_3[K^{T}V]_{1} q3[KTV]1点积后会得到一个值,所以 q 3 K T V q_3K^{T}V q3KTV最终的结果是一个向量,长度为隐层维度d。

q 3 K T V = q 3 [ [ K T V ] 1 [ K T V ] 2 ⋮ [ K T V ] d ] = [ q 3 [ K T V ] 1 q 3 [ K T V ] 2 ⋮ q 3 [ K T V ] d ] q_3K^{T}V= q_3 \begin{bmatrix} [K^{T}V]_{1} \\ [K^{T}V]_{2} \\ \vdots \\ [K^{T}V]_{d} \\ \end{bmatrix}= \begin{bmatrix} q_3[K^{T}V]_{1} \\ q_3[K^{T}V]_{2} \\ \vdots \\ q_3[K^{T}V]_{d} \\ \end{bmatrix} q3KTV=q3 [KTV]1[KTV]2[KTV]d = q3[KTV]1q3[KTV]2q3[KTV]d

也可以使用以下代码测试

q3 = torch.tensor([1, 2, 3, 4, 5, 6])
print(q3)# [n, d] = [4, 6]
kT = torch.tensor([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4],[5, 5, 5, 5],[6, 6, 6, 6]])
v = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]])print('kT @ v', kT @ v)
# q与(k.T @ v)的点积
result = torch.matmul(q, kT @ v)
print('result', result)

此时 K T V K^TV KTV的结果是双向的, k 3 k_3 k3的输出矩阵中使用了 v 4 v_4 v4,这样双向注意力就可以顺畅的右乘得到 K T V K^TV KTV结果再与Q相乘,得到所有token的输出。

但因果模型的注意力是单向的, K T V K^TV KTV在计算的时候前面的K不能与后面的V相乘,所以只能一个一个算然后累加求和。

o 1 = q 1 ( k 1 T v 1 ) o_1 = q_1(k_1^Tv_1) o1=q1(k1Tv1)

o 2 = q 2 ( k 1 T v 1 + k 2 T v 2 ) o_2 = q_2(k_1^Tv_1+k_2^Tv_2) o2=q2(k1Tv1+k2Tv2)

o 3 = q 3 ( k 1 T v 1 + k 2 T v 2 + k 3 T v 3 ) o_3 = q_3(k_1^Tv_1+k_2^Tv_2+k_3^Tv_3) o3=q3(k1Tv1+k2Tv2+k3Tv3)

o 4 = q 4 ( k 1 T v 1 + k 2 T v 2 + k 3 T v 3 + k 4 T v 4 ) o_4 = q_4(k_1^Tv_1+k_2^Tv_2+k_3^Tv_3+k_4^Tv_4) o4=q4(k1Tv1+k2Tv2+k3Tv3+k4Tv4)

这样的累加操作无法进行高效的矩阵乘法,虽然计算复杂度降低了,但实际运算的效率并不高。

Lightning Attention

到这里可以引出MiniMax-01 中所使用的Lightning Attention了,但其实这个注意力有两个版本,MiniMax-01中所提到的就是是Lightning Attention-2,那咱们先看看第一个版本做了什么。

Lightning Attention-1

源自:TransNormerLLM: A Faster and Better Large Language Model with Improved TransNormer

Lightning Attention-1针对于原始注意力取消了softmax,使用Swish激活函数代替。即先变成了
Attention ⁡ ( Q , K , V ) = ( ϕ ( Q ) ϕ ( K ) T ⊙ M ) V \operatorname{Attention}(Q,K,V)=(\phi(Q)\phi(K)^T\odot M)V Attention(Q,K,V)=(ϕ(Q)ϕ(K)TM)V
然后还是先左乘计算,并没有解决线性注意力的根本问题,但是借鉴了flash attention中的硬件加速。

其前向和反向传播流程如下,就是将QKV切块,放到高速SRAM中去计算。虽然变快了,但此时的复杂度还是 O ( n 2 d ) O(n^2d) O(n2d)
在这里插入图片描述
在这里插入图片描述

Lightning Attention-2

源自:Lightning Attention-2: A Free Lunch for Handling Unlimited Sequence Lengths in Large Language Models

Lightning Attention-2解决了因果模型在计算单向注意力时,需要进行累加求和操作导致无法矩阵运算的情况,实现了单向注意力先计算右乘,成功将复杂度降为 O ( n d 2 ) O(nd^2) O(nd2)
o 1 = q 1 ( k 1 T v 1 ) o_1 = q_1(k_1^Tv_1) o1=q1(k1Tv1)

o 2 = q 2 ( k 1 T v 1 + k 2 T v 2 ) o_2 = q_2(k_1^Tv_1+k_2^Tv_2) o2=q2(k1Tv1+k2Tv2)

o 3 = q 3 ( k 1 T v 1 + k 2 T v 2 + k 3 T v 3 ) o_3 = q_3(k_1^Tv_1+k_2^Tv_2+k_3^Tv_3) o3=q3(k1Tv1+k2Tv2+k3Tv3)

o 4 = q 4 ( k 1 T v 1 + k 2 T v 2 + k 3 T v 3 + k 4 T v 4 ) o_4 = q_4(k_1^Tv_1+k_2^Tv_2+k_3^Tv_3+k_4^Tv_4) o4=q4(k1Tv1+k2Tv2+k3Tv3+k4Tv4)

再将这个累加求和公式拿过来,配合下图观察发现,之前的问题是每次计算 Q K T QK^T QKT都在整个序列上计算,这样每次都是所有序列的token互相注意到。那如果在序列这个维度拆分成小份,比如图中右侧先计算 k 1 k_1 k1 k 2 k_2 k2,然后用于 q 3 q_3 q3的计算就完全没有问题, k 4 k_4 k4后面的就不计算了。这样就既能矩阵运算,又能符合单向掩码。

公式中也可以发现,当前step之前的k和v是可以相乘的,比如 q 3 q_3 q3在计算时,可以将 k 1 T v 1 + k 2 T v 2 + k 3 T v 3 k_1^Tv_1+k_2^Tv_2+k_3^Tv_3 k1Tv1+k2Tv2+k3Tv3使用矩阵操作运算。所以Lightning Attention-2将大矩阵拆开,类似flash attention拆成多个block。
在这里插入图片描述
这些 block 不能拆分成 n 份,这样block的意义就没有了,for循环计算反而更慢。所以每个 block 中会有多个时间步的token。

此时这些 block 就可以分为两类,一类是块内(intra block),一类是块间(inter block)。块内代表当前块 q 的序列下标和 kv 序列下标相同,块间即不同。

块内在计算 q i q_i qi时直接矩阵右乘很容易算上 k i + 1 v i + 1 k_{i+1}v_{i+1} ki+1vi+1,所以块内使用传统的左乘并与掩码矩阵相乘。块间计算时就可以先右乘计算 K t V K^tV KtV,因为之前的kv是可以双向注意力的。然后将之前的kv结果缓存下来并更新,用于下一个step计算。

下图是Lightning Attention-2的结构图, λ \lambda λ是它的模型所使用的位置编码,忽略即可。
在这里插入图片描述
以下是前向传播和反向传播流程。
在这里插入图片描述
在这里插入图片描述
问题:M矩阵维度是[B, B],相当于每一个块代表了多个序列步n,在对角线位置是1,那在这个块内前面的q就可以注意到后面的kv了

解答:M矩阵维度虽然是[B, B],但只是这么切割,其内部值仍然是下三角。

备注

个人理解,若有不对请指出,谢谢。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/7954.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux 内核进程调度

一、进程的分类 在CPU的角度看进程行为的话,可以分为两类: CPU消耗型:此类进程就是一直占用CPU计算,CPU利用率很高。IO消耗型:此类进程会涉及到IO,需要和用户交互,比如键盘输入,占用…

BLE透传方案,IoT短距无线通信的“中坚力量”

在物联网(IoT)短距无线通信生态系统中,低功耗蓝牙(BLE)数据透传是一种无需任何网络或基础设施即可完成双向通信的技术。其主要通过简单操作串口的方式进行无线数据传输,最高能满足2Mbps的数据传输速率&…

Linux 入门 常用指令 详细版

欢迎来到指令小仓库!! 宝剑锋从磨砺出,梅花香自苦寒来 什么是指令? 指令和可执行程序都是可以被执行的-->指令就是可执行程序。 指令一定是在系统的每一个位置存在的。 1.ls指令 语法: ls [选项][目…

Node.js下载安装及环境配置

目录 一、下载 1. 查看电脑版本,下载对应的安装包 2. 下载路径下载 | Node.js 中文网 二、安装步骤 1. 双击安装包 2. 点击Next下一步 3. 选择安装路径 4. 这里我选择默认配置,继续Next下一步(大家按需选择) 5. 最后inst…

为什么在编程中cast有强制类型转换的意思?

C语言或C在编程时,常常遇到“XXX without a cast”的警告信息,意思是 XXX 没有进行显示的强制类似转换,那么cast为什么会有强制类型转换的意思呢? 从英语的本义来看,cast 有“塑造、铸造”的意思。引申到编程中&#…

Spring Boot(6)解决ruoyi框架连续快速发送post请求时,弹出“数据正在处理,请勿重复提交”提醒的问题

一、整个前言 在基于 Ruoyi 框架进行系统开发的过程中,我们常常会遇到各种有趣且具有挑战性的问题。今天,我们就来深入探讨一个在实际开发中较为常见的问题:当连续快速发送 Post 请求时,前端会弹出 “数据正在处理,请…

瑞芯微方案:RV1126定制开发板方案定制

产品简介 RV1126 核心板是常州海图电子科技有限公司推出的一款以瑞芯微 RV1126处理器为核心的通用产品,其丰富的设计资源、稳定的产品性能、强力的设计支持,为客户二次开发快速转化产品提供强有力的技术保障。RV1126 核心板集多种优势于一身&#xff0c…

VB6.0 显示越南语字符

近期接到客户咨询,说是VB6.0写软件界面上显示越南语乱码,需要看看怎样解决。 我在自己电脑上也试了下,确实显示越南语结果是乱码。编辑器里乱码,运行起来界面上也是乱码。 经过一天的折腾,算是解决了问题&#xff0c…

理解C++中的右值引用

右值引用,顾名思义,就是对一个右值进行引用,或者说给右值一个别名。右值引用的规则和左值一用一模一样,都是对一个值或者对象起个别名。 1. 右值引用和左值引用一样,在定义的同时必须立即赋值,如果不立即赋…

unity.NavMesh Agent

这张图片展示的是Unity中 NavMesh Agent 组件的参数设置。NavMesh Agent 是Unity中用于实现角色自动寻路和移动的组件。下面我会通俗易懂地讲解这些参数的作用: 1. Agent Type(代理类型) 作用:定义代理的类型,比如是人…

83,【7】BUUCTF WEB [MRCTF2020]你传你[特殊字符]呢

进入靶场 图片上这个人和另一道题上的人长得好像 54,【4】BUUCTF WEB GYCTF2020Ezsqli-CSDN博客 让我们上传文件 桌面有啥传啥 /var/www/html/upload/344434f245b7ac3a4fae0a6342d1f94a/123.php.jpg 成功后我就去用蚁剑连了,连不上 看了别的wp知需要…

自签证书的dockerfile中from命令无法拉取镜像而docker的pull命令能拉取镜像

问题现象: docker pull images拉取镜像正常 dockerfile中的from命令拉取镜像就会报出证书错误。报错信息如下: [bjxtbwj-kvm-test-jenkins-6-243 ceshi_dockerfile]$ docker build . [] Building 0.4s (3/3) FINISHED …

在小红书挖掘信息的实践之旅(第一部分)

摘要 在信息爆炸时代,小红书承载大量用户真实生活分享,蕴含未挖掘价值。作者因日常观察到朋友常依赖小红书经验分享,决定尝试挖掘其中信息。在实践初期,受 DeepSeek 建议启发,确定 “以关键词为线索,层层递…

智慧消防营区一体化安全管控 2024 年度深度剖析与展望

在 2024 年,智慧消防营区一体化安全管控领域取得了令人瞩目的进展,成为保障营区安全稳定运行的关键力量。这一年,行业在政策驱动、技术创新应用、实践成果及合作交流等方面呈现出多元且深刻的发展态势,同时也面临着一系列亟待解决…

粒子群算法 笔记 数学建模

引入: 如何找到全局最大值:如果只是贪心的话,容易被局部最大解锁定 方法有:盲目搜索,启发式搜索 盲目搜索:枚举法和蒙特卡洛模拟,但是样例太多花费巨量时间 所以启发式算法就来了,通过经验和规…

从ai产品推荐到利用cursor快速掌握一个开源项目再到langchain手搓一个Text2Sql agent

目录 0. 经验分享:产品推荐 1. 经验分享:提示词优化 2. 经验分享:使用cursor 阅读一篇文章 3. 经验分享:使用cursor 阅读一个完全陌生的开源项目 4. 经验分享:手搓一个text2sql agent (使用langchain l…

14-6-1C++STL的list

(一)list容器的基本概念 list容器简介: 1.list是一个双向链表容器,可高效地进行插入删除元素 2.list不可以随机存取元素,所以不支持at.(pos)函数与[ ]操作符 (二)list容器头部和尾部的操作 list对象的默…

Couchbase UI: Dashboard

以下是 Couchbase UI Dashboard 页面详细介绍,包括页面布局和功能说明,帮助你更好地理解和使用。 1. 首页(Overview) 功能:提供集群的整体健康状态和性能摘要 集群状态 节点健康状况:绿色(正…

【WebRTC - STUN/TURN服务 - COTURN配置】

在WebRTC中,对于通信的两端不在同一个局域网的情况下,通信两端往往无法P2P直接连接,需要一个TURN中继服务,而中继服务可以选用 COTURN 构建。 注:COTURN 是一个开源的 TURN(Traversal Using Relays around…

基于OSAL的嵌入式裸机事件驱动框架——整体架构调度机制

参考B站up主【架构分析】嵌入式祼机事件驱动框架 感谢大佬分享 任务ID : TASK_XXX TASK_XXX 在系统中每个任务的ID是唯一的,范围是 0 to 0xFFFE,0xFFFF保留为SYS_TSK_INIT。 同时任务ID的大小也充当任务调度的优先级,ID越大&#…