66 使用注意力机制的seq2seq_by《李沐:动手学深度学习v2》pytorch版

系列文章目录


文章目录

  • 系列文章目录
  • 动机
  • 加入注意力
  • 总结
  • 代码
    • 定义注意力解码器
    • 训练
    • 小结
    • 练习


我们来真的看一下实际应用中,key,value,query是什么东西,但是取决于应用场景不同,这三个东西会产生变化。先将放在seq2seq这个例子。

动机

机器翻译中,每个生成的词可能相关于源句子中不同的词。
比如RNN的最后一个时刻的最后一个输出,把所有东西压到一起,你看到的东西可能看不清楚。我想要翻译对应的词的时候,关注原句子中对应的部分,这就是将注意力机制放在seq2seq的动机,而之前的seq2seq模型中不能对此直接建模。
在这里插入图片描述

加入注意力

在这里插入图片描述
大体上就是,我在解码器RNN的输入一部分来自embedding,之前还有一部分是来自RNN最后一个时刻的最后一层作为上下文和embedding一起传进去。现在说最后一个时刻作为Context传进去不好,我应该根据我的现在预测值的不一样去选择不是最后一个时刻,可能是前面某个时刻对应的那些隐藏状态(经过注意力机制)作为输入。
编码器对每次词的输出作为key和value(它们成对的)。
解码器RNN对上一个词的输出是query。
注意力的输出和下一个词的词嵌入合并进入解码器。

总结

  1. Seq2seq中通过隐状态在编码器和解码器中传递信息
  2. 注意力机制可以根据解码器RNN的输出来匹配到合适的编码器RNN的输出来更有效的传递信息

代码

import torch
from torch import nn
from d2l import torch as d2l

定义注意力解码器

下面看看如何定义Bahdanau注意力,实现循环神经网络编码器-解码器。
其实,我们只需重新定义解码器即可。
为了更方便地显示学习的注意力权重,
以下AttentionDecoder类定义了[带有注意力机制解码器的基本接口]。

#@save
class AttentionDecoder(d2l.Decoder):"""带有注意力机制解码器的基本接口"""def __init__(self, **kwargs):super(AttentionDecoder, self).__init__(**kwargs)@propertydef attention_weights(self): #画图所需代码raise NotImplementedError

接下来,让我们在接下来的Seq2SeqAttentionDecoder类中[实现带有Bahdanau注意力的循环神经网络解码器]。
首先,初始化解码器的状态,需要下面的输入:

  1. 编码器在所有时间步的最终层隐状态,将作为注意力的键和值;
  2. 上一时间步的编码器全层隐状态,将作为初始化解码器的隐状态;
  3. 编码器有效长度(排除在注意力池中填充词元)。

在每个解码时间步骤中,解码器上一个时间步的最终层隐状态将用作查询。
因此,注意力输出和输入嵌入都连结为循环神经网络解码器的输入。

class Seq2SeqAttentionDecoder(AttentionDecoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,dropout=0, **kwargs):super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)self.attention = d2l.AdditiveAttention(num_hiddens, num_hiddens, num_hiddens, dropout) #相比之前只是新增了这行代码self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = nn.GRU( embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)self.dense = nn.Linear(num_hiddens, vocab_size)def init_state(self, enc_outputs, enc_valid_lens, *args): #多了一个enc_valid_lens,之前不需要,现在需要知道英语的句子哪些是pad的。# outputs的形状为(batch_size,num_steps,num_hiddens).# hidden_state的形状为(num_layers,batch_size,num_hiddens)outputs, hidden_state = enc_outputsreturn (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)def forward(self, X, state):# enc_outputs的形状为(batch_size,num_steps,num_hiddens).# hidden_state的形状为(num_layers,batch_size, num_hiddens)enc_outputs, hidden_state, enc_valid_lens = state X = self.embedding(X).permute(1, 0, 2) # 输出X的形状为(num_steps,batch_size,embed_size)outputs, self._attention_weights = [], []for x in X:# print("x.shape = "+str(x.shape))  #x.shape = torch.Size([4, 8])query = torch.unsqueeze(hidden_state[-1], dim=1)  # query的形状改为(batch_size,1,num_hiddens) 虽然query只有一个,但是要把query数量这个维度加进去,才能应用上一博客的函数# context的形状为(batch_size,1,num_hiddens)context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens) # key和value是成对的,且都是encoder的output,enc_valid_lens是为了防止encoder中那些无效的pad的部分也被当作正常部分进行运算,encoder的长度是由num_steps确定好的定长的。query每次都会变。# print("context.shape = "+ str(context.shape))  #context.shape = torch.Size([4, 1, 16])# 在特征维度上连结x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)# 将x变形为(1,batch_size,embed_size+num_hiddens)out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)outputs.append(out)self._attention_weights.append(self.attention.attention_weights) # 将输出和注意力权重保存到列表中。# 全连接层变换后,outputs的形状为# (num_steps,batch_size,vocab_size)outputs = self.dense(torch.cat(outputs, dim=0))return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]@propertydef attention_weights(self):return self._attention_weights

接下来,使用包含7个时间步的4个序列输入的小批量[测试Bahdanau注意力解码器]。

encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,num_layers=2)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16,num_layers=2)
decoder.eval()
X = torch.zeros((4, 7), dtype=torch.long)  # (batch_size,num_steps)
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)
output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape
(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))

训练

与 :numref:sec_seq2seq_training类似,
我们在这里指定超参数,实例化一个带有Bahdanau注意力的编码器和解码器,
并对这个模型进行机器翻译训练。
由于新增的注意力机制,训练要比没有注意力机制的
:numref:sec_seq2seq_training慢得多。

embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 30, d2l.try_gpu()train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = d2l.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
loss 0.099, 10527.1 tokens/sec on cuda:0
<Figure size 350x250 with 1 Axes>

在这里插入图片描述
模型训练后,我们用它[将几个英语句子翻译成法语]并计算它们的BLEU分数。

engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):translation, dec_attention_weight_seq = d2l.predict_seq2seq( net, eng, src_vocab, tgt_vocab, num_steps, device, True)print(f'{eng} => {translation}, ', f'bleu {d2l.bleu(translation, fra, k=2):.3f}')
go . => <unk> .,  bleu 0.000
i lost . => je suis parti .,  bleu 0.000
he's calm . => il est <unk> .,  bleu 0.658
i'm home . => je suis <unk> .,  bleu 0.512
attention_weights = torch.cat([step[0][0][0] for step in dec_attention_weight_seq], 0).reshape((1, 1, -1, num_steps))

训练结束后,下面通过[可视化注意力权重]
会发现,每个查询都会在键值对上分配不同的权重,这说明
在每个解码步中,输入序列的不同部分被选择性地聚集在注意力池中。

# 加上一个包含序列结束词元
d2l.show_heatmaps(attention_weights[:, :, :, :len(engs[-1].split()) + 1].cpu(),xlabel='Key positions', ylabel='Query positions')

在这里插入图片描述

小结

  • 在预测词元时,如果不是所有输入词元都是相关的,那么具有Bahdanau注意力的循环神经网络编码器-解码器会有选择地统计输入序列的不同部分。这是通过将上下文变量视为加性注意力池化的输出来实现的。
  • 在循环神经网络编码器-解码器中,Bahdanau注意力将上一时间步的解码器隐状态视为查询,在所有时间步的编码器隐状态同时视为键和值。

练习

  1. 在实验中用LSTM替换GRU。
  2. 修改实验以将加性注意力打分函数替换为缩放点积注意力,它如何影响训练效率?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/439138.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

平面电磁波的电场能量磁场能量密度相等,能量密度的体积分等于能量,注意电场能量公式也没有复数形式(和坡印廷类似)

1、电场能量密度和磁场能量密度相等(实数场算的) 下面是电场能量密度和磁场能量密度的公式&#xff0c;注意这可不是坡印廷定理。且电场能量密度没有复数表达式&#xff0c;即不是把E和D换成复数形式就行的。注意&#xff0c;一个矢量可以转化为复数形式&#xff0c;两个矢量做…

深入计算机语言之C++:C到C++的过度

&#x1f511;&#x1f511;博客主页&#xff1a;阿客不是客 &#x1f353;&#x1f353;系列专栏&#xff1a;从C语言到C语言的渐深学习 欢迎来到泊舟小课堂 &#x1f618;博客制作不易欢迎各位&#x1f44d;点赞⭐收藏➕关注 一、什么是C C&#xff08;c plus plus&#xff…

使用NumPy进行线性代数的快速指南

介绍 NumPy 是 Python 中用于数值计算的基础包。它提供了处理数组和矩阵的高效操作&#xff0c;这对于数据分析和科学计算至关重要。在本指南中&#xff0c;我们将探讨 NumPy 中可用的一些基本线性代数操作&#xff0c;展示如何通过运算符重载和内置函数执行这些操作。 元素级…

Linux 文件 IO 管理(第四讲:软硬链接和动静态库)

Linux 文件 IO 管理&#xff08;第四讲&#xff1a;软硬链接和动静态库&#xff09; 软硬链接操作与现象软链接硬链接 解释软链接硬链接作用 动静态库初识静态库怎么做库&#xff08;开发角度&#xff09;怎么用库&#xff08;使用角度&#xff09;安装当前目录直接使用 动态库…

打卡第一天 B2005 字符三角形

今天是我打卡第一天&#xff0c;就做个水题吧(#^.^#) 题目描述 给定一个字符&#xff0c;用它构造一个底边长 55 个字符&#xff0c;高 33 个字符的等腰字符三角形。 输入格式 输入只有一行&#xff0c;包含一个字符。 输出格式 该字符构成的等腰三角形&#xff0c;底边长…

全新一区PID搜索算法+TCN-LSTM+注意力机制!PSA-TCN-LSTM-Attention多变量时间序列预测(Matlab)

全新一区PID搜索算法TCN-LSTM注意力机制&#xff01;PSA-TCN-LSTM-Attention多变量时间序列预测&#xff08;Matlab&#xff09; 目录 全新一区PID搜索算法TCN-LSTM注意力机制&#xff01;PSA-TCN-LSTM-Attention多变量时间序列预测&#xff08;Matlab&#xff09;效果一览基本…

今日凌晨,ChatGPT重磅更新!—— 我心目中的终极AGI界面

今日凌晨&#xff0c;ChatGPT重磅更新&#xff01;—— 我心目中的终极AGI界面 我心目中的终极 AGI 界面是一张空白画布&#xff08;canvas&#xff09;。 今日凌晨&#xff0c;OpenAI 发布 canvas&#xff0c;一个与 ChatGPT 合作写作和编程的新界面&#xff01; canvas&…

MySQL 启动失败 (code=exited, status=1/FAILURE) 异常解决方案

目录 前言1. 问题描述2. 查看错误日志文件2.1 确认日志文件路径2.2 查看日志文件内容 3. 定位问题3.1 问题分析 4. 解决问题4.1 注释掉错误配置4.2 重启 MySQL 服务 5. 总结结语 前言 在日常运维和开发过程中&#xff0c;MySQL数据库的稳定运行至关重要。然而&#xff0c;MySQ…

Leetcode—148. 排序链表【中等】

2024每日刷题&#xff08;171&#xff09; Leetcode—148. 排序链表 C实现代码 /*** Definition for singly-linked list.* struct ListNode {* int val;* ListNode *next;* ListNode() : val(0), next(nullptr) {}* ListNode(int x) : val(x), next(nullptr…

森林火灾检测数据集 7400张 森林火灾 带标注 voc yolo

森林火灾检测数据集 7400张 森林火灾 带标注 voc yolo 森林火灾检测数据集 名称 森林火灾检测数据集 (Forest Fire Detection Dataset) 规模 图像数量&#xff1a;共7780张图像。类别&#xff1a;仅包含一种类别——火源。 数据划分 训练集 (Train)&#xff1a;通常占总数据…

SpringBoot整合JPA详解

SpringBoot版本是2.0以上(2.6.13) JDK是1.8 一、依赖 <dependencies><!-- jdbc --><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-jdbc</artifactId></dependency><!--…

Oracle SQL语句没有过滤条件,究竟是否会走索引??

答案是&#xff1a;可能走索引也可能不走索引&#xff0c;具体要看列的值可不可为null&#xff0c;Oracle不会为所有列的nullable属性都为Y的sql语句走索引。 例子&#xff1a; create table t as select * from dba_objects; CREATE INDEX ix_t_name ON t(object_id, objec…

9.30学习记录(补)

手撕线程池: 1.进程:进程就是运行中的程序 2.线程的最大数量取决于CPU的核数 3.创建线程 thread t1; 在使用多线程时&#xff0c;由于线程是由上至下走的&#xff0c;所以主程序要等待线程全部执行完才能结束否则就会发生报错。通过thread.join()来实现 但是如果在一个比…

SpringBoot助力校园资料分享:快速上手指南

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多学生、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统&#xff0c;它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等&#xff0c;非常适…

多模态—文字生成图片

DALL-E是一个用于文字生成图片的模型&#xff0c;这也是一个很好思路的模型。该模型的训练分为两个阶段&#xff1a; 第一阶段&#xff1a;图片经过编码器编码为图片向量&#xff0c;当然我们应该注意这个过程存在无损压缩&#xff08;图片假设200*200&#xff0c;如果用one-h…

MATLAB|基于多主体主从博弈的区域综合能源系统低碳经济优化调度

目录 主要内容 程序亮点&#xff1a; 模型研究 一、综合能源模型 二、主从博弈框架 部分代码 结果一览 下载链接 主要内容 程序参考文献《基于多主体主从博弈的区域综合能源系统低碳经济优化调度》&#xff0c;采用了区域综合能源系统多主体博弈协同优化方…

Redis-预热雪崩击穿穿透

预热雪崩穿透击穿 缓存预热 缓存雪崩 有这两种原因 redis key 永不过期or过期时间错开redis 缓存集群实现高可用 主从哨兵Redis Cluster开启redis持久化aof&#xff0c;rdb&#xff0c;尽快恢复集群 多缓存结合预防雪崩&#xff1a;本地缓存 ehcache redis 缓存服务降级&…

国产RISC-V案例分享,基于全志T113-i异构多核平台!

RISC-V核心优势 全志T113-i是一款双核Cortex-A7@1.2GHz国产工业级处理器平台,并内置玄铁C906 RISC-V和HiFi4 DSP双副核心,可流畅运行Linux系统与Qt界面,并已适配OpenWRT系统、Docker容器技术。 而其中的RISC-V属于超高能效副核心,主频高达1008MHz,标配内存管理单元,可运…

程序员如何在 AI 时代保持核心竞争力

前言 随着 AIGC 大语言模型的不断涌现&#xff0c;AI 辅助编程工具的普及正在深刻改变程序员的工作方式。在这一趋势下&#xff0c;程序员面临着新的挑战与机遇&#xff0c;需要思考如何应对以保持并提升自身的核心竞争力。 目录 一、AI 对编程工作的影响 &#xff08;一&…