机器翻译之创建Seq2Seq的编码器、解码器

1.创建编码器、解码器的基类

1.1创建编码器的基类

from torch import nn#构建编码器的基类
class Encoder(nn.Module):   #继承父类nn.Moduledef __init__(self, **kwargs):   #**kwargs:不定常的关键字参数super().__init__(**kwargs)def forward(self, X, *args):  #*args:不定常的位置参数#若继承了Encoder这个基类,就必须实现forward(),否则就会报下这个错raise  NotImplementedError          

1.2创建解码器的基类

#创建解码器的基类
#创建解码器的基类比创建编码器的基类多一个 state的初始化
class Decoder(nn.Module):def __init__(self, **kwargs):super().__init__(**kwargs)#初始化statedef init_state(self, enc_outputs, *args):raise NotImplementedError#前向传播,解码器比编码器多传入一个statedef forward(self, X, state):raise NotImplementedError

 1.3合并编码器和解码器的基类

class EncoderDecoder(nn.Module):def __init__(self, encoder, decoder, **kwargs):super().__init__(**kwargs)self.encoder = encoderself.decoder = decoderdef forward(self, enc_X, dec_X, *args):"""enc_X:编码器需传入的数据dec_X:解码器需传入的数据"""enc_outputs = self.encoder(enc_X, *args)dec_state = self.decoder.init_state(enc_outputs, *args)return self.decoder(dec_X, dec_state)

 2.基于上述基类,正式创建Seq2Seq编码器与解码器的类

import collections
import math
import torch
import dltools

2.1创建Seq2Seq的编码器类 

class Seq2SeqEncoder(Encoder):  #继承父类Encoderdef __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):super().__init__(**kwargs)"""vocab_size:词汇表大小embed_size:嵌入层大小num_hiddens:隐藏层的神经元数量num_layers:隐藏层的层数dropout=0 : 默认所有的神经元参与计算"""#初始化嵌入层self.embedding = nn.Embedding(vocab_size, embed_size)#初始化神经网络层self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)def forward(self, X, *args):#在进行embedding之前,X的shape=(batch_size, num_steps, vocab_size)X = self.embedding(X) #X经过embedding处理,X的shape=(batch_size, num_steps, embed_size)X = X.permute(1, 0, 2)  #经过permute调换维度之后,X的shape=(num_steps, batch_size, embed_size)#此时, pytorch 会自动完成隐藏状态的初始化,即0, 不需要手动传入stateoutputs, state = self.rnn(X)#outputs的shape=(num_steps, batch_size, num_hiddens) ,最后一维是神经元的数量#state的shape=(num_layers, batch_size, num_hiddens)return outputs, state
#测试代码
encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=32, num_layers=2)
encoder.eval()
# batch_size=4, num_steps=7
X = torch.zeros((4, 7), dtype=torch.long)
outputs, state = encoder(X)print(outputs.shape, state.shape)
torch.Size([7, 4, 16]) torch.Size([2, 4, 16])

2.2 创建Seq2Seq的解码器类

class Seq2SeqDecoder(Decoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):super().__init__(**kwargs)#初始化嵌入层self.embedding = nn.Embedding(vocab_size, embed_size)#初始化神经网络层self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)#初始化输出层self.dense = nn.Linear(num_hiddens, vocab_size)#定义函数:获取状态statedef init_state(self, enc_outputs, *args):#编码器输出的结果有两个,第二个为statereturn enc_outputs[1]#前向传播def forward(self, X, state):#X的原始shape=(batch_size, num_steps, vocab_size)X = self.embedding(X)  #X的shape=(batch_size, num_steps, embed_size)X = X.permute(1, 0, 2)  #调整数据维度, X的shape=(num_steps, batch_size, embed_size)# 把X和state拼接到一起. 方便计算. # X现在的形状(num_steps, batch_size, embed_size) , # state的形状(batch_size, num_hiddens)# 要把state的形状扩充成三维. 变成(num_steps, batch_size, num_hiddens)context = state[-1].repeat(X.shape[0], 1, 1)  #扩充X.shape[0]=num_steps次,1:所对应的维度不变X_and_context = torch.cat((X, context), 2) #按照索引为2的维度合并#此时,X_and_context的shape=(num_steps, batch_size, embed_size+num_hiddens)#神经网络层outputs, state = self.rnn(X_and_context, state)#输出层outputs = self.dense(outputs).permute(1, 0, 2) #将数据维度重新调换过来#outputs的shape=(batch_size, num_steps, vocab_size)#state的shape=(num_layers, batch_size, num_hiddens)return outputs, state
#测试
decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=32, num_layers=2)
decoder.eval()
state = decoder.init_state(encoder(X))
outputs, state = decoder(X, state)
outputs.shape, state.shape
(torch.Size([4, 7, 10]), torch.Size([2, 4, 32]))

3.编码器 、解码器理论图

 

4.知识点个人理解

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/430985.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于SpringBoot+Vue+MySQL的美食点餐管理系统

系统展示 用户前台界面 管理员后台界面 系统背景 在数字化快速发展的今天,餐饮行业也迎来了转型升级的重要机遇。传统餐饮管理方式面临效率低下、顾客体验不佳等问题。为此,开发一款基于SpringBootVueMySQL架构的美食点餐管理系统显得尤为重要。该系统旨…

【可图(Kolors)部署与使用】大规模文本到图像生成模型部署与使用教程

✨ Blog’s 主页: 白乐天_ξ( ✿>◡❛) 🌈 个人Motto:他强任他强,清风拂山冈! 💫 欢迎来到我的学习笔记! 1.Kolors 简介 1.1.什么是Kolors? 开发团队 Kolors 是由快手 Kolors 团队…

网页护眼宝——全方位解析 Chrome Dark Reader 插件

网页护眼宝——全方位解析 Chrome Dark Reader 插件 1. 基本介绍:Chrome 插件的力量与 Dark Reader 的独特之处 随着现代浏览器的功能越来越强大,Chrome 插件为用户提供了极大的定制化能力。从广告屏蔽、性能优化到页面翻译,Chrome 插件几乎…

视频监控相关笔记

一、QT 之 QTreeWidget 树形控件 Qt编程指南,Qt新手教程,Qt Programming Guide 一个树形结构的节点中的图表文本 、附带数据的添加: QTreeWidgetItem* TourTreeWnd::InsertNode(NetNodeInfo node, QTreeWidgetItem* parent_item) { // …

C++: unordered系列关联式容器

目录 1. unordered系列关联式容器1.1 unordered_map1.2 unordered_set 2. 哈希概念3. 哈希冲突4. 闭散列5. 开散列 博客主页: 酷酷学 感谢关注!!! 正文开始 1. unordered系列关联式容器 在C98中,STL提供了底层为红黑树结构的一系列关联式容器,在查询时…

2024 天池云原生编程挑战赛决赛名单出炉,冠军来自中山大学、昆仑数智战队

9 月 20 日,2024 天池云原生编程挑战赛决赛答辩完美落幕,12 支进入决赛的团队用精彩的答辩,为历时 3 个月的大赛画下了圆满的句号。其中,来自中山大学的陈泓仰以及来自昆仑数智的冉旭欣、沈鑫糠、武鹏鹏, 以出色的方案…

[深度学习]神经网络

1 人工神经网络 全连接神经网络 2 激活函数 隐藏层激活函数由人决定输出层激活函数由解决的任务决定: 二分类:sigmoid多分类:softmax回归:不加激活(恒等激活identify)2.1 sigmoid激活函数 x为加权和小于-6或者大于6,梯度接近于0,会出现梯度消失的问题即使取值 [-6,6] ,…

乌克兰因安全风险首次禁用Telegram

据BleepingComputer消息,乌克兰国家网络安全协调中心 (NCCC) 以国家安全为由,已下令限制在政府机构、军事单位和关键基础设施内使用 Telegram 消息应用程序。 这一消息通过NCCC的官方 Facebook 账号对外发布,在公告中乌…

kubernetes网络(二)之bird实现节点间BGP互联的实验

摘要 上一篇文章中我们学习了calico的原理,kubernetes中的node节点,利用 calico 的 bird 程序相互学习路由,为了加深对 bird 程序的认识,本文我们将使用bird进行实验,实验中实现了BGP FULL MESH模式让宿主相互学习到对…

AI大模型日报#0923:李飞飞创业之后首个专访、华为云+腾讯音乐发布昇腾适配方案

导读:AI大模型日报,爬虫LLM自动生成,一文览尽每日AI大模型要点资讯!目前采用“文心一言”(ERNIE-4.0-8K-latest)、“智谱AI”(glm-4-0520)生成了今日要点以及每条资讯的摘要。欢迎阅…

深兰科技陈海波应邀出席2024长三角论坛暨虹桥人才创新发展大会

近日,以“人才引领 联动共融——国际化创新与长三角协同”为主题的“2024长三角人才发展论坛暨虹桥人才创新发展大会”在上海国际会议中心隆重举行。上海市委常委、组织部部长、市委人才办主任张为应邀出席并做大会致辞。 深兰科技创始人、董事长陈海波作为特邀企业…

数据结构强化(直播课)

应用题真题分析&备考指南 (三)线性表的应用 (六)栈、队列和数组的应用 (四)树与二叉树的应用 1.哈夫曼(Huffman)树和哈夫曼编码 2.并查集及其应用(重要) (四)图的基本应用 …

计算机组成原理(笔记4)

定点加减法运算 补码加法&#xff1a; 补码减法&#xff1a; 求补公式&#xff1a; 溢出的概念 在定点小数机器中,数的表示范围为|&#xff58;|<1。在运算过程中如出现大于1的现象,称为 “溢出”。 上溢&#xff1a;两个正数相加&#xff0c;结果大于机器所能表示的最…

【算法】堆与优先级队列

【ps】本篇有 4 道 leetcode OJ。 目录 一、算法简介 二、相关例题 1&#xff09;最后一块石头的重量 .1- 题目解析 .2- 代码编写 2&#xff09;数据流中的第 K 大元素 .1- 题目解析 .2- 代码编写 3&#xff09;前K个高频单词 .1- 题目解析 .2- 代码编写 4&#xf…

d2l | 目标检测数据集:RuntimeError: No such operator image::read_file

目录 1 存在的问题2 可能的解决方案3 最终的解决方案3.1 方案一&#xff08;我已弃用&#xff09;3.2 方案二&#xff08;基于方案一&#xff09; 1 存在的问题 李沐老师提供的读取香蕉数据集的函数如下&#xff1a; def read_data_bananas(is_trainTrue):""…

yolov10算法原理

文章目录 1. 模型效果2. 模型特点2.1 无NMS训练的一致性双重分配策略 (Consistent Dual Assignments for NMS-free Training)双重标签分配 (Dual Label Assignments)一致匹配度量&#xff08;Consistent Match. Metric&#xff09;一对一分配在一对多结果中的频率 2.2. 效率-准…

C++基础:第一个C++程序

初学C #include<iostream> int main() {std::cout << "Enter two numbers:" << std::endl;int v1 0, v2 0;std::cin >> v1 >> v2;std::cout << "The sum of "<< v1 << " and " << v2&…

Ubuntu磁盘不足扩容

1.问题 Ubuntu磁盘不足扩容 2.解决方法 安装一下 sudo apt-get install gpartedsudo gparted

JavaWeb--小白笔记07:servlet对表单数据的简单处理

这里的servlet对表单数据的处理是指使用IDEA创建web工程&#xff0c;再创建html和class文件进行连接&#xff0c;实现html创建一个表单网页&#xff0c;我们对网页中的表单进行填充&#xff0c;可以通过class文件得到网页我们填充的内容进行打印到控制台。 一登录系统页面---h…

【速成Redis】04 Redis 概念扫盲:事务、持久化、主从复制、哨兵模式

前言&#xff1a; 前三篇如下&#xff1a; 【速成Redis】01 Redis简介及windows上如何安装redis-CSDN博客 【速成Redis】02 Redis 五大基本数据类型常用命令-CSDN博客 【速成Redis】03 Redis 五大高级数据结构介绍及其常用命令 | 消息队列、地理空间、HyperLogLog、BitMap、…