手撕MultiHeadAttention
类的代码,结合具体的例子来说明每一步的作用和计算过程。
往期文章:
仅仅使用pytorch来手撕transformer架构(1):位置编码的类的实现和向前传播
最适合小白入门的Transformer介绍
1. 初始化方法 __init__
def __init__(self, embed_size, heads):super(MultiHeadAttention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsassert self.head_dim * heads == embed_size, "Embed size needs to be divisible by heads"self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
1.1参数解释
embed_size
:嵌入向量的维度,表示每个输入向量的大小。heads
:注意力头的数量。多头注意力机制将输入分割成多个“头”,每个头学习不同的特征。head_dim
:每个注意力头的维度大小,计算公式为embed_size // heads
。这意味着每个头处理的特征子集的大小。
1.2线性变换层
-
self.values
、self.keys
、self.queries
:- 这些是线性变换层,用于将输入的嵌入向量分别转换为值(Values)、键(Keys)和查询(Queries)。
- 每个线性层的输入和输出维度都是
self.head_dim
,因为每个头处理的特征子集大小为self.head_dim
。 - 使用
bias=False
是为了简化计算,避免引入额外的偏置项。
-
self.fc_out
:- 在多头注意力计算完成后,将所有头的输出拼接起来,并通过一个线性层将维度转换回原始的嵌入维度
embed_size
。
- 在多头注意力计算完成后,将所有头的输出拼接起来,并通过一个线性层将维度转换回原始的嵌入维度
2. 前向传播方法 forward
def forward(self, values, keys, query, mask):N = query.shape[0] # Batch sizevalue_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
2.1输入参数
values
、keys
、query
:- 这三个输入张量的形状通常为
(batch_size, seq_len, embed_size)
。 - 它们分别对应于值(Values)、键(Keys)和查询(Queries)。
- 这三个输入张量的形状通常为
mask
:- 用于遮蔽某些位置的注意力权重,避免模型关注到不应该关注的部分(例如,解码器中的未来信息)。
2.2多头注意力计算过程
2.2.1 将输入嵌入分割为多个头:
values = values.reshape(N, value_len, self.heads, self.head_dim)
keys = keys.reshape(N, key_len, self.heads, self.head_dim)
queries = query.reshape(N, query_len, self.heads, self.head_dim)
- 将输入的嵌入向量分割成
heads
个头,每个头的维度为self.head_dim
。 - 例如,如果
embed_size = 256
,heads = 8
,则self.head_dim = 32
,每个头处理 32 维的特征。 - 重塑后的形状为
(N, seq_len, heads, head_dim)
。
2.2.2 线性变换:
values = self.values(values)
keys = self.keys(keys)
queries = self.queries(queries)
- 对每个头的值、键和查询分别进行线性变换。
- 这一步将输入特征投影到不同的子空间中,使得每个头可以学习不同的特征。
2.2.3计算注意力分数(Attention Scores):
energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
- 使用
torch.einsum
计算查询和键之间的点积,得到注意力分数矩阵。 - 公式
nqhd,nkhd->nhqk
表示:n
:批量大小(Batch Size)。q
:查询序列的长度。k
:键序列的长度。h
:头的数量。d
:每个头的维度。
- 输出的
energy
形状为(N, heads, query_len, key_len)
。
2.2.4应用掩码(Masking):
if mask is not None:energy = energy.masked_fill(mask == 0, float("-1e20"))
- 如果提供了掩码,将掩码为 0 的位置的注意力分数设置为一个非常小的值(如
-1e20
),这样在后续的 softmax 计算中,这些位置的权重会趋近于 0。
2.2.5计算注意力权重:
attention = torch.softmax(energy / (self.embed_size ** (0.5)), dim=3)
- 对注意力分数进行 softmax 归一化,得到注意力权重。
- 除以
sqrt(embed_size)
是为了缩放点积结果,避免梯度消失或爆炸。
2.2.6应用注意力权重:
out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads * self.head_dim
)
- 使用
torch.einsum
将注意力权重与值相乘,得到加权的值。 - 公式
nhql,nlhd->nqhd
表示:n
:批量大小。h
:头的数量。q
:查询序列的长度。l
:值序列的长度。d
:每个头的维度。
- 输出的
out
形状为(N, query_len, heads * self.head_dim)
。
2.2.7线性变换输出:
out = self.fc_out(out)
- 将所有头的输出拼接起来,并通过一个线性层将维度转换回原始的嵌入维度
embed_size
。
3. 示例矩阵计算
假设:
embed_size = 4
heads = 2
head_dim = embed_size // heads = 2
- 输入序列长度为 3,批量大小为 1。
3.1输入张量
values = torch.tensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]], dtype=torch.float32)
keys = torch.tensor([[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]], dtype=torch.float32)
query = torch.tensor([[[25, 26, 27, 28], [29, 30, 31, 32], [33, 34, 35, 36]]], dtype=torch.float32)
mask = None
3.2重塑为多头
values = values.reshape(1, 3, 2, 2) # (N, value_len, heads, head_dim)
keys = keys.reshape(1, 3, 2, 2)
queries = query.reshape(1, 3, 2, 2)
3.3线性变换
假设线性变换层的权重为单位矩阵(简化计算),则:
values = self.values(values) # 不改变值
keys = self.keys(keys)
queries = self.queries(queries)
3.4计算注意力分数
energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
假设:
queries = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]
keys = [[[13, 14], [15, 16]], [[17, 18], [19, 20]], [[21, 22], [23, 24]]]
计算点积:
energy = [[[[1*13 + 2*14, 1*15 + 2*16], [1*17 + 2*18, 1*19 + 2*20]],
完整代码:
class MultiHeadAttention(nn.Module):def __init__(self, embed_size, heads):super(MultiHeadAttention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsassert self.head_dim * heads == embed_size, "嵌入尺寸需要被头部整除"self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)self.fc_out = nn.Linear(heads * self.head_dim, embed_size)def forward(self, values, keys, query, mask):N = query.shape[0]value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]# Split the embedding into self.heads different piecesvalues = values.reshape(N, value_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)queries = query.reshape(N, query_len, self.heads, self.head_dim)values = self.values(values)keys = self.keys(keys)queries = self.queries(queries)energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])if mask is not None:energy = energy.masked_fill(mask == 0, float("-1e20"))attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads * self.head_dim)out = self.fc_out(out)return out
作者码字不易,觉得有用的话不妨点个赞吧,关注我,持续为您更新AI的优质内容。