动手学深度学习(Pytorch版)代码实践 -循环神经网络-57长短期记忆网络(LSTM)

57长短期记忆网络(LSTM

1.LSTM原理

LSTM是专为解决标准RNN的长时依赖问题而设计的。标准RNN在训练过程中,随着时间步的增加,梯度可能会消失或爆炸,导致模型难以学习和记忆长时间间隔的信息。LSTM通过引入一组称为门的机制来解决这个问题:

  1. 输入门(Input Gate):控制有多少新的信息可以传递到记忆单元中。
  2. 遗忘门(Forget Gate):控制当前记忆单元中有多少信息会被保留。
  3. 输出门(Output Gate):控制记忆单元的输出有多少被传递到下一步。

LSTM还引入了一个称为记忆单元(Cell State)的概念,用于携带长期信息。这些门的组合使得LSTM能够选择性地记住或遗忘信息,从而解决了长时依赖问题。
在这里插入图片描述
在这里插入图片描述

2.优点
  1. 解决梯度消失问题:通过门控机制,LSTM能够有效地传递梯度,避免了梯度消失和爆炸的问题。
  2. 捕捉长时依赖LSTM能够记住和利用长时间间隔的信息,这是标准RNN难以做到的。
  3. 灵活性LSTM适用于各种序列数据处理任务,如时间序列预测、语言建模和序列到序列的翻译等。
3.LSTMGRU的区别

GRU(门控循环单元)是另一种解决长时依赖问题的RNN变体。GRULSTM都引入了门控机制,但它们的具体实现有所不同。

  1. 结构简化GRU的结构比LSTM更简单,参数更少,计算效率更高。
  2. 性能对比:在一些任务上,GRULSTM的性能相当,但在某些情况下,GRU可能表现更好,特别是在较小的数据集或较短的序列上。
  3. 门的数量LSTM有三个门(输入门、遗忘门和输出门),而GRU只有两个门(更新门和重置门)。
4.LSTM代码实践
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt# 设置批量大小和序列步数
batch_size, num_steps = 32, 35
# 加载时间机器数据集
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)# 初始化LSTM模型参数
def get_lstm_params(vocab_size, num_hiddens, device):# 输入输出的维度大小num_inputs = num_outputs = vocab_size# 正态分布初始化权重def normal(shape):return torch.randn(size=shape, device=device) * 0.01# 三个权重参数(用于输入门、遗忘门、输出门和候选记忆元)def three():return (normal((num_inputs, num_hiddens)),  # 输入到隐藏状态的权重normal((num_hiddens, num_hiddens)),  # 隐藏状态到隐藏状态的权重torch.zeros(num_hiddens, device=device))  # 偏置W_xi, W_hi, b_i = three()  # 输入门参数W_xf, W_hf, b_f = three()  # 遗忘门参数W_xo, W_ho, b_o = three()  # 输出门参数W_xc, W_hc, b_c = three()  # 候选记忆元参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))  # 隐藏状态到输出的权重b_q = torch.zeros(num_outputs, device=device)  # 输出偏置# 将所有参数附加到参数列表中params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,b_c, W_hq, b_q]for param in params:param.requires_grad_(True)  # 设置参数需要梯度return params# 初始化LSTM的隐藏状态
def init_lstm_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device),  # 隐藏状态torch.zeros((batch_size, num_hiddens), device=device))  # 记忆元# LSTM前向传播
def lstm(inputs, state, params):[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,W_hq, b_q] = params(H, C) = state  # 隐藏状态和记忆元outputs = []for X in inputs:# 输入门I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)# 遗忘门F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)# 输出门O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)# 候选记忆元C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)# 更新记忆元C = F * C + I * C_tilda# 更新隐藏状态H = O * torch.tanh(C)# 计算输出Y = (H @ W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H, C)  # 返回输出和状态# 训练和预测模型
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
# 创建自定义的LSTM模型
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
plt.show()
# perplexity 1.3, 34433.0 tokens/sec on cuda:0
# 预测结果示例:time traveller conellace there wardeal that are almost us we hou# 使用PyTorch的简洁实现
num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)  # 创建LSTM层
model = d2l.RNNModel(lstm_layer, len(vocab))  # 创建模型
model = model.to(device)  # 将模型移动到GPU
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
plt.show()
# perplexity 1.0, 317323.7 tokens/sec on cuda:0
# 预测结果示例:time travelleryou can show black is white by argument said filby

自定义的LSTM模型:

在这里插入图片描述
简洁实现:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/373698.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

大型企业如何整合集成全域数据、解决数据孤岛难题?

今天,我们说一下大型企业全域数据的整合集成问题。 通常,中大型企业和集团公司拥有大量多源异构的数据存储资源,如数据仓库、数据湖以及分布于分子公司和混合多云平台的业务系统,通过传统物理集中统一数据资产管理的方式难度高&a…

Springboot 设置个性化banner

在 Spring Boot 中自定义 banner 的方法有几种,可以通过以下步骤来实现: 1、使用文本文件作为 banner 在 src/main/resources 目录下创建一个名为 banner.txt 的文件。 编辑这个文件,输入想要显示的文本。确保文本中包含换行符和空格…

Android OpenGL ES 离屏幕渲染1——EGL环境的创建,以及基础概念的理解

创建EGL上下文、配置EGL环境、创建EGL DISPLAY 什么是EGL: 由于OpenGL ES并不负责窗口管理以及上下文管理,该职责由各个平台自行完成;在Android平台下OpenGL ES的上下文环境是依赖EGL的API进行搭建的。 对于EGL这个框架,谷歌已经提…

MES系统在工业4.0时代的应用与前景

在工业4.0时代,制造执行系统(MES)在现代制造业中的作用变得愈发重要。一个优秀的MES系统不仅可以提高生产效率和质量控制,还可以帮助企业实现智能制造,提升市场竞争力。 一、MES系统的基本概念与功能 制造执行系统&a…

补码一位乘法原理(布斯编码详讲)

最近在看补码乘法的时候,感觉到很奇怪的一点,那就是补码的一位乘法,就是上网查了大量的资料都没有理解到它真正的原理,总感觉还是不会。那么,补码乘法的原理到底是什么呢?而让我们一直困惑的点是哪里呢&…

宿州降本 提质 增效 数据采集监控平台提高生产自动化水平

在当今竞争激烈的市场环境中,企业追求降本、提质、增效已成为发展的关键。而我们的[数据采集监控平台名称]数据采集监控平台,正是助力企业实现这一目标的强大工具。 LP-SCADA数据采集监控平台是工业4.0中主要的数据采集系统之一,主要针对产线…

nginx的知识面试易考点

Nginx概念 Nginx 是一个高性能的 HTTP 和反向代理服务。其特点是占有内存少,并发能力强,事实上nginx的并发能力在同类型的网页服务器中表现较好。 Nginx 专为性能优化而开发,性能是其最重要的考量指标,实现上非常注重效率&#…

for nested data item, row-key is required.报错解决

今天差点被一个不起眼的bug搞到吐,就是在给表格设置row-key的时候,一直设置不成功,一直报错缺少row-key,一共就那两行代码 实在是找不到还存在什么问题... 先看下报错截图... 看下代码 我在展开行里面用到了一个表格 并且存放表格…

Python数据分析案例50——基于EEMD-LSTM的石油价格预测

案例背景 很久没更新时间序列预测有关的东西了。 之前写了很多CNN-LSTM,GRU-attention,这种神经网络之内的不同模型的缝合,现在写一个模态分解算法和神经网络的缝合。 虽然eemd-lstm已经在学术界被做烂了,但是还是很多新手小白或…

昇思MindSpore学习笔记3-03热门LLM及其他AI应用--基于MobileNetv2的垃圾分类

摘要: MindSpore AI框架使用MobileNetv2模型开发垃圾分检代码。检测本地图像中的垃圾物体,保存检测结果到文件。记录了开发过程和步骤,包括环境准备、数据下载、加载和预处理、模型搭建、训练、测试、推理应用等。 1、实验目的 了解垃圾分…

DDoS攻击详解

DDoS 攻击,其本质是通过操控大量的傀儡主机或者被其掌控的网络设备,向目标系统如潮水般地发送海量的请求或数据。这种行为的目的在于竭尽全力地耗尽目标系统的网络带宽、系统资源以及服务能力,从而致使目标系统无法正常地为合法用户提供其所应…

aop的几种动态代理以及简单案例(1)

Sping AOP是通过动态代理模式实现的,具体有两种实现方式,一种是基于Java原生的动态代理,一种是基于cglib的动态代理。 1.jdk动态代理 1.1创建需要被代理的方法接口 public interface TargetInteface {void method1();void method2();int me…

到底哪款护眼大路灯好?五款适合学生用的护眼落地灯分享

到底哪款护眼大路灯好?影响青少年近视的最大“杀手”竟是学习环境光的影响。而对于这种情形,尤其是对于需要长时间用眼的学生群体和伏案工作者来说,护眼大路灯简直就是必备神器,但有人会问,我手机打开一搜就出现了那么…

有没有适合低预算党的主食冻干?希喂主食冻干真实喂养体验分享

铲屎官们好,今天来和大家聊聊一款让我喂出花来的主食冻干——希喂CPMR2.0大橙罐。作为一个注重猫咪身体健康和幸福感的铲屎官,怎么会不喂主食冻干。铲屎官们都致力于找到一款适合自家猫咪和平衡自己预算的主食冻干,我也做了不少尝试&#xff…

Apache配置与应用(企业网站架构部署与优化)

本章结构 如果要修改以上文件中的内容,想要生效,需要在主配置文件中能够扫描到这个默认文件的修改: 文件在: Apache 连接保持 Apache 的访问控制 针对IP地址的限制缺陷是不可预知性,需要事先直到对方的IP才能进行基于…

深度学习模型分布式训练

单机单卡训练 单机多卡训练 使用torch.nn.DataParallel方式,修改简单,但单进程效率慢 使用DDP方式,多进程效率高,推荐 多机多卡 模型并行 示例:

如何让自动化测试框架更自动化?

一、引言 ​对于大厂的同学来说,接口自动化是个老生常谈的话题了,毕竟每年的MTSC大会议题都已经能佐证了,不是大数据测试,就是AI测试等等(越来越高大上了)。不可否认这些专项的方向是质量智能化发展的方向&…

Redis基本数据结构

Redis基本数据结构 ​Redis​是C​语言开发的一个开源的(遵从BSD​协议)高性能键值对(key-value​)的内存数据库,可以用作数据库、缓存、消息中间件等。它是一种NoSQL​(not-only sql​,泛指非…

几种不同的方式禁止IP访问网站(PHP、Nginx、Apache设置方法)

1、PHP禁止IP和IP段访问 <?//禁止某个IP$banned_ip array ("127.0.0.1",//"119.6.20.66","192.168.1.4");if ( in_array( getenv("REMOTE_ADDR"), $banned_ip ) ){die ("您的IP禁止访问&#xff01;");}//禁止某个IP段…

中国式报表怎么做?用这款免费可视化工具快速搞定复杂报表

1. 什么是中国式报表&#xff1f; 中国式报表是一种中国独有的复杂报表&#xff0c;有格式复杂、计算复杂、数据来源复杂等特点&#xff0c;并且还有多样化的功能要求&#xff0c;例如图形、联动、回填等。因此许多国外报表工具在制作中国式报表方便表现得有些“水土不服”&am…