LSTM长短时记忆网络:推导与实现(pytorch)

LSTM长短时记忆网络:推导与实现(pytorch)

  • 背景
  • 推导
    • 遗忘门
    • 输入门
    • 输出门
  • LSTM的改进:GRU
  • 实现

背景

  人类不会每秒钟都从头开始思考。当你阅读这篇文章时,你会根据你对以前单词的理解来理解每个单词。你不会把所有东西都扔掉,重新开始思考。你的思想是持久的。

  传统的神经网络无法做到这一点,这是一个主要缺点。递归神经网络(Rnn)解决了这个问题,这是一种带有记忆的模型,可以将其认为是同一网络的多个副本,每个副本将消息传递给继任者。
在这里插入图片描述
  RNN的优势是,它们可能能够将先前的信息与当前任务联系起来,例如使用以前的视频帧为理解当前帧提供信息。但是在实际情况中,它不一定能做到这一点。

  有时我们只需要查看最近的信息即可执行当前任务。例如,考虑一个语言模型,它试图根据前一个单词来预测下一个单词。如果我们试图预测“云在天空中”中的最后一个词,我们不需要任何进一步的上下文,很明显下一个词将是天空。在这种情况下,相关信息与需要信息的地方之间的差距很小,RNN可以学习使用过去的信息。
  但在某些情况下,我们需要更多的背景信息。不妨试着预测经文中的最后一个字:“我在法国长大…我能说一口流利的法语。最近的信息表明,下一个词可能是一种语言的名称,但如果我们想缩小哪种语言的范围,我们需要从更远的地方开始了解法国的上下文。相关信息与需要信息的点之间的差距完全有可能变得非常大。
  不幸的是,随着这种差距的扩大,RNN变得无法学习连接信息。这就是传统RNN的缺点,很难处理长距离的依赖。

  长短期记忆网络(通常简称为“LSTM”)解决了这个问题,这是一种特殊的RNN,能够学习长期依赖关系。我们将推导LSTM中每一层的结构,并实现一个pytorch版本LSTM。

推导

  长短时记忆网络的思路很简单。传统RNN的隐藏层只有一个状态,即h,它对于短期的输入非常敏感,假如我们再增加一个状态,即c,让它来保存长期的状态,那么问题就解决了:
在这里插入图片描述
  新增加的状态c,称为单元状态(cell state)。我们把上图按照时间维度展开:
在这里插入图片描述
  在长短时记忆网络的前向计算中,通过门(gate)控制向量的变化。门实际上就是一层全连接层,它的输入是一个向量,输出是一个0到1之间的实数向量。那么门可以表示为: g ( x ) = σ ( W x + b ) g(x)=\sigma(Wx+b) g(x)=σ(Wx+b)  由于sigmod函数的性质,门的输出是0到1之间的实数向量,那么,当门输出为0时,任何向量与之相乘都会得到0向量;输出为1时,任何向量与之相乘都不会有任何改变。
  LSTM用两个门来控制单元状态c的内容,一个是遗忘门(forget gate),它决定了上一时刻的单元状态有多少保留到当前时刻ct;另一个是输入门(input gate),它决定了当前时刻网络的输入xt有多少保存到单元状态ct。LSTM用输出门(output gate)来控制单元状态ct有多少输出到LSTM的当前输出值ht。

遗忘门

  遗忘门通过门控制上一时刻的输入的单元状态 c t − 1 c_{t-1} ct1有多少被保留下来,门的权值通过上一时刻的输出值 h t − 1 h_{t-1} ht1和这一时刻的输入值 x t x_t xt得到,即: f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t=\sigma(W_f\cdot[h_{t-1},x_t]+b_f) ft=σ(Wf[ht1,xt]+bf)  这个过程可以被分解为 f t = W f h h t − 1 + W f h x t f_t=W_{fh}h_{t-1}+W_{fh}x_t ft=Wfhht1+Wfhxt  在我们对门的实现中,我们都遵循这种方式。下图显示了遗忘门的计算:在这里插入图片描述

输入门

  输入门决定了将当前的输入有多少被保留到c中,门的权值计算方式与遗忘门相同,即: i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t=\sigma(W_i\cdot[h_{t-1},x_t]+b_i) it=σ(Wi[ht1,xt]+bi)  通过门值,我们计算当前输入的单元状态 c ^ t \hat c_t c^t,它是通过上一次的输出和本次输入计算的: c ^ t = t a n h ( W c ⋅ [ h t − 1 , x t ] + b c ) \hat c_t=tanh(W_c\cdot [h_{t-1},x_t]+b_c) c^t=tanh(Wc[ht1,xt]+bc)  下图显示了输入门的计算:
在这里插入图片描述
  到这里,我们可以计算当前时刻的单元状态 c t c_t ct,它是由上一次的单元状态 c t − 1 c_{t-1} ct1乘以遗忘门权值 f t f_t ft,再用当前输入的单元状态 c ^ t \hat c_t c^t乘以输入门权值 i t i_t it,再将两个积加和产生的: c t = f t ⋅ c t − 1 + i t ⋅ c ^ t c_t=f_t\cdot c_{t-1}+i_t\cdot \hat c_t ct=ftct1+itc^t  这样,我们就把LSTM关于当前的记忆 c ^ t \hat c_t c^t和长期的记忆 c t − 1 c_{t-1} ct1组合在一起,形成了新的单元状态 c t c_t ct。由于遗忘门的控制,它可以保存很久很久之前的信息,由于输入门的控制,它又可以避免当前无关紧要的内容进入记忆。

输出门

  输出门的权值计算方式与上面两个门相同: o t = σ ( W o ⋅ [ h t − 1 , x t ] + b i ) o_t=\sigma(W_o\cdot[h_{t-1},x_t]+b_i) ot=σ(Wo[ht1,xt]+bi)  LSTM最终的输出,是由输出门和单元状态共同确定的: h t = o t ⋅ t a n h ( c t ) h_t=o_t\cdot tanh(c_t) ht=ottanh(ct)在这里插入图片描述
  至此,LSTM的推导就讲完了。

LSTM的改进:GRU

  LSTM也有许多缺点,因此提出了很多变体,GRU(Gated Recurrent Unit)就是比较成功的一种,针对LSTM有三个不同的门,参数较多,训练困难的缺点,GRU将LSTM中的输入门和遗忘门合二为一,称为更新门(update gate),控制前边记忆信息能够继续保留到当前时刻的数据量;另一个门称为重置门(reset gate),控制要遗忘多少过去的信息。其结构如下所示,读者可以自行对比:在这里插入图片描述

实现

  我们使用pytorch实现了一个LSTMlayer的模型:

class LstmLayer(torch.nn.Module):def __init__(self, bert_model, fea_dim, dropout):super(LstmLayer, self).__init__()self.fea_dim = fea_dim# 激活函数self.sigmod = nn.Sigmoid()self.tanh = nn.Tanh()# 遗忘门权重矩阵Wfh, Wfxself.Wfh = torch.nn.Linear(self.fea_dim, 1)self.Wfx = torch.nn.Linear(self.fea_dim, 1)# 输入门权重矩阵Wfh, Wfxself.Wih = torch.nn.Linear(self.fea_dim, 1)self.Wix = torch.nn.Linear(self.fea_dim, 1)# 单元状态更新权重矩阵Wch, Wcxself.Wch = torch.nn.Linear(self.fea_dim, self.fea_dim)self.Wcx = torch.nn.Linear(self.fea_dim, self.fea_dim)# 输出门权重矩阵Woh, Woxself.Woh = torch.nn.Linear(self.fea_dim, self.fea_dim)self.Wox = torch.nn.Linear(self.fea_dim, self.fea_dim)def forward(self,x_t,c_t,h_t):# 遗忘门fg = self.calc_gate(x_t,h_t, self.Wfx, self.Wfh,self.sigmod)c_t = fg * c_t# 输入门ig = self.calc_gate(x_t,h_t, self.Wix, self.Wih,self.sigmod)c_t_temp = ig * self.tanh(self.Wch(h_t) + self.Wcx(x_t))c_out = c_t_temp + c_t# 输出门og = self.calc_gate(x_t,h_t, self.Wox, self.Woh,self.sigmod)h_out = self.tanh(c_out) * ogreturn c_out,h_outdef calc_gate(self, x,h, Wx, Wh, activator):# 计算门权值gate_weight = Wx(x) + Wh(h)gate_out = activator(gate_weight)return gate_out

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/335286.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

STM32-12-OLED模块

STM32-01-认识单片机 STM32-02-基础知识 STM32-03-HAL库 STM32-04-时钟树 STM32-05-SYSTEM文件夹 STM32-06-GPIO STM32-07-外部中断 STM32-08-串口 STM32-09-IWDG和WWDG STM32-10-定时器 STM32-11-电容触摸按键 文章目录 1. OLED显示屏介绍2. OLED驱动原理3. OLED驱动芯片简介4…

v4l2抓取rv1126图像

0.准备工作 本文是基于正点原子的rv1126开发板使用mx415摄像头对不同节点的图像进行抓取 1.数据流向 图1 mx415采集到的数据为原始的拜尔格式(也就是raw格式),我们需要通过isp进行图像的调节才符合视觉,其中isp和ispp是两个处理的…

【大数据】Hadoop 2.X和1.X升级优化对比

目录 1.前言 2.hadoop 1.X的缺点和优化方向 3.解决NameNode的局限性 3.1.Hadoop HA 3.2.Haddop federation 4.yarn 5.周边组件 1.前言 本文是作者大数据系列中的一文,专栏地址: https://blog.csdn.net/joker_zjn/category_12631789.html?spm10…

网络侦察技术

网络侦察技术 收集的信息网络侦察步骤搜索引擎检索命令bing搜索引擎Baidu搜索引擎Shodan钟馗之眼(zoomeye) whois数据库:信息宝库查询注册资料 域名系统网络拓扑社交网络跨域拓展攻击 其它侦察手段社会工程学社会工程学常见形式Web网站查询 其它非技术侦察手段总结网…

GDPU 操作系统 天码行空13

文章目录 ❌ TODO:本文仅供参考,极有可能有误1.生产者消费者问题(信号量)💖 ProducerConsumerExample.java🏆 运行结果 💖 ProducerConsumerSelectiveExample.java🏆 运行结果 2.实现…

将四种算法的预测结果绘制在一张图中

​ 声明:文章是从本人公众号中复制而来,因此,想最新最快了解各类智能优化算法及其改进的朋友,可关注我的公众号:强盛机器学习,不定期会有很多免费代码分享~ 之前的一期推文中,我们推出了…

TREK高压发生器维修高压电源615-3-L-JX 615-3

美国TREK高压电源维修故障分析应注意两点: 故障分析检测和故障硬件更换,由高压电源故障和工作表现初步判断故障的类型和哪些硬件出了问题,初步判断缩小检测范围,通过排除法和更替新配件准确找到故障硬件。维修过程需要对trek电源维…

C语言学习笔记之指针(一)

目录 什么是指针? 指针和指针类型 指针的类型 指针类型的意义 指针-整数 指针的解引用 指针 - 指针 指针的关系运算 野指针 什么是野指针? 野指针的成因 如何规避野指针? 二级指针 什么是指针? 在介绍指针之前&#…

【ai】livekit:Agents 1 : Agents Framework 与 LiveKit 核心 API 原语

agents 官方文档LiveKit Agents LiveKit Agents is an end-to-end framework for building realtime, multimodal AI “agents” that interact with end-users through voice, video, and data channels. This framework allows you to build an agent using Python.是一个端到…

2024年6月1日(星期六)骑行禹都甸

2024年6月1日 (星期六)骑行禹都甸(韭葱花),早8:30到9:00,昆明氧气厂门口集合,9:30准时出发【因迟到者,骑行速度快者,可自行追赶偶遇。】 偶遇地点:昆明氧气厂门口集合 ,…

010-Linux磁盘介绍

文章目录 1、名词 2、类型 3、尺寸 4、接口/协议/总线 5、命名 6、分区方式 MBR分区 GPT分区 1、名词 磁盘是计算机主要的存储介质,可以存储大量的二进制数据,并且断电后也能保持数据不丢失。早期计算机使用的磁盘是软磁盘(Floppy D…

三方语言中调用, Go Energy GUI编译的dll动态链接库CEF

如何在其它编程语言中调用energy编译的dll动态链接库,以使用CEF 或 LCL库 Energy是Go语言基于LCL CEF开发的跨平台GUI框架, 具有很容易使用CEF 和 LCL控件库 interface 便利 示例链接 正文 为方便起见使用 python 调用 go energy 编译的dll 准备 系统&#x…

C++:vector的模拟实现

✨✨✨学习的道路很枯燥,希望我们能并肩走下来! 文章目录 目录 文章目录 前言 一、vector的模拟实现 1.1 迭代器的获取 1.2 构造函数和赋值重载 1.2.1 无参构造函数 1.2.2 有参构造函数(对n个对象的去调用他们的构造) 1.2.3 迭代器区…

【UnityShader入门精要学习笔记】第十五章 使用噪声

本系列为作者学习UnityShader入门精要而作的笔记,内容将包括: 书本中句子照抄 个人批注项目源码一堆新手会犯的错误潜在的太监断更,有始无终 我的GitHub仓库 总之适用于同样开始学习Shader的同学们进行有取舍的参考。 文章目录 使用噪声上…

亮相CCIG2024,合合信息文档解析技术破解大模型语料“饥荒”难题

近日,2024中国图象图形大会在古都西安盛大开幕。本届大会由中国图象图形学学会主办,空军军医大学、西安交通大学、西北工业大学承办,通过二十多场论坛、百余项成果,集中展示了生成式人工智能、大模型、机器学习、类脑计算等多个图…

Compose第一弹 可组合函数+Text

目标: 1.Compose是什么?有什么特征? 2.Compose的文本控件 一、Compose是什么? Jetpack Compose 是用于构建原生 Android 界面的新工具包。 Compose特征: 1)声明式UI:使用声明性的函数构建一…

opencascade 快速显示AIS_ConnectedInteractive源码学习

AIS_ConcentricRelation typedef PrsDim_ConcentricRelation AIS_ConcentricRelation AIS_ConnectedInteractive 简介 创建一个任意位置的另一个交互对象实例作为参考。这允许您使用连接的交互对象,而无需重新计算其表示、选择或图形结构。这些属性是从您的参考对…

蓝桥杯嵌入式国赛笔记(4):多路AD采集

1、前言 蓝桥杯的国赛会遇到多路AD采集的情况,这时候之前的单路采集的方式就不可用了,下面介绍两种多路采集的方式。 以第13届国赛为例 2、方法一(配置通道) 2.1 使用CubeMx配置 设置IN13与IN17为Single-ended 在Parameter S…

今日好料推荐(大数据湖体系规划)

今日好料推荐(大数据湖体系规划) 参考资料在文末获取,关注我,获取优质资源。 大数据湖体系规划 一、大数据湖简介 大数据湖(Data Lake)是一个集中式的存储库,用于存储来自各种来源的结构化和…

蓝桥杯杨辉三角

PREV-282 杨辉三角形【第十二届】【蓝桥杯省赛】【B组】 (二分查找 递推): 解析: 1.杨辉三角具有对称性: 2.杨辉三角具有一定规律 通过观察发现,第一次出现的地方一定在左部靠右的位置,所以从…