【论文复现】LSTM长短记忆网络

LSTM

  • 前言
  • 网络架构
    • 总线
    • 遗忘门
    • 记忆门
    • 记忆细胞
    • 输出门
  • 模型定义
    • 单个LSTM神经元的定义
    • LSTM层内结构的定义
  • 模型训练
  • 模型评估
  • 代码细节
    • LSTM层单元的首尾的处理
    • 配置Tensorflow的GPU版本

前言

LSTM作为经典模型,可以用来做语言模型,实现类似于语言模型的功能,同时还经常用于做时间序列。由于LSTM的原版论文相关版权问题,这里以colah大佬的博客为基础进行讲解。之前写过一篇Tensorflow中的LSTM详解,但是原理部分跟代码部分的联系并不紧密,实践性较强但是如果想要进行更加深入的调试就会出现原理性上面的问题,因此特此作文解决这个问题,想要用LSTM这个有趣的模型做出更加好的机器学习效果😊。

网络架构

LSTM框架图
这张图展示了LSTM在整体结构,下面就开始分部分介绍中间这个东东。

总线

在这里插入图片描述
这条是总线,可以实现神经元结构的保存或者更改,如果就是像上图一样一条总线贯穿不做任何改变,那么就是不改变细胞状态。那么如果想要改变细胞状态怎么办?可以通过来实现,这里的门跟高中生物中学的神经兴奋阈值比较像,用数学来表示就是sigmoid函数或者其他的激活函数,当门的输入达到要求,门就会打开,允许当前门后面的信息“穿过”门改变主线上面传递的信息,如果把每一个神经元看成一个时间节点,那么从上一个时间节点传到下一个时间节点过程中的门的开启与关闭就实现了时间序列数据的信息传递。
在这里插入图片描述

遗忘门

在这里插入图片描述
首先是遗忘门,这个门的作用是决定从上一个神经元传输到当前神经元的数据丢弃的程度,如果经过sigmoid函数以后输出0表示全部丢弃,输出1表示全部保留,这个层的输入是旧的信息和当前的新信息。

σ \sigma σ:sigmoid函数
W f W_f Wf:权重向量
b f b_f bf:偏置项,决定丢弃上一个时间节点的程度,如果是正数,表示更容易遗忘,如果是负数,表示比较容易记忆
h t − 1 h_{t-1} ht1:上一个时刻的输入
x t x_t xt:当前层的输入

记忆门

在这里插入图片描述
接下来是记忆门,这个门决定要记住什么信息,同时决定按照什么程度记住上一个状态的信息。

i t i_t it:在时间步t时刻的输入门激活值,计算方法跟上面的遗忘门是一样的,只是目的不一样,这里是记忆
C ~ t \tilde{C}_{t} C~t:表示上一个时刻的信息和当前时刻的信息的集合,但是是规则化到[-1,1]这个范围内了的

记忆细胞

在这里插入图片描述
有了上面的要记忆的信息和要丢弃的信息,记忆细胞的功能就可以得到实现,用 f t f_t ft这个标量决定上一个状态要遗忘什么,用 i t i_t it这个标量决定上一个状态要记住什么以及当前状态的信息要记住什么。这样就形成了一个记忆闭环了。

输出门

在这里插入图片描述
最后,在有了记忆细胞以后不仅仅不要将当前细胞状态记住,还要将当前的信息向下一层继续传输,实现公式中的状态转移。

o t o_t ot:跟前面的门公式都一样,但是功能是决定输出的程度
h t h_t ht:将输出规范到[-1,1]的区间,这里有两个输出的原因是在构建LSTM网络的时候需要有纵向向上的那个 h t h_t ht,然而在当前层的LSTM的神经元之间还是首尾相接的😍。

模型定义

单个LSTM神经元的定义


# 定义单个LSTM单元
# 定义单个LSTM单元
class My_LSTM(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(My_LSTM, self).__init__()self.input_size = input_sizeself.hidden_size = hidden_sizeself.output_size = output_size# 初始化门的权重和偏置,由于每一个神经元都有自己的偏置,所以在定义单元内部定义self.Wf = nn.Parameter(torch.Tensor(input_size + hidden_size, hidden_size))self.bf = nn.Parameter(torch.Tensor(hidden_size))self.Wi = nn.Parameter(torch.Tensor(input_size + hidden_size, hidden_size))self.bi = nn.Parameter(torch.Tensor(hidden_size))self.Wo = nn.Parameter(torch.Tensor(input_size + hidden_size, hidden_size))self.bo = nn.Parameter(torch.Tensor(hidden_size))self.Wg = nn.Parameter(torch.Tensor(input_size + hidden_size, hidden_size))self.bg = nn.Parameter(torch.Tensor(hidden_size))# 初始化输出层的权重和偏置self.W = nn.Parameter(torch.Tensor(hidden_size, output_size))self.b = nn.Parameter(torch.Tensor(output_size))# 用于计算每一种权重的函数def cal_weight(self, input, weight, bias):return F.linear(input, weight, bias)# x是输入的数据,数据的格式是(batch, seq_len, input_size),包含的是batch个序列,每个序列有seq_len个时间步,每个时间步有input_size个特征def forward(self, x):# 初始化隐藏层和细胞状态h = torch.zeros(1, 1, self.hidden_size).to(x.device)c = torch.zeros(1, 1, self.hidden_size).to(x.device)# 遍历每一个时间步for i in range(x.size(1)):input = x[:, i, :].view(1, 1, -1) # 取出每一个时间步的数据# 计算每一个门的权重f = torch.sigmoid(self.cal_weight(input, self.Wf, self.bf)) # 遗忘门i = torch.sigmoid(self.cal_weight(input, self.Wi, self.bi)) # 输入门o = torch.sigmoid(self.cal_weight(input, self.Wo, self.bo)) # 输出门C_ = torch.tanh(self.cal_weight(input, self.Wg, self.bg)) # 候选值# 更新细胞状态c = f * c + i * C_# 更新隐藏层h = o * torch.tanh(c) # 将输出标准化到-1到1之间output = self.cal_weight(h, self.W, self.b) # 计算输出return output

LSTM层内结构的定义

class My_LSTMNetwork(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(My_LSTMNetwork, self).__init__()self.hidden_size = hidden_sizeself.lstm = My_LSTM(input_size, hidden_size)  # 使用自定义的LSTM单元self.fc = nn.Linear(hidden_size, output_size)  # 定义全连接层def forward(self, x):h0 = torch.zeros(1, x.size(0), self.hidden_size).to(x.device)c0 = torch.zeros(1, x.size(0), self.hidden_size).to(x.device)out, _ = self.lstm(x, (h0, c0))  # LSTM层的前向传播out = self.fc(out[:, -1, :])  # 全连接层的前向传播return out

模型训练

history = model.fit(trainX, trainY, batch_size=64, epochs=50, validation_split=0.1, verbose=2)
print('compilation time:', time.time()-start)

模型评估

为了更加直观展示,这里用画图的方法进行结果展示。

fig3 = plt.figure(figsize=(20, 15))
plt.plot(np.arange(train_size+1, len(dataset)+1, 1), scaler.inverse_transform(dataset)[train_size:], label='dataset')
plt.plot(testPredictPlot, 'g', label='test')
plt.ylabel('price')
plt.xlabel('date')
plt.legend()
plt.show()

代码细节

LSTM层单元的首尾的处理

  • 首部:由于第一个节点不用接受来自上一个节点的输入,不需要有输入,当然也有一些是添加标识。

  • 尾部:由于已经进行到当前层的最后一个节点,因此输出只需要向下一层进行传递而不用向下一个节点传递,添加标识也是可以的。

配置Tensorflow的GPU版本

这一篇写的比较好,我自己的硬件环境如下图所示,需要的可以借鉴一下,当然也可以在我提供的代码链接直接用我给的environment.yml一键构建环境😃。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/332463.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

kubenetes中K8S的命名空间状态异常强制删除Terminating的ns

查看ns状态为异常: 查看ns为monitoring的状态为Termingating状态 使用方法一: kubectl delete ns monitoring --force --grace-period0 使用方法二: kubectl get ns monitoring -o json > monitoring.json 修改删除文件中的"kubern…

go select 原理

编译器会使用如下的流程处理 select 语句: 将所有的 case 转换成包含 channel 以及类型等信息的 runtime.scase 结构体。调用运行时函数 runtime.selectgo 从多个准备就绪的 channel 中选择一个可执行的 runtime.scase 结构体。通过 for 循环生成一组 if 语句&…

动态规划之背包问题中如何确定遍历顺序的问题-组合or排列?

关于如何确定遍历顺序 322. 零钱兑换中,本题求钱币最小个数,那么钱币有顺序和没有顺序都可以,都不影响钱币的最小个数。 所以本题并不强调集合是组合还是排列。 如果求组合数就是外层for循环遍历物品,内层for遍历背包。 如果求…

MySQL事务篇1:事物的四大特性(ACID)、三类数据读取问题与隔离级别

一、什么是事务? MySQL的事务(Transaction)是一组由数据库管理系统(DBMS)执行的一个或多个SQL语句的集合,这些SQL语句作为一个单独的工作单元执行。事务的主要目的是确保数据库的一致性和完整性&#xff0c…

民国漫画杂志《时代漫画》第16期.PDF

时代漫画16.PDF: https://url03.ctfile.com/f/1779803-1248612470-6a05f0?p9586 (访问密码: 9586) 《时代漫画》的杂志在1934年诞生了,截止1937年6月战争来临被迫停刊共发行了39期。 ps:资源来源网络!

深入了解 Golang 多架构编译:交叉编译最佳实践

随着软件开发领域的不断发展,我们面临着越来越多的挑战,其中之一是如何在不同的平台和架构上部署我们的应用程序。Golang(Go)作为一种现代化的编程语言,具有出色的跨平台支持,通过其强大的多架构编译功能&a…

网络模型-NQA与网络协议联动

一、NQA定义 网络质量分析NQA(Network QualityAnalysis)是一种实时的网络性能探测和统计技术,可以对响应时间、网络抖动、丢包率等网络信息进行统计。NOA能够实时监视网络0oS,在网络发生故障时进行有效的故障诊断和定位。 部署IPv4静态路由与BFD…

【Text2SQL 经典模型】TypeSQL

论文:TypeSQL: Knowledge-Based Type-Aware Neural Text-to-SQL Generation ⭐⭐⭐ Code: TypeSQL | GitHub 一、论文速读 本论文是在 SQLNet 网络上做的改进,其思路也是先预先构建一个 SQL sketch,然后再填充 slots 从而生成 SQL。 论文发…

Debezium+Kafka:Oracle 11g 数据实时同步至 DolphinDB 解决方案

随着越来越多用户使用 DolphinDB,各式各样的应用场景对 DolphinDB 的数据接入提出了不同的要求。部分用户需要将 Oracle 11g 的数据实时同步到 DolphinDB 中来,以满足在 DolphinDB 中实时使用数据的需求。本篇教程将介绍使用 Debezium 来实时捕获和发布 …

浅谈Docker容器的网络通信原理

文章目录 1、回顾容器概念2、容器网络3、容器与主机之间的网络连通4、交换机的虚拟实现---虚拟网桥(Bridge)5、Docker 守护进程daemon管理容器网络 1、回顾容器概念 我们知道容器允许我们在同一台宿主机(电脑)上运行多个服务&…

【UE数字孪生学习笔记】 使用DataSmith对模型快速导入 UE5.3.2使用unreal DataSmith文件

声明:部分内容来自于b站,慕课,公开课等的课件,仅供学习使用。如有问题,请联系删除。 部分内容来自UE官方文档,博客等 UE5.3.2使用 3D Max 导出的unreal DataSmith文件 1. 去UE官网下载DataSmith导出器并导…

Wpf 使用 Prism 实战开发Day25

首页待办事项及备忘录添加功能 一.修改待办事项和备忘录逻辑处理类,即AddMemoViewModel和AddTodoViewModel 1.AddMemoViewModel 逻辑处理类,添加View视图数据要绑定的实体类 Model public class AddMemoViewModel :BindableBase,IDialogHostAware{public AddMemoV…

代码随想录-Day20

654. 最大二叉树 给定一个不重复的整数数组 nums 。 最大二叉树 可以用下面的算法从 nums 递归地构建: 创建一个根节点,其值为 nums 中的最大值。 递归地在最大值 左边 的 子数组前缀上 构建左子树。 递归地在最大值 右边 的 子数组后缀上 构建右子树。 返回 nums…

[杂项]优化AMD显卡对DX9游戏(天谕)的支持

目录 关键词平台说明背景RDNA 1、2、3 架构的显卡支持游戏一、 优化方法1.1 下载 二、 举个栗子(以《天谕》为例)2.1 下载微星 afterburner 软件 查看游戏内信息(可跳过)2.2 查看D3D9 帧数2.3 关闭游戏,替换 dll 文件2…

DBAPI怎么进行数据格式转换

DBAPI如何进行数据格式的转换 假设现在有个API,根据学生id查询学生信息,访问API查看数据格式如下 {"data":[{"name":"Michale","phone_number":null,"id":77,"age":55}],"msg"…

AIGC-常见图像质量评估MSE、PSNR、SSIM、LPIPS、FID、CSFD,余弦相似度----理论+代码

持续更新和补充中…多多交流! 参考: 图像评价指标PNSR和SSIM 函数 structural_similarity 图片相似度计算方法总结 MSE和PSNR MSE: M S E 1 m n ∑ i 0 m − 1 ∑ j 0 n − 1 [ I ( i , j ) − K ( i , j ) ] 2 MSE\frac{1}{mn}\sum_{i0}^{m-1}\sum_{j0}^{n-1}[…

车载电子电器架构 —— 应用软件开发(下)

车载电子电器架构 —— 应用软件开发(下) 我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 屏蔽力是信息过载时代一个人的特殊竞争力,任何消耗你的人和事,多看一眼都是你的不对。非必要不费力证…

【全开源】赛事报名系统源码(Fastadmin+ThinkPHP和Uniapp)

基于FastadminThinkPHP和Uniapp开发的赛事报名系统,包含个人报名和团队报名、成绩查询、成绩证书等。 构建高效便捷的赛事参与平台 一、引言:赛事报名系统的重要性 在举办各类赛事时,一个高效便捷的报名系统对于组织者和参与者来说都至关重…

全域运营是本地生活的下半场?新的创业风口来了?

随着全域概念的兴起,全域运营赛道也逐渐进入人们的视野之中,甚至有业内人士预测,全域运营将会是本地生活下半场的大趋势。 之所以这么说,是因为全域运营作为包含了公域和私域内所有运营业务的新模式,不仅能同时做所有本…

仿《Q极速体育》NBACBA体育直播吧足球直播综合体育直播源码

码名称:仿《Q极速体育》NBACBA体育直播吧足球直播综合体育直播源码 开发环境:帝国cms7.5 空间支持:phpmysql 仿《Q极速体育》NBACBA体育直播吧足球直播综合体育直播源码自动采集 - 我爱模板网源码名称:仿《Q极速体育》NBACBA体育直…