【深度学习】LSTM、BiLSTM详解

文章目录

  • 1. LSTM简介:
  • 2. LSTM结构图:
  • 3. 单层LSTM详解
  • 4. 双层LSTM详解
  • 5. BiLSTM
  • 6. Pytorch实现LSTM示例
  • 7. nn.LSTM参数详解

1. LSTM简介:

    LSTM是一种循环神经网络,它可以处理和预测时间序列中间隔和延迟相对较长的重要事件。LSTM通过使用门控单元来控制信息的流动,从而缓解RNN中的梯度消失和梯度爆炸的问题。LSTM的核心是三个门:输入门遗忘门输出门

遗忘门: 遗忘门的作用是决定哪些信息从记忆单元中遗忘,它使用sigmoid激活函数,可以输出在0到1之间的值,可以理解为保留信息的比例。
输入门: 作用是决定哪些新信息被存储在记忆单元中
输出门: 输出门决定了下一个隐藏状态,即生成当前时间步的输出并传递到下一时间步
记忆单元:负责长期信息的存储,通过遗忘门和输入门的相互作用,记忆单元能够学习如何选择性地记住或忘记信息

2. LSTM结构图:

在这里插入图片描述

涉及到的计算公式如下:
在这里插入图片描述

3. 单层LSTM详解

(1)设定有3个字的序列【“早”“上”“好”】要经过LSTM处理,每个序列由20个元素组成的列向量构成,所以input size就为20。

(2)设定全连接层中有100个隐藏单元,LSTM的层数为1。

(3)因为是3个字的序列,所以LSTM需要3个时间步(即会自循环3次)才能处理完这个序列。

(4)nn.LSTM()每层也可以拆开写,这样每层的隐藏单元个数就可以分别设定。

在这里插入图片描述
    LSTM单元包含三个输入参数x、c、h;首先t1时刻作为第一个时间步,输入到第一个LSTM单元中,此时输入的初始从c(0)和h(0)都是0矩阵,计算完成后,第一个LSTM单元输出一组h(t1)\c(t1),作为本层LSTM的第二个时间步的输入参数;因此第二个时间步的输入就是h(t1),c(t1),x(t2),而输出是h(t2),c(t2);因此第三个时间步的输入就是h(t2),c(t2),x(t3),而输出是h(t3),c(t3)。

4. 双层LSTM详解

(1)设定有3个字的序列【“早”“上”“好”】要经过LSTM处理,每个序列由20个元素组成的列向量构成,所以input size就为20。

(2)设定全连接层中有100个隐藏单元,LSTM的层数为2。

(3)因为是3个字的序列,所以LSTM需要3个时间步(即会自循环3次)才能处理完这个序列。

(4)nn.LSTM()每层也可以拆开写,这样每层的隐藏单元个数就可以分别设定。

在这里插入图片描述
    第二层LSTM没有输入参数x(t1)、x(t2)、x(t3);所以我们将第一层LSTM输出的h(t1)、h(t2)、h(t3)作为第二层LSTM的输入x(t1)、x(t2)、x(t3)。第一个时间步输入的初始c(0)和h(0)都为0矩阵,计算完成后,第一个时间步输出新的一组h(t1)、c(t1),作为本层LSTM的第二个时间步的输入参数;因此第二个时间步的输入就是h(t1),c(t1),x(t2),而输出是h(t2),c(t2);因此第三个时间步的输入就是h(t2),c(t2),x(t3),而输出是h(t3),c(t3)。

5. BiLSTM

单层的BiLSTM其实就是2个LSTM,一个正向去处理序列,一个反向去处理序列,处理完后,两个LSTM的输出会拼接起来。
在这里插入图片描述

6. Pytorch实现LSTM示例

import torch 
import torch.nn as nndevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")class LSTM(nn.Module):def __init__(self, input_dim, hidden_dim, num_layers, output_dim):super(LSTM, self).__init__()self.hidden_dim = hidden_dim  # 隐藏层维度self.num_layers = num_layers  # LSTM层的数量# LSTM网络层self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)# 全连接层,用于将LSTM的输出转换为最终的输出维度self.fc = nn.Linear(hidden_dim, output_dim)def forward(self, x):# 初始化隐藏状态和细胞状态h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(device)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(device)# 前向传播LSTM,返回输出和最新的隐藏状态与细胞状态out, (hn, cn) = self.lstm(x, (h0.detach(), c0.detach()))# 将LSTM的最后一个时间步的输出通过全连接层out = self.fc(out[:, -1, :])return out

7. nn.LSTM参数详解

pytorch官方定义:

CLASS torch.nn.LSTM(
        input_size,
        hidden_size,
        num_layers=1,
        bias=True,
        batch_first=False,
        dropout=0.0,
        bidirectional=False,
        proj_size=0,
        device=None,
        dtype=None
    )

input_size – 输入 x 中预期的特征数量
hidden_size – 隐藏状态 h 中的特征数量
num_layers – 循环层的数量。例如,设置 num_layers=2 表示将两个 LSTM 堆叠在一起形成一个 stacked LSTM,其中第二个 LSTM 接收第一个 LSTM 的输出并计算最终结果。默认值:1
bias – 如果 False,则该层不使用偏差权重 b_ih 和 b_hh。默认值:True
batch_first – 如果 True,则输入和输出张量将以 (batch, seq, feature) 而不是 (seq, batch, feature) 的形式提供。请注意,这并不适用于隐藏状态或单元状态。有关详细信息,请参见下面的输入/输出部分。默认值:False
dropout – 如果非零,则在除最后一层之外的每个 LSTM 层的输出上引入一个 Dropout 层,其 dropout 概率等于 dropout。默认值:0
bidirectional – 如果 True,则变为双向 LSTM。默认值:False
proj_size – 如果 > 0,则将使用具有相应大小的投影的 LSTM。默认值:0

对于输入序列每一个元素,每一层都会进行以下计算:
在这里插入图片描述
网络输入:
在这里插入图片描述

网络输出:
在这里插入图片描述

本文参考:https://blog.csdn.net/qq_34486832/article/details/134898868
https://pytorch.ac.cn/docs/stable/generated/torch.nn.LSTM.html#

LSTM每层的输出都要经过全连接层吗,还是直接对隐藏层进行输出?
通过在代码中对lstm的输出进行print输出:

import torch 
import torch.nn as nndevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")class LSTM(nn.Module):def __init__(self, input_dim, hidden_dim, num_layers, output_dim):super(LSTM, self).__init__()self.hidden_dim = hidden_dim  # 隐藏层维度self.num_layers = num_layers  # LSTM层的数量# LSTM网络层self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)# 全连接层,用于将LSTM的输出转换为最终的输出维度self.fc = nn.Linear(hidden_dim, output_dim)def forward(self, x):# 初始化隐藏状态和细胞状态h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(device)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(device)# 前向传播LSTM,返回输出和最新的隐藏状态与细胞状态out, (hn, cn) = self.lstm(x, (h0.detach(), c0.detach()))print(out)print(hn)print(cn)# 将LSTM的最后一个时间步的输出通过全连接层out = self.fc(out[:, -1, :])return out
if __name__ == "__main__":input_dim = 3        # 输入特征的维度hidden_dim = 4       # 隐藏层的维度num_layers = 1       # LSTM 层的数量output_dim = 1       # 输出特征的维度lstm = LSTM(input_dim, hidden_dim, num_layers, output_dim).to(device)batch_size = 1seq_length = 10input_tensor = torch.randn(batch_size, seq_length, input_dim).to(device)output = lstm(input_tensor)

通过对LSTM网络的输出我们可以看到,out的最后一层与最后一层隐藏层hn一致,说明并未经过全连接层,而是直接输出隐藏层
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/469988.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用ookii-dialogs-wpf在WPF选择文件夹时能输入路径

在进行WPF开发时,System.Windows.Forms.FolderBrowserDialog的选择文件夹功能不支持输入路径: 希望能够获得下图所示的选择文件夹功能: 于是,通过NuGet中安装Ookii.Dialogs.Wpf包,并创建一个简单的工具类: …

【leetcode练习·二叉树】用「分解问题」思维解题 II

本文参考labuladong算法笔记[【强化练习】用「分解问题」思维解题 II | labuladong 的算法笔记] 技巧一 类似于判断镜像二叉树、翻转二叉树的问题,一般也可以用分解问题的思路,无非就是把整棵树的问题(原问题)分解成子树之间的问…

Qt 编写插件plugin,支持接口定义信号

https://blog.csdn.net/u014213012/article/details/122434193?spm1001.2014.3001.5506 本教程基于该链接的内容进行升级,在编写插件的基础上,支持接口类定义信号。 环境:Qt5.12.12 MSVC2017 一、创建项目 新建一个子项目便于程序管理【…

PaaS云原生:分布式集群中如何构建自动化压测工具

场景 测试环境中,压测常常依赖环境中的各种工具获取基础信息,而这些工具可能集中在某个中控机上,此时想打造的自动化工具的运行模式是: 通过中控机工具获取压测所需的基本信息在中控机部署压测工具,实际压测任务分发…

关于sass在Vue3中编写bem框架报错以及警告问题记录

在编写完bem框架后 在vite.config.ts文件进行预编译处理时,报错的错误 1. 处理方式:使用新版api, 如图: 2. 处理方式:使用 use 替换掉 import, 如图: 3. 处理方式:使用路径别名&am…

内置RTK北斗高精度定位的4G执法记录仪、国网供电服务器记录仪

内置RTK北斗高精度定位的4G执法记录仪、国网供电服务器记录仪BD311R 发布时间: 2024-10-23 11:28:42 一、 产品图片: 二、 产品特性: 4G性能:支持2K超高清图传,数据传输不掉帧,更稳定。 独立北…

【前端】深入浅出的React.js详解

React 是一个用于构建用户界面的 JavaScript 库,由 Facebook 开发并维护。随着 React 的不断演进,官方文档也在不断更新和完善。本文将详细解读最新的 React 官方文档,涵盖核心概念、新特性、最佳实践等内容,帮助开发者更好地理解…

【Elasticsearch入门到落地】1、初识Elasticsearch

一、什么是Elasticsearch Elasticsearch(简称ES)是一款非常强大的开源搜索引擎,可以帮助我们从海量数据中快速找到需要的内容。它使用Java编写,基于Apache Lucene来构建索引和提供搜索功能,是一个分布式、可扩展、近实…

扫雷游戏代码分享(c基础)

hi , I am 36. 代码来之不易👍👍👍 创建两个.c 一个.h 1:test.c #include"game.h"void game() {//创建数组char mine[ROWS][COLS] { 0 };char show[ROWS][COLS] { 0 };char temp[ROWS][COLS] { 0 };//初始化数…

ORA-01092 ORA-14695 ORA-38301

文章目录 前言一、MAX_STRING_SIZE--12C 新特性扩展数据类型 varchar2(32767)二、恢复操作1.尝试恢复MAX_STRING_SIZE参数为默认值2.在upgrade模式下执行utl32k.sql 前言 今天客户发来一个内部测试库数据库启动截图报错,描述是“上午出现服务卡顿,然后重…

ODOO学习笔记(3):Odoo和Django的区别是什么?

Odoo和Django都是基于Python的开源框架,但它们的设计目标和用途有所不同: 设计目标和用途: Odoo:Odoo是一个企业资源规划(ERP)系统,它提供了一套完整的商业管理软件,包括会计、库存…

零基础玩转IPC之——海思平台实现P2P远程传输实验(基于TUTK,国科君正全志海思通用)

老规矩,先做实验测试。以本店Hi3516EV200\GK7205开发板为例,其他开发板操作类似。 将源码包p2p-h264.tgz放到虚拟机,解压,编译 tar -jxvf p2p-h264.tgz cd p2p-h264 make clean make 得到可执行文件p2p-h264 启动开发板&…

如何理解DDoS安全防护在企业安全防护中的作用

DDoS安全防护在安全防护中扮演着非常重要的角色。DDoS(分布式拒绝服务)攻击是一种常见的网络攻击,旨在通过向目标服务器发送大量请求,以消耗服务器资源并使其无法正常运行。理解DDoS安全防护的作用,可以从以下几个方面…

Python如何从HTML提取img标签下的src属性

目录 前提准备步骤1. 解析HTML内容2. 查找所有的img标签3. 提取src属性 完整代码 前提准备 在处理网页数据时,我们经常需要从HTML中提取特定的信息,比如图片的URL。 这通常通过获取img标签的src属性来实现。 在开始之前,你需要确保已经安装…

Redis主从复制(replication)

文章目录 是什么作用使用案例实操主从复制原理和工作流程slave启动,同步初请首次连接,全量复制心跳持续,保持通信进入平稳,增量复制从机下线,重连续传 复制的缺点 是什么 主从复制,master以写为主&#xf…

Android OpenGL ES详解——纹理:纹理过滤GL_NEAREST和GL_LINEAR的区别

目录 一、概念 1、纹理过滤 2、邻近过滤 3、线性过滤 二、邻近过滤和线性过滤的区别 三、源码下载 一、概念 1、纹理过滤 当纹理被应用到三维物体上时,随着物体表面的形状和相机视角的变化,会导致纹理在渲染过程中出现一些问题,如锯齿…

记录日志中logback和log4j2不能共存的问题

本文章记录设置两个日志时候,控制台直接报错 标黄处就是错误原因:1. SLF4J(W):类路径包含多个SLF4J提供程序。 SLF4J(W):找到提供程序[org.apache.logging.slf4j. net]。 SLF4J(W):找到提供程序[ch.qos.log .classi…

【PGCCC】Postgresql Toast 原理

前言 上篇博客讲述了 postgresql 如何存储变长数据,它的应用主要是在 toast 。Toast 在存储大型数据时,会将它存储在单独的表中(称为 toast 表)。因为 postgresql 的 tuple(行数据)是存在在 Page 中的&…

C指针创建三维数组

定义的时候变量的位置就是最后一个星号的位置 int*** matrix3d_int(int nz, int nrh, int nch) {int*** matrix (int***)malloc(nz * sizeof(int**));for (int z 0; z < nz; z) {matrix[z] (int**)malloc(nrh * sizeof(int*));for (int y 0; y < nrh; y) {matrix[z][…

window下安装rust 及 vscode配置

安装 安装mingw64 &#xff08;c语言环境 选择posix-ucrt&#xff09; ucrt:通用c运行时库配置mingw64/bin的路径到环境变量中在cmd窗口中输入命令 "gcc -v" 4. 下载Rust安装程序 安装 Rust - Rust 程序设计语言 5. 配置rustup和cargo目录 &#xff08;cargo是包管…