Deepseek Natively Sparse Attention

NSA(Natively Sparse Attention)论文原理解析

论文标题: Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention
作者团队: DeepSeek-AI, Peking University, University of Washington
核心目标: 提出一种高效、可训练的稀疏注意力机制,以提高长文本处理的计算效率,同时保持模型性能。


1. NSA 研究背景

1.1 长文本建模的挑战

  • 现代大模型(如 GPT-4, Gemini 1.5)需要处理超长文本(64k 甚至更长)。
  • 传统 全注意力(Full Attention) 计算复杂度为 O(N²),在长文本上计算开销巨大,导致 训练和推理效率低下
  • 现有的稀疏注意力(Sparse Attention)方法在 训练阶段支持较弱,通常只优化推理阶段。

1.2 现有稀疏注意力方法的局限

  • 理论计算减少 ≠ 真实速度提升
    • 许多方法仅在 推理阶段(Inference Stage) 优化,忽略了 训练时间的计算成本
    • 例如 H2O、Quest 关注 KV 缓存剪枝(KV-cache pruning),但 在真实硬件上加速有限
  • 训练阶段支持不足
    • 许多方法采用 离散选择(Discrete Selection),导致 梯度无法回传,难以进行端到端训练(End-to-End Training)。

2. NSA 方法:基于层次化的稀疏注意力

NSA 提出的创新点:

  1. 层次化稀疏策略(Hierarchical Sparse Strategy)
    • 结合 粗粒度 token 压缩(Compression)细粒度 token 选择(Selection),同时保留 全局信息局部精度
  2. 硬件优化(Hardware-Aligned System)
    • 设计 适用于现代 GPU(如 A100, H100)的优化算子,提升推理效率。
  3. 可训练性增强(Natively Trainable Design)
    • 允许在 训练阶段 进行稀疏优化,而不仅仅是在推理阶段加速。

2.1 NSA 关键机制

NSA 通过 三种注意力路径 进行计算:

  1. 压缩注意力(Compressed Attention)
    • 通过块级 Token 压缩(Blockwise Token Compression),减少计算开销。
  2. 选择性注意力(Selected Attention)
    • 仅保留 Top-k 重要 token,忽略不重要的 Token,提高计算效率。
  3. 滑动窗口注意力(Sliding Attention)
    • 确保局部上下文不会丢失,提高信息完整性。
NSA 计算过程
  1. 查询(Query) 经过 三种注意力路径 计算 注意力得分(Attention Score)
  2. 不同路径的注意力结果通过门控机制(Gating Mechanism)进行加权融合
  3. 最终得到优化后的注意力输出(Sparse Attention Output)

3. NSA 在硬件上的优化

3.1 计算强度均衡(Arithmetic Intensity Balance)

  • 在现代 GPU 上,计算强度(Arithmetic Intensity)决定了性能瓶颈:
    • 高计算强度(Compute-Bound):计算单元占用率高,计算能力未完全发挥。
    • 低计算强度(Memory-Bound):计算单元空闲,受限于显存访问速度。
  • NSA 通过 块级计算(Blockwise Computation) 提高 计算密度(Compute Density),减少显存访问瓶颈。

3.2 Triton 自定义内核(Triton Kernel Optimization)

  • 传统注意力计算 内存访问不连续,GPU 计算利用率低
  • NSA 通过 基于 Triton 的自定义 GPU 内核(Custom GPU Kernel for Sparse Selection)
    • 组级数据加载(Group-Centric Data Loading):避免多次访问 KV 缓存,减少内存带宽压力。
    • 共享 KV 读取(Shared KV Fetching):减少重复数据加载,提高计算效率。

4. NSA 在实验中的表现

4.1 计算加速

  • 相比全注意力(Full Attention),NSA 在 64k 序列上的速度提升最高可达 11.6×
  • 在训练阶段,NSA 前向传播(Forward)速度提高 9.0×,反向传播(Backward)速度提高 6.0×

4.2 模型性能

  • 在多个 自然语言任务(NLP Benchmarks) 上,NSA 在 保持甚至超过全注意力性能 的同时,大幅提高计算效率。
  • 64k 长文本任务(LongBench Benchmark)中,NSA 超过所有现有稀疏注意力方法

4.3 复杂推理能力

  • NSA 在 数学推理任务(AIME 24 Benchmark) 中表现出色:
    • 8k 和 16k 上下文长度下,NSA 比全注意力基线提高 2.5× 和 1.6×

5. NSA 的关键优势

特点NSA 贡献
计算复杂度降低通过 层次化稀疏选择,将 O(N²) 降至 O(N log K)
硬件优化适配 GPU Tensor Cores,优化内存访问,提高计算效率。
训练支持NSA 可训练(Natively Trainable),不同于只优化推理的稀疏方法。
长文本处理能力64k 长文本任务上超越全注意力,同时加速 推理和训练

6. 论文总结

NSA 通过 层次化稀疏注意力、硬件优化、训练可行性,在 计算加速和性能保持之间取得了平衡
相较于现有方法,NSA 不仅优化了推理(Inference),还显著降低了训练(Training)计算成本,为长文本建模提供了新的解决方案。


压缩注意力(Compressed Attention)机制解析

目标

  • 在保持全局信息的同时 降低计算复杂度,减少 Query-Key 计算量。
  • 通过 块级(blockwise)token 聚合,减少注意力计算中需要处理的 Key-Value 数量。

1. 为什么需要压缩注意力?

  • 标准注意力机制:每个 Query q q q 需要计算所有 Key K K K 的注意力分数,计算复杂度为 O ( N 2 ) O(N^2) O(N2)
  • 稀疏注意力(Sparse Attention):可以减少部分 Query-Key 计算,但仍然面临计算量和显存占用的问题。
  • 压缩注意力(Compressed Attention) 通过 对 Key-Value 进行块级聚合,减少 Key-Value 数量,降低计算复杂度。

2. 压缩注意力的具体方法

NSA 采用 块级 token 聚合 的方式,将 Key-Value 压缩成更少的代表性 token
这一过程可以分为 四步

2.1. 按块划分 Key-Value

  • 输入序列长度为 T T T,Key-Value 维度为 d k d_k dk(Key 维度)和 d v d_v dv(Value 维度)。
  • 选择 块大小(block size) l l l,把 Key-Value 分成多个块:
    • i i i 块的 Key 表示为:

K i = { k i ⋅ l , k i ⋅ l + 1 , … , k ( i + 1 ) ⋅ l − 1 } K_i = \{ k_{i \cdot l}, k_{i \cdot l+1}, \dots, k_{(i+1) \cdot l - 1} \} Ki={kil,kil+1,,k(i+1)l1}
- 第 i i i 块的 Value 表示为:

V i = { v i ⋅ l , v i ⋅ l + 1 , … , v ( i + 1 ) ⋅ l − 1 } V_i = \{ v_{i \cdot l}, v_{i \cdot l+1}, \dots, v_{(i+1) \cdot l - 1} \} Vi={vil,vil+1,,v(i+1)l1}
- 这样,原始 Key-Value 变成了 块级 Key-Value,大幅减少了 Key 的数量。

2.2. 计算块级 Key 的代表性

  • 块级 Key K cmp K_{\text{cmp}} Kcmp 需要能够代表整个块的信息,可以用 平均池化(Mean Pooling)可训练 MLP
    • 平均池化(Mean Pooling):

K cmp , i = 1 l ∑ j = 0 l − 1 K i ⋅ l + j K_{\text{cmp}, i} = \frac{1}{l} \sum_{j=0}^{l-1} K_{i \cdot l + j} Kcmp,i=l1j=0l1Kil+j
- 可训练 MLP(Multi-Layer Perceptron):

K cmp , i = MLP ( K i ⋅ l : ( i + 1 ) ⋅ l ) K_{\text{cmp}, i} = \text{MLP}(K_{i \cdot l : (i+1) \cdot l}) Kcmp,i=MLP(Kil:(i+1)l)
- 其中 MLP 可以学习更丰富的特征,而平均池化计算量更低。

2.3. 计算块级 Value

  • 块级 Value V cmp V_{\text{cmp}} Vcmp 也可以采用类似方法:
    • 平均池化:

V cmp , i = 1 l ∑ j = 0 l − 1 V i ⋅ l + j V_{\text{cmp}, i} = \frac{1}{l} \sum_{j=0}^{l-1} V_{i \cdot l + j} Vcmp,i=l1j=0l1Vil+j
- 或使用 MLP:

V cmp , i = MLP ( V i ⋅ l : ( i + 1 ) ⋅ l ) V_{\text{cmp}, i} = \text{MLP}(V_{i \cdot l : (i+1) \cdot l}) Vcmp,i=MLP(Vil:(i+1)l)
- 这样可以降低计算量,同时保留重要信息。

2.4. 使用压缩 Key-Value 计算注意力

  • 计算 Query Q Q Q 和压缩后的 Key K cmp K_{\text{cmp}} Kcmp 之间的注意力

A cmp = Q K cmp T d k A_{\text{cmp}} = \frac{Q K_{\text{cmp}}^T}{\sqrt{d_k}} Acmp=dk QKcmpT

  • 计算 Softmax:

A cmp ′ = Softmax ( A cmp ) A'_{\text{cmp}} = \text{Softmax}(A_{\text{cmp}}) Acmp=Softmax(Acmp)

  • 计算最终的注意力输出:

O cmp = A cmp ′ V cmp O_{\text{cmp}} = A'_{\text{cmp}} V_{\text{cmp}} Ocmp=AcmpVcmp


3. 压缩注意力的优势

对比项普通注意力稀疏注意力(Sparse Attention)压缩注意力(Compressed Attention)
计算复杂度 O ( N 2 ) O(N^2) O(N2) O ( N log ⁡ k ) O(N \log k) O(Nlogk) O ( N ⋅ M ) O(N \cdot M) O(NM)(( M \ll N \))
信息保留完整信息仅保留 Top-k 信息保留全局信息,同时减少计算量
适用场景短文本长文本,但计算仍然较大适合超长文本(64k+),计算高效
  • 相比全注意力(Full Attention),压缩注意力减少了计算量。
  • 相比其他稀疏注意力方法,压缩注意力能保留更多全局信息,同时具有更好的计算效率。

4. 代码示例

这里是一个 PyTorch 实现的 压缩注意力

import torch
import torch.nn as nnclass CompressedAttention(nn.Module):def __init__(self, embed_dim, num_heads, block_size=32):super().__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.block_size = block_sizeself.head_dim = embed_dim // num_headsself.q_proj = nn.Linear(embed_dim, embed_dim)self.k_proj = nn.Linear(embed_dim, embed_dim)self.v_proj = nn.Linear(embed_dim, embed_dim)self.out_proj = nn.Linear(embed_dim, embed_dim)def forward(self, query, key, value):B, T, C = query.size()  # Batch, Sequence Length, Embedding Dimension# ProjectionQ = self.q_proj(query).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)K = self.k_proj(key).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)V = self.v_proj(value).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)# Blockwise compression (mean pooling)num_blocks = T // self.block_sizeK_cmp = K.view(B, self.num_heads, num_blocks, self.block_size, self.head_dim).mean(dim=3)V_cmp = V.view(B, self.num_heads, num_blocks, self.block_size, self.head_dim).mean(dim=3)# Compute attention with compressed keysattn_weights = torch.matmul(Q, K_cmp.transpose(-2, -1)) / (self.head_dim ** 0.5)attn_weights = torch.softmax(attn_weights, dim=-1)attn_output = torch.matmul(attn_weights, V_cmp)# Reshape and outputattn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)return self.out_proj(attn_output)# 示例调用
B, T, C = 2, 64, 128  # Batch size, Sequence length, Embedding dimension
num_heads = 8
block_size = 16attention = CompressedAttention(C, num_heads, block_size)
query = torch.randn(B, T, C)
key = torch.randn(B, T, C)
value = torch.randn(B, T, C)output = attention(query, key, value)
print(output.shape)  # (B, T, C)

5. 总结

  • 压缩注意力(Compressed Attention) 通过 块级聚合 Key-Value,大幅降低计算量,同时保留全局信息。
  • 计算复杂度降低 O ( N 2 ) → O ( N ⋅ M ) O(N^2) \to O(N \cdot M) O(N2)O(NM),其中 M ≪ N M \ll N MN(压缩后的块数)。
  • 适用于超长文本建模,在 64k 甚至更长的序列 上能够高效工作。
  • 硬件友好,支持 GPU Tensor Core 优化,减少显存占用。

选择性注意力(Selected Attention)机制解析

目标

  • 选择 最重要的 Key-Value 进行计算,而不是对所有 Key 计算注意力,从而降低计算复杂度。
  • 通过 Top-K 选择策略,保留最关键的信息,减少冗余计算,提高长序列建模能力。

1. 为什么需要选择性注意力?

  • 普通注意力(Full Attention) 计算复杂度 O(N²),当序列长度很长(如 64k+),计算量巨大。
  • 压缩注意力(Compressed Attention) 通过 块级聚合 降低计算量,但可能损失部分细节信息。
  • 选择性注意力(Selected Attention) 进一步优化,只保留最重要的 Token 参与计算,避免处理不重要的信息,减少计算开销,同时保持全局和局部信息。

2. 选择性注意力的核心步骤

NSA 采用 基于注意力得分的动态 Top-K 选择(Top-K Token Selection) 方法来筛选关键 Token:

2.1. 计算 Query-Key 相关性

首先,计算 查询(Query)所有键(Key) 的相似性(即注意力分数):

A = Q K T d k A = \frac{Q K^T}{\sqrt{d_k}} A=dk QKT

其中:

  • A A A 是注意力分数矩阵,形状为 ( B , H , T , T ) (B, H, T, T) (B,H,T,T),表示每个 Query 对应 Key 的注意力得分。

2.2. 选择 Top-K 重要 Token

  • 对于每个 Query,选择 Top-K 重要的 Key,其余的 Key 设为 − ∞ -\infty (即被 Mask)。
  • 具体实现:
    • 计算每个 Query 对所有 Key 的注意力分数。
    • 使用 Top-K 算法 找出最大的 K K K 个值,索引存入 I top-k I_{\text{top-k}} Itop-k

I top-k = argtopk ( A , K ) I_{\text{top-k}} = \text{argtopk}(A, K) Itop-k=argtopk(A,K)
- 构造稀疏化的注意力分数矩阵:

A i j ′ = { A i j , j ∈ I top-k ( i ) − ∞ , 否则 A'_{ij} = \begin{cases} A_{ij}, & j \in I_{\text{top-k}}(i) \\ -\infty, & \text{否则} \end{cases} Aij={Aij,,jItop-k(i)否则
- 这样,我们 只在最重要的 Top-K Token 上计算 Softmax

A ~ = Softmax ( A ′ ) \tilde{A} = \text{Softmax}(A') A~=Softmax(A)

2.3. 计算注意力输出

最终,用选择的 Top-K 注意力分数 计算新的 Value 权重求和

O = A ~ V O = \tilde{A} V O=A~V

这样,Query 只会与 最相关的 Key-Value 交互,提高计算效率,同时保留重要信息。


3. 选择性注意力的优势

方法计算复杂度信息保留能力适用场景
全注意力(Full Attention) O ( N 2 ) O(N^2) O(N2)完整适用于短文本
压缩注意力(Compressed Attention) O ( N ⋅ M ) O(N \cdot M) O(NM)保留全局信息适用于长文本
选择性注意力(Selected Attention) O ( N ⋅ K ) O(N \cdot K) O(NK)只保留最重要信息适用于超长文本(64k+)
  • 相比全注意力(Full Attention),选择性注意力只计算 Top-K 重要信息,大幅降低计算量。
  • 相比压缩注意力(Compressed Attention),选择性注意力能保留 更精确的局部信息,保证高精度。

4. PyTorch 实现

以下是 选择性注意力 的 PyTorch 代码:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass SelectedAttention(nn.Module):def __init__(self, embed_dim, num_heads, top_k):super().__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.top_k = top_kself.head_dim = embed_dim // num_headsself.q_proj = nn.Linear(embed_dim, embed_dim)self.k_proj = nn.Linear(embed_dim, embed_dim)self.v_proj = nn.Linear(embed_dim, embed_dim)self.out_proj = nn.Linear(embed_dim, embed_dim)def forward(self, query, key, value):B, T, C = query.shape# ProjectionQ = self.q_proj(query).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, H, T, d_k)K = self.k_proj(key).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)V = self.v_proj(value).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)# Compute attention scoresattn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)  # (B, H, T, T)# Select Top-K tokenstopk_values, topk_indices = torch.topk(attn_scores, self.top_k, dim=-1)  # (B, H, T, K)# Create a mask for non-Top-K elementsmask = torch.full_like(attn_scores, float('-inf'))  # Default maskmask.scatter_(-1, topk_indices, topk_values)  # Retain Top-K values# Apply softmax on selected tokensattn_weights = F.softmax(mask, dim=-1)# Compute attention outputattn_output = torch.matmul(attn_weights, V)  # (B, H, T, d_k)# Reshape and outputattn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)return self.out_proj(attn_output)# 示例调用
B, T, C = 2, 64, 128  # Batch size, Sequence length, Embedding dimension
num_heads = 8
top_k = 16  # 选择 Top-K 重要 Tokenattention = SelectedAttention(C, num_heads, top_k)
query = torch.randn(B, T, C)
key = torch.randn(B, T, C)
value = torch.randn(B, T, C)output = attention(query, key, value)
print(output.shape)  # (B, T, C)

5. 选择性注意力的优化方向

  1. Top-K 选择的优化
    • 目前使用 torch.topk() 进行选择,时间复杂度为 O ( N log ⁡ K ) O(N \log K) O(NlogK)
    • 可以优化为 Heap Sort + 近似选择算法,进一步提高效率。
  2. 自适应 K 值选择
    • 目前的 K 值是固定的,可以使用 Learnable Gate 机制,让模型 动态决定 K 的大小
  3. 结合其他稀疏注意力
    • 压缩注意力 + 选择性注意力 可以同时 减少计算量保留最关键信息,适合超长序列任务(64k+)。

6. 总结

  • 选择性注意力(Selected Attention) 通过 Top-K 选择 只保留最重要的 Key-Value,降低计算量。
  • 计算复杂度从 O ( N 2 ) O(N^2) O(N2) 降低到 O ( N ⋅ K ) O(N \cdot K) O(NK),适用于 超长文本(64k+)
  • 相较于全注意力和压缩注意力,选择性注意力能更精准地保留信息,同时减少计算成本。
  • 可进一步优化 通过 更快的 Top-K 选择算法自适应 K 值选择 提升性能。

滑动窗口注意力(Sliding Attention)机制解析

目标

  • 在减少计算量的同时,保留局部上下文信息,确保模型能够感知短期依赖关系
  • 结合 压缩注意力(Compressed Attention)选择性注意力(Selected Attention),在局部窗口范围内保留完整的注意力计算,避免远程信息丢失。

1. 为什么需要滑动窗口注意力?

  • 全注意力(Full Attention) 计算复杂度 O(N²),长序列(64k+)下计算成本极高。
  • 压缩注意力(Compressed Attention) 关注全局信息,但可能会丢失局部细节。
  • 选择性注意力(Selected Attention) 关注最关键的信息,但可能无法保留局部语境。
  • 滑动窗口注意力(Sliding Attention) 通过局部窗口机制,确保模型可以关注最近的信息,同时减少计算量。

2. 滑动窗口注意力的核心步骤

NSA 采用 基于局部窗口的注意力计算(Local Context Attention),主要分为 四步

2.1. 定义窗口范围

  • 设序列长度为 T T T,窗口大小设定为 W W W(window size),则对于每个 Query Q i Q_i Qi,它只会计算:

K win , i = { k i − W , k i − W + 1 , … , k i } K_{\text{win}, i} = \{ k_{i-W}, k_{i-W+1}, \dots, k_i \} Kwin,i={kiW,kiW+1,,ki}

V win , i = { v i − W , v i − W + 1 , … , v i } V_{\text{win}, i} = \{ v_{i-W}, v_{i-W+1}, \dots, v_i \} Vwin,i={viW,viW+1,,vi}
- 窗口只包含最近 W W W 个 Token,降低计算复杂度。
- 可变窗口机制:可根据任务需求设定不同的窗口大小(例如代码生成任务可能需要更大的窗口)。

2.2. 计算窗口内的 Query-Key 注意力

  • 在窗口范围 W W W 内计算标准注意力:

A win , i = Q i K win , i T d k A_{\text{win}, i} = \frac{Q_i K_{\text{win}, i}^T}{\sqrt{d_k}} Awin,i=dk QiKwin,iT
- 相比于全局注意力(O(N²)),窗口内计算量为 O(N × W),显著降低复杂度。
- 仅关注 最近 W W W 个 Token,保证短期依赖关系。

2.3. 计算 Softmax 并加权求和

  • 计算窗口内的注意力分布:

A win , i ′ = Softmax ( A win , i ) A'_{\text{win}, i} = \text{Softmax}(A_{\text{win}, i}) Awin,i=Softmax(Awin,i)

  • 计算最终的注意力输出:

O win , i = A win , i ′ V win , i O_{\text{win}, i} = A'_{\text{win}, i} V_{\text{win}, i} Owin,i=Awin,iVwin,i

2.4. 结合其他注意力机制

  • 最终输出:

O = g cmp O cmp + g sel O sel + g win O win O = g_{\text{cmp}} O_{\text{cmp}} + g_{\text{sel}} O_{\text{sel}} + g_{\text{win}} O_{\text{win}} O=gcmpOcmp+gselOsel+gwinOwin
- g cmp , g sel , g win g_{\text{cmp}}, g_{\text{sel}}, g_{\text{win}} gcmp,gsel,gwin 是可学习的门控参数(Gating Mechanism)。
- 这样可以在训练过程中,让模型学习最佳的注意力组合方式


3. 滑动窗口注意力的优势

方法计算复杂度局部信息保留适用场景
全注意力(Full Attention) O ( N 2 ) O(N^2) O(N2)✅ 完整适用于短文本
压缩注意力(Compressed Attention) O ( N ⋅ M ) O(N \cdot M) O(NM)⚠️ 可能丢失局部信息适用于长文本
选择性注意力(Selected Attention) O ( N ⋅ K ) O(N \cdot K) O(NK)⚠️ 仅保留关键 Token适用于超长文本
滑动窗口注意力(Sliding Attention) O ( N ⋅ W ) O(N \cdot W) O(NW)✅ 重点保留局部信息适用于超长文本(64k+)
  • 相比全注意力(Full Attention),滑动窗口注意力 显著减少计算量
  • 相比压缩注意力(Compressed Attention),滑动窗口注意力确保局部信息不会丢失
  • 相比选择性注意力(Selected Attention),滑动窗口注意力不会忽略短期依赖

4. PyTorch 实现

以下是 滑动窗口注意力(Sliding Attention) 的 PyTorch 代码:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass SlidingWindowAttention(nn.Module):def __init__(self, embed_dim, num_heads, window_size):super().__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.window_size = window_sizeself.head_dim = embed_dim // num_headsself.q_proj = nn.Linear(embed_dim, embed_dim)self.k_proj = nn.Linear(embed_dim, embed_dim)self.v_proj = nn.Linear(embed_dim, embed_dim)self.out_proj = nn.Linear(embed_dim, embed_dim)def forward(self, query, key, value):B, T, C = query.shape# ProjectionQ = self.q_proj(query).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, H, T, d_k)K = self.k_proj(key).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)V = self.v_proj(value).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)# Initialize attention scores (masked)attn_scores = torch.full((B, self.num_heads, T, T), float('-inf'), device=query.device)# Apply sliding window maskfor i in range(T):start_idx = max(0, i - self.window_size)attn_scores[:, :, i, start_idx:i+1] = torch.matmul(Q[:, :, i:i+1, :], K[:, :, start_idx:i+1, :].transpose(-2, -1)) / (self.head_dim ** 0.5)# Compute attention with masked softmaxattn_weights = F.softmax(attn_scores, dim=-1)# Compute attention outputattn_output = torch.matmul(attn_weights, V)  # (B, H, T, d_k)# Reshape and outputattn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)return self.out_proj(attn_output)# 示例调用
B, T, C = 2, 64, 128  # Batch size, Sequence length, Embedding dimension
num_heads = 8
window_size = 16  # 滑动窗口大小attention = SlidingWindowAttention(C, num_heads, window_size)
query = torch.randn(B, T, C)
key = torch.randn(B, T, C)
value = torch.randn(B, T, C)output = attention(query, key, value)
print(output.shape)  # (B, T, C)

5. 进一步优化方向

  1. 动态窗口大小
    • 当前的 窗口大小 W W W 是固定的,可以使用 自适应机制(Adaptive Window Size) 让模型学习最佳的窗口大小。
  2. 结合 FlashAttention 提高计算效率
    • 目前的 滑动窗口计算仍然需要遍历 Query,可以优化成 块级计算(Blockwise Computation),提升 GPU 利用率。

6. 总结

  • 滑动窗口注意力(Sliding Attention) 通过 局部窗口计算,减少计算量,同时保留最近的上下文信息。
  • 计算复杂度从 O ( N 2 ) O(N^2) O(N2) 降低到 O ( N ⋅ W ) O(N \cdot W) O(NW),适用于 超长文本(64k+)
  • 结合其他注意力机制(压缩 + 选择性 + 滑动窗口)可以 提高计算效率,同时保留全局 + 局部信息

NSA 论文中如何结合三种注意力机制?

Natively Sparse Attention(NSA) 机制中,作者采用了一种 层次化稀疏注意力策略(Hierarchical Sparse Strategy),将 压缩注意力(Compressed Attention)、选择性注意力(Selected Attention)和滑动窗口注意力(Sliding Attention) 结合,以 同时保留全局信息、关键 Token 以及局部信息,提高计算效率并优化长序列建模。


1. NSA 采用的三条注意力路径

NSA 通过以下三种不同的注意力计算路径,让 Transformer 既能高效处理长序列,又不会丢失关键信息

  1. 压缩注意力(Compressed Attention)
    • 作用:全局信息提取
    • 方式:将 Key-Value 进行 块级压缩,生成粗粒度的全局 Token 表示。
    • 计算复杂度: O ( N ⋅ M ) O(N \cdot M) O(NM)(其中 M ≪ N M \ll N MN)。
  2. 选择性注意力(Selected Attention)
    • 作用:筛选最关键的 Token 进行计算
    • 方式:对所有 Query 计算注意力分数,并选择 Top-K 重要 Token,仅对这些 Key 计算注意力。
    • 计算复杂度: O ( N ⋅ K ) O(N \cdot K) O(NK)(其中 K ≪ N K \ll N KN)。
  3. 滑动窗口注意力(Sliding Attention)
    • 作用:局部上下文信息保留
    • 方式:每个 Query 仅在其 最近的 W W W 个 Token 内 计算注意力,保留短期依赖信息。
    • 计算复杂度: O ( N ⋅ W ) O(N \cdot W) O(NW)(其中 W ≪ N W \ll N WN)。

最终的注意力输出是三种机制的加权和

O = g cmp O cmp + g sel O sel + g win O win O = g_{\text{cmp}} O_{\text{cmp}} + g_{\text{sel}} O_{\text{sel}} + g_{\text{win}} O_{\text{win}} O=gcmpOcmp+gselOsel+gwinOwin

其中:

  • g cmp , g sel , g win g_{\text{cmp}}, g_{\text{sel}}, g_{\text{win}} gcmp,gsel,gwin可学习的门控参数(Gating Mechanism),用于控制不同注意力机制的重要性。

2. NSA 具体如何组合这三种注意力?

(1) 计算 Query-Key 相关性

首先,对 Query 计算三种不同 Key 形式的注意力分数:

  1. 压缩 Key( K cmp K_{\text{cmp}} Kcmp:计算 Query 和 压缩后的 Key 的相关性:

A cmp = Q K cmp T d k A_{\text{cmp}} = \frac{Q K_{\text{cmp}}^T}{\sqrt{d_k}} Acmp=dk QKcmpT
2. 选择性 Key( K sel K_{\text{sel}} Ksel:计算 Query 和 Top-K 选择的 Key 的相关性:

A sel = Q K sel T d k A_{\text{sel}} = \frac{Q K_{\text{sel}}^T}{\sqrt{d_k}} Asel=dk QKselT
3. 滑动窗口 Key( K win K_{\text{win}} Kwin:计算 Query 在 局部窗口范围内 的注意力:

A win = Q K win T d k A_{\text{win}} = \frac{Q K_{\text{win}}^T}{\sqrt{d_k}} Awin=dk QKwinT

(2) 计算 Softmax 归一化

对每个注意力分数进行 Softmax 计算:

A ~ cmp = Softmax ( A cmp ) \tilde{A}_{\text{cmp}} = \text{Softmax}(A_{\text{cmp}}) A~cmp=Softmax(Acmp)

A ~ sel = Softmax ( A sel ) \tilde{A}_{\text{sel}} = \text{Softmax}(A_{\text{sel}}) A~sel=Softmax(Asel)

A ~ win = Softmax ( A win ) \tilde{A}_{\text{win}} = \text{Softmax}(A_{\text{win}}) A~win=Softmax(Awin)

(3) 计算注意力输出

计算不同注意力的加权求和:

O cmp = A ~ cmp V cmp O_{\text{cmp}} = \tilde{A}_{\text{cmp}} V_{\text{cmp}} Ocmp=A~cmpVcmp

O sel = A ~ sel V sel O_{\text{sel}} = \tilde{A}_{\text{sel}} V_{\text{sel}} Osel=A~selVsel

O win = A ~ win V win O_{\text{win}} = \tilde{A}_{\text{win}} V_{\text{win}} Owin=A~winVwin

(4) 加权融合不同注意力结果

最终的输出由三种注意力结果加权融合

O = g cmp O cmp + g sel O sel + g win O win O = g_{\text{cmp}} O_{\text{cmp}} + g_{\text{sel}} O_{\text{sel}} + g_{\text{win}} O_{\text{win}} O=gcmpOcmp+gselOsel+gwinOwin

其中:

  • g cmp , g sel , g win g_{\text{cmp}}, g_{\text{sel}}, g_{\text{win}} gcmp,gsel,gwin可学习的门控参数,通过 MLP 计算:

g = σ ( MLP ( X ) ) g = \sigma(\text{MLP}(X)) g=σ(MLP(X))

其中 σ \sigma σ 是 Sigmoid 激活函数,确保 g g g 取值在 (0,1) 之间。


3. PyTorch 实现

以下是 结合三种注意力的 NSA 模型

import torch
import torch.nn as nn
import torch.nn.functional as Fclass NSA(nn.Module):def __init__(self, embed_dim, num_heads, top_k, window_size):super().__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.top_k = top_kself.window_size = window_sizeself.head_dim = embed_dim // num_headsself.q_proj = nn.Linear(embed_dim, embed_dim)self.k_proj = nn.Linear(embed_dim, embed_dim)self.v_proj = nn.Linear(embed_dim, embed_dim)self.out_proj = nn.Linear(embed_dim, embed_dim)self.gate_mlp = nn.Sequential(nn.Linear(embed_dim, 3), nn.Sigmoid())  # 生成3个门控权重def forward(self, query, key, value):B, T, C = query.shape# ProjectionQ = self.q_proj(query).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)K = self.k_proj(key).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)V = self.v_proj(value).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)# 计算门控权重gate_weights = self.gate_mlp(query).unsqueeze(-1).unsqueeze(-1)  # (B, T, 3) -> (B, T, 3, 1, 1)# 压缩注意力K_cmp = K.mean(dim=-2, keepdim=True)V_cmp = V.mean(dim=-2, keepdim=True)attn_cmp = torch.matmul(Q, K_cmp.transpose(-2, -1)) / (self.head_dim ** 0.5)O_cmp = torch.matmul(F.softmax(attn_cmp, dim=-1), V_cmp)# 选择性注意力topk_values, topk_indices = torch.topk(attn_cmp, self.top_k, dim=-1)attn_sel = torch.zeros_like(attn_cmp).scatter_(-1, topk_indices, topk_values)O_sel = torch.matmul(F.softmax(attn_sel, dim=-1), V)# 滑动窗口注意力attn_win = attn_cmp.masked_fill(torch.arange(T)[:, None] < (torch.arange(T) - self.window_size), float('-inf'))O_win = torch.matmul(F.softmax(attn_win, dim=-1), V)# 加权求和O = gate_weights[..., 0] * O_cmp + gate_weights[..., 1] * O_sel + gate_weights[..., 2] * O_winO = O.transpose(1, 2).contiguous().view(B, T, C)return self.out_proj(O)# 测试
B, T, C = 2, 64, 128
attention = NSA(C, num_heads=8, top_k=16, window_size=16)
query = torch.randn(B, T, C)
key = torch.randn(B, T, C)
value = torch.randn(B, T, C)
output = attention(query, key, value)
print(output.shape)  # (B, T, C)

总结

  • NSA 通过三种注意力机制的组合,既保证全局信息,又保留关键 Token 和局部上下文信息。
  • 最终的注意力结果通过可学习的门控机制(Gating Mechanism)进行融合,实现动态调整。
  • 计算复杂度降低到 O ( N log ⁡ K ) O(N \log K) O(NlogK),适用于超长文本(64k+)。

代码是AI生成的 还在调试中

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/22203.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2.21力扣-回溯组合

77. 组合 - 力扣&#xff08;LeetCode&#xff09; 一&#xff1a;JAVA class Solution {List<Integer> list new LinkedList<>();List<List<Integer>> ans new LinkedList<>();public List<List<Integer>> combine(int n, int k)…

智能合约的部署

https://blog.csdn.net/qq_40261606/article/details/123249473 编译 点击图中的 “Compile 1_Storage.sol” 存和取一个数的合约&#xff0c;remix自带 pragma solidity >0.8.2 <0.9.0; /*** title Storage* dev Store & retrieve value in a variable* custom:d…

vmvare kali如何配置桥接模式进行上网

注意点:虚拟机可以PING通物理机,但是PING不通其他的网站。经过收集资料,得知由于是校园网连接,所以DHCP只能分配一个授权的IP地址给连接的主机,由于KALI是桥接物理机,物理机已经获得了这个授权的IP,所以导致桥接的虚拟机无法上网。所以不是因为配置的有问题,而是网络的…

了解Python中的SciPy库

么是 SciPy&#xff1f; SciPy&#xff08;发音为“Sigh Pie”&#xff09;是 Scientific Python 的首字母缩写词&#xff0c;它是 Python 的开源库&#xff0c;用于科学和技术计算。它是 Python 编程语言中称为 Numpy 的基本数组处理库的扩展&#xff0c;旨在支持高级科学和工…

python网络安全怎么学 python做网络安全

&#x1f345; 点击文末小卡片 &#xff0c;免费获取网络安全全套资料&#xff0c;资料在手&#xff0c;涨薪更快 众所周知&#xff0c;python是近几年比较火的语言之一&#xff0c;它具有简单易懂、免费开源、可移植、可扩展、丰富的第三方库函数等特点&#xff0c;Java需要大…

Ubuntu下mysql主从复制搭建

本文介绍mysql 8.4主从集群的搭建&#xff0c;从单个机器安装到集群的配置&#xff0c;整体走了一遍&#xff0c;希望对大家有帮助。mysql 8.4和之前的版本命令上有些变化&#xff0c;大家用来参考。 0、环境 ubuntu&#xff1a; 22.04mysql&#xff1a;8.4 1、安装mysql 1…

MAC快速本地部署Deepseek (win也可以)

MAC快速本地部署Deepseek (win也可以) 下载安装ollama 地址: https://ollama.com/ Ollama 是一个开源的大型语言模型&#xff08;LLM&#xff09;本地运行框架&#xff0c;旨在简化大模型的部署和管理流程&#xff0c;使开发者、研究人员及爱好者能够高效地在本地环境中实验和…

Spring Boot框架总结(超级详细)

前言 本篇文章包含Springboot配置文件解释、热部署、自动装配原理源码级剖析、内嵌tomcat源码级剖析、缓存深入、多环境部署等等&#xff0c;如果能耐心看完&#xff0c;想必会有不少收获。 一、Spring Boot基础应用 Spring Boot特征 概念&#xff1a; 约定优于配置&#…

易基因: ChIP-seq+DRIP-seq揭示AMPK通过调控H3K4me3沉积和R-loop形成以维持基因组稳定性和生殖细胞完整性|NAR

原文&#xff1a;ChIP-seqDRIP-seq揭示AMPK通过调控H3K4me3沉积和R-loop形成以维持基因组稳定性和生殖细胞完整性&#xff5c;NAR 大家好&#xff0c;这里是专注表观组学十余年&#xff0c;领跑多组学科研服务的易基因。 在饥饿等能量胁迫条件下&#xff0c;生物体会通过调整…

uniapp h5端和app端 使用 turn.js

前提:添加页后,添加页与当前页会重叠在一起,不知道为什么,没有找到解决办法 1.h5端 <template><view class"container"><view id"flipbook"><view class"page page1">Page 1</view><view class"page pag…

MySQL数据库(3)—— 表操作

目录 一&#xff0c;创建表 1.1 创建表的SQL 1.2 演示 二&#xff0c;查看表 三&#xff0c;修改表 四&#xff0c;删除表 常用的表操作会涉及到两种SWL语句 DDL&#xff08;Data Definition Language&#xff09;数据定义语言&#xff1a;建表、改表、删表等&#xff0…

【精调】LLaMA-Factory 快速开始4 自定义个一个sharegpt数据集并训练

数据格式说明 LLaMA Factory:微调LLaMA3模型实现角色扮演 数据集 参考 开源模型应用落地-DeepSeek-R1-Distill-Qwen-7B-LoRA微调-LLaMA-Factory-单机单卡-V100(一) 大神给出的数据集的讲解:注册 如

Unity 位图字体

下载Bitmap Font Generator BMFont - AngelCode.com 解压后不用安装直接双击使用 提前设置 1、设置Bit depth为32 Options->Export options 2、清空所选字符 因为我们将在后边导入需要的字符。 Edit->Select all chars 先选择所有字符 Edit->Clear all chars i…

open webui 部署 以及解决,首屏加载缓慢,nginx反向代理访问404,WebSocket后端服务器链接失败等问题

项目地址&#xff1a;GitHub - open-webui/open-webui: User-friendly AI Interface (Supports Ollama, OpenAI API, ...) 选择了docker部署 如果 Ollama 在您的计算机上&#xff0c;请使用以下命令 docker run -d -p 3000:8080 --add-hosthost.docker.internal:host-gatewa…

Servlet概述(Ⅰ)

目录 一、Servlet概述 演示 创建JavaWeb项目&#xff08;2017版本为例&#xff09; 1. 打开 IntelliJ IDEA 2. 选择项目类型 3. 配置框架 二、Servlet初识(熟练) 1.servlet说明 2.Servlet 接口方法 3.创建Servlet 4.JavaWeb请求响应流程 ​编辑 ​编辑 5.servlet…

Spring Cloud — Hystrix 服务隔离、请求缓存及合并

Hystrix 的核心是提供服务容错保护&#xff0c;防止任何单一依赖耗尽整个容器的全部用户线程。使用舱壁隔离模式&#xff0c;对资源或失败单元进行隔离&#xff0c;避免一个服务的失效导致整个系统垮掉&#xff08;雪崩效应&#xff09;。 1 Hystrix监控 Hystrix 提供了对服务…

DeepSeek 助力 Vue 开发:打造丝滑的 键盘快捷键(Keyboard Shortcuts)

前言&#xff1a;哈喽&#xff0c;大家好&#xff0c;今天给大家分享一篇文章&#xff01;并提供具体代码帮助大家深入理解&#xff0c;彻底掌握&#xff01;创作不易&#xff0c;如果能帮助到大家或者给大家一些灵感和启发&#xff0c;欢迎收藏关注哦 &#x1f495; 目录 Deep…

WPS接入deepseek-OfficeAI助手插件下载

功能简介 OfficeAI 助手 是一款免费的智能AI办公工具软件&#xff0c;专为 Microsoft Office 和 WPS 用户打造。 无论你是在寻找如何输入“打勾&#xff08;√&#xff09;符号”的方法&#xff0c;还是想知道“怎么在插入表格前添加文字”&#xff0c;或者“该用哪个公式”&a…

关系数据理论

一、函数依赖 若t1(X)t2(X),必有t1(Y)t2(Y),那么我们称属性组X函数确定属性组Y&#xff0c;或者说Y函数依赖于X。记为X->Y&#xff0c;其中X叫决定因素&#xff0c;Y叫依赖因素。 平凡函数依赖与非平凡函数依赖&#xff1a; 二、1-BCNF 评价关系模式“好坏”的理论标准就…

【C】队列与栈的相互转换

栈与队列是两种特点相反的数据结构&#xff0c;一个特点是后进先出&#xff0c;一个特点是先进先出&#xff0c;但是他们之间是可以相互转换的。 目录 1 用队列实现栈 1&#xff09; 题目解析 2&#xff09; 算法解析 &#xff08;1&#xff09; 结构(MyStack) &#xff…