机器翻译之多头注意力(MultiAttentionn)在Seq2Seq的应用

目录

1.多头注意力(MultiAttentionn)的理念图

2.代码实现 

2.1创建多头注意力函数 

2.2验证上述封装的代码 

2.3 创建 添加了Bahdanau的decoder 

 2.4训练

 2.5预测

3.知识点个人理解 


 

1.多头注意力(MultiAttentionn)的理念图

2.代码实现 

2.1创建多头注意力函数 

class MultiHeadAttention(nn.Module):#初始化属性和方法def __init__(self, query_size, key_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):"""query_size_size: query_size的特征数featureskey_size: key_size的特征数featuresvalue_size: value_size的特征数featuresnum_hiddens:隐藏层的神经元的数量num_heads:多头注意力的header的数量dropout: 释放模型需要计算的参数的比例bias=False:没有偏差**kwargs : 不定长度的关键字参数"""super().__init__(**kwargs)#接收参数self.num_heads = num_heads#初始化注意力,    #使用DotProductAttention时, keys与 values具有相同的长度, 经过decoder,他们长度相同self.attention = dltools.DotProductAttention(dropout)#初始化四个w模型参数self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)def forward(self, queries, keys, values, valid_lens):def transpose_qkv(X, num_heads):"""实现queries, keys, values的数据维度转化"""#输入的X的shape=(batch_size, 查询数/键值对数量, num_hiddens)#这里,不能直接用reshape,需要索引维度,防止数据不能一一对应X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)   #将原维度的num_hiddens拆分成num_heads, -1,  -1相当于num_hiddens/num_heads的数值X = X.permute(0, 2, 1, 3)  #X的shape=(batch_size, num_size, 查询数/键值对数量, num_hiddens/num_heads)return X.reshape(-1, X.shape[2], X.shape[3])  #X的shape=(batch_size*num_heads, 查询数/键值对数量, num_hiddens/num_heads)def transpose_outputs(X, num_heads):"""逆转transpose_qkv的操作"""#此时数据的X的shape =(batch_size*num_heads, 查询数/键值对数量, num_hiddens/num_heads)X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])  #X的shape=(batch_size, num_heads, 查询数/键值对数量, num_hiddens/num_heads)X = X.permute(0, 2, 1, 3)  #X的shape=(batch_size, 查询数/键值对数量, num_heads,  num_hiddens/num_heads)return X.reshape(X.shape[0], X.shape[1], -1)  #X的shape还原了=(batch_size, 查询数/键值对数, num_hiddens)#queries, keys, values,传入的shape=(batch_size, 查询数/键值对数, num_hiddens)#获取转换维度之后的queries, keys, values,queries = transpose_qkv(self.W_q(queries), self.num_heads)keys = transpose_qkv(self.W_k(keys), self.num_heads)values = transpose_qkv(self.W_v(values), self.num_heads)#若valid_len不为空,存在if valid_lens is not None:#将valid_lens重复数据self.num_heads次,在0维度上valid_lens = torch.repeat_interleave(valid_lens, repeats = self.num_heads, dim=0)#若为空,什么都不做,跳出if判断,继续执行其他代码#通过注意力函数获取输出outputs#outputs的shape = (batch_size*num_heads, 查询的个数, num_hiddens/num_heads)outputs = self.attention(queries, keys, values, valid_lens)#逆转outputs的维度outputs_concat = transpose_outputs(outputs, self.num_heads)return self.W_o(outputs_concat)

2.2验证上述封装的代码 

#假设变量
num_hiddens, num_heads, dropout = 100, 5, 0.2
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, dropout)
attention.eval()  #需要预测,加上
MultiHeadAttention((attention): DotProductAttention((dropout): Dropout(p=0.2, inplace=False))(W_q): Linear(in_features=100, out_features=100, bias=False)(W_k): Linear(in_features=100, out_features=100, bias=False)(W_v): Linear(in_features=100, out_features=100, bias=False)(W_o): Linear(in_features=100, out_features=100, bias=False)
)
#假设变量
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])X = torch.ones((batch_size, num_queries, num_hiddens))  #shape(2,4,100)
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))  #shape(2,6,100) attention(X, Y, Y, valid_lens).shape

torch.Size([2, 4, 100])

2.3 创建 添加了Bahdanau的decoder 

# 添加Bahdanau的decoder
class Seq2SeqMultiHeadAttentionDecoder(dltools.AttentionDecoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_heads, num_layers, dropout=0, **kwargs):super().__init__(**kwargs)self.attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, dropout)self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)self.dense = nn.Linear(num_hiddens, vocab_size)def init_state(self, enc_outputs, enc_valid_lens, *args):# outputs : (batch_size, num_steps, num_hiddens)# hidden_state: (num_layers, batch_size, num_hiddens)outputs, hidden_state = enc_outputsreturn (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)def forward(self, X, state):# enc_outputs (batch_size, num_steps, num_hiddens)# hidden_state: (num_layers, batch_size, num_hiddens)enc_outputs, hidden_state, enc_valid_lens = state# X : (batch_size, num_steps, vocab_size)X = self.embedding(X) # X : (batch_size, num_steps, embed_size)X = X.permute(1, 0, 2)outputs, self._attention_weights = [], []for x in X:query = torch.unsqueeze(hidden_state[-1], dim=1) # batch_size, 1, num_hiddenscontext = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)outputs.append(out)self._attention_weights.append(self.attention_weights)outputs = self.dense(torch.cat(outputs, dim=0))return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]@propertydef attention_weights(self):return self._attention_weights

 2.4训练

# 训练
embed_size, num_hiddens, num_layers, dropout = 32, 100, 2, 0.1
batch_size, num_steps, num_heads = 64, 10, 5
lr, num_epochs, device = 0.005, 200, dltools.try_gpu()train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)encoder = dltools.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)decoder = Seq2SeqMultiHeadAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_heads, num_layers, dropout)net = dltools.EncoderDecoder(encoder, decoder)dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

 2.5预测

engs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):translation = dltools.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)print(f'{eng} => {translation}, bleu {dltools.bleu(translation[0], fra, k=2):.3f}')

go . => ('va !', []), bleu 1.000
i lost . => ("j'ai perdu .", []), bleu 1.000
he's calm . => ('trouvez tom .', []), bleu 0.000
i'm home . => ('je suis chez moi .', []), bleu 1.000

3.知识点个人理解 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/430629.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OpenCV_距离变换的图像分割和Watershed算法详解

在学习watershed算法的时候,书写代码总会出现一些错误: 上述代码运行报错,显示OpenCV(4.10.0) Error: Assertion failed (src.type() CV_8UC3 && dst.type() CV_32SC1) in cv::watershed 查找资料:目前已解决 这个错…

在Windows系统上安装的 Arrow C++ 库

在Windows系统上安装的 Arrow C 库 正文第一步第二步第三步第四步注: 检查是否安装成功 吐槽 正文 第一步 git clone gitgithub.com:apache/arrow.git第二步 打开powershell (好像cmd也可以,不过我试了powershell中不报错,cmd中报错,不是很清楚为什么) 打开arrow的目录 cd …

java项目之基于springboot框架开发的景区民宿预约系统的设计与实现(源码+文档)

项目简介 基于springboot框架开发的景区民宿预约系统的设计与实现的主要使用者分为: 管理员的功能有:用户信息的查询管理,可以删除用户信息、修改用户信息、新增用户信息,根据公告信息进行新增、修改、查询操作等等。。 &#x1…

EasyExcel将数据库里面的数据生成excel文件

EasyExcel官方文档 1.在model模块导入依赖 <!-- 生成报表--> <dependency><groupId>com.alibaba</groupId><artifactId>easyexcel</artifactId><version>4.0.3</version> </dependency> 2.修饰实体类 package…

【网络】TCP协议的简单使用

目录 echo_service server 单进程单线程 多进程 多线程 线程池 client echo_service_code echo_service 还是跟之前UDP一样&#xff0c;我们先通过实际的代码来实现一些小功能&#xff0c;简单的来使用TCP协议进行简单的通信&#xff0c;话不多说&#xff0c;我们先实现…

【结构型】树形结构的应用王者,组合模式

目录 一、组合模式1、组合模式是什么&#xff1f;2、组合模式的主要参与者&#xff1a; 二、优化案例&#xff1a;文件系统1、不使用组合模式2、通过组合模式优化上面代码优化点&#xff1a; 三、使用组合模式有哪些优势1、统一接口&#xff0c;简化客户端代码2、递归结构处理方…

UART配置流程

S3C2440A 的通用异步收发器&#xff08;UART&#xff09;配有3 个独立异步串行I/O&#xff08;SIO&#xff09;端口&#xff0c;每个都可以是基于中断或基于DMA 模式的操作。换句话说&#xff0c;UART 可以通过产生中断或DMA 请求来进行CPU 和UART 之间的数据传输。UART 通过使…

CUDA并行架构

一、CUDA简介 CUDA(Compute Unified Device Architecture)是一种由NVIDIA推出的通用并行计算架构&#xff0c;该架构使GPU(Graphics Processing Unit)能够对复杂的计算问题做性能速度优化。 二、串并行模式 高性能计算的关键是利用多核处理器进行并行计算。 串行模式&#…

DesignMode__unity__抽象工厂模式在unity中的应用、用单例模式进行资源加载

目录 抽象工厂模式 思维导图 接口&#xff08;抽象类&#xff09; 工厂接口 抽象产品类 抽象武器接口 抽象人物接口 具体工厂和具体产品 具体工厂 &#xff08;1&#xff09;产品接口&#xff0c;生成具体人物 &#xff08;2&#xff09;武器接口&#xff0c;生成具体…

vue3 vxe-grid 通过数据库返回的列信息,生成columns,并且其中有一列是img类型,进行slots的格式化处理。

1、一般我们写死的列信息的时候&#xff0c;会这样定义&#xff1a; 2、然后我们在template里面&#xff0c;这样这样写slots格式化部分&#xff1a; 这样表格中就会展示出一张图片&#xff0c;并且&#xff0c;我们点击了可以查看大图。 3、那么我们从数据库中返回的列&#…

maxwell 输出消息到 kafka

文章目录 1、kafka-producer2、运行一个Docker容器&#xff0c;该容器内运行的是Zendesk的Maxwell工具&#xff0c;一个用于实时捕获MySQL数据库变更并将其发布到Kafka或其他消息系统的应用3、进入kafka容器内部4、tingshu_album 数据库中 新增数据5、tingshu_album 数据库中 更…

RHCS认证-Linux(RHel9)-Ansible

文章目录 一、ansible 简介二 、ansible部署三、ansible服务端测试四 、ansible 清单inventory五、Ad-hot 点对点模式六、YAML语言模式七、RHCS-Ansible附&#xff1a;安装CentOS-Stream 9系统7.1 ansible 执行过程7.2 安装ansible&#xff0c;ansible-navigator7.2 部署ansibl…

Docker UI强大之处?

DockerUI是一款由国内开发者打造的优秀Docker可视化管理工具。它拥有简洁直观的用户界面&#xff0c;使得Docker主机管理、集群管理和任务编排变得轻松简单。DockerUI不仅能展示资源利用率、系统信息和更新日志&#xff0c;还提供了镜像管理功能&#xff0c;帮助用户高效清理中…

清华大学开源视频转文本模型——CogVLM2-Llama3-Caption

通常情况下&#xff0c;大多数视频数据并不附带相应的描述性文本&#xff0c;因此有必要将视频数据转换为文本描述&#xff0c;为文本到视频模型提供必要的训练数据。 CogVLM2-Caption 是一个视频字幕模型&#xff0c;用于为 CogVideoX 模型生成训练数据。 文件 使用 import i…

计算机毕业设计 校运会管理系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍&#xff1a;✌从事软件开发10年之余&#xff0c;专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精…

青柠视频云——如何开启HTTPS服务?

前言 由于青柠视频云的语音对讲会使用到HTTPS服务&#xff0c;这里我们说一下如何申请证书以及如何在实战中部署并且配置使用。 一、证书申请 1、进入控制台 我们拿阿里云的免费个人证书为例&#xff0c;首先登录阿里云&#xff0c;在控制台找到数字证书管理服务&#xff0c;进…

音视频入门基础:FLV专题(4)——使用flvAnalyser工具分析FLV文件

一、引言 有很多工具可以分析FLV格式&#xff0c;这里推荐flvAnalyser。其支持&#xff1a; 1.FLV 文件分析&#xff08;Tag 列表、时间戳、码率、音视频同步等&#xff09;&#xff0c;HEVC(12)/AV1(13) or Enhanced RTMP v1 with fourCC(hvc1/av01)&#xff1b; 2.RTMP/HTT…

Linux突发网络故障常用排查的命令

测试环境 系统&#xff1a;Ubuntu 18硬件&#xff1a;单核2G ping 用于测试客户机和目标主机通信状况&#xff0c;是否畅通。以及测量通信的往返时间&#xff0c;判断网络质量的好坏。 它通过发送ICMP回显请求消息到目标主机&#xff0c;并等待返回的ICMP回显回复消息。 pin…

DeiT:Data-efficient Image Transformer(2020),基于新型蒸馏的ViT

Training data-efficient image transformers & distillation through attention&#xff1a;通过注意力训练数据高效的图像转换器和蒸馏 论文地址&#xff1a; https://arxiv.org/abs/2012.12877 代码地址&#xff1a; https://github.com/facebookresearch/deit 这篇论文…

k8s介绍-搭建k8s

官网&#xff1a;https://kubernetes.io/ 应用部署方式演变 传统部署&#xff1a;互联网早期&#xff0c;会直接将应用程序部署在物理机上 优点&#xff1a;简单&#xff0c;不需要其他技术的参与 缺点&#xff1a;不能为应用程序定义资源使用边界&#xff0c;很难合理地分配计…