更多内容:XiaoJ的知识星球
目录
- 2.6 非 Transformer 架构
- 2.6.1 状态空间模型 SSM
- 1)SSM(State Space Model)
- 2)RWKV(Receptance Weighted Key Value)
- 3)Mamba
- 2.6.2 训练时更新TTT(Test-TimeTraining)
2.6 非 Transformer 架构
Transformer 结构是当前大语言模型的主流模型架构,其具备构建灵活、易并行、易扩展等优势。
但是,Transformer并行输入的机制会导致模型规模随输入序列长度平方增长,导致其在处理长序列时面临计算瓶颈。
传统的 RNN 模型(如 GRU、LSTM 等)在处理长序列时可能难以捕捉到长期依赖关系,且面临着梯度消失或爆炸问题。
为了克服这些问题,近年来,研究者提出了两类现代 RNN 变体:
-
状态空间模型(State Space Model,SSM);
-
测试时训练(Test-Time Training,TTT)。
2.6.1 状态空间模型 SSM
状态空间模型(State Space Model,SSM)范式可以有效处理长文本的长程依赖性(Long-Range Dependencies, LRDs),并且可有效降低语言模型的计算和内存开销。
本节先介绍 SSM 范式,再介绍两种 SSM 代表性模型:RWKV 和 Mamba。
1)SSM(State Space Model)
SSM基于控制理论中的动力系统,利用一组状态变量来捕捉系统状态随时间的连续变化,这种连续时间的表示方法天然地适用于描述长时间范围内的依赖关系。
此外,SSM 还具有递归和卷积的离散化表示形式,既能在推理时通过递归更新高效处理序列数据,又能在训练时通过卷积操作捕捉全局依赖关系。
图 2.17: SSM 范式
(1). 连续型SSM:
如图,SSM 在三个随时间 t 变化的变量(u、x’、x、y)和四个可学习的矩阵(A、B、C、D)的基础上构造而成。SSM 的系统方程为:
-
x′(t) = Ax(t) + Bu(t)
-
y(t) = Cx(t) + Du(t)
该方程可视作 SSM 系统方程的连续形式,适用于对连续数据(例如音频信号、时间序列)的处理,但是在训练和推理都非常慢。
为了提高对 SSM 的处理效率,需要对该方程进行 离散化(Discretization) 操作。
(2). 递归型SSM:
SSM 系统方程离散化,原理是将定义在特定区间上的函数曲线下的区域视为梯形,并利用梯形面积公式计算其面积。由此,可以得出离散化后递归形式下的系统方程:
x k = A ˉ x k − 1 + B ˉ u k x_k = \bar{\mathbf{A}} x_{k-1} + \bar{\mathbf{B}} u_k xk=Aˉxk−1+Bˉuk
y k = C ˉ x k y_k = \bar{\mathbf{C}} x_k yk=Cˉxk
在该方程中,状态方程由前一步的状态和当前输入计算当前状态,体现了递归的思想。其中, A ˉ \bar{\mathbf{A}} Aˉ , B ˉ \bar{\mathbf{B}} Bˉ, C ˉ \bar{\mathbf{C}} Cˉ 为离散形式下的矩阵,其与连续形式下矩阵 A, B, C 的关系分别表示为: A ˉ = ( I − Δ 2 A ) − 1 ( I + Δ 2 A ) \bar{\mathbf{A}} = (I - \frac{\Delta}{2} \mathbf{A})^{-1} (I + \frac{\Delta}{2} \mathbf{A}) Aˉ=(I−2ΔA)−1(I+2ΔA), B ˉ = ( I − Δ 2 A ) − 1 Δ B \bar{\mathbf{B}} = (I - \frac{\Delta}{2} \mathbf{A})^{-1} \Delta \mathbf{B} Bˉ=(I−2ΔA)−1ΔB, C ˉ = C \bar{\mathbf{C}} = \mathbf{C} Cˉ=C,其中 Δ = t n + 1 − t n \Delta = t_{n+1} - t_n Δ=tn+1−tn,I 是单位矩阵。
递归形式的 SSM 类似于 RNN,具有 RNN 的优缺点。
(3). 卷积型SSM:
递归形式的 SSM 类似于 RNN,具有 RNN 的优缺点。其适用于顺序数据的处理,能够实现与序列长度呈线性复杂度的高效推理,但是无法并行训练,当序列时存在梯度消失或爆炸问题。
将系统方程的递归形式进行迭代,可以得到卷积形式。
x k = A ˉ k B ˉ u 0 + A ˉ k − 1 B ˉ u 1 + ⋯ + B ˉ u k x_k = \bar{\mathbf{A}}^k \bar{\mathbf{B}} u_0 + \bar{\mathbf{A}}^{k-1} \bar{\mathbf{B}} u_1 + \cdots + \bar{\mathbf{B}} u_k xk=AˉkBˉu0+Aˉk−1Bˉu1+⋯+Bˉuk
y k = C ˉ x k = C ˉ A ˉ k B ˉ u 0 + C ˉ A ˉ k − 1 B ˉ u 1 + ⋯ + C ˉ B ˉ u k y_k = \bar{\mathbf{C}} x_k = \bar{\mathbf{C}} \bar{\mathbf{A}}^k \bar{\mathbf{B}} u_0 + \bar{\mathbf{C}} \bar{\mathbf{A}}^{k-1} \bar{\mathbf{B}} u_1 + \cdots + \bar{\mathbf{C}} \bar{\mathbf{B}} u_k yk=Cˉxk=CˉAˉkBˉu0+CˉAˉk−1Bˉu1+⋯+CˉBˉuk
可以观察到,将系统方程的递归形式迭代展开后,输出 yk 是状态输入 uk 的卷积结果,其卷积核为:
K ˉ k = ( C ˉ B ˉ , C ˉ A ˉ B ˉ , … , C ˉ A ˉ k B ˉ ) \bar{\mathbf{K}}_k = (\bar{\mathbf{C}} \bar{\mathbf{B}}, \bar{\mathbf{C}} \bar{\mathbf{A}} \bar{\mathbf{B}}, \ldots, \bar{\mathbf{C}} \bar{\mathbf{A}}^{k} \bar{\mathbf{B}}) Kˉk=(CˉBˉ,CˉAˉBˉ,…,CˉAˉkBˉ)
因此,SSM 系统方程的卷积形式为:
y k = K ˉ k ∗ u k y_k = \bar{\mathbf{K}}_k * u_k yk=Kˉk∗uk
其中,卷积核是由 SSM 中的矩阵参数决定的,这些参数在整个序列处理过程是固定的,被称为时不变性。
时不变性使得 SSM 能够一致地处理不同时间步长的数据,进行高效的并行化训练。但由于上下文长度固定,卷积形式的 SSM 在进行自回归任务时延迟长且计算消耗大。
结合 SSM 的递归和卷积形式的优缺点,可以选择在训练时使用卷积形式,推理时使用递归形式。
综上,SSM 架构的系统方程具有三种形式,分别为连续形式、离散化的递归形式以及离散化的卷积形式,可应用于文本、视觉、音频和时间序列等任务,在应用时,需要根据具体情况选择合适的表示形式。
SSM 的优势在于能够处理非常长的序列,虽然比其它模型参数更少,但在处理长序列时仍然可以保持较快的速度。
RWKV 和 Mamba 是两种基于 SSM 范式的经典架构,下面将分别介绍这两种架构。
.
2)RWKV(Receptance Weighted Key Value)
(注:这里讨论的是 RWKV-v4)
RWKV(Receptance Weighted Key Value):
-
是基于 SSM 范式的创新架构,其核心机制 WKV 的计算可以看作是两个 SSM 的比。
-
核心模块:时间混合模块、通道混合模块。
-
基本元素:接收向量R、权重W、键向量K、值向量V。
RWKV 的设计结合了 RNNs 和 Transformers 的优点,既保留了推理阶段的高效性,又实现了训练阶段的并行化。
计算流程和作用:
图 2.18: RWKV 架构
时间混合模块和通道混合模块中共有的操作是 Token 位移,该步通过对当前时间步和前一时间步的输入进行线性插值来实现,从而确保了模型对序列中时间变化的敏感性。
在时间混合模块中:
-
接受向量 R 负责接收并整合来自序列历史的信息,权重 W 表示位置权重衰减,键向量 K 和值向量 V 类似传统注意力机制中的键和值,分别用于匹配和携带信息。
-
时间混合模块首先将当前步长和前一步长的输入进行线性组合,通过线性投影得到 R、K、V 向量;
-
随后,通过 WKV 机制来确保每个通道的权重随时间推移逐步衰减;
-
最后,将表示过去信息的 σ® 和表示当前信息的 WKV 向量通过输出门控进行整合,传递给通道混合模块。
RWKV的WKV机制是其核心,使用与通道相关的时间衰减向量W,可表示为 w t , i = − ( t − i ) w w_t,i=−(t−i)w wt,i=−(t−i)w,其中i是从当前时间步长t向后追溯的某个时间步长,w是一个非负向量,长度为通道数。
WKV机制通过线性时不变递归更新隐藏状态,可视为两个SSM之比,公式如下:
w k v t = ∑ i = 1 t − 1 e − ( t − 1 − i ) w + k i ⊙ v i + e u + k t ⊙ v t ∑ i = 1 t − 1 e − ( t − 1 − i ) w + k i + e u + k t wkv_t = \frac{\sum_{i=1}^{t-1} e^{-(t-1-i) w + k_i} \odot v_i + e^{u + k_t} \odot v_t}{\sum_{i=1}^{t-1} e^{-(t-1-i) w + k_i} + e^{u + k_t}} wkvt=∑i=1t−1e−(t−1−i)w+ki+eu+kt∑i=1t−1e−(t−1−i)w+ki⊙vi+eu+kt⊙vt
在通道混合模块中:
-
R ′、K ′、V ′ 的作用与时间混合模块类似,R ′ 和 K ′ 同样由 输入的线性投影得到,V ′ 的更新则额外依赖于 K ′。
-
之后,将 σ(R ′ ) 和 V ′ 整合,以 实现不同通道之间的信息交互和融合。
此外,RWKV 架构还采用了时间依赖的 Softmax 操作,提高数值稳定性和梯度传播效率,以及采用层归一化来稳定梯度,防止梯度消失和爆炸。
为了进一步提升性能,RWKV 还采用了自定义 CUDA 内核、小值初始化嵌入以及自定义初始化等措施。
RWKV 通过创新的线性注意力机制,成功结合了 Transformer 和 RNN 的优势,在模型规模和性能方面取得了显著进展。
然而,在处理长距离依赖关系和复杂任务时,RWKV 仍面临一些局限性。为了解决这些问题并进一步提升长序列建模能力,研究者们提出了 Mamba 架构。
.
3)Mamba
时不变性使得SSM能够一致地处理不同时间步长的数据,进行高效的并行化 训练,但是同时也导致其处理信息密集的数据(如文本)的能力较弱。原因如下:
-
缺乏内容感知推理能力:时不变的SSM参数固定,无法根据输入内容动态调整行为,难以进行基于内容的推理。
-
无法选择性地关注或忽略输入:无法根据任务或上下文选择性地关注或忽略特定输入信息。
-
对上下文信息的利用不足:在状态更新过程中,难以有效积累和利用长距离的上下文信息。
-
难以处理离散和结构化数据:文本数据离散且结构复杂,时不变的SSM更适合处理连续、时不变的信号。
为了弥补SSM处理信息密集的数据(如文本)的能力较弱,Mamba基于SSM架构,提出了:
-
选择机制(SelectionMechanism):使模型执行基于内容的推理。
-
硬件感知算法(Hardware-awareAlgorithm):实现了在GPU上的高效计算。
从而同时保证了快速训练和推理、高质量数据生成以及长序列处理能力。
(1). 选择机制(SelectionMechanism):
Mamba 的选择机制通过动态调整离散化 SSM 中的参数 B、C、Δ 来选择需要关注的信息,使模型参数能够根据输入数据动态变化。
具体来说,Mamba 将这些参数分别转换为以下函数: s B ( x ) = Linear N ( x ) s_{\mathbf{B}}(x) = \text{Linear}_N(x) sB(x)=LinearN(x), s C ( x ) = Linear N ( x ) s_{\mathbf{C}}(x) = \text{Linear}_N(x) sC(x)=LinearN(x), s Δ ( x ) = Broadcast D ( Linear 1 ( x ) ) s_{\Delta}(x) = \text{Broadcast}_D(\text{Linear}_1(x)) sΔ(x)=BroadcastD(Linear1(x)),并采用非线性激活函数 τ Δ \tau_{\Delta} τΔ = softplus 来调节参数 ∆。其中, Linear d \text{Linear}_d Lineard 是对特征维数 d 的参数化投影, s Δ s_{\Delta} sΔ 和 τ Δ \tau_{\Delta} τΔ 函数的选择与 RNN 的门控机制相关联。参数函数化和 RNN 的门控机制解释如下:
-
s B ( x ) = Linear N ( x ) s_{\mathbf{B}}(x) = \text{Linear}_N(x) sB(x)=LinearN(x):将输入数据 x 通过一个线性投影 * Linear N \text{Linear}N LinearN*转换为与 B \mathbf{B} B 维度匹配的向量。这个线性投影的权重和偏置是可学习的,因此 ** s B ( x ) s_{\mathbf{B}}(x) sB(x) 的输出会随着输入 x 的不同而变化。
-
s C ( x ) = Linear N ( x ) s_{\mathbf{C}}(x) = \text{Linear}_N(x) sC(x)=LinearN(x):与 s B ( x ) s_{\mathbf{B}}(x) sB(x) 类似,将输入数据 x通过另一个线性投影 Linear N \text{Linear}_N LinearN 转换为与 C \mathbf{C} C 维度匹配的向量,输出同样依赖于输入 x。
-
s Δ ( x ) = Broadcast D ( Linear 1 ( x ) ) s_{\Delta}(x) = \text{Broadcast}_D(\text{Linear}_1(x)) sΔ(x)=BroadcastD(Linear1(x)):将输入数据 x 通过一个标量线性投影 Linear 1 \text{Linear}_1 Linear1 转换为一个标量,然后通过广播操作 $\text{Broadcast}_D $ 将其扩展为与 Δ \Delta Δ 维度匹配的对角矩阵。这里的标量线性投影同样依赖于输入 x 。
-
RNN 中,门控机制(如 LSTM 的输入门、遗忘门和输出门)通过动态调整门控值来控制信息的流动。
此外,Mamba还对张量形状进行调整,使模型参数具有时间维度,从而在每个时间步都有不同的值,从时间不变转变为时间变化。
(2). 硬件感知算法(Hardware-awareAlgorithm)
选择机制使模型参数变成了输入的函数,且具有时间维度,因此模型不再具备卷积操作的平移不变性和线性时不变性,从而影响其效率。
为了实现选择性SSM 模型在GPU上的高效计算,Mamba提出一种硬件感知算法,主要包括:
-
内核融合:通过减少内存I/O操作来提高速度。
-
并行扫描:利用并行化算法提高效率。
-
重计算:则在反向传播时重新计算中间状态,以减少内存需求。
在具体实现中,SSM参数从较慢的HBM加载到更快的SRAM中进行计算,最后将输出写回HBM。这既保持了高效计算,又减少了内存使用,使SSM的内存需求与优化的Transformer实现(如FlashAttention)相当。
图2.19: Mamba 架构
Mamba通过将带有选择机制的SSM模块与Transformer的前馈层相结合,形成了一个简单且同质的架构设计。
如图2.19所示,Mamba架构是由完全相同的Mamba 模块组成的递归模型,每个Mamba模块都在前馈层中插入了卷积层和带有选择机制的SSM模块,其中激活函数σ选用SiLU/Swish激活。
.
(3). 小结:
通过引入选择机制和硬件感知算法,Mamba在实际应用中展示了卓越的性能和效率,包括:
-
快速训练和推理:训练时,计算和内存需求随着序列长度线性增长,而推理时,每一步只需常数时间,不需要保存之前的所有信息。通过硬件感知算法,Mamba不仅在理论上实现了序列长度的线性扩展,而且在A100GPU上,其推理吞吐量比类似规模的Transformer提高了5倍。
-
高质量数据生成:在语言建模、基因组学、音频、合成任务等多个模态和设置上,Mamba均表现出色。在语言建模方面,Mamba-3B模型在预训练和后续评估中性能超过了两倍参数量的Transformer模型性能。
-
长序列处理能力:Mamba能够处理长达百万级别的序列长度,展示了处理长上下文时的优越性。
虽然Mamba在硬件依赖性和模型复杂度上存在一定的局限性,但是它通过引 入选择机制和硬件感知算法显著提高了处理长序列和信息密集数据的效率,展示 了在多个领域应用的巨大潜力。Mamba在多种应用上的出色表现,使其成为一种理想的通用基础模型。
2.6.2 训练时更新TTT(Test-TimeTraining)
在处理长上下文序列时,上述基于SSM范式的架构(例如RWKV和Mamba) 通过将上下文信息压缩到固定长度的隐藏状态中,成功将计算复杂度降低至线性级别,有效扩展了模型处理长上下文的能力。
随着上下文长度的持续增长,基于SSM范式的模型可能会过早出现性能饱和,主要原因:
-
隐藏状态的表达能力限制:随着上下文长度的增加,隐藏状态无法有效地捕捉和存储长上下文中所有的关键信息。
-
参数固定与上下文增长的矛盾:当上下文长度超过一定阈值时,模型的参数量不足以处理如此长的上下文,从而出现性能饱和。
-
信息压缩与丢失:在将长上下文压缩到固定长度隐藏状态的过程中,可能会丢失一些关键信息。
为了解决这一限制,测试时训练(Test-TimeTraining,TTT)范式提供了一种有效的解决方案。TTT在推理阶段会针对每一条测试数据一边循环训练一边推理。
图2.20: TTT 范式下的推理流程
1)TTT 范式的预训练阶段:
训练过程分为内部循环和外部循环:
-
外部循环:执行传统的下词预测任务,通过自回归方式优化模型的全局权重参数。
-
内部循环:基于自监督优化隐藏状态。在每个时间步,模型根据当前输入 x t x_t xt 和先前隐藏状态 W t − 1 W_{t−1} Wt−1 计算重构损失,并利用该损失进行梯度下降更新隐藏状态,最终基于更新后的隐藏状态和当前输入生成输出。
重构损失计算公式:
ℓ ( W t − 1 , x t ) = ∥ f ( θ K x t ; W t − 1 ) − θ V x t ∥ 2 \ell\left(W_{t-1}, x_t\right) = \left\|f\left(\theta_{K} x_t; W_{t-1}\right) - \theta_{V} x_t\right\|^2 ℓ(Wt−1,xt)=∥f(θKxt;Wt−1)−θVxt∥2
隐藏状态更新公式:
W t = W t − 1 − η ∇ ℓ ( W t − 1 , x t ) W_t = W_{t-1} - \eta \nabla \ell\left(W_{t-1}, x_t\right) Wt=Wt−1−η∇ℓ(Wt−1,xt)
输出生成公式:
z t = f ( x t ; W t ) z_t = f(x_t; W_t) zt=f(xt;Wt)
2)TTT 范式的推理阶段:
推理阶段只需执行内部循环来更新隐藏状态,以适应新数据分布并提升预测性能。
3)小结:
与Transformer 相比,基于TTT范式的模型具有线性时间复杂度,这对于处理长序列数据至关重要。
相较于基于SSM的RWKV和Mamba架构,TTT通过模型参数来保存上下文信息,能够更有效地捕捉超长上下文中的语义联系和结构信息。
因此,TTT在长上下文建模任务中展现出卓越的性能,特别是在需要处理超长上 下文的应用场景中。
未来,TTT范式有望在超长序列处理任务中发挥重要作用。
.
其他参考:【大模型基础_毛玉仁】系列文章
声明:资源可能存在第三方来源,若有侵权请联系删除!