seq2seq翻译实战-Pytorch复现

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

code

from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import randomimport torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as Fdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu

SOS_token = 0
EOS_token = 1# 语言类,方便对语料库进行操作
class Lang:def __init__(self, name):self.name = nameself.word2index = {}self.word2count = {}self.index2word = {0: "SOS", 1: "EOS"}self.n_words    = 2  # Count SOS and EOSdef addSentence(self, sentence):for word in sentence.split(' '):self.addWord(word)def addWord(self, word):if word not in self.word2index:self.word2index[word] = self.n_wordsself.word2count[word] = 1self.index2word[self.n_words] = wordself.n_words += 1else:self.word2count[word] += 1
def unicodeToAscii(s):return ''.join(c for c in unicodedata.normalize('NFD', s)if unicodedata.category(c) != 'Mn')# 小写化,剔除标点与非字母符号
def normalizeString(s):s = unicodeToAscii(s.lower().strip())s = re.sub(r"([.!?])", r" \1", s)s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)return s
def readLangs(lang1, lang2, reverse=False):print("Reading lines...")# 以行为单位读取文件lines = open('./data/%s-%s.txt'%(lang1,lang2), encoding='utf-8').\read().strip().split('\n')# 将每一行放入一个列表中# 一个列表中有两个元素,A语言文本与B语言文本pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]# 创建Lang实例,并确认是否反转语言顺序if reverse:pairs       = [list(reversed(p)) for p in pairs]input_lang  = Lang(lang2)output_lang = Lang(lang1)else:input_lang  = Lang(lang1)output_lang = Lang(lang2)return input_lang, output_lang, pairs
MAX_LENGTH = 10      # 定义语料最长长度eng_prefixes = ("i am ", "i m ","he is", "he s ","she is", "she s ","you are", "you re ","we are", "we re ","they are", "they re "
)def filterPair(p):return len(p[0].split(' ')) < MAX_LENGTH and \len(p[1].split(' ')) < MAX_LENGTH and p[1].startswith(eng_prefixes)def filterPairs(pairs):# 选取仅仅包含 eng_prefixes 开头的语料return [pair for pair in pairs if filterPair(pair)]
def prepareData(lang1, lang2, reverse=False):# 读取文件中的数据input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)print("Read %s sentence pairs" % len(pairs))# 按条件选取语料pairs = filterPairs(pairs[:])print("Trimmed to %s sentence pairs" % len(pairs))print("Counting words...")# 将语料保存至相应的语言类for pair in pairs:input_lang.addSentence(pair[0])output_lang.addSentence(pair[1])# 打印语言类的信息    print("Counted words:")print(input_lang.name, input_lang.n_words)print(output_lang.name, output_lang.n_words)return input_lang, output_lang, pairsinput_lang, output_lang, pairs = prepareData('eng', 'fra', True)
print(random.choice(pairs))

Reading lines…
Read 135842 sentence pairs
Trimmed to 10599 sentence pairs
Counting words…
Counted words:
fra 4345
eng 2803
[‘elles le font correctement .’, ‘they re doing it right .’]

class EncoderRNN(nn.Module):def __init__(self, input_size, hidden_size):super(EncoderRNN, self).__init__()self.hidden_size = hidden_sizeself.embedding   = nn.Embedding(input_size, hidden_size)self.gru         = nn.GRU(hidden_size, hidden_size)def forward(self, input, hidden):embedded       = self.embedding(input).view(1, 1, -1)output         = embeddedoutput, hidden = self.gru(output, hidden)return output, hiddendef initHidden(self):return torch.zeros(1, 1, self.hidden_size, device=device)
class DecoderRNN(nn.Module):def __init__(self, hidden_size, output_size):super(DecoderRNN, self).__init__()self.hidden_size = hidden_sizeself.embedding   = nn.Embedding(output_size, hidden_size)self.gru         = nn.GRU(hidden_size, hidden_size)self.out         = nn.Linear(hidden_size, output_size)self.softmax     = nn.LogSoftmax(dim=1)def forward(self, input, hidden):output         = self.embedding(input).view(1, 1, -1)output         = F.relu(output)output, hidden = self.gru(output, hidden)output         = self.softmax(self.out(output[0]))return output, hiddendef initHidden(self):return torch.zeros(1, 1, self.hidden_size, device=device)
# 将文本数字化,获取词汇index
def indexesFromSentence(lang, sentence):return [lang.word2index[word] for word in sentence.split(' ')]# 将数字化的文本,转化为tensor数据
def tensorFromSentence(lang, sentence):indexes = indexesFromSentence(lang, sentence)indexes.append(EOS_token)return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)# 输入pair文本,输出预处理好的数据
def tensorsFromPair(pair):input_tensor  = tensorFromSentence(input_lang, pair[0])target_tensor = tensorFromSentence(output_lang, pair[1])return (input_tensor, target_tensor)
teacher_forcing_ratio = 0.5def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH):# 编码器初始化encoder_hidden = encoder.initHidden()# grad属性归零encoder_optimizer.zero_grad()decoder_optimizer.zero_grad()input_length  = input_tensor.size(0)target_length = target_tensor.size(0)# 用于创建一个指定大小的全零张量(tensor),用作默认编码器输出encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)loss = 0# 将处理好的语料送入编码器for ei in range(input_length):encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)encoder_outputs[ei]            = encoder_output[0, 0]# 解码器默认输出decoder_input  = torch.tensor([[SOS_token]], device=device)decoder_hidden = encoder_hiddenuse_teacher_forcing = True if random.random() < teacher_forcing_ratio else False# 将编码器处理好的输出送入解码器if use_teacher_forcing:# Teacher forcing: Feed the target as the next inputfor di in range(target_length):decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)loss         += criterion(decoder_output, target_tensor[di])decoder_input = target_tensor[di]  # Teacher forcingelse:# Without teacher forcing: use its own predictions as the next inputfor di in range(target_length):decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)topv, topi    = decoder_output.topk(1)decoder_input = topi.squeeze().detach()  # detach from history as inputloss         += criterion(decoder_output, target_tensor[di])if decoder_input.item() == EOS_token:breakloss.backward()encoder_optimizer.step()decoder_optimizer.step()return loss.item() / target_length
import time
import mathdef asMinutes(s):m = math.floor(s / 60)s -= m * 60return '%dm %ds' % (m, s)def timeSince(since, percent):now = time.time()s = now - sincees = s / (percent)rs = es - sreturn '%s (- %s)' % (asMinutes(s), asMinutes(rs))
def trainIters(encoder,decoder,n_iters,print_every=1000,plot_every=100,learning_rate=0.01):start = time.time()plot_losses      = []print_loss_total = 0  # Reset every print_everyplot_loss_total  = 0  # Reset every plot_everyencoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)# 在 pairs 中随机选取 n_iters 条数据用作训练集training_pairs    = [tensorsFromPair(random.choice(pairs)) for i in range(n_iters)]criterion         = nn.NLLLoss()for iter in range(1, n_iters + 1):training_pair = training_pairs[iter - 1]input_tensor  = training_pair[0]target_tensor = training_pair[1]loss = train(input_tensor, target_tensor, encoder,decoder, encoder_optimizer, decoder_optimizer, criterion)print_loss_total += lossplot_loss_total  += lossif iter % print_every == 0:print_loss_avg   = print_loss_total / print_everyprint_loss_total = 0print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters),iter, iter / n_iters * 100, print_loss_avg))if iter % plot_every == 0:plot_loss_avg = plot_loss_total / plot_everyplot_losses.append(plot_loss_avg)plot_loss_total = 0return plot_losses
hidden_size   = 256
encoder1      = EncoderRNN(input_lang.n_words, hidden_size).to(device)
attn_decoder1 = DecoderRNN(hidden_size, output_lang.n_words).to(device)plot_losses = trainIters(encoder1, attn_decoder1, 20000, print_every=5000)
7m 2s (- 21m 6s) (5000 25%) 2.8981
13m 59s (- 13m 59s) (10000 50%) 2.3636
21m 3s (- 7m 1s) (15000 75%) 2.0134
28m 10s (- 0m 0s) (20000 100%) 1.7973
import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               # 忽略警告信息
# plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        # 分辨率epochs_range = range(len(plot_losses))plt.figure(figsize=(8, 3))plt.subplot(1, 1, 1)
plt.plot(epochs_range, plot_losses, label='Training Loss')
plt.legend(loc='upper right')
plt.title('Training Loss')
plt.show()

在这里插入图片描述

总结

构建基于PyTorch的seq2seq翻译系统是一个综合性的过程,它首先涉及数据的收集与预处理,包括将源语言和目标语言的文本对转换为适合模型训练的格式,并构建相应的词汇表。随后,定义包含编码器(负责将源语言序列编码为上下文信息)和解码器(利用编码信息生成目标语言序列)的seq2seq模型架构。在训练阶段,通过迭代训练数据,优化模型参数以最小化翻译损失,如交叉熵损失,同时采用正则化技术和梯度裁剪来防止过拟合和梯度爆炸。最后,在独立的测试集上评估模型的翻译性能,通过计算如BLEU分数等指标来衡量其准确性和流畅度。这一过程不仅要求深入理解seq2seq模型和PyTorch框架,还需要细致的数据处理和模型调优策略,以构建出高效且准确的翻译系统。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/428190.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

半导体器件制造5G智能工厂数字孪生物联平台,推进制造业数字化转型

半导体器件制造行业作为高科技领域的核心驱动力&#xff0c;正积极探索和实践以5G智能工厂数字孪生平台为核心的新型制造模式。这一创新不仅极大地提升了生产效率与质量&#xff0c;更为制造业的未来发展绘制了一幅智能化、网络化的宏伟蓝图。 在半导体器件制造5G智能工厂中&a…

RabbitMQ 高级特性——发送方确认

文章目录 前言发送方确认confirm 确认模式return 退回模式 常见面试题 前言 前面我们学习了 RabbitMQ 中交换机、队列和消息的持久化&#xff0c;这样能够保证存储在 RabbitMQ Broker 中的交换机和队列中的消息实现持久化&#xff0c;就算 RabbitMQ 服务发生了重启或者是宕机&…

中间件知识点-消息中间件(Rabbitmq)一

消息中间件介绍 MQ的作用(优点)主要有以下三个方面&#xff1a; a.异步 b.解耦 c.削峰 MQ的作用(缺点)主要有以下三个方面&#xff1a; a.系统可用性降低 b.系统复杂度提高 c.存在消息一致性问题需要解决 备注&#xff1a; 引入MQ后系统的复杂度会大大提高。 以前服务之间可以…

【软件基础知识】什么是 API,详细解读

想象一下,你正在使用智能手机上的天气应用。你打开应用,瞬间就能看到实时天气、未来预报,甚至是空气质量指数。但你有没有想过,这些数据是如何神奇地出现在你的屏幕上的?答案就在三个字母中:API。 API,全称Application Programming Interface(应用程序编程接口),是现代软件世…

计算机网络 --- 初识协议

序言 上一篇文章中 &#xff08;&#x1f449;点击查看&#xff09;&#xff0c;我们简单的了解了怎么寻找目标计算机&#xff0c;需要通过交换机&#xff0c;路由器等设备跨越多个网络来不断的转发我们需要传输的数据&#xff0c;直至到达目标计算机。  那我们设备之间数据是…

重回极简:华为如何走向全面智能化?

“人类发现地球只是宇宙一员的时候&#xff0c;也是我们距离群星最遥远的时候。” 这个来自天文领域的喟叹&#xff0c;今天同样出现在行业与企业的智能化之路上。在这个时代坐标上&#xff0c;AI大模型技术极速成熟&#xff0c;AIGC和AI Agent等应用受到了各个行业的巨大期待。…

昇腾大模型推理解决方案MindIE部署

MindIE大模型推理套件 MindIE&#xff08;Mind Inference Engine&#xff0c;昇腾推理引擎&#xff09;是华为公司针对AI全场景推出的整体解决方案&#xff0c;包含丰富的推理加速套件。通过开放各层次AI能力&#xff0c;支撑客户多样化的AI业务需求&#xff0c;使能百模千态&a…

存储 NFS

目录 1.存储的应用场景 2.存储分类 3.NFS服务组成 4.环境说明 ​编辑 5.服务端部署 6.NFS服务端的配置 7.NFS服务端本地进行测试 1.存储的应用场景 存储一般用于上传网站数据&#xff08;内容&#xff09;&#xff0c;一般用于在网站集群中。使用存储的话用户上传的…

成型的程序

加一个提示信息 加上python 常用的包 整个程序打包完 250M 安装 960MB matplot numpy pandas scapy pysearial 常用的包 (pyvisa)… … 啥都有 Python 解释器组件构建 要比 lua 容易的多 &#xff08;C/Rust 的组件库)

JavaSE--集合总览02:单列集合Collection的体系之一:List

Collection体系的特点 分为 list 和set集合&#xff0c;这篇文章主要讲述List&#xff0c;下篇讲述Set。 简单认识单列集合collection集合的特点 : list集合的特点&#xff1a; 有序 可重复 有索引 set集合的特点&#xff1a;无序 不重复 无索引 其中LinkedHashSet有序 TreeS…

微服务架构陷阱与挑战

微服务架构6大陷阱 现在微服务的基础设施还是越来越完善了&#xff0c;现在基础设施缺乏的问题逐渐被解决了。 拆分粒度太细&#xff0c;服务关系复杂 拆分降低了服务的内部复杂度&#xff0c;但是提升了系统的外部复杂度&#xff0c;服务越多&#xff0c;服务和服务之间的连接…

from tqdm.auto import tqdm用法详细介绍

tqdm 是一个 Python 库&#xff0c;用于在长时间运行的任务中显示进度条。tqdm.auto 是 tqdm 的一个版本&#xff0c;能够自动适配输出环境&#xff08;如 Jupyter Notebook、命令行等&#xff09;&#xff0c;以确保进度条在各种环境下显示正确。下面是 tqdm.auto 的详细用法介…

每天五分钟计算机视觉:将人脸识别问题转换为二分类问题

本文重点 在前面的课程中,我们学习了两种人脸识别的网络模型,这两种人脸识别网络不能算是基于距离或者Triplet loss等等完成的神经网络参数的学习。我们比较熟悉的是分类任务,那么人脸识别是否可以转变为分类任务呢? 本节课程我们将介绍一种全新的方法来学习神经网络的参…

用友U8二次开发工具KK-FULL-*****-EFWeb使用方法

1、安装: 下一步&#xff0c;下一步即可。弹出黑框不要关闭&#xff0c;让其自动执行并关闭。 2、服务配置&#xff1a; 输入服务器IP地址&#xff0c;选择U8数据源&#xff0c;输入U8用户名及账号&#xff0c;U8登录日期勾选系统日期。测试参数有效性&#xff0c;提示测试通过…

利用 FastAPI 和 Jinja2 模板引擎快速构建 Web 应用

FastAPI 是一个现代、快速&#xff08;高性能&#xff09;的 Web 框架&#xff0c;用于构建 API&#xff0c;它基于标准 Python 类型提示。FastAPI 支持异步编程&#xff0c;使得开发高性能的 Web 应用变得简单快捷。在本文中&#xff0c;我们将探讨如何使用 FastAPI 结合 Jinj…

VS code 查看 ${workspaceFolder} 目录指代路径

VS code 查看 ${workspaceFolder} 目录指代路径 引言正文 引言 在 VS code 创建与运行 task.json 文件 一文中我们已经介绍了如何创建属于自己的 .json 文件。在 VS code 中&#xff0c;有时候我们需要添加一些文件路径供我们导入自定义包使用&#xff0c;此时&#xff0c;我们…

OpenCV系列教程二:基本图像增强(数值运算)、滤波器(去噪、边缘检测)

文章目录 一、基本图像增强&#xff08;数值运算&#xff09;1.1 加法 &#xff08;cv2.add&#xff09;1.1.1 图像与标量相加&#xff08;调节亮度&#xff09;1.1.2 图像与图像相加&#xff08;两个图像shape要相同&#xff09;1.1.3 图像的加权加法&#xff08;渐变切换&…

基于SpringBoot+Vue+MySQL的网上租赁系统

系统展示 用户前台界面 管理员后台界面 系统背景 在当前共享经济蓬勃发展的背景下&#xff0c;网上租赁系统作为连接租赁双方的重要平台&#xff0c;正逐步改变着人们的消费观念和生活方式。通过构建一个基于SpringBoot、Vue.js与MySQL的网上租赁系统&#xff0c;我们旨在为用户…

python线程(python threading模块、python多线程)(守护线程与非守护线程)

文章目录 Python多线程入门1. Python多线程概述2. threading模块基础- Thread 类: 这是一个代表线程的类。可以通过创建Thread类的实例来新建一个线程。- Lock 类: 在多线程环境中&#xff0c;为了防止数据错乱&#xff0c;通常需要用到锁机制。Lock类提供了基本的锁功能&#…

本地搭建我的世界服务器(JAVA)简单记录

网上参考教程挺多的&#xff0c;踩了不少坑&#xff0c;简单记录一下&#xff0c;我做的是一个私人服务器&#xff0c;就是和朋友3、4个人玩。 笨蛋 MC 开服教程 先放一个比较系统和完整的教程&#xff0c;萌新可用&#xff0c;这个教程很详细&#xff0c;我只是记录一下自己的…