Transformer 模型由 Vaswani 等人于2017年提出,主要应用于序列到序列的任务,最初应用于机器翻译。其核心思想是通过自注意力机制捕捉序列中的长期依赖关系,从而有效地进行任务建模
在著名的论文《Attention Is All You Need》中,编码器和解码器的层数 h 明确设定为6层。这一复合架构统称为Transformer架构
在本篇文章中,我们将逐步深入探讨 Transformer 模型的架构,以机器翻译任务为例,输入是一种语言的句子,输出是另一种语言的句子。我们将从整体架构入手,逐步拆解模型的各个组件
目录
1 Transformer 模型架构
1.1 编码器(Encoder)
1.2 解码器(Decoder)
2 自注意力机制 Self-Attention
2.1 为每个单词创建查询(Query)、键(Key)和值(Value)向量
2.2 使用查询向量对键向量进行评分
2.3 对注意力分数进行标准化
2.4 将值向量乘以标准化后的注意力分数并加权求和
2.5 进一步解析
2.6 流程小结
3 多头注意力机制 Multi-Head Attention
3.1 类比:模拟人类的注意力方式
3.2 工作原理
3.3 计算注意力分数并获取输出
3.4 合并所有头的输出
3.5 流程总结
4 多头注意力机制优化方案
4.1 优化
4.2 多查询注意力机制 Multi-Query Attention
4.3 分组查询注意力机制 Grouped Query Attention
4.4 MHA、GQA 和 MQA 的对比
1 Transformer 模型架构
整个 Transformer 模型可以简要总结为以下几步:
- 输入:通过位置编码将单词的位置信息嵌入到输入词向量中
- 编码器:输入句子经过自注意力层和前馈神经网络,逐步转换为丰富的上下文表示
- 解码器:解码器根据编码器的输出和之前生成的目标单词,通过自注意力层和编码器-解码器注意力层逐步生成目标语言句子
- 输出:解码器的最终输出通过线性层和 Softmax 层生成目标单词的概率分布
因此,Transformer 模型由以下两个主要部分组成:
- 编码器(Encoder):负责处理输入句子,将其转化为一个上下文丰富的表示
- 解码器(Decoder):根据编码器生成的上下文向量来生成目标语言句子
编码器和解码器都由多个相同的层堆叠而成,每一层都有其特定的结构和功能
1.1 编码器(Encoder)
编码器由多个相同的编码器层堆叠而成,每个编码器层包含两个主要子层:
-
自注意力层(Self-Attention Layer)
-
前馈神经网络(Feed-Forward Neural Network)
在 Transformer 模型的编码器中,输入句子首先经过自注意力层,然后进入前馈神经网络
1. 自注意力层(Self-Attention)
自注意力机制的核心思想是通过计算输入序列中不同单词之间的关系,来为每个单词分配不同的权重。具体来说,对于输入句子中的每个单词,计算与其他单词的相似度(注意力分数),从而为每个单词生成一个加权的表示。公式为:
其中,
- Q、K、V分别是查询(Query)、键(Key)和值(Value)矩阵,
- dk 是键向量的维度,
是为了防止数值过大而进行的缩放操作
- 通过计算所有单词对其他单词的注意力分数,最终生成加权的值(Value)向量
2. 前馈神经网络(Feed-Forward Neural Network)
在经过自注意力层的处理后,输出会通过一个前馈神经网络。该网络由两个全连接层组成,中间使用 ReLU 激活函数进行非线性映射。前馈神经网络的形式为:
其中,W1、W2 为权重矩阵,b1、b2 为偏置向量
1.2 解码器(Decoder)
解码器的结构与编码器类似,但多了一些额外的组件。每个解码器同样由三个主要部分组成:
- 自注意力层(Self-Attention)
- 编码器-解码器注意力层(Encoder-Decoder Attention)
- 前馈神经网络(Feed-Forward Neural Network)
其中自注意力层和前馈神经网络的结构与编码器中相同,在这两层之间是一个“编码器-解码器”注意力层,使用来自编码器压缩的上下文向量,让解码器不仅可以关注输入句子中的相关单词,还可以关注输出句子中的相关单词
因此,本文首先讲解自注意力机制 Self-Attention
2 自注意力机制 Self-Attention
在 Transformer 模型中,自注意力是核心组成部分之一,它通过计算输入序列中每个单词与其他单词之间的关系,生成包含上下文信息的表示。在计算过程中,输入单词会被映射为查询(Query)、键(Key)和值(Value)三个向量,然后通过这些向量之间的相似度来计算注意力分数。具体的计算可分为四个步骤:
2.1 为每个单词创建查询(Query)、键(Key)和值(Value)向量
假设输入序列中的第 i 个单词的词向量表示为 xi,,分别生成查询向量 Qi、键向量 Ki 和值向量 Vi。这些向量是通过将单词的词向量与不同的权重矩阵相乘得到的,具体公式为:
其中,WQ、WK、WV 是通过训练学习得到的权重矩阵,用于将输入的词向量映射到查询、键和值空间。这三个向量代表了不同的功能:
- 查询向量(Query):用于从其他单词中“查询”信息
- 键向量(Key):用于与查询向量进行匹配,计算它们之间的相关性
- 值向量(Value):是与键向量相关的实际信息,它将在后续步骤中根据注意力分数加权求和
其中,WQ、WK 和 WV 是通过训练学习得到的权重矩阵。通过这种方式,每个输入单词都会对应三个不同的向量,这些向量在后续的计算中将起到至关重要的作用
2.2 使用查询向量对键向量进行评分
接下来,对于输入序列中的每个单词 xi,计算其查询向量 Qi 与所有其他单词 xj 的键向量 Kj 之间的相关性。这是通过计算查询向量和键向量的内积来完成的。该内积值称为注意力分数(Attention Score),它衡量了查询单词与其他单词之间的相关性。计算公式为:
该操作对于每一对单词 xi 和 xj 都执行一次,结果是一组相关性分数,用于指示输入序列中每个单词与其他单词的关系,即查询单词与键单词的匹配程度。通过这种方式,我们能够为每个单词计算出它与其他单词之间的关系强度
2.3 对注意力分数进行标准化
内积计算的结果可能会很大,特别是当查询向量和键向量的维度 dk 较大时,从而导致数值不稳定,尤其是在训练过程中。为了避免这种情况,我们对注意力分数进行缩放操作。具体来说,将注意力分数除以键向量维度的平方根 。标准化后的注意力分数的计算公式为:
其中,dk 是键向量的维度。缩放操作的目的是防止内积的值过大,导致 Softmax 函数的梯度消失或爆炸,确保后续计算的数值稳定性
2.4 将值向量乘以标准化后的注意力分数并加权求和
经过缩放的注意力分数随后会通过 Softmax函数 进行归一化,使得所有注意力分数的和为1。Softmax 函数的作用是将原始的注意力分数转化为概率分布,从而表明每个单词对当前查询单词的贡献比例。Softmax函数的计算公式为:
Softmax归一化后得到的分数 将用于加权求和值向量 Vj,最终得到每个单词的输出表示。这一过程的计算公式为:
在这里,Softmax的输出值 表示了查询单词 xi 与所有键单词之间的相关性权重,而这些权重会加权值向量 Vj,得到每个位置的最终输出表示
2.5 进一步解析
1. 为什么需要三个向量来计算注意力分数?
查询(Query)、键(Key)和值(Value)三个向量的设计灵感来源于信息检索系统。在搜索引擎中,用户输入一个查询词,这个查询词就是查询(Query);搜索结果中的文章标题是键(Key),而文章的内容则是值(Value)。查询和标题的匹配度决定了该文章是否相关,最终返回的内容是文章的正文
同样,在自注意力机制中,查询向量用于“查询”相关的信息,键向量用于“匹配”查询信息,而值向量则包含了实际的信息。这种机制让模型能够动态地聚焦于输入序列中的相关部分。
在 Transformer 中,查询向量与键向量的作用类似于搜索引擎中的查询与标题的匹配过程。通过计算查询向量与键向量的相似度,我们可以找到哪些单词在当前上下文中更相关。然后,通过加权求和将这些相关的信息整合到一起,生成新的单词表示
2. 为什么要进行缩放?
在计算查询向量和键向量的内积时,内积的结果随着键向量维度 dk 的增大而变得越来越大。假设查询向量和键向量的元素是独立且均值为0、方差为1的随机变量,那么它们的内积的期望值为0,方差为 dk。因此,当 dk 较大时,内积的方差也会较大,这导致计算出的注意力分数会非常大,从而影响模型的训练
为了解决这一问题,将内积结果除以 ,这样可以保持注意力分数的方差在一个合理范围内,从而避免数值不稳定
3. Softmax函数与梯度稳定性
在注意力分数计算中,Softmax函数通过将分数归一化成概率分布,使得所有的分数和为1
Softmax的输出会用作加权求和值向量的权重。如果注意力分数的差异过大,Softmax的输出可能会非常不均匀,导致某些分数接近0或接近1,从而导致梯度消失或梯度爆炸的情况。通过对注意力分数进行缩放,我们可以平滑Softmax的输出,确保梯度的稳定性,防止训练过程中出现数值不稳定的问题。
4. 更直观理解
为了更直观地理解注意力分数,我们可以通过一个具体的例子来进行阐述。假设我们正在进行一句话的机器翻译任务,原句为:
英文:"The animal didn’t cross the street because it was too tired."
中文:"动物没有过马路,因为它太累了。"
在翻译过程中,问题是 “it” 指的到底是什么?是 "animal" 还是 "street"
对于人类来说,这个问题通常不难回答。我们很自然地理解,"it" 指的是 "animal",即“动物”才是句子中 "it" 的指代对象。然而,对于机器来说,这个问题可能会复杂得多,因为计算机在理解上下文和语义关系时,需要根据输入的上下文信息做出判断,而这种判断依赖于捕捉句子中不同单词之间的关系
此时,自注意力机制发挥了重要作用。当模型处理 "it" 时,注意力机制能够帮助模型通过计算 "it" 与其他单词(如 "animal" )的相关性来做出正确的决策。通过计算注意力分数,模型能够在处理 "it" 时将更多的权重分配给 "animal"
可以通过 Tensor2Tensor Notebook 交互式地查看训练后的模型输出,观察模型在翻译过程中如何根据不同的注意力分数加权每个单词的影响。经过训练后,模型能够在遇到 "it" 时与 "animal" 建立更强的联系,而不是与 "street" 产生关联
2.6 流程小结
在自注意力机制中,每个单词都对应查询(Query)、键(Key)和值(Value)三个向量。在理论上,我们是针对每个单词单独进行计算的,即每个单词都有各自的查询、键和值向量。然而,在实际的实现中,通常不会单独对每个单词进行操作,而是将整个输入序列组织成一个矩阵来处理。
在实际代码中,整个自注意力计算的流程通常是如下所示:
-
输入矩阵:将输入序列表示为矩阵 X
-
计算查询、键、值矩阵:通过权重矩阵将输入矩阵 X 转换为查询矩阵 Q、键矩阵 K 和值矩阵 V
-
计算注意力分数矩阵:通过查询矩阵与键矩阵的内积计算注意力分数矩阵 A
-
Softmax 归一化:对注意力分数矩阵 A 进行 Softmax 归一化
-
加权求和值矩阵:将归一化后的注意力分数矩阵与值矩阵 V 相乘,得到输出矩阵 Z
最终,矩阵 Z 将被传递到前馈神经网络中,进一步进行处理
具体而言:
1. 输入序列矩阵
假设输入序列为 X={x1,x2,...,xn},其中 xi 表示输入序列中的第 i 个单词的词向量。我们将整个输入序列表示成一个矩阵 ,其中 n 是输入序列中的单词数量,d 是词向量的维度
2. 通过矩阵计算查询、键和值
接下来,通过权重矩阵 WQ,WK,WV 将输入矩阵 X 转换为查询矩阵 Q、键矩阵 K 和值矩阵 V。具体来说,通过矩阵乘法来完成这一转换,公式为:
其中, 是训练过程中学习得到的权重矩阵,分别用于将输入的词向量映射到查询空间、键空间和值空间,得到三个矩阵:
- 查询矩阵
- 键矩阵
- 值矩阵
这些矩阵中的每一行对应输入序列中每个单词的查询、键和值向量
3. 计算注意力分数矩阵
接下来,计算注意力分数矩阵,表示查询和键之间的相关性。注意力分数矩阵 A 是通过查询矩阵 Q 与键矩阵 K 的内积来获得的。矩阵 A 中的每一个元素 Aij 表示第 i 个单词的查询向量与第 j 个单词的键向量之间的相关性。计算公式为:
其中,KT 是键矩阵的转置。这个操作将会得到一个大小为 n×n 的矩阵,表示输入序列中每个单词与其他单词之间的相关性
4. Softmax 归一化
为了将注意力分数转化为概率分布,我们对注意力分数矩阵 A 进行 Softmax 归一化。Softmax 会将每一行的注意力分数转换为一个概率分布。计算公式为:
这里的 Softmax 是对每一行进行操作,使得每个单词对于其他单词的注意力权重之和为1。最终得到的矩阵 softmax(A) 是一个大小为 n×n 的矩阵,表示每个单词对其他单词的“注意力权重”
5. 计算最终的输出矩阵
最后,将归一化后的注意力分数矩阵与值矩阵 V 相乘,得到最终的输出矩阵 Z。该输出矩阵表示了经过注意力加权后的每个单词的表示,公式为:
矩阵 是经过自注意力机制计算后得到的输出矩阵,将被传递到后续的前馈神经网络中
3 多头注意力机制 Multi-Head Attention
在 Transformer 模型的论文中,还提出了一种增强注意力的设计——多头注意力(Multi-Head Attention,MHA)机制。这一机制是对注意力层的进一步完善,旨在让模型可以同时关注来自不同表示子空间的信息。多头注意力机制的核心思想在于能够将信息从多个不同的角度进行并行处理,进而提高模型对信息的捕捉能力
3.1 类比:模拟人类的注意力方式
为了帮助理解多头注意力机制,我们可以借用一个类比。当我们阅读一篇文章时,我们的注意力并不是均匀地分布在整篇文章上。例如,标题和粗体文字通常会比正文部分更加引起我们的注意,而颜色鲜艳的文字(如红色的标题)也比那些黑色的正文更加吸引视线。这里,“标题”和“正文”代表了不同的表示子空间,而字体和颜色是两个独立的信息维度。如果我们能够同时关注这两个维度,结合字体和颜色的信息,那么我们可以更好地定位文章中的关键内容
在多头注意力机制中,模型正是通过这种方式进行信息处理。每个“头”可以从不同的角度(即不同的表示子空间)捕捉信息。这种设计让模型能够从多个不同的特征子空间中综合提取关键信息,从而更全面地理解输入数据,避免了仅依赖单一特征所带来的偏见
3.2 工作原理
为了实现多头注意力机制,模型需要维护多个查询(Query)、键(Key)和值(Value)权重矩阵。假设我们使用了 h 个注意力头(在标准 Transformer 模型中通常为8个),那么每个头都会有一组独立的权重矩阵 ,其中 i∈[1,h],每一组权重矩阵将会把输入特征向量映射到不同的表示子空间中
对于每一个头,输入矩阵 X 会分别与对应的权重矩阵进行矩阵乘法,得到各自的查询、键和值矩阵:
其中,,表示第 i 个头的查询、键和值矩阵,其中 n 是输入序列的长度,dk 是每个头的查询(或键、值)向量的维度
3.3 计算注意力分数并获取输出
接下来,对每个头计算注意力分数矩阵并进行 Softmax 归一化。对于每个头,注意力分数矩阵 计算公式为:
然后,通过 Softmax 对每个分数进行归一化:
最后,将注意力分数矩阵与对应的值矩阵 相乘,得到每个头的输出矩阵
3.4 合并所有头的输出
在多头注意力机制中,每个头会产生一个独立的输出矩阵 ,但这些输出矩阵不能直接输入到后续的前馈神经网络中。因此,我们需要将这些输出矩阵合并成一个单一的矩阵。最常用的做法是将每个头的输出按列拼接起来:
这里,Zconcat 是一个大小为 n×(h⋅dk) 的矩阵,其中 h 是头的数量,dk 是每个头的维度。接下来,模型会用一个线性变换矩阵 WO 对拼接后的输出进行投影,将它映射回原始的输出空间:
这样,最终得到的矩阵 就是经过多头注意力机制处理后的输出矩阵,它将作为输入传递给前馈神经网络
3.5 流程总结
- 多个头的权重矩阵:每个注意力头拥有一组独立的查询、键和值的权重矩阵
- 矩阵计算:对于每个头,输入序列通过相应的权重矩阵得到查询、键和值矩阵
- 计算注意力分数:通过查询矩阵和键矩阵的内积计算注意力分数,并进行 Softmax 归一化
- 加权求和值矩阵:使用归一化后的注意力分数对值矩阵进行加权求和,得到每个头的输出
- 拼接与线性变换:将所有头的输出拼接成一个矩阵,并通过线性变换得到最终的输出
4 多头注意力机制优化方案
在 Transformer 模型中,多头注意力机制是由多个点积注意力模块(Dot-Product Attention)组合而成。每个点积注意力模块独立计算自己的注意力分数,并生成相应的输出。可以将多头注意力的结构表示为:
其中, 代表第 i 个头的计算过程,WO 是最终的输出线性变换矩阵,h 是注意力头的数量
多头中的 h 个点积注意力模块是可以并行计算的,因为它们之间没有依赖关系。这种并行性使得多头注意力机制在处理大量数据时能够显著提升计算效率
4.1 优化
尽管多头注意力机制在实践中表现出色,提升了模型的效果,但在原始的 Transformer 论文中并没有给出完备的理论解释。研究人员并未从理论层面彻底证明多头注意力的优越性,而是通过大量实验发现其效果优于传统方法。实际上,这种通过实验验证的新思路,在AI研究中非常常见,许多突破性的创新往往来源于研究人员凭借敏锐的科研意识和实验直觉,提出新颖的研究方向
这些方向虽然有时缺乏完善的理论支持,但经过实验验证后,仍然可以有效地提高模型性能。而这些未被完全解释的设计也为后续的研究提供了继续深入探讨的空间
MQA 和 GQA 就是其中的两种优化方案
4.2 多查询注意力机制 Multi-Query Attention
虽然多头注意力机制有效提升了模型的表达能力,但它也带来了一定的存储和计算开销。具体来说,每个头都需要独立存储自己的查询(Query)、键(Key)和值(Value)矩阵,这会占用大量的内存空间,尤其是在模型隐藏层维度较大的情况下,这一问题更为突出
为了解决这一问题,研究者提出了多查询注意力(MQA)机制。这一机制的核心思想是查询(Query)保持原有的多头设计,而键(Key)和值(Value)则共享一个头。在这种设计中,所有的查询头共享一组键值对矩阵(KV 矩阵)。因此,这种设计被称为多查询注意力
尽管这种设计可能在某些情况下对模型性能产生微小的影响,但基于其显著的内存优化效果,性能的轻微下降是可以接受的。实验表明,MQA 模型通常可以提高30%到40%的处理效率,主要原因如下:
- 减少 KV 缓存的大小:在 MQA 中,键(Key)和值(Value)矩阵仅需存储一组,而不是为每个查询头存储独立的矩阵。这大大降低了内存占用。
- 减少内存读取:由于所有查询头共享一组 KV 矩阵,内存读取的数据量减少,从而降低了计算单元的等待时间,提高了计算效率。
- 提高显存利用率:由于 KV 矩阵的大小缩小,显存中需要保存的张量也相应减少,这为增大批处理量提供了空间,进一步提高了显存的利用率。
4.3 分组查询注意力机制 Grouped Query Attention
除了 MQA 之外,还有一种折中的解决方案,即分组查询注意力(GQA)机制。GQA 机制可以看作是 MHA 与 MQA 的混合模型,其设计目的是在不大幅牺牲性能的前提下,尽可能地获得MQA模型的推理加速优势
在 GQA 模型中,不是所有的查询头都共享一组 KV,而是将一定数量的查询头分组,共享同一组 KV。具体来说,GQA 设计将查询头分为若干组,每组查询头共享一个 KV 矩阵。这种设计有效平衡了性能和推理速度,尤其适用于需要在性能和效率之间做出权衡的场景
4.4 MHA、GQA 和 MQA 的对比
多头注意力机制(MHA)、分组查询注意力(GQA)和多查询注意力(MQA)这三种模型各自具有不同的优缺点,适用于不同的应用场景。通过实验验证,这些模型的效果对比如下表所示。在实际应用中,GQA 模型通常表现得相对较好,其推理速度较快且能维持较好的性能,而 MHA 的效果稍逊色一些。MQA 虽然在处理效率上有显著的提升,但其性能相比 GQA 略有下降
Llama 2论文中MHA、GQA和MQA在不同任务上的效果
顺便提一下,Llama 3 的 70B 参数版本已默认采用 GQA 技术,相比 Llama 2 的 MHA,推理速度提升 30% 且内存占用降低,同时保持模型质量接近 MHA