- 文献阅读:LONGNET: Scaling Transformers to 1,000,000,000 Tokens
- 1. 文章简介
- 2. 方法原理
- 1. 方法思路
- 2. Dilated Attention
- 1. 具体原理
- 2. 多头实现
- 3. 复杂度分析
- 3. 训练方法
- 3. 实验结果
- 4. 结论 & 思考
- 5. 参考链接
- 文献链接:https://arxiv.org/abs/2307.02486
1. 文章简介
这篇文章算是我司最近的一篇力作吧,即DeepNet, Foundation Transformer之后,大佬们终于还是盯上了attention layer,毕竟attention层 O ( N 2 ) O(N^2) O(N2)的计算复杂度一直是制约Transformer往长文本发展的主要原因。
想当年,像是线性化Attention的Linformer,或者以更直观的稀疏化attention的Reformer,亦或者结合局部与全局attention的Longformer,或者类似金字塔型的将长文本拆分为短文本然后各自做attention然后逐层往上的方式(不过这篇具体文章给忘了),总之当年零零碎碎有不少关于优化attention层计算量,使之可以拓展到长文本上的工作。
不过可惜的是,虽然当时大家都觉得这个方向很重要,结果以GPT3还有PALM等为代表的大模型反而从工程上发力,直接强行扩展文本长度,从头上干掉了这个问题……
这两年,感觉这方面的工作已经比较少听到了,不过我司的大佬们似乎还是重新抓出了这个方向,然后像是DeepNet那样直接干出了一个量级上碾压的工作,也是真的厉害……
2. 方法原理
1. 方法思路
LongNet的整体的一个思路其实和之前的Reformer,Linformer等一致,还是在attention层方面做文章,希望将attention layer的计算复杂度从原始的 O ( N 2 d ) O(N^2d) O(N2d)进行优化,使得其与句长 N N N呈线性关系而非平方关系,从而使得模型整体的计算复杂度得到缩减。
对于,文中提出了dilated attention的结构,成功地将attention layer的计算复杂度从 O ( N 2 d ) O(N^2d) O(N2d)降维至 O ( N d ) O(Nd) O(Nd)复杂度。
需要注意的是,这里的比较没有包含linear transformer,它虽然很早之前已经实现了 O ( N d ) O(Nd) O(Nd)复杂度的attention实现,不过貌似效果不佳,不算是主流的attention方法,因此文中弃用了linear transformer作为对照。
下面,我们就需要具体看一下Dilated Attention层的具体实现方法。
2. Dilated Attention
1. 具体原理
首先,我们给出Dilated Attention层的整体原理图如下:
具体来说,就是首先给出一个局部窗口长度 w w w和间隔距离 r r r,那么,就可以将总长为 N N N的序列拆分为 N / w N/w N/w个子序列,然后在每一个子序列当中按照间隔 r r r取出token,一共就能够取出 w / r w/r w/r个token,然后用着 w / r w/r w/r个token作为新的序列计算attention,然后把这 N / w N/w N/w个attention矩阵concat起来,就能得到一个 N × N N \times N N×N的稀疏attention矩阵。
考察对于固定的 w , r w,r w,r下的第 i i i个attention矩阵,有:
{ Q i = [ Q i w Q i w + r ⋯ Q ( i + 1 ) w − r ] K i = [ K i w K i w + r ⋯ K ( i + 1 ) w − r ] V i = [ V i w V i w + r ⋯ V ( i + 1 ) w − r ] \left\{ \begin{aligned} Q_i &= [Q_{iw} & Q_{iw+r} & \cdots & Q_{(i+1)w-r}] \\ K_i &= [K_{iw} & K_{iw+r} & \cdots & K_{(i+1)w-r}] \\ V_i &= [V_{iw} & V_{iw+r} & \cdots & V_{(i+1)w-r}] \end{aligned} \right. ⎩ ⎨ ⎧QiKiVi=[Qiw=[Kiw=[ViwQiw+rKiw+rViw+r⋯⋯⋯Q(i+1)w−r]K(i+1)w−r]V(i+1)w−r]
此时有:
O i = s o f t m a x ( Q i ⋅ K i T d ) V i O_i = \mathop{softmax}(\frac{Q_i \cdot K_i^T}{\sqrt{d}})V_i Oi=softmax(dQi⋅KiT)Vi
当然,这样的一个attention矩阵事实上只包含了局部的attention信息,因此无法兼顾长距离和短距离的attention信息。因此,如果要令总的attention兼顾长距离和短距离的attention信息,就需要取出多组 w , r w,r w,r,分别计算attention然后进行矩阵加和。也就是上图中的合并部分,从而才能获得包含全局attention信息的矩阵。
具体实现上来说,文中采用的是等比数列的方式进行实现,比如如下的方式:
{ w = w , α w , α 2 w , ⋯ , α n w r = r , α r , α 2 r , ⋯ , α n r \left\{ \begin{aligned} w &= {w, \alpha w, \alpha^2 w, \cdots, \alpha^n w} \\ r &= {r, \alpha r, \alpha^2 r, \cdots, \alpha^n r} \end{aligned} \right. {wr=w,αw,α2w,⋯,αnw=r,αr,α2r,⋯,αnr
在上图的demo中,取用的 w , r w,r w,r就是 4 4 4和 1 1 1, α \alpha α的取值为 2 2 2。
当然,考虑到由于 w , r w,r w,r取值不同导致的attention的密度不同,因此加和的时候需要对权重进行调整,具体而言:
O = ∑ i = 1 k s i ∑ j s j O r i , w i O = \sum\limits_{i=1}^{k}\frac{s_i}{\sum_j s_j}O_{r_i, w_i} O=i=1∑k∑jsjsiOri,wi
其中, s i s_i si是 ( w i , r i ) (w_i, r_i) (wi,ri)这组参数下计算得到的attention矩阵( Q i ⋅ K i T d \frac{Q_i \cdot K_i^T}{\sqrt{d}} dQi⋅KiT)在计算softmax时的分母部分,也就是:
∑ j e Q i ⋅ K i T d \sum\limits_{j} e^{\frac{Q_i \cdot K_i^T}{\sqrt{d}}} j∑edQi⋅KiT
这样也就得到了一组 n n n维的系数向量,作为我们这里的 s s s。
2. 多头实现
关于Dilated Attention的多头实现,整体来说和vanilla transformer的实现方式是一致的,还是在input的向量当中进行split,然后分别过一个上述介绍的Dilated Attention层,最后将output的结果concat起来即可。
不过,感谢作者Shuming大佬的解释,这里和vanilla transformer存在一定的区别,具体就在于对于每一个context window,我们事实上都是等间隔的sample了其中的几个token进行attention的计算,某种意义上来说总是会丢失掉一些信息的。
因此,在设计多头attention的时候,文中进行了一定的优化,即对于input的token位置在不同的head上面给了不同的位置偏移量,从而使得尽可能地覆盖更多的token之间的attention。
具体来说就是,对于第 j j j个head,选取的token为:
{ Q i = [ Q i w + j ( ≡ r ) Q i w + r + j ( ≡ r ) ⋯ Q ( i + 1 ) w − r + j ( ≡ r ) ] K i = [ K i w + j ( ≡ r ) K i w + r + j ( ≡ r ) ⋯ K ( i + 1 ) w − r + j ( ≡ r ) ] V i = [ V i w + j ( ≡ r ) V i w + r + j ( ≡ r ) ⋯ V ( i + 1 ) w − r + j ( ≡ r ) ] \left\{ \begin{aligned} Q_i &= [Q_{iw + j(\equiv r)} & Q_{iw+r + j(\equiv r)} & \cdots & Q_{(i+1)w-r + j(\equiv r)}] \\ K_i &= [K_{iw + j(\equiv r)} & K_{iw+r + j(\equiv r)} & \cdots & K_{(i+1)w-r + j(\equiv r)}] \\ V_i &= [V_{iw + j(\equiv r)} & V_{iw+r + j(\equiv r)} & \cdots & V_{(i+1)w-r + j(\equiv r)}] \end{aligned} \right. ⎩ ⎨ ⎧QiKiVi=[Qiw+j(≡r)=[Kiw+j(≡r)=[Viw+j(≡r)Qiw+r+j(≡r)Kiw+r+j(≡r)Viw+r+j(≡r)⋯⋯⋯Q(i+1)w−r+j(≡r)]K(i+1)w−r+j(≡r)]V(i+1)w−r+j(≡r)]
可以用文中的图3来对上述不同头的attention进行更为形象化的展示如下:
3. 复杂度分析
下面,我们来考察一下Dilated Attention层的算法复杂度。
我们首先来考察对于一组确定的 w , r w,r w,r对应的Dilated Attention层的算法复杂度,其对应的结果如下:
F L O P s = 2 N w ⋅ ( w r ) 2 d = 2 N w d r 2 FLOPs = \frac{2N}{w} \cdot (\frac{w}{r})^2d = \frac{2Nwd}{r^2} FLOPs=w2N⋅(rw)2d=r22Nwd
因此,遍历 w , r w,r w,r,我们即可得到完整的Dilated Attention层的算法复杂度如下:
F L O P s = ∑ i = 0 k − 1 2 N w i d r i 2 = 2 N w 0 d r 0 2 ∑ i = 0 k − 1 1 α i < 2 N w 0 d r 0 2 ⋅ α α − 1 ∼ O ( N d ) FLOPs = \sum\limits_{i=0}^{k-1}\frac{2Nw_id}{r_i^2} = \frac{2Nw_0d}{r_0^2} \sum\limits_{i=0}^{k-1} \frac{1}{\alpha^i} < \frac{2Nw_0d}{r_0^2} \cdot \frac{\alpha}{\alpha-1} \sim O(Nd) FLOPs=i=0∑k−1ri22Nwid=r022Nw0di=0∑k−1αi1<r022Nw0d⋅α−1α∼O(Nd)
3. 训练方法
最后,我们看一下文中实际的训练过程。
注意到,这里由于极限的扩展了输入的context的序列长度,因此事实上如何将文本塞入GPU也就成了一个大问题,因此,这方面也需要有一些工程上的实现细节考察。
具体来说,文中给出的方法还是说先对sequence进行一下split,然后由不同的GPU分别计算,最后进行加总实现。
其原理图可以参考文中的图4:
不过需要注意的是,这里在不同的gpu当中计算完了不同的部分的input seq之后,在计算dilated attention的时候会有一个slice的过程,然后slice之后的得到的dilated attention会在不同的GPU之间进行聚合,从而确保不同的gpu上的token之间的attention能够相互计算和聚合。
由于这里只是slice之后的attention,因此可以避免掉由于过长的文本长度(比如文中给出的1B)导致的内存爆炸的问题。
3. 实验结果
文中使用torchscale作为基准库,然后替换attention layer之后train了一个768维,12层的模型进行实验考察。
得到结果如下:
而除了最终的ppl之外,文中还比较了transformer与LongNet在处理不同文本长度的文本时所需的计算量。
可以看到:
- LongNet可以在更少的计算量下获得相较于原始的transformer更好的ppl。
此外,文中还对LongNet在不同的参数量以及不同的context window进行了一下考察,得到结果如下:
可以看到:
- 随着参数量的增长,模型的ppl是在不断减小的,说明LongNet具有很好的扩展能力;
- context window越大,模型的效果也能够不断地提升,说明LongNet对于长文本有较好的理解能力。
最后,文中还非常直观的给出了将输入文本长度扩展到1B之后vanilla transformer与LongNet的infer时间变化的比较:
其结果直观地证明了LongNet对于长文本处理能力的能力,较之Vanilla Transformer耗时的快速增长,Dilated Attention基本没有发生什么太大的变化。
4. 结论 & 思考
综上,整体而言这篇文章还是很惊艳的,至少从context length的角度来说这种突破性的震撼确实厉害,结合他之前的foundation transformer等工作,我觉得他们在transformer的基础架构上面确实花了不少的功夫来做优化,这一点确实是厉害。
不过考虑到工程上,这篇文章的主要贡献可能还是在于长文本的关联attention上面,也就意味着其优势必然还是需要长上下文+大语料的前提下才能充分发挥出它的效果,就目前我的工作而言,可能还是有点用不太到……
所以,就只能膜拜一下大佬了,后面有机会的话可以考虑一下在业余时间复现一下看看了,在工作上倒是觉得ROI应该是不会很大了……
5. 参考链接
- Longformer: 局部Attention和全局attention的混搭