Pytorch实现多层LSTM模型,并增加emdedding、Dropout、权重共享等优化

简述

本文是 Pytorch封装简单RNN模型,进行中文训练及文本预测 一文的延申,主要做以下改动:

1.将nn.RNN替换为nn.LSTM,并设置多层LSTM:

既然使用pytorch了,自然不需要手动实现多层,注意nn.RNNnn.LSTM 在实例化时均有参数num_layers来指定层数,本文设置num_layers=2

2.新增emdedding层,替换掉原来的nn.functional.one_hot向量化,这样得到的emdedding层可以用来做词向量分布式表示;

3.在emdedding后、LSTM内部、LSTM后均增加Dropout层,来抑制过拟合:

nn.LSTM内部的Dropout可以通过实例化时的参数dropout来设置,需要注意pytorch仅在两层lstm之间应用Dropout,不会在最后一层的LSTM输出上应用Dropout

emdedding后、LSTM后与线性层之间则需要手动添加Dropout层。

4.考虑emdedding与最后的Linear层共享权重:

这样做可以在保证精度的情况下,减少学习参数,但本文代码没有实现该部分。

不考虑第四条时,模型结构如下:

在这里插入图片描述

代码

模型代码:

class MyLSTM(nn.Module):  def __init__(self, vocab_size, wordvec_size, hidden_size, num_layers=2, dropout=0.5):  super(MyLSTM, self).__init__()  self.vocab_size = vocab_size  self.word_vec_size = wordvec_size  self.hidden_size = hidden_size  self.embedding = nn.Embedding(vocab_size, wordvec_size)  self.dropout = nn.Dropout(dropout)  self.rnn = nn.LSTM(wordvec_size, hidden_size, num_layers=num_layers, dropout=dropout)  # self.rnn = rnn_layer  self.linear = nn.Linear(self.hidden_size, vocab_size)  def forward(self, x, h0=None, c0=None):  # nn.Embedding 需要的类型 (IntTensor or LongTensor)        # 传过来的X是(batch_size, seq), embedding之后 是(batch_size, seq, vocab_size)  # nn.LSTM 支持的X默认为(seq, batch_size, vocab_size)  # 若想用(batch_size, seq, vocab_size)作参数, 则需要在创建self.embedding实例时指定batch_first=True  # 这里用(seq, batch_size, vocab_size) 作参数,所以先给x转置,再embedding,以便再将结果传给lstm  x = x.T  x.long()  x = self.embedding(x)  x = self.dropout(x)  outputs = self.dropout(outputs)  outputs = outputs.reshape(-1, self.hidden_size)  outputs = self.linear(outputs)  return outputs, (h0, c0)  def init_state(self, device, batch_size=1):  return (torch.zeros((self.rnn.num_layers, batch_size, self.hidden_size), device=device),  torch.zeros((self.rnn.num_layers, batch_size, self.hidden_size), device=device))

训练代码:

模型应用可以参考 Pytorch封装简单RNN模型,进行中文训练及文本预测 一文。

def start_train():  # device = torch.device("cpu")  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  print(f'\ndevice: {device}')  corpus, vocab = load_corpus("../data/COIG-CQIA/chengyu_qa.txt")  vocab_size = len(vocab)  wordvec_size = 100  hidden_size = 256  epochs = 1  batch_size = 50  learning_rate = 0.01  time_size = 4  max_grad_max_norm = 0.5  num_layers = 2  dropout = 0.5  dataset = make_dataset(corpus=corpus, time_size=time_size)  data_loader = data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)  net = MyLSTM(vocab_size=vocab_size, wordvec_size=wordvec_size, hidden_size=hidden_size, num_layers=num_layers, dropout=dropout)  net.to(device)  # print(net.state_dict())  criterion = nn.CrossEntropyLoss()  criterion.to(device)  optimizer = optim.Adam(net.parameters(), lr=learning_rate)  writer = SummaryWriter('./train_logs')  # 随便定义个输入, 好使用add_graph  tmp = torch.randint(0, 100, size=(batch_size, time_size)).to(device)  h0, c0 = net.init_state(batch_size=batch_size, device=device)  writer.add_graph(net, [tmp, h0, c0])  loss_counter = 0  total_loss = 0  ppl_list = list()  total_train_step = 0  for epoch in range(epochs):  print('------------Epoch {}/{}'.format(epoch + 1, epochs))  for X, y in data_loader:  X, y = X.to(device), y.to(device)  # 这里batch_size=X.shape[0]是因为在加载数据时, DataLoader没有设置丢弃不完整的批次, 所以存在实际批次不满足设定的batch_size  h0, c0 = net.init_state(batch_size=X.shape[0], device=device)  outputs, (hn, cn) = net(X, h0, c0)  optimizer.zero_grad()  # y也变成 时间序列*批次大小的行数, 才和 outputs 一致  y = y.T.reshape(-1)  # 交叉熵的第二个参数需要LongTorch  loss = criterion(outputs, y.long())  loss.backward()  # 求完梯度之后可以考虑梯度裁剪, 再更新梯度  grad_clipping(net, max_grad_max_norm)  optimizer.step()  total_loss += loss.item()  loss_counter += 1  total_train_step += 1  if total_train_step % 10 == 0:  print(f'Epoch: {epoch + 1}, 累计训练次数: {total_train_step}, 本次loss: {loss.item():.4f}')  writer.add_scalar('train_loss', loss.item(), total_train_step)  ppl = np.exp(total_loss / loss_counter)  ppl_list.append(ppl)  print(f'Epoch {epoch + 1} 结束, batch_loss_average: {total_loss / loss_counter}, perplexity: {ppl}')  writer.add_scalar('ppl', ppl, epoch + 1)  total_loss = 0  loss_counter = 0  torch.save(net.state_dict(), './save/epoch_{}_ppl_{}.pth'.format(epoch + 1, ppl))  writer.close()  return net, ppl_list

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/412562.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Threejs之OrbitControls轨道控制器

本文目录 前言一、Orbitcontrols(轨道控制器)1.1 基础使用1.2 代码演示 二、效果展示 前言 Orbitcontrols(轨道控制器)可以使得相机围绕目标进行轨道运动。 一、Orbitcontrols(轨道控制器) 1.1 基础使用 C…

【Python 千题 —— 基础篇】身份证隐藏的信息

Python 千题持续更新中 …… 脑图地址 👉:⭐https://twilight-fanyi.gitee.io/mind-map/Python千题.html⭐ 题目描述 题目描述 在一个用户信息管理系统中,你需要处理和验证用户提供的身份证号。编写一个程序来从用户信息字符串中提取和验证身份证号,并提供相应的处理方式…

图论----最小生成树讲解与相关题解

目前已更新系列 当前--图论----最小生成树讲解与相关题解 滑动窗口系列算法总结与题解一 算法系列----并查集总结于相关题解 图论---dfs系列 差分与前缀和总结与对应题解(之前笔试真的很爱考) 数论---质数判断、质因子分解、质数筛(埃氏…

在 Cilium CNI 集群上运行 vCluster 虚拟集群

上周在 KubeCon China 2024 大会上,我和社区伙伴们作为志愿者在 Cilium 项目展台与用户交流。有位用户询问 Cilium 是否能与 vCluster 集成,当时未能给出明确答复,特地回来后进行了测试。 答案是:在最新的 vCluster v0.20 中容器…

【Python篇】Python 类和对象:详细讲解(上篇)

文章目录 Python 类和对象:详细讲解1. 什么是类(Class)类的定义 2. 什么是对象(Object)创建对象 3. 属性和方法属性(Attributes)方法(Methods)在类中定义属性和方法使用对…

重生奇迹MU 小清新职业智弓MM

游戏中有一种令人迷醉的职业——智弓MM,她们以高超的射箭技能闻名于世。本文将为您介绍这个悠闲的小清新职业,在游戏中的特点以及如何成为一名出色的智弓MM。跟随我们一起探索这个奇妙而神秘的职业吧! 悠闲的游戏节奏是游戏的初衷之一&#…

52 mysql 启动过程中常见的相关报错信息

前言 我们这里主要是看一下 service mysql start, service mysql stop 的过程中的一些常见的错误问题 这些 也是之前经常碰到, 但是 每次都是 去搜索, 尝试 1, 2, 3, 4 去解决问题 但是 从来未曾思考过 这个问题到底是 怎么造成的 The server quit without updating PID fil…

【设计模式】创建型模式——抽象工厂模式

抽象工厂模式 1. 模式定义2. 模式结构3. 实现3.1 实现抽象产品接口3.2 定义具体产品3.3 定义抽象工厂接口3.4 定义具体工厂3.5 客户端代码 4. 模式分析4.1 抽象工厂模式退化为工厂方法模式4.2 工厂方法模式退化为简单工厂模式 5. 模式特点5.1 优点5.2 缺点 6. 适用场景6.1 需要…

用3点结构的s1顺序标定2点结构的s2顺序

在行列可自由变换的条件下,3点结构有6个 (A,B)---6*30*2---(0,1)(1,0) 让A分别是3a1,2,…,6,让B全是0。当收敛误差为7e-4,收敛199次取迭代次数平均值,得到 迭代 搜索难度 1 13913.2 1 2 …

客服系统简易版

整体架构解读 客服端和商城端都通过websocket连接到客服系统, 并定期维持心跳当客户接入客服系统时, 先根据策略选择在线客服, 然后再发送消息给客服 websocket实现 用netty实现websocket协议, 增加心跳处理的handler, 详见chat-server模块 客服路由规则 暂时仅支持轮询的…

视频结构化从入门到精通——视频结构化主要技术介绍

视频结构化主要技术 1 视频接入 “视频接入”是视频结构化管道的起点(SRC Point)视频接入是视频结构化处理的第一步,它涉及将视频数据从各种采集源获取到系统中进行进一步处理。视频接入的质量和稳定性对后续的数据处理、分析和应用至关重要…

【openwrt-21.02】T750 openwrt-21.02 Linux-5.4.238 input子系统----gpio-keys实现分析

input子系统 输入子系统是由设备驱动层(input driver)、输入核心层(input core)、输入事件处理层(input event handle)组成 input子系统架构图 gpio-keys gpio-keys是基于input子系统实现的一个通用按键驱动,该驱动也符合linux驱动实现模型,即driver和device分离模型.一…

毕设创新点之一:基于GD32/STM32的AI模型部署-github库

将AI模型成功部署到边缘MCU中,常常受限于MCU的计算峰值和内存峰值的限制,部署较为困难,目前有一个将AI算法MCU部署到GD32系列MCU中的宝藏的开源库。 项目网址:HomiKetalys/gd32ai-modelzoo: Provide deployable deep learning mo…

Vue.js 模板语法详解:插值表达式与指令使用指南

Vue.js 模板语法详解:插值表达式与指令使用指南 引言 简要介绍主题: Vue.js 是一个现代化的 JavaScript 框架,用于构建用户界面。Vue 的模板语法提供了直观且功能强大的工具,用于将数据与 DOM 绑定。本文将深入探讨 Vue.js 的两个…

Training language models to follow instructionswith human feedback

Abstract 将语言模型做得更大并不会自动提高它们遵循用户意图的能力。例如,大型语言模型可能会生成不真实、有毒或对用户不有帮助的输出。换句话说,这些模型并未与用户对齐(aligned)。本文展示了一种通过人类反馈来对齐语言模型与…

2024实战指南:四款全免费的数据恢复工具盘点!

在这个数字化的时代里,数据的安全至关重要。如果一不小心删除或丢失了重要数据应该怎么办呢?这几个全免费的数据恢复工具可以帮你解决问题,亲测好用哦! 第一款:福昕数据恢复 直达链接:www.pdf365.cn/foxi…

【并发编程】从AQS机制到同步工具类

AQS机制 Java 中常用的锁主要有两类,一种是 Synchronized 修饰的锁,被称为 Java 内置锁或监视器锁。另一种就是在 JUC 包中的各类同步器,包括 ReentrantLock(可重入锁)、Semaphore(信号量)、Co…

Android13 Launcher3 客制化Workspace页面指示器

需求:原生态的workspace页面指示器是个长条,不大好看,需要进行客制化 实现效果如图: 实现原理: 代码实现在WorkspacePageIndicator.java 布局在launcher.xml里 实现在WorkspacePageIndicator.java通过重写onDraw函数…

顺序循环队列

顺序循环队列 队头插入元素,队尾删除元素 本来应该判空和判断是否存满的条件都是:队头 队尾,但这样就没办法区分了,所以,就牺牲一个空间(比如长度为10,但只能存9个),这…

auto的使用场景

auto的两面性 合理使用auto 不仅可以减少代码量, 也会大大提高代码的可读性. 但是事情总有它的两面性 如果滥用auto, 则会让代码失去可读性 推荐写法 这里推荐两种情况下使用auto 一眼就能看出声明变量的初始化类型的时候 比如迭代器的循环, 用例如下 #include <iostre…