4.11.seq2seq 序列到序列学习

序列到序列学习(seq2seq)

​ 使用两个循环神经网络的编码器和解码器,应用于序列到薛烈类的学习任务。

在这里插入图片描述

​ 在图中,特定的"<eos>"表示序列结束词元。一旦输出序列生成此词元,模型就会停止预测。在循环神经网络解码器的初始化时间步中,有两个特定的设计决定:

​ 首先,特定的"<bos>"表示序列开始词元,它是解码器的输入序列的第一个词元。

​ 其次使用循环神经网络编码器最终的隐状态来初始化解码器的隐状态。

​ 编码器是一个RNN,读取输入句子(可以是双向,因为encode不用预测,有上下文的)

​ 解码器使用另一个RNN来输出

1.衡量生成序列的好坏的BLEU

p n p_n pn是预测所有 n − g r a m n-gram ngram的精度,标签序列A B C D E F 和预测序列A B B C D,有 p 1 = 4 5 , p 2 = 3 4 , p 3 = 1 3 , p 4 = 0 p_1=\frac 45,p_2 =\frac 34 ,p_3 = \frac 13 ,p_4 =0 p1=54,p2=43,p3=31,p4=0

n-gram就是连续几个字母预测对,比如 p 2 p_2 p2,就是连续2个,那么有4个匹配 AB,BB,BC,BD,其中有三个是对的,则为3/4

​ BLEU定义为:
e x p ( m i n ( 0 , 1 − l e n l a b e l l e n p r e d ) ) Π n = 1 k p n 1 2 n exp(min(0,1-\frac{len_{label}}{len_{pred}})){\large\Pi ^k_{n=1}p_n^{\frac{1}{2^n}}} exp(min(0,1lenpredlenlabel))Πn=1kpn2n1
m i n min min部分是用于惩罚过短的预测,而指数为 1 2 n \frac {1}{2^n} 2n1意味着长匹配有更高的权重

2.代码实现

在这里插入图片描述

2.1 编码器

​ 使用嵌入层(embedding layer) 来获得输入序列中每个词元的特征向量。 嵌入层的权重是一个矩阵, 其行数等于输入词表的大小(vocab_size), 其列数等于特征向量的维度(embed_size)。 对于任意输入词元的索引i, 嵌入层获取权重矩阵的第i行(从0开始)以返回其特征向量。并选择使用多层门控循环单元来实现编码器

class Seq2SeqEncoder(d2l.Encoder):"""用于序列到序列学习的循环神经网络编码器"""def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,dropout=0, **kwargs):super(Seq2SeqEncoder, self).__init__(**kwargs)# 嵌入层self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = nn.GRU(embed_size, num_hiddens, num_layers,dropout=dropout)def forward(self, X, *args):# 输出'X'的形状:(batch_size,num_steps,embed_size)X = self.embedding(X)# 在循环神经网络模型中,第一个轴对应于时间步X = X.permute(1, 0, 2)# 如果未提及状态,则默认为0output, state = self.rnn(X)# output的形状:(num_steps,batch_size,num_hiddens)# state的形状:(num_layers,batch_size,num_hiddens)return output, stateencoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,num_layers=2)
encoder.eval()
X = torch.zeros((4, 7), dtype=torch.long)
output, state = encoder(X)
print(output.shape) #(时间步,batch_size,隐藏层大小)
print(state.shape) #(num_layers,时间步,隐藏层大小)

2.2 解码器

​ 编码器输出的上下文变量 c c c对整个输入序列及进行编码。来自训练数据集的输出序列 y 1 , y 2 , ⋯ , y T ′ y_1,y_2,\cdots,y_{T'} y1,y2,,yT,对于每个时间步 t ′ t' t(与输入序列或编码器的时间步 t t t不同),解码器输出 y t ′ y_{t'} yt的概率取决于先前的输出子序列 y 1 , ⋯ , y t ′ − 1 y_1,\cdots,y_{t'-1} y1,,yt1和上下文变量 c c c,即 P ( y t ′ ∣ y 1 , ⋯ , y t ′ − 1 , c ) P(y_{t'}|y_1,\cdots,y_{t'-1},c) P(yty1,,yt1,c)

​ 为了在序列上模型化这种条件概率,可以使用另一个循环神经网络作为解码器,在输出序列上的任意时间步 t ′ t' t,循环神经网络将来自上一时间步的输出 y t ′ − 1 y_{t'-1} yt1和上下文变量 c c c作为其输入,然后在当前时间步将它们和上一隐状态 s t ′ − 1 s_{t'-1} st1转换为隐状态 s t ′ s_{t'} st

​ 即 s t ′ = g ( y t ′ − 1 , c , s t ′ − 1 ) s_{t'} =g(y_{t'-1},c,s_{t'-1}) st=g(yt1,c,st1)

​ 获得当前步的隐状态,随后就做一个softmax来计算在时间步 t ′ t' t时输出 y t ′ y_{t'} yt的条件概率分布 P ( y t ′ ∣ y 1 , ⋯ , y t ′ − 1 , c ) P(y_{t'}|y_1,\cdots,y_{t'-1},c) P(yty1,,yt1,c)

​ 由于我们将直接使用编码器最后一个隐藏状态来初始化解码器的隐状态,所以要求编码器和解码器有相同数量的层和隐藏单元。

​ 为了更好的包含编码后的输入序列的信息,上下文变量(编码器的输出)在所有的时间步都要与解码器的输入进行拼接。最后使用全连接层来变换隐状态。

class Seq2SeqDecoder(d2l.Decoder):"""用于序列到序列学习的循环神经网络解码器"""def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,dropout=0, **kwargs):super(Seq2SeqDecoder, self).__init__(**kwargs)self.embedding = nn.Embedding(vocab_size, embed_size)# 输入是embed_size,还有编码器的输出num_hiddens,两个cat在一起作为rnn的输入self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers,dropout=dropout)self.dense = nn.Linear(num_hiddens, vocab_size)def init_state(self, enc_outputs, *args):return enc_outputs[1] # 0是编码器的output,1是state,用编码器的隐状态来初始化def forward(self, X, state):# 输出'X'的形状:(batch_size,num_steps,embed_size)X = self.embedding(X).permute(1, 0, 2)# 广播context,使其具有与X相同的num_steps# state[-1]是最后一层的隐状态输出,repeat时间步数次,变成三维的context = state[-1].repeat(X.shape[0], 1, 1)# 和输入拼接X_and_context = torch.cat((X, context), 2)output, state = self.rnn(X_and_context, state)output = self.dense(output).permute(1, 0, 2)# output的形状:(batch_size,num_steps,vocab_size)# state的形状:(num_layers,batch_size,num_hiddens)return output, statedecoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=16,num_layers=2)
decoder.eval()
state = decoder.init_state(encoder(X))
output, state = decoder(X, state)
print(output.shape, state.shape)

在这里插入图片描述

​ 说得很对,seq2seq这样更新就使得encoder的输出发生了改变,随意我们简单处理一下,将decoder的state和encoder输出状态分开,只更新state,就对了

2.3 损失函数

​ 在每个时间步,解码器都预测了输出词元的概率分布,随后可以使用softmax来获得分布,并通过交叉熵损失函数来进行优化。

​ 但注意到,某些特定的填充词被添加了,这可以使得不同长度的序列可以以相同形状的小批量加载,但我们应该将这些填充词元的预测排除在损失函数的计算之外。

​ 使用零值化屏蔽不相关的项:

def sequence_mask(X, valid_len, value=0):"""在序列中屏蔽不相关的项"""maxlen = X.size(1)mask = torch.arange((maxlen), dtype=torch.float32,device=X.device)[None, :] < valid_len[:, None]X[~mask] = valuereturn XX = torch.tensor([[1, 2, 3], [4, 5, 6]])
sequence_mask(X, torch.tensor([1, 2])) # 保留前1项,前两项'''tensor([[1, 0, 0],[4, 5, 0]])'''
# 使用该函数屏蔽最后几个轴上的所有项,其实也是选择第二维
X = torch.ones(2, 3, 4)
sequence_mask(X, torch.tensor([1, 2]), value=-1)'''
tensor([[[ 1.,  1.,  1.,  1.],[-1., -1., -1., -1.],[-1., -1., -1., -1.]],[[ 1.,  1.,  1.,  1.],[ 1.,  1.,  1.,  1.],[-1., -1., -1., -1.]]])
'''

​ 随后,通过扩展softmax交叉损失函数来遮蔽不相关的预测:

#@save
class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):"""带遮蔽的softmax交叉熵损失函数"""# pred的形状:(batch_size,num_steps,vocab_size)# label的形状:(batch_size,num_steps)# valid_len的形状:(batch_size,)def forward(self, pred, label, valid_len):weights = torch.ones_like(label) weights = sequence_mask(weights, valid_len) #生成weight,0是不要的self.reduction='none'unweighted_loss = super(MaskedSoftmaxCELoss, self).forward(pred.permute(0, 2, 1), label) #正常计算,得到的是一个向量weighted_loss = (unweighted_loss * weights).mean(dim=1)# 矩阵对应乘一下,再取均值return weighted_lossloss = MaskedSoftmaxCELoss()
print(loss(torch.ones(3, 4, 10), torch.ones((3, 4), dtype=torch.long),torch.tensor([4, 2, 0])))

2.4 训练

#@save
def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):"""训练序列到序列模型"""def xavier_init_weights(m):if type(m) == nn.Linear:nn.init.xavier_uniform_(m.weight)if type(m) == nn.GRU:for param in m._flat_weights_names:if "weight" in param:nn.init.xavier_uniform_(m._parameters[param])net.apply(xavier_init_weights)net.to(device)optimizer = torch.optim.Adam(net.parameters(), lr=lr)loss = MaskedSoftmaxCELoss()net.train()animator = d2l.Animator(xlabel='epoch', ylabel='loss',xlim=[10, num_epochs])for epoch in range(num_epochs):timer = d2l.Timer()metric = d2l.Accumulator(2)  # 训练损失总和,词元数量for batch in data_iter:optimizer.zero_grad()X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],device=device).reshape(-1, 1)dec_input = torch.cat([bos, Y[:, :-1]], 1)  # 强制教学,最后一个单词不能再接着预测了,所以没用,直接拿掉,并把开始标志放到每一句的最前面Y_hat, _ = net(X, dec_input, X_valid_len)l = loss(Y_hat, Y, Y_valid_len)l.sum().backward()      # 损失函数的标量进行“反向传播”d2l.grad_clipping(net, 1)num_tokens = Y_valid_len.sum()optimizer.step()with torch.no_grad():metric.add(l.sum(), num_tokens)if (epoch + 1) % 10 == 0:animator.add(epoch + 1, (metric[0] / metric[1],))print(f'loss {metric[0] / metric[1]:.3f}, {metric[1] / timer.stop():.1f} 'f'tokens/sec on {str(device)}')embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 300, d2l.try_gpu()train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers,dropout)
decoder = Seq2SeqDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers,dropout)
net = d2l.EncoderDecoder(encoder, decoder)
train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

2.5 预测

#@save
def predict_seq2seq(net, src_sentence, src_vocab, tgt_vocab, num_steps,device, save_attention_weights=False):"""序列到序列模型的预测"""# 在预测时将net设置为评估模式net.eval()#将源句子转换为小写并分词,然后查找每个词元在词汇表中的索引,并添加 <eos> 标志。src_tokens = src_vocab[src_sentence.lower().split(' ')] + [src_vocab['<eos>']]# 记录每个句子的有效长度,注意力机制需要使用enc_valid_len = torch.tensor([len(src_tokens)], device=device)#截断、填充src_tokens = d2l.truncate_pad(src_tokens, num_steps, src_vocab['<pad>'])# 添加批量轴enc_X = torch.unsqueeze(torch.tensor(src_tokens, dtype=torch.long, device=device), dim=0)enc_outputs = net.encoder(enc_X, enc_valid_len)dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)# 添加批量轴dec_X = torch.unsqueeze(torch.tensor([tgt_vocab['<bos>']], dtype=torch.long, device=device), dim=0)output_seq, attention_weight_seq = [], []for _ in range(num_steps):Y, dec_state = net.decoder(dec_X, dec_state)# 我们使用具有预测最高可能性的词元,作为解码器在下一时间步的输入dec_X = Y.argmax(dim=2)pred = dec_X.squeeze(dim=0).type(torch.int32).item()# 保存注意力权重(稍后讨论)if save_attention_weights:attention_weight_seq.append(net.decoder.attention_weights)# 一旦序列结束词元被预测,输出序列的生成就完成了if pred == tgt_vocab['<eos>']:breakoutput_seq.append(pred)return ' '.join(tgt_vocab.to_tokens(output_seq)), attention_weight_seq

2.5 BLEU的实现

def bleu(pred_seq, label_seq, k):  #@save"""计算BLEU"""pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')len_pred, len_label = len(pred_tokens), len(label_tokens)score = math.exp(min(0, 1 - len_label / len_pred))for n in range(1, k + 1):num_matches, label_subs = 0, collections.defaultdict(int)for i in range(len_label - n + 1):label_subs[' '.join(label_tokens[i: i + n])] += 1for i in range(len_pred - n + 1):if label_subs[' '.join(pred_tokens[i: i + n])] > 0:num_matches += 1label_subs[' '.join(pred_tokens[i: i + n])] -= 1score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))return scoreengs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):translation, attention_weight_seq = predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)print(f'{eng} => {translation}, bleu {bleu(translation, fra, k=2):.3f}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/394254.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JS+CSS案例:可适应上下布局和左右布局的菜单(含二级菜单)

今天,我给大家分享一个原创的CSS菜单,整个菜单全由CSS写成,仅在切换布局时使用JS。合不合意,先看看效果图。 本例图片 接下来,我来详细给大家分享它的制作方法。 文件夹结构 因为涉及到了样式表切换,所以,你需要借鉴一下我的文件夹结构。 CSS文件夹: reset.css 用于…

维吉尼亚密码加解密实现(python)

维吉尼亚密码 原理 维吉尼亚密码&#xff08;Vigenere&#xff09;是使用一系列凯撒密码组成密码字母表的加密算法&#xff0c;属于多表密码的一种简单形式。 下面给出一个例子 明文&#xff1a;come greatwall 密钥&#xff1a;crypto首先&#xff0c;对密钥进行填充使其长…

【算法】普里姆算法解决修路问题

应用场景——修路问题 1.某地有 7 个村庄&#xff08;A&#xff0c;B&#xff0c;C&#xff0c;D&#xff0c;E&#xff0c;F&#xff0c;G&#xff09;&#xff0c;现在需要修路把 7 个村庄连通 2.各个村庄的距离用边线表示&#xff08;权&#xff09;&#xff0c;比如 A - …

ORM工具之SQLAlchemy

SQLAlchemy是Python编程语言下的一款开源软件。提供了SQL工具包及对象关系映射&#xff08;ORM&#xff09;工具&#xff0c;使用MIT许可证发行。 SQLAlchemy“采用简单的Python语言&#xff0c;为高效和高性能的数据库访问设计&#xff0c;实现了完整的企业级持久模型”。SQL…

从 Pandas 到 Polars 四十四:Polars 和 数据可视化库Seaborn

在我对Matplotlib感到沮丧并发表帖子时&#xff0c;我的朋友让我试试Seaborn库。近年来我一直在使用Altair&#xff0c;因此并没有过多考虑Seaborn。然而&#xff0c;Seaborn的新界面给我留下了深刻印象&#xff0c;并且我很高兴地发现&#xff0c;Seaborn将直接接受Polars的Da…

【web安全】权限漏洞之未授权访问

一.Jenkins未授权访问漏洞 步骤一&#xff1a;使用以下fofa语法进行搜索 port"8080" && app"JENKINS" && title"Dashboard [Jenkins]" 步骤二&#xff1a;进入执行页面http://xxx.xxx.xxx.xxx:xxxx/manage/script/index.php 执…

Linux下自动监控进程运行状态

目录 背景应用举例1、使用crontab脚本监控服务2、使用shell脚本监控服务2.1 编写自定义监控脚本2.2 运行脚本 背景 假设有一个服务需要长期运行&#xff0c;但可能会由于某种原因导致服务意外停止&#xff0c;不能及时发现&#xff0c;某天来到公司后发现出问题了才意识到服务…

(Qt) QThread 信号槽所在线程

文章目录 &#x1f481;&#x1f3fb;前言&#x1f481;&#x1f3fb;Code&#x1f481;&#x1f3fb;‍♂️Code&#x1f481;&#x1f3fb;‍♂️环境 &#x1f481;&#x1f3fb;当前线程信号&#x1f481;&#x1f3fb;‍♂️默认效果&#x1f481;&#x1f3fb;‍♂️Qt::…

最新CSS3伪类和伪元素详解

第4章 伪类和伪元素 4.1结构伪类 E:first-child{},第一个元素 样式&#xff1a; p:first-child {color: red; } <div><p>Lorem ipsum</p><p>Dolor sit amet.</p> </div> 4.1.1nth-*伪类 以计数为基础的&#xff0c;默认情况下&…

探索下一代互联网协议:IPv6的前景与优势

探索下一代互联网协议&#xff1a;IPv6的前景与优势 文章目录 探索下一代互联网协议&#xff1a;IPv6的前景与优势**IPv6 的特点****IPv6的基本首部****IPv6的地址****总结** 互联网的核心协议&#xff1a;从IPv4到IPv6 互联网的核心协议IP&#xff08;Internet Protocol&#…

Docker Deskpot出现Docker Engine Stopped的解决历程

前提&#xff1a;我的操作系统是Win11家庭版, Docker Descktop下载的是最新版&#xff08;此时是4.30.0&#xff09; 出现了如图所示的问题“Docker Engine Stopped”,个人认为解决问题的关键是第四点&#xff0c;读者可以直接看第四点&#xff0c;如果只看第四点就成功解决&am…

python开发上位机 - PyCharm环境搭建、安装PyQt5及工具

目录 简介&#xff1a; 一、安装PyCharm 1、下载 PyCharm 2、PyCharm安装 1&#xff09;配置安装目录 2&#xff09;安装选项 3、问题及解决方法 二、安装PyQt5 1、打开 Pycharm&#xff0c;新建 Project 2、安装 pyqt5 3、安装很慢怎么办&#xff1f; 4、安装 pyq…

RHCSA第三次作业

磁盘管理及分区&#xff1a; [rootMYyyy ~]# fdisk /dev/sda [rootMYyyy ~]# lsblk NAME MAJ:MIN RM SIZE RO TYPE MOUNTPOINTS sda 8:0 0 10G 0 disk └─sda1 8:1 0 9G 0 part sdb 8:16 0 30G 0 disk sdc …

docker 部署 mysql8

命令 docker run --restartalways --name mysql8 -v /data/mysql/conf:/etc/mysql -v /data/mysql/data:/var/lib/mysql -v /data/mysql/log:/var/log -v /data/mysql/mysql-files:/var/lib/mysql-files -p 3308:3306 -e MYSQL_ROOT_PASSWORD123456 -d mysql:8 \解释 --rest…

基于SpringBoot框架的企业财务管理系统设计与实现(论文+源码)_kaic

摘 要 在快速增长的信息时代&#xff0c;每个企业都在紧随其后&#xff0c;不断改进其办公模式。与此同时&#xff0c;各家企业的传统管理模式也逐步发生变化&#xff0c;政府和企业都将需要一个更加自动化和现代化的财务管理系统。这能够便利员工之间的信息交流和公司的工作…

day22回溯学习记录第一天- - -代码随想录

77.组合 给定两个整数 n 和 k&#xff0c;返回范围 [1, n] 中所有可能的 k 个数的组合。 你可以按 任何顺序 返回答案。 思路:回溯的经典题目&#xff0c;回溯的整体结构类似于二叉树&#xff0c;如下图所示。据此图可知&#xff0c;采用递归是一种解决方法&#xff0c;在此引入…

定点数的运算

目录 1.定点数的移位运算 1.1算数移位 数学含义&#xff1a; 规律总结&#xff1a; 1.2逻辑移位 1.3循环移位 不带进位位 带进位位 2.定点数的加减运算 3.定点数的乘除运算 3.1原码 一位乘法 除法 3.2补码 一位乘法 除法 1.定点数的移位运算 1.1算数移位 数学…

Java日志框架

笔记学习原视频&#xff08;B站&#xff09;&#xff1a;全面深入学习多种java日志框架 目前常见日志框架有&#xff1a; JULLogbacklog4jlog4j2 目前常见的日志门面&#xff08;统一的日志API&#xff09;&#xff1a; JCLSlf4j 一、 老技术&#xff08;基本都弃用了&…

STM32——外部中断(EXTI)

目录 前言 一、外部中断基础知识 二、使用步骤 三、固件库实现 四、STM32CubeMX实现 总结 前言 外部中断&#xff08;External Interrupt&#xff0c;简称EXTI&#xff09;是微控制器用于响应外部事件的一种方式&#xff0c;当外部事件发生时&#xff08;如按键按下、传感器信号…

软件设计模式概述

模式的诞生 模式(Pattern)起源于建筑业而非软件业 模式之父——美国加利佛尼亚大学环境结构中心研究所所长Christopher Alexander&#xff08;克里斯托弗亚历山大&#xff09;博士 《A Pattern Language: Towns, Buildings, Construction》——253个建筑和城市规划模式。 他给出…