NSA(Natively Sparse Attention)论文原理解析
论文标题: Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention
作者团队: DeepSeek-AI, Peking University, University of Washington
核心目标: 提出一种高效、可训练的稀疏注意力机制,以提高长文本处理的计算效率,同时保持模型性能。
1. NSA 研究背景
1.1 长文本建模的挑战
- 现代大模型(如 GPT-4, Gemini 1.5)需要处理超长文本(64k 甚至更长)。
- 传统 全注意力(Full Attention) 计算复杂度为 O(N²),在长文本上计算开销巨大,导致 训练和推理效率低下。
- 现有的稀疏注意力(Sparse Attention)方法在 训练阶段支持较弱,通常只优化推理阶段。
1.2 现有稀疏注意力方法的局限
- 理论计算减少 ≠ 真实速度提升:
- 许多方法仅在 推理阶段(Inference Stage) 优化,忽略了 训练时间的计算成本。
- 例如 H2O、Quest 关注 KV 缓存剪枝(KV-cache pruning),但 在真实硬件上加速有限。
- 训练阶段支持不足:
- 许多方法采用 离散选择(Discrete Selection),导致 梯度无法回传,难以进行端到端训练(End-to-End Training)。
2. NSA 方法:基于层次化的稀疏注意力
NSA 提出的创新点:
- 层次化稀疏策略(Hierarchical Sparse Strategy)
- 结合 粗粒度 token 压缩(Compression) 和 细粒度 token 选择(Selection),同时保留 全局信息 和 局部精度。
- 硬件优化(Hardware-Aligned System)
- 设计 适用于现代 GPU(如 A100, H100)的优化算子,提升推理效率。
- 可训练性增强(Natively Trainable Design)
- 允许在 训练阶段 进行稀疏优化,而不仅仅是在推理阶段加速。
2.1 NSA 关键机制
NSA 通过 三种注意力路径 进行计算:
- 压缩注意力(Compressed Attention)
- 通过块级 Token 压缩(Blockwise Token Compression),减少计算开销。
- 选择性注意力(Selected Attention)
- 仅保留 Top-k 重要 token,忽略不重要的 Token,提高计算效率。
- 滑动窗口注意力(Sliding Attention)
- 确保局部上下文不会丢失,提高信息完整性。
NSA 计算过程
- 查询(Query) 经过 三种注意力路径 计算 注意力得分(Attention Score)。
- 不同路径的注意力结果通过门控机制(Gating Mechanism)进行加权融合。
- 最终得到优化后的注意力输出(Sparse Attention Output)。
3. NSA 在硬件上的优化
3.1 计算强度均衡(Arithmetic Intensity Balance)
- 在现代 GPU 上,计算强度(Arithmetic Intensity)决定了性能瓶颈:
- 高计算强度(Compute-Bound):计算单元占用率高,计算能力未完全发挥。
- 低计算强度(Memory-Bound):计算单元空闲,受限于显存访问速度。
- NSA 通过 块级计算(Blockwise Computation) 提高 计算密度(Compute Density),减少显存访问瓶颈。
3.2 Triton 自定义内核(Triton Kernel Optimization)
- 传统注意力计算 内存访问不连续,GPU 计算利用率低。
- NSA 通过 基于 Triton 的自定义 GPU 内核(Custom GPU Kernel for Sparse Selection):
- 组级数据加载(Group-Centric Data Loading):避免多次访问 KV 缓存,减少内存带宽压力。
- 共享 KV 读取(Shared KV Fetching):减少重复数据加载,提高计算效率。
4. NSA 在实验中的表现
4.1 计算加速
- 相比全注意力(Full Attention),NSA 在 64k 序列上的速度提升最高可达 11.6×。
- 在训练阶段,NSA 前向传播(Forward)速度提高 9.0×,反向传播(Backward)速度提高 6.0×。
4.2 模型性能
- 在多个 自然语言任务(NLP Benchmarks) 上,NSA 在 保持甚至超过全注意力性能 的同时,大幅提高计算效率。
- 在 64k 长文本任务(LongBench Benchmark)中,NSA 超过所有现有稀疏注意力方法。
4.3 复杂推理能力
- NSA 在 数学推理任务(AIME 24 Benchmark) 中表现出色:
- 在 8k 和 16k 上下文长度下,NSA 比全注意力基线提高 2.5× 和 1.6×。
5. NSA 的关键优势
特点 | NSA 贡献 |
---|---|
计算复杂度降低 | 通过 层次化稀疏选择,将 O(N²) 降至 O(N log K)。 |
硬件优化 | 适配 GPU Tensor Cores,优化内存访问,提高计算效率。 |
训练支持 | NSA 可训练(Natively Trainable),不同于只优化推理的稀疏方法。 |
长文本处理能力 | 在 64k 长文本任务上超越全注意力,同时加速 推理和训练。 |
6. 论文总结
NSA 通过 层次化稀疏注意力、硬件优化、训练可行性,在 计算加速和性能保持之间取得了平衡。
相较于现有方法,NSA 不仅优化了推理(Inference),还显著降低了训练(Training)计算成本,为长文本建模提供了新的解决方案。
压缩注意力(Compressed Attention)机制解析
目标:
- 在保持全局信息的同时 降低计算复杂度,减少 Query-Key 计算量。
- 通过 块级(blockwise)token 聚合,减少注意力计算中需要处理的 Key-Value 数量。
1. 为什么需要压缩注意力?
- 标准注意力机制:每个 Query q q q 需要计算所有 Key K K K 的注意力分数,计算复杂度为 O ( N 2 ) O(N^2) O(N2)。
- 稀疏注意力(Sparse Attention):可以减少部分 Query-Key 计算,但仍然面临计算量和显存占用的问题。
- 压缩注意力(Compressed Attention) 通过 对 Key-Value 进行块级聚合,减少 Key-Value 数量,降低计算复杂度。
2. 压缩注意力的具体方法
NSA 采用 块级 token 聚合 的方式,将 Key-Value 压缩成更少的代表性 token。
这一过程可以分为 四步:
2.1. 按块划分 Key-Value
- 设 输入序列长度为 T T T,Key-Value 维度为 d k d_k dk(Key 维度)和 d v d_v dv(Value 维度)。
- 选择 块大小(block size) l l l,把 Key-Value 分成多个块:
- 第 i i i 块的 Key 表示为:
K i = { k i ⋅ l , k i ⋅ l + 1 , … , k ( i + 1 ) ⋅ l − 1 } K_i = \{ k_{i \cdot l}, k_{i \cdot l+1}, \dots, k_{(i+1) \cdot l - 1} \} Ki={ki⋅l,ki⋅l+1,…,k(i+1)⋅l−1}
- 第 i i i 块的 Value 表示为:
V i = { v i ⋅ l , v i ⋅ l + 1 , … , v ( i + 1 ) ⋅ l − 1 } V_i = \{ v_{i \cdot l}, v_{i \cdot l+1}, \dots, v_{(i+1) \cdot l - 1} \} Vi={vi⋅l,vi⋅l+1,…,v(i+1)⋅l−1}
- 这样,原始 Key-Value 变成了 块级 Key-Value,大幅减少了 Key 的数量。
2.2. 计算块级 Key 的代表性
- 块级 Key K cmp K_{\text{cmp}} Kcmp 需要能够代表整个块的信息,可以用 平均池化(Mean Pooling) 或 可训练 MLP:
- 平均池化(Mean Pooling):
K cmp , i = 1 l ∑ j = 0 l − 1 K i ⋅ l + j K_{\text{cmp}, i} = \frac{1}{l} \sum_{j=0}^{l-1} K_{i \cdot l + j} Kcmp,i=l1j=0∑l−1Ki⋅l+j
- 可训练 MLP(Multi-Layer Perceptron):
K cmp , i = MLP ( K i ⋅ l : ( i + 1 ) ⋅ l ) K_{\text{cmp}, i} = \text{MLP}(K_{i \cdot l : (i+1) \cdot l}) Kcmp,i=MLP(Ki⋅l:(i+1)⋅l)
- 其中 MLP 可以学习更丰富的特征,而平均池化计算量更低。
2.3. 计算块级 Value
- 块级 Value V cmp V_{\text{cmp}} Vcmp 也可以采用类似方法:
- 平均池化:
V cmp , i = 1 l ∑ j = 0 l − 1 V i ⋅ l + j V_{\text{cmp}, i} = \frac{1}{l} \sum_{j=0}^{l-1} V_{i \cdot l + j} Vcmp,i=l1j=0∑l−1Vi⋅l+j
- 或使用 MLP:
V cmp , i = MLP ( V i ⋅ l : ( i + 1 ) ⋅ l ) V_{\text{cmp}, i} = \text{MLP}(V_{i \cdot l : (i+1) \cdot l}) Vcmp,i=MLP(Vi⋅l:(i+1)⋅l)
- 这样可以降低计算量,同时保留重要信息。
2.4. 使用压缩 Key-Value 计算注意力
- 计算 Query Q Q Q 和压缩后的 Key K cmp K_{\text{cmp}} Kcmp 之间的注意力:
A cmp = Q K cmp T d k A_{\text{cmp}} = \frac{Q K_{\text{cmp}}^T}{\sqrt{d_k}} Acmp=dkQKcmpT
- 计算 Softmax:
A cmp ′ = Softmax ( A cmp ) A'_{\text{cmp}} = \text{Softmax}(A_{\text{cmp}}) Acmp′=Softmax(Acmp)
- 计算最终的注意力输出:
O cmp = A cmp ′ V cmp O_{\text{cmp}} = A'_{\text{cmp}} V_{\text{cmp}} Ocmp=Acmp′Vcmp
3. 压缩注意力的优势
对比项 | 普通注意力 | 稀疏注意力(Sparse Attention) | 压缩注意力(Compressed Attention) |
---|---|---|---|
计算复杂度 | O ( N 2 ) O(N^2) O(N2) | O ( N log k ) O(N \log k) O(Nlogk) | O ( N ⋅ M ) O(N \cdot M) O(N⋅M)(( M \ll N \)) |
信息保留 | 完整信息 | 仅保留 Top-k 信息 | 保留全局信息,同时减少计算量 |
适用场景 | 短文本 | 长文本,但计算仍然较大 | 适合超长文本(64k+),计算高效 |
- 相比全注意力(Full Attention),压缩注意力减少了计算量。
- 相比其他稀疏注意力方法,压缩注意力能保留更多全局信息,同时具有更好的计算效率。
4. 代码示例
这里是一个 PyTorch 实现的 压缩注意力:
import torch
import torch.nn as nnclass CompressedAttention(nn.Module):def __init__(self, embed_dim, num_heads, block_size=32):super().__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.block_size = block_sizeself.head_dim = embed_dim // num_headsself.q_proj = nn.Linear(embed_dim, embed_dim)self.k_proj = nn.Linear(embed_dim, embed_dim)self.v_proj = nn.Linear(embed_dim, embed_dim)self.out_proj = nn.Linear(embed_dim, embed_dim)def forward(self, query, key, value):B, T, C = query.size() # Batch, Sequence Length, Embedding Dimension# ProjectionQ = self.q_proj(query).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)K = self.k_proj(key).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)V = self.v_proj(value).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)# Blockwise compression (mean pooling)num_blocks = T // self.block_sizeK_cmp = K.view(B, self.num_heads, num_blocks, self.block_size, self.head_dim).mean(dim=3)V_cmp = V.view(B, self.num_heads, num_blocks, self.block_size, self.head_dim).mean(dim=3)# Compute attention with compressed keysattn_weights = torch.matmul(Q, K_cmp.transpose(-2, -1)) / (self.head_dim ** 0.5)attn_weights = torch.softmax(attn_weights, dim=-1)attn_output = torch.matmul(attn_weights, V_cmp)# Reshape and outputattn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)return self.out_proj(attn_output)# 示例调用
B, T, C = 2, 64, 128 # Batch size, Sequence length, Embedding dimension
num_heads = 8
block_size = 16attention = CompressedAttention(C, num_heads, block_size)
query = torch.randn(B, T, C)
key = torch.randn(B, T, C)
value = torch.randn(B, T, C)output = attention(query, key, value)
print(output.shape) # (B, T, C)
5. 总结
- 压缩注意力(Compressed Attention) 通过 块级聚合 Key-Value,大幅降低计算量,同时保留全局信息。
- 计算复杂度降低: O ( N 2 ) → O ( N ⋅ M ) O(N^2) \to O(N \cdot M) O(N2)→O(N⋅M),其中 M ≪ N M \ll N M≪N(压缩后的块数)。
- 适用于超长文本建模,在 64k 甚至更长的序列 上能够高效工作。
- 硬件友好,支持 GPU Tensor Core 优化,减少显存占用。
选择性注意力(Selected Attention)机制解析
目标:
- 选择 最重要的 Key-Value 进行计算,而不是对所有 Key 计算注意力,从而降低计算复杂度。
- 通过 Top-K 选择策略,保留最关键的信息,减少冗余计算,提高长序列建模能力。
1. 为什么需要选择性注意力?
- 普通注意力(Full Attention) 计算复杂度 O(N²),当序列长度很长(如 64k+),计算量巨大。
- 压缩注意力(Compressed Attention) 通过 块级聚合 降低计算量,但可能损失部分细节信息。
- 选择性注意力(Selected Attention) 进一步优化,只保留最重要的 Token 参与计算,避免处理不重要的信息,减少计算开销,同时保持全局和局部信息。
2. 选择性注意力的核心步骤
NSA 采用 基于注意力得分的动态 Top-K 选择(Top-K Token Selection) 方法来筛选关键 Token:
2.1. 计算 Query-Key 相关性
首先,计算 查询(Query) 和 所有键(Key) 的相似性(即注意力分数):
A = Q K T d k A = \frac{Q K^T}{\sqrt{d_k}} A=dkQKT
其中:
- A A A 是注意力分数矩阵,形状为 ( B , H , T , T ) (B, H, T, T) (B,H,T,T),表示每个 Query 对应 Key 的注意力得分。
2.2. 选择 Top-K 重要 Token
- 对于每个 Query,选择 Top-K 重要的 Key,其余的 Key 设为 − ∞ -\infty −∞(即被 Mask)。
- 具体实现:
- 计算每个 Query 对所有 Key 的注意力分数。
- 使用 Top-K 算法 找出最大的 K K K 个值,索引存入 I top-k I_{\text{top-k}} Itop-k:
I top-k = argtopk ( A , K ) I_{\text{top-k}} = \text{argtopk}(A, K) Itop-k=argtopk(A,K)
- 构造稀疏化的注意力分数矩阵:
A i j ′ = { A i j , j ∈ I top-k ( i ) − ∞ , 否则 A'_{ij} = \begin{cases} A_{ij}, & j \in I_{\text{top-k}}(i) \\ -\infty, & \text{否则} \end{cases} Aij′={Aij,−∞,j∈Itop-k(i)否则
- 这样,我们 只在最重要的 Top-K Token 上计算 Softmax:
A ~ = Softmax ( A ′ ) \tilde{A} = \text{Softmax}(A') A~=Softmax(A′)
2.3. 计算注意力输出
最终,用选择的 Top-K 注意力分数 计算新的 Value 权重求和:
O = A ~ V O = \tilde{A} V O=A~V
这样,Query 只会与 最相关的 Key-Value 交互,提高计算效率,同时保留重要信息。
3. 选择性注意力的优势
方法 | 计算复杂度 | 信息保留能力 | 适用场景 |
---|---|---|---|
全注意力(Full Attention) | O ( N 2 ) O(N^2) O(N2) | 完整 | 适用于短文本 |
压缩注意力(Compressed Attention) | O ( N ⋅ M ) O(N \cdot M) O(N⋅M) | 保留全局信息 | 适用于长文本 |
选择性注意力(Selected Attention) | O ( N ⋅ K ) O(N \cdot K) O(N⋅K) | 只保留最重要信息 | 适用于超长文本(64k+) |
- 相比全注意力(Full Attention),选择性注意力只计算 Top-K 重要信息,大幅降低计算量。
- 相比压缩注意力(Compressed Attention),选择性注意力能保留 更精确的局部信息,保证高精度。
4. PyTorch 实现
以下是 选择性注意力 的 PyTorch 代码:
import torch
import torch.nn as nn
import torch.nn.functional as Fclass SelectedAttention(nn.Module):def __init__(self, embed_dim, num_heads, top_k):super().__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.top_k = top_kself.head_dim = embed_dim // num_headsself.q_proj = nn.Linear(embed_dim, embed_dim)self.k_proj = nn.Linear(embed_dim, embed_dim)self.v_proj = nn.Linear(embed_dim, embed_dim)self.out_proj = nn.Linear(embed_dim, embed_dim)def forward(self, query, key, value):B, T, C = query.shape# ProjectionQ = self.q_proj(query).view(B, T, self.num_heads, self.head_dim).transpose(1, 2) # (B, H, T, d_k)K = self.k_proj(key).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)V = self.v_proj(value).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)# Compute attention scoresattn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5) # (B, H, T, T)# Select Top-K tokenstopk_values, topk_indices = torch.topk(attn_scores, self.top_k, dim=-1) # (B, H, T, K)# Create a mask for non-Top-K elementsmask = torch.full_like(attn_scores, float('-inf')) # Default maskmask.scatter_(-1, topk_indices, topk_values) # Retain Top-K values# Apply softmax on selected tokensattn_weights = F.softmax(mask, dim=-1)# Compute attention outputattn_output = torch.matmul(attn_weights, V) # (B, H, T, d_k)# Reshape and outputattn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)return self.out_proj(attn_output)# 示例调用
B, T, C = 2, 64, 128 # Batch size, Sequence length, Embedding dimension
num_heads = 8
top_k = 16 # 选择 Top-K 重要 Tokenattention = SelectedAttention(C, num_heads, top_k)
query = torch.randn(B, T, C)
key = torch.randn(B, T, C)
value = torch.randn(B, T, C)output = attention(query, key, value)
print(output.shape) # (B, T, C)
5. 选择性注意力的优化方向
- Top-K 选择的优化
- 目前使用
torch.topk()
进行选择,时间复杂度为 O ( N log K ) O(N \log K) O(NlogK)。 - 可以优化为 Heap Sort + 近似选择算法,进一步提高效率。
- 目前使用
- 自适应 K 值选择
- 目前的 K 值是固定的,可以使用 Learnable Gate 机制,让模型 动态决定 K 的大小。
- 结合其他稀疏注意力
- 压缩注意力 + 选择性注意力 可以同时 减少计算量 和 保留最关键信息,适合超长序列任务(64k+)。
6. 总结
- 选择性注意力(Selected Attention) 通过 Top-K 选择 只保留最重要的 Key-Value,降低计算量。
- 计算复杂度从 O ( N 2 ) O(N^2) O(N2) 降低到 O ( N ⋅ K ) O(N \cdot K) O(N⋅K),适用于 超长文本(64k+)。
- 相较于全注意力和压缩注意力,选择性注意力能更精准地保留信息,同时减少计算成本。
- 可进一步优化 通过 更快的 Top-K 选择算法 或 自适应 K 值选择 提升性能。
滑动窗口注意力(Sliding Attention)机制解析
目标:
- 在减少计算量的同时,保留局部上下文信息,确保模型能够感知短期依赖关系。
- 结合 压缩注意力(Compressed Attention) 和 选择性注意力(Selected Attention),在局部窗口范围内保留完整的注意力计算,避免远程信息丢失。
1. 为什么需要滑动窗口注意力?
- 全注意力(Full Attention) 计算复杂度 O(N²),长序列(64k+)下计算成本极高。
- 压缩注意力(Compressed Attention) 关注全局信息,但可能会丢失局部细节。
- 选择性注意力(Selected Attention) 关注最关键的信息,但可能无法保留局部语境。
- 滑动窗口注意力(Sliding Attention) 通过局部窗口机制,确保模型可以关注最近的信息,同时减少计算量。
2. 滑动窗口注意力的核心步骤
NSA 采用 基于局部窗口的注意力计算(Local Context Attention),主要分为 四步:
2.1. 定义窗口范围
- 设序列长度为 T T T,窗口大小设定为 W W W(window size),则对于每个 Query Q i Q_i Qi,它只会计算:
K win , i = { k i − W , k i − W + 1 , … , k i } K_{\text{win}, i} = \{ k_{i-W}, k_{i-W+1}, \dots, k_i \} Kwin,i={ki−W,ki−W+1,…,ki}
V win , i = { v i − W , v i − W + 1 , … , v i } V_{\text{win}, i} = \{ v_{i-W}, v_{i-W+1}, \dots, v_i \} Vwin,i={vi−W,vi−W+1,…,vi}
- 窗口只包含最近 W W W 个 Token,降低计算复杂度。
- 可变窗口机制:可根据任务需求设定不同的窗口大小(例如代码生成任务可能需要更大的窗口)。
2.2. 计算窗口内的 Query-Key 注意力
- 在窗口范围 W W W 内计算标准注意力:
A win , i = Q i K win , i T d k A_{\text{win}, i} = \frac{Q_i K_{\text{win}, i}^T}{\sqrt{d_k}} Awin,i=dkQiKwin,iT
- 相比于全局注意力(O(N²)),窗口内计算量为 O(N × W),显著降低复杂度。
- 仅关注 最近 W W W 个 Token,保证短期依赖关系。
2.3. 计算 Softmax 并加权求和
- 计算窗口内的注意力分布:
A win , i ′ = Softmax ( A win , i ) A'_{\text{win}, i} = \text{Softmax}(A_{\text{win}, i}) Awin,i′=Softmax(Awin,i)
- 计算最终的注意力输出:
O win , i = A win , i ′ V win , i O_{\text{win}, i} = A'_{\text{win}, i} V_{\text{win}, i} Owin,i=Awin,i′Vwin,i
2.4. 结合其他注意力机制
- 最终输出:
O = g cmp O cmp + g sel O sel + g win O win O = g_{\text{cmp}} O_{\text{cmp}} + g_{\text{sel}} O_{\text{sel}} + g_{\text{win}} O_{\text{win}} O=gcmpOcmp+gselOsel+gwinOwin
- g cmp , g sel , g win g_{\text{cmp}}, g_{\text{sel}}, g_{\text{win}} gcmp,gsel,gwin 是可学习的门控参数(Gating Mechanism)。
- 这样可以在训练过程中,让模型学习最佳的注意力组合方式。
3. 滑动窗口注意力的优势
方法 | 计算复杂度 | 局部信息保留 | 适用场景 |
---|---|---|---|
全注意力(Full Attention) | O ( N 2 ) O(N^2) O(N2) | ✅ 完整 | 适用于短文本 |
压缩注意力(Compressed Attention) | O ( N ⋅ M ) O(N \cdot M) O(N⋅M) | ⚠️ 可能丢失局部信息 | 适用于长文本 |
选择性注意力(Selected Attention) | O ( N ⋅ K ) O(N \cdot K) O(N⋅K) | ⚠️ 仅保留关键 Token | 适用于超长文本 |
滑动窗口注意力(Sliding Attention) | O ( N ⋅ W ) O(N \cdot W) O(N⋅W) | ✅ 重点保留局部信息 | 适用于超长文本(64k+) |
- 相比全注意力(Full Attention),滑动窗口注意力 显著减少计算量。
- 相比压缩注意力(Compressed Attention),滑动窗口注意力确保局部信息不会丢失。
- 相比选择性注意力(Selected Attention),滑动窗口注意力不会忽略短期依赖。
4. PyTorch 实现
以下是 滑动窗口注意力(Sliding Attention) 的 PyTorch 代码:
import torch
import torch.nn as nn
import torch.nn.functional as Fclass SlidingWindowAttention(nn.Module):def __init__(self, embed_dim, num_heads, window_size):super().__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.window_size = window_sizeself.head_dim = embed_dim // num_headsself.q_proj = nn.Linear(embed_dim, embed_dim)self.k_proj = nn.Linear(embed_dim, embed_dim)self.v_proj = nn.Linear(embed_dim, embed_dim)self.out_proj = nn.Linear(embed_dim, embed_dim)def forward(self, query, key, value):B, T, C = query.shape# ProjectionQ = self.q_proj(query).view(B, T, self.num_heads, self.head_dim).transpose(1, 2) # (B, H, T, d_k)K = self.k_proj(key).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)V = self.v_proj(value).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)# Initialize attention scores (masked)attn_scores = torch.full((B, self.num_heads, T, T), float('-inf'), device=query.device)# Apply sliding window maskfor i in range(T):start_idx = max(0, i - self.window_size)attn_scores[:, :, i, start_idx:i+1] = torch.matmul(Q[:, :, i:i+1, :], K[:, :, start_idx:i+1, :].transpose(-2, -1)) / (self.head_dim ** 0.5)# Compute attention with masked softmaxattn_weights = F.softmax(attn_scores, dim=-1)# Compute attention outputattn_output = torch.matmul(attn_weights, V) # (B, H, T, d_k)# Reshape and outputattn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)return self.out_proj(attn_output)# 示例调用
B, T, C = 2, 64, 128 # Batch size, Sequence length, Embedding dimension
num_heads = 8
window_size = 16 # 滑动窗口大小attention = SlidingWindowAttention(C, num_heads, window_size)
query = torch.randn(B, T, C)
key = torch.randn(B, T, C)
value = torch.randn(B, T, C)output = attention(query, key, value)
print(output.shape) # (B, T, C)
5. 进一步优化方向
- 动态窗口大小
- 当前的 窗口大小 W W W 是固定的,可以使用 自适应机制(Adaptive Window Size) 让模型学习最佳的窗口大小。
- 结合 FlashAttention 提高计算效率
- 目前的 滑动窗口计算仍然需要遍历 Query,可以优化成 块级计算(Blockwise Computation),提升 GPU 利用率。
6. 总结
- 滑动窗口注意力(Sliding Attention) 通过 局部窗口计算,减少计算量,同时保留最近的上下文信息。
- 计算复杂度从 O ( N 2 ) O(N^2) O(N2) 降低到 O ( N ⋅ W ) O(N \cdot W) O(N⋅W),适用于 超长文本(64k+)。
- 结合其他注意力机制(压缩 + 选择性 + 滑动窗口)可以 提高计算效率,同时保留全局 + 局部信息。
NSA 论文中如何结合三种注意力机制?
在 Natively Sparse Attention(NSA) 机制中,作者采用了一种 层次化稀疏注意力策略(Hierarchical Sparse Strategy),将 压缩注意力(Compressed Attention)、选择性注意力(Selected Attention)和滑动窗口注意力(Sliding Attention) 结合,以 同时保留全局信息、关键 Token 以及局部信息,提高计算效率并优化长序列建模。
1. NSA 采用的三条注意力路径
NSA 通过以下三种不同的注意力计算路径,让 Transformer 既能高效处理长序列,又不会丢失关键信息:
- 压缩注意力(Compressed Attention)
- 作用:全局信息提取
- 方式:将 Key-Value 进行 块级压缩,生成粗粒度的全局 Token 表示。
- 计算复杂度: O ( N ⋅ M ) O(N \cdot M) O(N⋅M)(其中 M ≪ N M \ll N M≪N)。
- 选择性注意力(Selected Attention)
- 作用:筛选最关键的 Token 进行计算
- 方式:对所有 Query 计算注意力分数,并选择 Top-K 重要 Token,仅对这些 Key 计算注意力。
- 计算复杂度: O ( N ⋅ K ) O(N \cdot K) O(N⋅K)(其中 K ≪ N K \ll N K≪N)。
- 滑动窗口注意力(Sliding Attention)
- 作用:局部上下文信息保留
- 方式:每个 Query 仅在其 最近的 W W W 个 Token 内 计算注意力,保留短期依赖信息。
- 计算复杂度: O ( N ⋅ W ) O(N \cdot W) O(N⋅W)(其中 W ≪ N W \ll N W≪N)。
最终的注意力输出是三种机制的加权和:
O = g cmp O cmp + g sel O sel + g win O win O = g_{\text{cmp}} O_{\text{cmp}} + g_{\text{sel}} O_{\text{sel}} + g_{\text{win}} O_{\text{win}} O=gcmpOcmp+gselOsel+gwinOwin
其中:
- g cmp , g sel , g win g_{\text{cmp}}, g_{\text{sel}}, g_{\text{win}} gcmp,gsel,gwin 是 可学习的门控参数(Gating Mechanism),用于控制不同注意力机制的重要性。
2. NSA 具体如何组合这三种注意力?
(1) 计算 Query-Key 相关性
首先,对 Query 计算三种不同 Key 形式的注意力分数:
- 压缩 Key( K cmp K_{\text{cmp}} Kcmp):计算 Query 和 压缩后的 Key 的相关性:
A cmp = Q K cmp T d k A_{\text{cmp}} = \frac{Q K_{\text{cmp}}^T}{\sqrt{d_k}} Acmp=dkQKcmpT
2. 选择性 Key( K sel K_{\text{sel}} Ksel):计算 Query 和 Top-K 选择的 Key 的相关性:
A sel = Q K sel T d k A_{\text{sel}} = \frac{Q K_{\text{sel}}^T}{\sqrt{d_k}} Asel=dkQKselT
3. 滑动窗口 Key( K win K_{\text{win}} Kwin):计算 Query 在 局部窗口范围内 的注意力:
A win = Q K win T d k A_{\text{win}} = \frac{Q K_{\text{win}}^T}{\sqrt{d_k}} Awin=dkQKwinT
(2) 计算 Softmax 归一化
对每个注意力分数进行 Softmax 计算:
A ~ cmp = Softmax ( A cmp ) \tilde{A}_{\text{cmp}} = \text{Softmax}(A_{\text{cmp}}) A~cmp=Softmax(Acmp)
A ~ sel = Softmax ( A sel ) \tilde{A}_{\text{sel}} = \text{Softmax}(A_{\text{sel}}) A~sel=Softmax(Asel)
A ~ win = Softmax ( A win ) \tilde{A}_{\text{win}} = \text{Softmax}(A_{\text{win}}) A~win=Softmax(Awin)
(3) 计算注意力输出
计算不同注意力的加权求和:
O cmp = A ~ cmp V cmp O_{\text{cmp}} = \tilde{A}_{\text{cmp}} V_{\text{cmp}} Ocmp=A~cmpVcmp
O sel = A ~ sel V sel O_{\text{sel}} = \tilde{A}_{\text{sel}} V_{\text{sel}} Osel=A~selVsel
O win = A ~ win V win O_{\text{win}} = \tilde{A}_{\text{win}} V_{\text{win}} Owin=A~winVwin
(4) 加权融合不同注意力结果
最终的输出由三种注意力结果加权融合:
O = g cmp O cmp + g sel O sel + g win O win O = g_{\text{cmp}} O_{\text{cmp}} + g_{\text{sel}} O_{\text{sel}} + g_{\text{win}} O_{\text{win}} O=gcmpOcmp+gselOsel+gwinOwin
其中:
- g cmp , g sel , g win g_{\text{cmp}}, g_{\text{sel}}, g_{\text{win}} gcmp,gsel,gwin 是 可学习的门控参数,通过 MLP 计算:
g = σ ( MLP ( X ) ) g = \sigma(\text{MLP}(X)) g=σ(MLP(X))
其中 σ \sigma σ 是 Sigmoid 激活函数,确保 g g g 取值在 (0,1) 之间。
3. PyTorch 实现
以下是 结合三种注意力的 NSA 模型:
import torch
import torch.nn as nn
import torch.nn.functional as Fclass NSA(nn.Module):def __init__(self, embed_dim, num_heads, top_k, window_size):super().__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.top_k = top_kself.window_size = window_sizeself.head_dim = embed_dim // num_headsself.q_proj = nn.Linear(embed_dim, embed_dim)self.k_proj = nn.Linear(embed_dim, embed_dim)self.v_proj = nn.Linear(embed_dim, embed_dim)self.out_proj = nn.Linear(embed_dim, embed_dim)self.gate_mlp = nn.Sequential(nn.Linear(embed_dim, 3), nn.Sigmoid()) # 生成3个门控权重def forward(self, query, key, value):B, T, C = query.shape# ProjectionQ = self.q_proj(query).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)K = self.k_proj(key).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)V = self.v_proj(value).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)# 计算门控权重gate_weights = self.gate_mlp(query).unsqueeze(-1).unsqueeze(-1) # (B, T, 3) -> (B, T, 3, 1, 1)# 压缩注意力K_cmp = K.mean(dim=-2, keepdim=True)V_cmp = V.mean(dim=-2, keepdim=True)attn_cmp = torch.matmul(Q, K_cmp.transpose(-2, -1)) / (self.head_dim ** 0.5)O_cmp = torch.matmul(F.softmax(attn_cmp, dim=-1), V_cmp)# 选择性注意力topk_values, topk_indices = torch.topk(attn_cmp, self.top_k, dim=-1)attn_sel = torch.zeros_like(attn_cmp).scatter_(-1, topk_indices, topk_values)O_sel = torch.matmul(F.softmax(attn_sel, dim=-1), V)# 滑动窗口注意力attn_win = attn_cmp.masked_fill(torch.arange(T)[:, None] < (torch.arange(T) - self.window_size), float('-inf'))O_win = torch.matmul(F.softmax(attn_win, dim=-1), V)# 加权求和O = gate_weights[..., 0] * O_cmp + gate_weights[..., 1] * O_sel + gate_weights[..., 2] * O_winO = O.transpose(1, 2).contiguous().view(B, T, C)return self.out_proj(O)# 测试
B, T, C = 2, 64, 128
attention = NSA(C, num_heads=8, top_k=16, window_size=16)
query = torch.randn(B, T, C)
key = torch.randn(B, T, C)
value = torch.randn(B, T, C)
output = attention(query, key, value)
print(output.shape) # (B, T, C)
总结
- NSA 通过三种注意力机制的组合,既保证全局信息,又保留关键 Token 和局部上下文信息。
- 最终的注意力结果通过可学习的门控机制(Gating Mechanism)进行融合,实现动态调整。
- 计算复杂度降低到 O ( N log K ) O(N \log K) O(NlogK),适用于超长文本(64k+)。
代码是AI生成的 还在调试中