在大语言模型(LLM)训练过程中,Masked Attention(掩码注意力) 是一个关键机制,它决定了 模型如何在训练时只利用过去的信息,而不会看到未来的 token。这篇文章将帮助你理解 Masked Attention 的作用、实现方式,以及为什么它能确保当前 token 只依赖于过去的 token,而不会泄露未来的信息。
1. Masked Attention 在 LLM 训练中的作用
在 LLM 训练时,我们通常使用 自回归(Autoregressive) 方式来让模型学习文本的生成。例如,给定输入序列:
"The cat is very"
模型需要预测下一个 token:
"cute"
但是,为了保证模型的生成方式符合自然语言流向,每个 token 只能看到它之前的 token,不能看到未来的 token。
Masked Attention 的作用就是:
- 屏蔽未来的 token,使当前 token 只能关注之前的 token
- 保证训练阶段的注意力机制符合推理时的因果(causal)生成方式
- 防止信息泄露,让模型学会自回归生成文本
如果没有 Masked Attention,模型在训练时可以“偷看”未来的 token,导致它学到的规律无法泛化到推理阶段,从而影响文本生成的效果。
举例说明
假设输入是 "The cat is cute",模型按 token 级别计算注意力:
(1) 没有 Mask(BERT 方式)
Token | The | cat | is | cute |
---|---|---|---|---|
The | ✅ | ✅ | ✅ | ✅ |
cat | ✅ | ✅ | ✅ | ✅ |
is | ✅ | ✅ | ✅ | ✅ |
cute | ✅ | ✅ | ✅ | ✅ |
每个 token 都能看到整个句子,适用于 BERT 这种双向模型。
(2) 有 Mask(GPT 方式)
Token | The | cat | is | cute |
---|---|---|---|---|
The | ✅ | ❌ | ❌ | ❌ |
cat | ✅ | ✅ | ❌ | ❌ |
is | ✅ | ✅ | ✅ | ❌ |
cute | ✅ | ✅ | ✅ | ✅ |
每个 token 只能看到它自己及之前的 token,保证训练和推理时的生成顺序一致。
2. Masked Attention 的工作原理
在标准的 自注意力(Self-Attention) 机制中,注意力分数是这样计算的:
其中:
-
Q, K, V 是 Query(查询)、Key(键)和 Value(值)矩阵
-
计算所有 token 之间的相似度
-
如果不做 Masking,每个 token 都能看到所有的 token
而在 Masked Attention 中,我们会使用一个 上三角掩码(Upper Triangular Mask),使得未来的 token 不能影响当前 token:
Mask 是一个 上三角矩阵,其中:
-
未来 token 的位置填充
,确保 softmax 之后它们的注意力权重为 0
-
只允许关注当前 token 及之前的 token
例如,假设有 4 个 token:
经过 softmax 之后:
最终,每个 token 只会关注它自己和它之前的 token,完全忽略未来的 token!
3. Masked Attention 计算下三角部分的值时,如何保证未来信息不会泄露?
换句话说,我们需要证明 Masked Attention 计算出的下三角部分的值(即历史 token 之间的注意力分数)不会受到未来 token 的影响。
1. 问题重述
Masked Attention 的核心计算是:
其中:
-
Q, K, V 是整个序列的矩阵。
-
计算的是所有 token 之间的注意力分数。
-
Mask 确保 softmax 后未来 token 的注意力分数变为 0。
这个问题可以分解成两个关键点:
-
未来 token 是否影响了下三角部分的 Q 或 K?
-
即使未来 token 参与了 Q, K 计算,为什么它们不会影响下三角的注意力分数?
2. 未来 token 是否影响了 Q 或 K?
我们先看 Transformer 计算 Q, K, V 的方式:
这里:
-
X 是整个输入序列的表示。
-
是相同的投影矩阵,作用于所有 token。
由于 每个 token 的 Q, K, V 只取决于它自己,并不会在计算时使用未来 token 的信息,所以:
-
计算第 i 个 token 的
时,并没有用到
,所以未来 token 并不会影响当前 token 的 Q, K, V。
结论 1:未来 token 不会影响当前 token 的 Q 和 K。
3. Masked Attention 如何确保下三角部分不包含未来信息?
即使 Q, K 没有未来信息,我们仍然要证明 计算出的注意力分数不会受到未来信息影响。
我们来看注意力计算:
这是一个 所有 token 之间的相似度矩阵,即:
然后,我们应用 因果 Mask(Causal Mask):
Mask 让右上角(未来 token 相关的部分)变成 :
然后计算 softmax:
由于 ,所有未来 token 相关的注意力分数都变成 0:
最后,我们计算:
由于未来 token 的注意力权重是 0,它们的 V 在计算中被忽略。因此,下三角部分(历史 token 之间的注意力)完全不受未来 token 影响。
结论 2:未来 token 的信息不会影响下三角部分的 Attention 计算。
4. 为什么 Masked Attention 能防止未来信息泄露?
你可能会问:
即使有 Mask,计算 Attention 之前,我们不是还是用到了整个序列的 Q, K, V 吗?未来 token 的 Q, K, V 不是已经算出来了吗?
的确,每个 token 的 Q, K, V 是独立计算的,但 Masked Attention 确保了:
-
计算 Q, K, V 时,每个 token 只依赖于它自己的输入
-
只来自 token i,不会用到未来的信息
-
未来的 token 并不会影响当前 token 的 Q, K, V
-
-
Masked Softmax 阻止了未来 token 的影响
-
虽然 Q, K, V 都计算了,但 Masking 让未来 token 的注意力分数变为 0,确保计算出的 Attention 结果不包含未来信息。
-
最终,当前 token 只能看到过去的信息,未来的信息被完全屏蔽!
5. 训练时使用 Masked Attention 的必要性
Masked Attention 的一个关键作用是 让训练阶段和推理阶段保持一致。
-
训练时:模型学习如何根据 历史 token 预测 下一个 token,确保生成文本时符合自然语言流向。
-
推理时:模型生成每个 token 后,仍然只能访问过去的 token,而不会看到未来的 token。
如果 训练时没有 Masked Attention,模型会学习到“作弊”策略,直接利用未来信息进行预测。但在推理时,模型无法“偷看”未来的信息,导致生成质量急剧下降。
6. 结论
Masked Attention 是 LLM 训练的核心机制之一,其作用在于:
- 确保当前 token 只能访问过去的 token,不会泄露未来信息
- 让训练阶段与推理阶段保持一致,避免模型在推理时“失效”
- 利用因果 Mask 让 Transformer 具备自回归能力,学会按序生成文本
Masked Attention 本质上是 Transformer 训练过程中对信息流动的严格约束,它确保了 LLM 能够正确学习自回归生成任务,是大模型高质量文本生成的基础。