目录
1.多头注意力(MultiHeadAttention)的原理图
2.代码实现
2.1创建多头注意力函数
2.2验证上述封装的代码
2.3 创建使用多头注意力的Bahdanau式decoder
2.4训练
2.5预测
3.知识点个人理解
1.多头注意力(MultiHeadAttention)的原理图
2.代码实现
2.1创建多头注意力函数
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``dltools.DotProductAttention``.

    Queries, keys and values are each linearly projected to ``num_hiddens``
    features, split into ``num_heads`` heads of ``num_hiddens // num_heads``
    features, attended per head in one batched call, then the heads are
    concatenated back and projected by ``W_o``.
    """

    def __init__(self, query_size, key_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        """
        Args:
            query_size: number of features of each query.
            key_size: number of features of each key.
            value_size: number of features of each value.
            num_hiddens: output size of every projection; must be divisible
                by ``num_heads``.
            num_heads: number of attention heads.
            dropout: dropout probability handed to the attention module.
            bias: whether the four linear projections use a bias term.
            **kwargs: forwarded to ``nn.Module.__init__``.

        Raises:
            ValueError: if ``num_heads`` is not positive or does not divide
                ``num_hiddens``.
        """
        # Validate up front: otherwise the mismatch only surfaces inside
        # forward() as an opaque reshape error.
        if num_heads <= 0 or num_hiddens % num_heads != 0:
            raise ValueError(
                f"num_hiddens ({num_hiddens}) must be divisible by a positive "
                f"num_heads ({num_heads})")
        super().__init__(**kwargs)
        self.num_heads = num_heads
        # Dot-product attention needs queries and keys with the same feature
        # size; projecting both to num_hiddens (then splitting equally into
        # heads) guarantees that.
        self.attention = dltools.DotProductAttention(dropout)
        # Projections for queries, keys, values, and the concatenated output.
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    @staticmethod
    def _transpose_qkv(X, num_heads):
        """Split the feature axis into heads and fold heads into the batch.

        (batch_size, seq_len, num_hiddens)
            -> (batch_size * num_heads, seq_len, num_hiddens // num_heads)
        """
        # (batch_size, seq_len, num_heads, head_dim)
        X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
        # (batch_size, num_heads, seq_len, head_dim). The permute is required:
        # reshaping straight to the folded layout would mix features from
        # different time steps into a single head.
        X = X.permute(0, 2, 1, 3)
        # (batch_size * num_heads, seq_len, head_dim)
        return X.reshape(-1, X.shape[2], X.shape[3])

    @staticmethod
    def _transpose_outputs(X, num_heads):
        """Invert ``_transpose_qkv``.

        (batch_size * num_heads, seq_len, head_dim)
            -> (batch_size, seq_len, num_heads * head_dim)
        """
        # (batch_size, num_heads, seq_len, head_dim)
        X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
        # (batch_size, seq_len, num_heads, head_dim)
        X = X.permute(0, 2, 1, 3)
        # Concatenate the heads back into one feature axis.
        return X.reshape(X.shape[0], X.shape[1], -1)

    def forward(self, queries, keys, values, valid_lens):
        """Apply multi-head attention.

        Args:
            queries: (batch_size, num_queries, query_size).
            keys: (batch_size, num_kvpairs, key_size).
            values: (batch_size, num_kvpairs, value_size).
            valid_lens: per-sample valid lengths (first axis is batch), or
                None for no masking.

        Returns:
            Tensor of shape (batch_size, num_queries, num_hiddens).
        """
        queries = self._transpose_qkv(self.W_q(queries), self.num_heads)
        keys = self._transpose_qkv(self.W_k(keys), self.num_heads)
        values = self._transpose_qkv(self.W_v(values), self.num_heads)
        if valid_lens is not None:
            # Row i of the folded batch belongs to sample i // num_heads, so
            # each sample's valid length is repeated num_heads times in a row.
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)
        # (batch_size * num_heads, num_queries, num_hiddens // num_heads)
        outputs = self.attention(queries, keys, values, valid_lens)
        outputs_concat = self._transpose_outputs(outputs, self.num_heads)
        return self.W_o(outputs_concat)
2.2验证上述封装的代码
# Toy hyperparameters for a quick shape check of MultiHeadAttention
num_hiddens, num_heads, dropout = 100, 5, 0.2
attention = MultiHeadAttention(
    query_size=num_hiddens,
    key_size=num_hiddens,
    value_size=num_hiddens,
    num_hiddens=num_hiddens,
    num_heads=num_heads,
    dropout=dropout,
)
# Evaluation mode (turns dropout off) since we only run inference below
attention.eval()
MultiHeadAttention((attention): DotProductAttention((dropout): Dropout(p=0.2, inplace=False))(W_q): Linear(in_features=100, out_features=100, bias=False)(W_k): Linear(in_features=100, out_features=100, bias=False)(W_v): Linear(in_features=100, out_features=100, bias=False)(W_o): Linear(in_features=100, out_features=100, bias=False) )
# Toy inputs: check that the output has shape (batch_size, num_queries, num_hiddens)
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))  # shape (2, 4, 100)
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))  # shape (2, 6, 100)
output_shape = attention(X, Y, Y, valid_lens).shape
print(output_shape)  # expected: torch.Size([2, 4, 100])
assert output_shape == (batch_size, num_queries, num_hiddens)
torch.Size([2, 4, 100])
2.3 创建使用多头注意力的Bahdanau式decoder
# Bahdanau-style decoder whose attention module is multi-head attention
class Seq2SeqMultiHeadAttentionDecoder(dltools.AttentionDecoder):
    """GRU decoder that attends over the encoder outputs at every step.

    At each time step the top-layer decoder hidden state is the query for
    multi-head attention over the encoder outputs. The resulting context is
    concatenated with the step's embedding and fed to the GRU.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_heads,
                 num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        self.attention = MultiHeadAttention(num_hiddens, num_hiddens,
                                            num_hiddens, num_hiddens,
                                            num_heads, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # Each GRU input is [context ; embedding], hence embed_size + num_hiddens.
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers,
                          dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)
        # Make the attention_weights property usable before the first forward().
        self._attention_weights = []

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        """Build the initial decoder state from the encoder output.

        ``enc_outputs`` is unpacked as ``(outputs, hidden_state)``, and
        ``outputs`` has its first two axes swapped. This assumes the encoder
        returns ``outputs`` as (num_steps, batch_size, num_hiddens), which
        gives (batch_size, num_steps, num_hiddens) for use as attention
        keys/values. TODO: confirm against dltools.Seq2SeqEncoder.
        """
        outputs, hidden_state = enc_outputs
        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)

    def forward(self, X, state):
        """Run the decoder over every time step of ``X``.

        Args:
            X: token ids, (batch_size, num_steps).
            state: ``[enc_outputs, hidden_state, enc_valid_lens]`` as produced
                by ``init_state`` or a previous ``forward`` call.

        Returns:
            ``(logits, new_state)``. ``logits`` has shape
            (batch_size, num_steps, vocab_size).
        """
        # enc_outputs: presumably (batch_size, num_steps, num_hiddens) after
        # init_state. hidden_state: (num_layers, batch_size, num_hiddens).
        enc_outputs, hidden_state, enc_valid_lens = state
        # (batch_size, num_steps, embed_size) -> (num_steps, batch_size, embed_size)
        X = self.embedding(X).permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        for x in X:  # x: (batch_size, embed_size)
            # Query is the top-layer hidden state: (batch_size, 1, num_hiddens)
            query = torch.unsqueeze(hidden_state[-1], dim=1)
            context = self.attention(query, enc_outputs, enc_outputs,
                                     enc_valid_lens)
            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
            outputs.append(out)
            # Record this step's weights from the inner dot-product attention.
            # (self.attention_weights is this very list, so appending it would
            # make the list contain itself.)
            self._attention_weights.append(
                getattr(self.attention.attention, 'attention_weights', None))
        outputs = self.dense(torch.cat(outputs, dim=0))
        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state,
                                          enc_valid_lens]

    @property
    def attention_weights(self):
        """Per-step attention weights recorded by the last ``forward`` call."""
        return self._attention_weights
2.4训练
# Training
embed_size, num_hiddens, num_layers, dropout = 32, 100, 2, 0.1
batch_size, num_steps, num_heads = 64, 10, 5
lr, num_epochs, device = 0.005, 200, dltools.try_gpu()

train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)
encoder = dltools.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens,
                                 num_layers, dropout)
decoder = Seq2SeqMultiHeadAttentionDecoder(len(tgt_vocab), embed_size,
                                           num_hiddens, num_heads, num_layers,
                                           dropout)
net = dltools.EncoderDecoder(encoder, decoder)
dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
2.5预测
# Translate a few English sentences and score each against its reference
engs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    # The sample output suggests predict_seq2seq returns a tuple whose first
    # element is the translated sentence, so BLEU is computed on translation[0].
    translation = dltools.predict_seq2seq(net, eng, src_vocab, tgt_vocab,
                                          num_steps, device)
    print(f'{eng} => {translation}, bleu {dltools.bleu(translation[0], fra, k=2):.3f}')
go . => ('va !', []), bleu 1.000 i lost . => ("j'ai perdu .", []), bleu 1.000 he's calm . => ('trouvez tom .', []), bleu 0.000 i'm home . => ('je suis chez moi .', []), bleu 1.000
3.知识点个人理解