PyTorch实战:基于Seq2seq模型处理机器翻译任务(模型预测)

文章目录

  • 引言
  • 数据预处理
    • 加载字典对象`en2id`和`zh2id`
    • 文本分词
  • 加载训练好的Seq2Seq模型
  • 模型预测完整代码
  • 结束语

引言

随着全球化的深入,翻译需求日益增长。传统的人工翻译方式虽然质量高,但效率低,成本高。机器翻译的出现,为解决这一问题提供了可能。英译中机器翻译任务是机器翻译领域的一个重要分支,旨在将英文文本自动翻译成中文。本博客以《PyTorch自然语言处理入门与实战》第九章的Seq2seq模型处理英译中翻译任务作为基础,附上模型预测模块。

模型的训练及验证模块的详细解析见PyTorch实战:基于Seq2seq模型处理机器翻译任务(模型训练及验证)

数据预处理

加载字典对象en2idzh2id

在预测阶段中,需要加载模型训练及验证阶段保存的字典对象en2idzh2id

代码如下:

import picklewith open("en2id.pkl", 'rb') as f:en2id = pickle.load(f)
with open("zh2id.pkl", 'rb') as f:zh2id = pickle.load(f)

文本分词

在对输入文本进行预测时,需要先将文本进行分词操作。参考代码如下:

def extract_words(sentence):  """  从给定的英文句子中提取单词,并去除单词后的标点符号。  Args:  sentence (str): 要提取单词的英文句子。  Returns:  List[str]: 提取并处理后的单词列表。  """  en_words = []  for w in sentence.split(' '):  # 将英文句子按空格分词  w = w.replace('.', '').replace(',', '')  # 去除跟单词连着的标点符号  w = w.lower()  # 统一单词大小写  if w:  en_words.append(w)  return en_words  # 测试函数  
sentence = 'I am Dave Gallo.'  
print(extract_words(sentence))

运行结果:

加载训练好的Seq2Seq模型

代码如下:

import torch
import torch.nn as nnclass Encoder(nn.Module):def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):super().__init__()self.hid_dim = hid_dimself.n_layers = n_layersself.embedding = nn.Embedding(input_dim, emb_dim)  # 词嵌入self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)self.dropout = nn.Dropout(dropout)def forward(self, src):# src = (src len, batch size)embedded = self.dropout(self.embedding(src))# embedded = (src len, batch size, emb dim)outputs, (hidden, cell) = self.rnn(embedded)# outputs = (src len, batch size, hid dim * n directions)# hidden = (n layers * n directions, batch size, hid dim)# cell = (n layers * n directions, batch size, hid dim)# rnn的输出总是来自顶部的隐藏层return hidden, cellclass Decoder(nn.Module):def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):super().__init__()self.output_dim = output_dimself.hid_dim = hid_dimself.n_layers = n_layersself.embedding = nn.Embedding(output_dim, emb_dim)self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)self.fc_out = nn.Linear(hid_dim, output_dim)self.dropout = nn.Dropout(dropout)def forward(self, input, hidden, cell):# 各输入的形状# input = (batch size)# hidden = (n layers * n directions, batch size, hid dim)# cell = (n layers * n directions, batch size, hid dim)# LSTM是单向的  ==> n directions == 1# hidden = (n layers, batch size, hid dim)# cell = (n layers, batch size, hid dim)input = input.unsqueeze(0)  # (batch size)  --> [1, batch size)embedded = self.dropout(self.embedding(input))  # (1, batch size, emb dim)output, (hidden, cell) = self.rnn(embedded, (hidden, cell))# LSTM理论上的输出形状# output = (seq len, batch size, hid dim * n directions)# hidden = (n layers * n directions, batch size, hid dim)# cell = (n layers * n directions, batch size, hid dim)# 解码器中的序列长度 seq len == 1# 解码器的LSTM是单向的 n directions == 1 则实际上# output = (1, batch size, hid dim)# hidden = (n layers, batch size, hid dim)# cell = (n layers, batch size, hid dim)prediction = self.fc_out(output.squeeze(0))# prediction = (batch size, output dim)return prediction, hidden, cellclass Seq2Seq(nn.Module):def __init__(self, input_word_count, output_word_count, encode_dim, decode_dim, hidden_dim, n_layers,encode_dropout, decode_dropout, device):""":param input_word_count:    英文词表的长度     34737:param output_word_count:   中文词表的长度     4015:param encode_dim:          编码器的词嵌入维度:param decode_dim:          解码器的词嵌入维度:param hidden_dim:          LSTM的隐藏层维度:param n_layers:            采用n层LSTM:param encode_dropout:      编码器的dropout概率:param decode_dropout:      编码器的dropout概率:param device:              cuda / cpu"""super().__init__()self.encoder = Encoder(input_word_count, encode_dim, hidden_dim, n_layers, encode_dropout)self.decoder = Decoder(output_word_count, decode_dim, hidden_dim, n_layers, decode_dropout)self.device = devicedef forward(self, src):# src = (src len, batch size)# 编码器的隐藏层输出将作为解码器的第一个隐藏层输入hidden, cell = self.encoder(src)# 解码器的第一个输入应该是起始标识符<sos>input = src[0, :]  # 取trg的第“0”行所有列  “0”指的是索引pred = [0] # 预测的第一个输出应该是起始标识符top1 = 0while top1 != 1 and len(pred) < 100:# 解码器的输入包括:起始标识符的词嵌入input; 编码器输出的 hidden and cell states# 解码器的输出包括:输出张量(predictions) and new hidden and cell statesoutput, hidden, cell = self.decoder(input, hidden, cell)top1 = output.argmax(dim=1)  # (batch size, )pred.append(top1.item())input = top1return preddevice = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')  # GPU可用 用GPU
# Seq2Seq模型实例化
source_word_count = 34737  # 英文词表的长度     34737
target_word_count = 4015  # 中文词表的长度     4015
encode_dim = 256  # 编码器的词嵌入维度
decode_dim = 256  # 解码器的词嵌入维度
hidden_dim = 512  # LSTM的隐藏层维度
n_layers = 2  # 采用n层LSTM
encode_dropout = 0.5  # 编码器的dropout概率
decode_dropout = 0.5  # 编码器的dropout概率
model = Seq2Seq(source_word_count, target_word_count, encode_dim, decode_dim, hidden_dim, n_layers, encode_dropout,decode_dropout, device).to(device)# 加载训练好的模型
model.load_state_dict(torch.load("best_model.pth"))
model.eval()

模型预测完整代码

提示预测代码是我们基于训练及验证代码进行改造的,不一定完全正确,可以参考后自行修改~

import torch
import torch.nn as nn
import pickleclass Encoder(nn.Module):def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):super().__init__()self.hid_dim = hid_dimself.n_layers = n_layersself.embedding = nn.Embedding(input_dim, emb_dim)  # 词嵌入self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)self.dropout = nn.Dropout(dropout)def forward(self, src):# src = (src len, batch size)embedded = self.dropout(self.embedding(src))# embedded = (src len, batch size, emb dim)outputs, (hidden, cell) = self.rnn(embedded)# outputs = (src len, batch size, hid dim * n directions)# hidden = (n layers * n directions, batch size, hid dim)# cell = (n layers * n directions, batch size, hid dim)# rnn的输出总是来自顶部的隐藏层return hidden, cellclass Decoder(nn.Module):def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):super().__init__()self.output_dim = output_dimself.hid_dim = hid_dimself.n_layers = n_layersself.embedding = nn.Embedding(output_dim, emb_dim)self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)self.fc_out = nn.Linear(hid_dim, output_dim)self.dropout = nn.Dropout(dropout)def forward(self, input, hidden, cell):# 各输入的形状# input = (batch size)# hidden = (n layers * n directions, batch size, hid dim)# cell = (n layers * n directions, batch size, hid dim)# LSTM是单向的  ==> n directions == 1# hidden = (n layers, batch size, hid dim)# cell = (n layers, batch size, hid dim)input = input.unsqueeze(0)  # (batch size)  --> [1, batch size)embedded = self.dropout(self.embedding(input))  # (1, batch size, emb dim)output, (hidden, cell) = self.rnn(embedded, (hidden, cell))# LSTM理论上的输出形状# output = (seq len, batch size, hid dim * n directions)# hidden = (n layers * n directions, batch size, hid dim)# cell = (n layers * n directions, batch size, hid dim)# 解码器中的序列长度 seq len == 1# 解码器的LSTM是单向的 n directions == 1 则实际上# output = (1, batch size, hid dim)# hidden = (n layers, batch size, hid dim)# cell = (n layers, batch size, hid dim)prediction = self.fc_out(output.squeeze(0))# prediction = (batch size, output dim)return prediction, hidden, cellclass Seq2Seq(nn.Module):def __init__(self, input_word_count, output_word_count, encode_dim, decode_dim, hidden_dim, n_layers,encode_dropout, decode_dropout, device):""":param input_word_count:    英文词表的长度     34737:param output_word_count:   中文词表的长度     4015:param encode_dim:          编码器的词嵌入维度:param decode_dim:          解码器的词嵌入维度:param hidden_dim:          LSTM的隐藏层维度:param n_layers:            采用n层LSTM:param encode_dropout:      编码器的dropout概率:param decode_dropout:      编码器的dropout概率:param device:              cuda / cpu"""super().__init__()self.encoder = Encoder(input_word_count, encode_dim, hidden_dim, n_layers, encode_dropout)self.decoder = Decoder(output_word_count, decode_dim, hidden_dim, n_layers, decode_dropout)self.device = devicedef forward(self, src):# src = (src len, batch size)# 编码器的隐藏层输出将作为解码器的第一个隐藏层输入hidden, cell = self.encoder(src)# 解码器的第一个输入应该是起始标识符<sos>input = src[0, :]  # 取trg的第“0”行所有列  “0”指的是索引pred = [0] # 预测的第一个输出应该是起始标识符top1 = 0while top1 != 1 and len(pred) < 100:# 解码器的输入包括:起始标识符的词嵌入input; 编码器输出的 hidden and cell states# 解码器的输出包括:输出张量(predictions) and new hidden and cell statesoutput, hidden, cell = self.decoder(input, hidden, cell)top1 = output.argmax(dim=1)  # (batch size, )pred.append(top1.item())input = top1return predif __name__ == '__main__':sentence = 'I am Dave Gallo.'en_words = []for w in sentence.split(' '):  # 英文内容按照空格字符进行分词# 按照空格进行分词后,某些单词后面会跟着标点符号 "." 和 “,”w = w.replace('.', '').replace(',', '')  # 去掉跟单词连着的标点符号w = w.lower()  # 统一单词大小写if w:en_words.append(w)print(en_words)with open("en2id.pkl", 'rb') as f:en2id = pickle.load(f)with open("zh2id.pkl", 'rb') as f:zh2id = pickle.load(f)device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')  # GPU可用 用GPU# Seq2Seq模型实例化source_word_count = 34737  # 英文词表的长度     34737target_word_count = 4015  # 中文词表的长度     4015encode_dim = 256  # 编码器的词嵌入维度decode_dim = 256  # 解码器的词嵌入维度hidden_dim = 512  # LSTM的隐藏层维度n_layers = 2  # 采用n层LSTMencode_dropout = 0.5  # 编码器的dropout概率decode_dropout = 0.5  # 编码器的dropout概率model = Seq2Seq(source_word_count, target_word_count, encode_dim, decode_dim, hidden_dim, n_layers, encode_dropout,decode_dropout, device).to(device)model.load_state_dict(torch.load("best_model.pth"))model.eval()src = [0] # 0 --> 起始标识符的编码for i in range(len(en_words)):src.append(en2id[en_words[i]])src = src + [1] # 1 --> 终止标识符的编码text_input = torch.LongTensor(src)text_input = text_input.unsqueeze(-1).to(device)text_output = model(text_input)print(text_output)id2zh = dict()for k, v in zh2id.items():id2zh[v] = ktext_output = [id2zh[index] for index in text_output]text_output = " ".join(text_output)print(text_output)

结束语

  • 亲爱的读者,感谢您花时间阅读我们的博客。我们非常重视您的反馈和意见,因此在这里鼓励您对我们的博客进行评论。
  • 您的建议和看法对我们来说非常重要,这有助于我们更好地了解您的需求,并提供更高质量的内容和服务。
  • 无论您是喜欢我们的博客还是对其有任何疑问或建议,我们都非常期待您的留言。让我们一起互动,共同进步!谢谢您的支持和参与!
  • 我会坚持不懈地创作,并持续优化博文质量,为您提供更好的阅读体验。
  • 谢谢您的阅读!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/225866.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

虚函数的讲解

文章目录 虚函数的声明与定义代码演示基类Person派生类Man派生类Woman 测试代码动态绑定静态绑定访问私有虚函数总结一下通过成员函数指针调用函数的方式 虚函数的声明与定义 虚函数存在于C的类、结构体等中&#xff0c;不能存在于全局函数中&#xff0c;只能作为成员函数存在…

IntelliJ IDEA [插件 MybatisX] mapper和xml间跳转

文章目录 1. 安装插件2. 如何使用3. 主要功能总结 MybatisX 是一款为 IntelliJ IDEA 提供支持的 MyBatis 开发插件 它通过提供丰富的功能集&#xff0c;大大简化了 MyBatis XML 文件的编写、映射关系的可视化查看以及 SQL 语句的调试等操作。本文将介绍如何安装、配置和使用 In…

知识库问答LangChain+LLM的二次开发:商用时的典型问题及其改进方案

前言 如之前的文章所述&#xff0c;我司下半年成立大模型项目团队之后&#xff0c;我虽兼管整个项目团队&#xff0c;但为让项目的推进效率更高&#xff0c;故分成了三大项目组 第一项目组由霍哥带头负责类似AIGC模特生成系统第二项目组由阿荀带头负责论文审稿GPT以及AI agen…

基于飞浆OCR的文本框box及坐标中心点检测JSON格式保存文本

OCR的文本框box及JSON数据保存 需求说明 一、借助飞浆框出OCR识别的文本框 二、以圆圈形式标出每个框的中心点位置 三、以JSON及文本格式保存OCR识别的文本 四、以文本格式保存必要的文本信息 解决方法 一、文本的坐标来自飞浆的COR识别 二、借助paddleocr的draw_ocr画出…

go语言,ent库与gorm库,插入一条null值的time数据

情景介绍 使用go语言&#xff0c;我需要保存xxxTime的字段至数据库中&#xff0c;这个字段可能为空&#xff0c;也可能是一段时间。我采取的是统一先赋值为空&#xff0c;若有需要&#xff0c;则再进行插入&#xff08;需要根据另一个字段判断是否插入&#xff09; 在我的数据…

最新国内使用GPT4教程,GPT语音对话使用,Midjourney绘画,ChatFile文档对话总结+DALL-E3文生图

一、前言 ChatGPT3.5、GPT4.0、GPT语音对话、Midjourney绘画&#xff0c;文档对话总结DALL-E3文生图&#xff0c;相信对大家应该不感到陌生吧&#xff1f;简单来说&#xff0c;GPT-4技术比之前的GPT-3.5相对来说更加智能&#xff0c;会根据用户的要求生成多种内容甚至也可以和…

HPCC:高精度拥塞控制

HPCC&#xff1a;高精度拥塞控制 文章目录 HPCC&#xff1a;高精度拥塞控制摘要1 引言1.1 背景1.2 现有CC的局限性1.3 HPCC的提出 2 研究动机2.1 大型RDMA部署2.2 RDMA目标2.3 当前RDMA CC中的权衡DCQCNTIMELY 2.4 下一代高速CC 3 技术方案3.1 INT3.2 HPCC设计3.3 HPPC的参数 4…

浅谈WPF之ToolTip工具提示

在日常应用中&#xff0c;当鼠标放置在某些控件上时&#xff0c;都会有相应的信息提示&#xff0c;从软件易用性上来说&#xff0c;这是一个非常友好的功能设计。那在WPF中&#xff0c;如何进行控件信息提示呢&#xff1f;这就是本文需要介绍的ToolTip【工具提示】内容&#xf…

数据结构入门到入土——List的介绍

目录 一&#xff0c;什么是List&#xff1f; 二&#xff0c;常见接口介绍 三&#xff0c;List的使用 一&#xff0c;什么是List&#xff1f; 在集合框架中&#xff0c;List是一个接口&#xff0c;继承自Collection。 Collection也是一个接口&#xff0c;该接口中规范了后序容…

MATLAB中./和/,.*和*,.^和^的区别

MATLAB中./和/&#xff0c;.*和*&#xff0c;.^ 和^ 的区别 MATLAB中./和/&#xff0c;.*和*&#xff0c;.^ 和^ 的区别./ 和 / 的区别.//实验实验结果 .* 和 * 的区别.**实验实验结果 .^ 和^ 的区别.^n^n实验运行结果 MATLAB中./和/&#xff0c;.和&#xff0c;.^ 和^ 的区别 …

关于SQL时间盲注(基于sleep函数)的手动测试、burpsuite爆破、sqlmap全自动化注入

SQL时间注入是一种常见的SQL注入攻击方式&#xff0c;攻击者通过在SQL语句中注入时间相关的代码&#xff0c;来获取敏感信息或者执行非法操作。其基本原理如下&#xff1a; 攻击者向Web应用程序中输入一段恶意代码&#xff0c;通过SQL语句查询数据库&#xff0c;并注入时间相关…

钉钉机器人接入定时器(钉钉API+XXL-JOB)

钉钉机器人接入定时器&#xff08;钉钉APIXXL-JOB&#xff09; 首先需要创建钉钉内部群 在群设置中找到机器人选项 选择“自定义”机器人 通过Webhook接入自定义服务 创建完成后会生成一个send URL和一个加签码 下面就是干货 代码部分了 DingDingUtil.sendMessageByText(webho…

什么是迁移学习(Transfer Learning)?定义,优势,方法

迄今为止&#xff0c;大多数人工智能&#xff08;AI&#xff09;项目都是通过监督学习技术构建的。监督学习是一种从无到有构建机器学习&#xff08;ML&#xff09;模型的方法&#xff0c;它对推动AI发展起到了关键作用。然而&#xff0c;由于需要大量的数据集和强大的计算能力…

账号租号平台PHP源码,支持单独租用或合租使用

源码简介 租号平台源码&#xff0c;采用常见的租号模式。 平台的主要功能如下&#xff1a; 支持单独租用或采用合租模式&#xff1b; 采用易支付通用接口进行支付&#xff1b; 添加邀请返利功能&#xff0c;以便站长更好地推广&#xff1b; 提供用户提现功能&#xff1b;…

PHP的Laravel加一个小页面出现问题(whereRaw的用法)

1.权限更新问题 因为是已经有样例了所以html和php页面很快写出来了 然后就是页面写完了路由不知道在哪写&#xff0c;后来想起来之前有要开权限来着&#xff0c;试了一下&#xff0c;还是不行&#xff0c;不过方向是对了 这是加的路由&#xff0c;不过需要在更新一下权限 这…

【产品经理】axure中继器的使用——表格增删改查分页实现

笔记为个人总结笔记&#xff0c;若有错误欢迎指出哟~ axure中继器的使用——表格增删改查分页实现 中继器介绍总体视图视频预览功能1.表头设计2.中继器3.添加功能实现4.删除功能实现5.修改功能实现6.查询功能实现7.批量删除 中继器介绍 在 Axure RP9 中&#xff0c;中继器&…

leetcode贪心算法题总结(一)

此系列分三章来记录leetcode的有关贪心算法题解&#xff0c;题目我都会给出具体实现代码&#xff0c;如果看不懂的可以后台私信我。 本章目录 1.柠檬水找零2.将数组和减半的最少操作次数3.最大数4.摆动序列5.最长递增子序列6.递增的三元子序列7.最长连续递增序列8.买卖股票的最…

C++:map和set的介绍及使用

目录 1. 关联式容器 2. 键值对 3. 树形结构的关联式容器 3.1 set 3.1.1 set的介绍 3.1.2 set的使用 3.2 map 3.2.1 map的介绍 3.2.2 map的使用 3.3 multiset 3.3.1 multiset的介绍 3.3.2 multiset的使用 3.4 multimap 3.4.1 multimap的介绍 3.4.2 multimap的使用…

Linux磁盘与文件系统管理

在linux系统中使用硬盘 建立分区 安装文件系统 挂载 磁盘的数据结构 磁盘&#xff1a;扇区固定大小&#xff0c;每个扇区4k。磁盘会进行磨损&#xff0c;损失生命周期。 扇区 磁道 柱面 磁盘接口类型 ide SATA SAS SCSI SCSI 设备类型 块设备&#xff1a;block …

kubelet源码学习(二):kubelet创建Pod流程

本文基于Kubernetes v1.22.4版本进行源码学习 4、kubelet创建Pod流程 syncLoop()的主要逻辑是在syncLoopIteration()方法中实现&#xff0c;Pod创建相关代码只需要看处理configCh部分的代码 // pkg/kubelet/kubelet.go // 该方法会监听多个channel,当发现任何一个channel有数…