机器翻译之Bahdanau注意力机制在Seq2Seq中的应用

目录

1.创建 添加了Bahdanau的decoder 

2. 训练

 3.定义评估函数BLEU

 4.预测

 5.知识点个人理解


1.创建 添加了Bahdanau的decoder 

import torch
from torch import nn
import dltools#定义注意力解码器基类
class AttentionDecoder(dltools.Decoder):  #继承dltools.Decoder写注意力编码器的基类def __init__(self, **kwargs):super().__init__(**kwargs)@property    #装饰器, 定义的函数方法可以像类的属性一样被调用def attention_weights(self):#raise用于引发(或抛出)异常raise NotImplementedError  #通常用于抽象基类中,作为占位符,提醒子类必须实现这个方法。 #创建 添加了Bahdanau的decoder
#继承AttentionDecoder这个基类创建Seq2SeqAttentionDecoder子类, 子类必须实现父类中NotImplementedError占位的方法
class Seq2SeqAttentionDecoder(AttentionDecoder):  #初始化属性和方法def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):"""vocab_size:此表大小,  相当于输入数据的特征数features,  也是输出数据的特征数embed_size:嵌入层的大小:将输入数据处理成小批量的数据num_hiddens:隐藏层神经元的数量num_layers:循环网络的层数dropout=0:不释放模型的参数(比如:神经元)"""super().__init__(**kwargs)#初始化注意力机制的评分函数方法self.attention = dltools.AdditiveAttention(key_size=num_hiddens,query_size=num_hiddens, num_hiddens=num_hiddens,dropout=dropout)#初始化嵌入层:将输入的数据处理成小批量的tensor数据   (文本--->数值的映射转化)self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)#初始化循环网络self.rnn = nn.GRU(embed_size+num_hiddens, num_hiddens, num_layers, dropout=dropout)#初始化线性层  (输出层)self.dense = nn.Linear(num_hiddens, vocab_size)#初始化隐藏层的状态state   (计算state,需要编码器的输出结果、序列的有效长度)def init_state(self, enc_outputs, enc_valid_lens, *args):#enc_outputs是一个元组(输出结果,隐藏状态)#outputs的shape=(batch_size, num_steps, num_hiddens)#hidden_state的shape=(num_layers, batch_size, num_hiddens)outputs, hidden_state = enc_outputs#返回一个元组(,),可以用一个变量接收#outputs.permute(1, 0, 2)转换数据的维度是因为rnn循环神经网络的输入要求是先num_steps,再batch_size,return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)#定义前向传播   (输入数据X,state)def forward(self, X, state):#变量赋值:接收编码器encoder的输出结果、隐藏状态、序列有效长度#enc_outputs的shape=(batch_size, num_steps, num_hiddens)#hidden_state的shape=(num_layers, batch_size, num_hiddens)enc_outputs, hidden_state, enc_valid_lens = state#X的shape=(batch_size, num_steps, vocab_size)X = self.embedding(X)   #将X输入embedding嵌入层后, X的shape=(batch_size, num_steps, embed_size)#调换X的0维度和1维度数据X = X.permute(1, 0, 2)   #X的shape=(num_steps, batch_size, embed_size)outputs, self._attention_weights = [], []  #创建空列表,用于存储数据for x in X:  #遍历每一批数据#获取query#hidden_state[-1]表示最后一层循环网络的隐藏层状态  (有两层循环网络)#hidden_state[-1]的shape=(batch_size, num_hiddens)    #dim=1表示在原索引1的维度增加一个维度query = torch.unsqueeze(hidden_state[-1], dim=1)  
#             print('query的shape:', query.shape)   #query的shape=(batch_size, 1, num_hiddens)#通过注意力机制获取上下文序列context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
#             print('context的shape:', context.shape)  #context的shape=(batch_size, 1, num_hiddens)#用最后一个维度 拼接context, x 数据x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
#             print('x的shape:', x.shape)   #x的shape=(batch_size, 1, num_hiddens+embed_size)#将x和hidden_state输入循环神经网络中,获取输出结果和新的hidden_stateout, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
#             print('out的shape:', out.shape)   #out的shape=(1, batch_size, num_hiddens)
#             print('hidden_state的shape:', hidden_state.shape) #两层循环层:hidden_state的shape=(2, batch_size, num_hiddens)#将输出结果添加到列表中outputs.append(out)self._attention_weights.append(self.attention_weights)outputs = self.dense(torch.cat(outputs, dim=0))
#         print('outputs的shape:', outputs.shape)  #outputs的shape=(num_steps, batch_size, vocab_size)return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]@propertydef attention_weights(self):return self._attention_weights#测试代码
#创建编码器对象
encoder = dltools.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
#需要预测, 要加encoder.eval()
encoder.eval()
#创建解码器对象
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()#假设数据
batch_size, num_steps = 4, 7
X = torch.zeros((4, 7), dtype = torch.long)
#初始化状态state
state = decoder.init_state(encoder(X), None)
outputs, state = decoder(X, state)
#state包含三个东西(enc_outputs, hidden_state, enc_valid_lens)
#state[0]是 enc_outputs
#state[1]是 hidden_state, 两层循环层,就会有两个hidden_state, state[1][0]是第一层的hidden_state
outputs.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
outputs的shape: torch.Size([7, 4, 10])

Out[11]:

(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))

2. 训练

#声明变量
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 200, dltools.try_gpu()#加载数据
train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)#创建编辑器对象
encoder = dltools.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
#创建编辑器对象
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)#创建网络模型
net = dltools.EncoderDecoder(encoder, decoder)#模型训练
dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

 

 3.定义评估函数BLEU

def bleu(pred_seq, label_seq, k):print('pred_seq:', pred_seq)print('label_seq:', label_seq)#将pred_seq, label_seq分别进行空格分隔pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')#获取pred_seq, label_seq的长度len_pred, len_label = len(pred_seq), len(label_seq)score = math.exp(min(0, 1 - (len_label / len_pred)))for n in range(1, k+1): #n的取值范围,  range()左闭右开num_matches, label_subs = 0, collections.defaultdict(int)for i in range(len_label - n + 1):label_subs[' '.join(label_tokens[i: i+n])] += 1for i in range(len_pred - n + 1):if label_subs[' '.join(pred_tokens[i: i+n])] > 0:num_matches += 1label_subs[' '.join(pred_tokens[i: i+n])] -=1score *= math.pow(num_matches / (len_pred -n + 1), math.pow(0.5, n))return score

 4.预测

import math
import collectionsengs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):translation = dltools.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)print(f'{eng} => {translation}, bleu {dltools.bleu(translation[0], fra, k=2):.3f}')

go . => ('va !', []), bleu 1.000
i lost . => ("j'ai perdu .", []), bleu 1.000
he's calm . => ('il est bon .', []), bleu 0.658
i'm home . => ('je suis chez moi .', []), bleu 1.000

 5.知识点个人理解

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/429050.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LabVIEW提高开发效率技巧----使用事件结构优化用户界面响应

事件结构(Event Structure) 是 LabVIEW 中用于处理用户界面事件的强大工具。通过事件驱动的编程方式,程序可以在用户操作时动态执行特定代码,而不是通过轮询(Polling)的方式不断检查界面控件状态。这种方式…

【学习笔记】 使用AD24完成相同电路的自动布线布局(相同模块布局布线ROOM布线快速克隆)

【学习笔记】 使用AD24完成相同电路的自动布线布局 一、适用基本条件二、基于ROOM的自动布局/布线的方法三、可能出现的报错四、ROOM自动布局的一些优点和缺点 当面对多个相同电路模块时,使用 ROOM 可以一次性对一个模块进行精心布局,然后将该布局快速复…

粒子向上持续瀑布动画效果(直接粘贴到记事本改html即可)

代码&#xff1a; 根据个人喜好修改即可 <!DOCTYPE html> <html lang"zh"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>宽粒子向上…

正则表达式匹配整数与浮点数失败与解决方案

正则表达式匹配整数与浮点数失败与解决方案 问题描述问题分析解决方案总结 问题描述 在处理数据的时候需要提取文本内整数与浮点数&#xff0c;这个时候想到使用正则表达式&#xff0c;咨询百度文心一言给出以下方案及参考代码 import re text "我有100元&#xff0c;…

解决mac下 Android Studio gradle 下载很慢,如何手动配置

抓住人生中的一分一秒&#xff0c;胜过虚度中的一月一年! 小做个动图开篇引题 前言 平时我们clone git 上项目&#xff0c;项目对应gradle版本本地没有&#xff0c;ide编译会自动下载&#xff0c;但是超级慢可能还下载失败&#xff0c;下面讲解下此问题如 如下图所示&#xff…

ML 系列:机器学习和深度学习的深层次总结(04)多元线性回归 (MLR)

图 1.多元线性回归与简单线性回归 一、说明 线性回归从一维推广到多维&#xff0c;这与单变量线性回归有很多不同&#xff0c;情况更加复杂&#xff0c;而在梯度优化也需要改成向量梯度&#xff0c;同时&#xff0c;数据预处理也成了必要步骤。 二、综述 多元线性回归是简单线性…

【AI算法岗面试八股面经【超全整理】——深度学习】

AI算法岗面试八股面经【超全整理】 概率论【AI算法岗面试八股面经【超全整理】——概率论】信息论【AI算法岗面试八股面经【超全整理】——信息论】机器学习【AI算法岗面试八股面经【超全整理】——机器学习】深度学习【AI算法岗面试八股面经【超全整理】——深度学习】CVNLP …

【RabbitMQ】⾼级特性

RabbitMQ ⾼级特性 1. 消息确认1.1 消息确认机制1.2 代码示例 2. 持久化2.1 交换机持久化2.2 队列持久化2.3 消息持久化 3. 发送⽅确认3.1 confirm确认模式3.2 return退回模式3.3 问题: 如何保证RabbitMQ消息的可靠传输? 4. 重试机制5. TTL5.1 设置消息的TTL5.2 设置队列的TTL…

vue3集成google第三方登陆

网上资源很多&#xff0c;但乱七八糟&#xff0c;踩坑几小时后&#xff0c;发现下面的方式没问题。 npm install vue3-google-login 插件文档&#xff1a;vue3-google-登录 (devbaji.github.io) 修改main.js import ./assets/main.css import { createApp } from vue impor…

医院伤员消费点餐限制———未来之窗行业应用跨平台架构

一、点餐上限 医院点餐上限具有以下几方面的意义&#xff1a; 1. 控制成本 - 有助于医院合理规划餐饮预算&#xff0c;避免食物的过度供应造成浪费&#xff0c;从而降低餐饮成本。 2. 保障饮食均衡 - 防止患者或陪护人员过度点餐某一类食物&#xff0c;有利于引导合…

【AcWing】873. 欧拉函数

#include<iostream> using namespace std;int main(){int n;cin>>n;while(n--){int x;cin>>x;int resx;for(int i2;i<x/i;i){if(x%i0){//resres*(1-1/i);整数1/i等于0&#xff0c;算不对且会溢出//以下几种都能ac//resres/i*(i-1);i*(1-1/i)i-1&#xff0…

[JavaEE] TCP协议

目录 一、TCP协议段格式 二、TCP确保传输可靠的机制 2.1 确认应答 2.2 超时重传 2.3 连接管理 2.3.1 三次握手 2.3.2 四次挥手 2.4 滑动窗口 2.4.1 基础知识 2.4.2 两种丢包情况 2.4.2.1 数据报已经抵达&#xff0c;ACK丢包 2.4.2.2 数据包丢包 2.5 流量控制…

音视频入门基础:AAC专题(7)——FFmpeg源码中计算AAC裸流每个packet的size值的实现

音视频入门基础&#xff1a;AAC专题系列文章&#xff1a; 音视频入门基础&#xff1a;AAC专题&#xff08;1&#xff09;——AAC官方文档下载 音视频入门基础&#xff1a;AAC专题&#xff08;2&#xff09;——使用FFmpeg命令生成AAC裸流文件 音视频入门基础&#xff1a;AAC…

Vue3:v-model实现组件通信

目录 一.性质 1.双向绑定 2.语法糖 3.响应式系统 4.灵活性 5.可配置性 6.多属性绑定 7.修饰符支持 8.defineModel使用 二.使用 1.父组件 2.子组件 三.代码 1.父组件代码 2.子组件代码 四.效果 一.性质 在Vue3中&#xff0c;v-model指令的性质和作用主要体现在…

AI健身之俯卧撑计数和姿态矫正-角度估计

在本项目中&#xff0c;实现了Yolov7-Pose用于人体姿态估计。以下是如何在Windows 11操作系统上设置和运行该项目的详细步骤。 环境准备 首先&#xff0c;确保您的计算机已经安装了Anaconda。Anaconda是一个开源的Python发行版本&#xff0c;它包含了conda、Python以及众多科…

【滑动窗口】算法总结

文章目录 滑动窗口算法总结1.暴力求解vs滑动窗口2.需要注意的细节问题 2.滑动窗口的基本模板1.非固定窗口大小的滑动窗口2.固定窗口大小的滑动窗口细节 滑动窗口算法总结 1.暴力求解vs滑动窗口 遇到那些可以转化成一个子数组的长度的问题时&#xff0c;往往需要用到双指针。 …

安全基础学习-AES128加密算法

前言 AES&#xff08;Advanced Encryption Standard&#xff09;是对称加密算法的一个标准&#xff0c;主要用于保护电子数据的安全。AES 支持128、192、和256位密钥长度&#xff0c;其中AES-128是最常用的一种&#xff0c;它使用128位&#xff08;16字节&#xff09;的密钥进…

探索人工智能绘制宇宙地图的实现

人工智能 (AI) 已成为了解世界的重要工具。现在&#xff0c;随着人们对太空探索的兴趣重新升温&#xff0c;人工智能也可能对其他世界产生同样的影响。 尽管经过了几十年的研究&#xff0c;科学家们对地球大气层以外的宇宙仍然知之甚少。绘制行星、恒星、星系及其在太空中的运…

python爬虫初体验(一)

文章目录 1. 什么是爬虫&#xff1f;2. 为什么选择 Python&#xff1f;3. 爬虫小案例3.1 安装python3.2 安装依赖3.3 requests请求设置3.4 完整代码 4. 总结 1. 什么是爬虫&#xff1f; 爬虫&#xff08;Web Scraping&#xff09;是一种从网站自动提取数据的技术。简单来说&am…

安卓14剖析SystemUI的ShadeLogger/LogBuffer日志动态控制输出dumpsy机制

背景&#xff1a; 看SystemUI的锁屏相关代码时候发现SystemUI有一个日志打印相关的方法调用&#xff0c;相比于常规的Log.i直接可以logcat查看方式还是比较新颖。 具体日志打印代码如下&#xff1a; 下面就来介绍一下这个ShadeLogger到底是如何打印的。 分析源码&#xff1…