什么是长短期记忆网络?

一、概念

        长短期记忆网络(Long Short-Term Memory, LSTM)是一种特殊的循环神经网络(RNN),旨在解决标准RNN在处理长序列时的梯度消失和梯度爆炸问题。LSTM通过引入三个门(输入门、遗忘门和输出门)来控制信息的流动。其中,每个门都是一个神经网络层,用于决定哪些信息应该被保留,哪些信息应该被丢弃。LSTM的核心是细胞状态(cell state),它通过这些门的控制来更新和传递信息。

二、核心算法

        令x_{t}为时间步 t 的输入向量,h_{t-1}为前一个时间步的隐藏状态向量,h_{t}为当前时间步的隐藏状态向量,C_{t-1}为前一个时间步的细胞状态向量,C_{t}为当前时间步的细胞状态变量,f_{t}为当前时间步的遗忘门向量,i_{t}为当前时间步的输入门向量,\bar{C_{t}}为当前时间步的候选细胞状态向量,o_{t}为当前时间步的输出门向量,W_{f},W_{i},W_{C},W_{o}分别为各门的权重矩阵,b_{f},b_{i},b_{C},b_{o}为偏置向量,\sigma为sigmoid激活函数,tanh为tanh激活函数,*为元素级乘法。LSTM的核心内容包括以下几个部分:

1、遗忘门(Forget Gate)

        遗忘门决定细胞状态中哪些信息需要被遗忘。通过sigmoid激活函数,遗忘门的输出在0到1之间,表示每个细胞状态元素被保留的比例。

f_{t} = \sigma(W_{f} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{f})

2、输入门(Input Gate)

        输入门决定哪些新的信息需要被写入细胞状态。通过sigmoid激活函数,输入门的输出在0到1之间,表示每个候选细胞状态元素被写入的比例。候选细胞状态通过tanh激活函数生成,表示新的信息。

i_{t} = \sigma(W_{i} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{i})

\bar{C}_{t} = tanh(W_{C} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{C})

3、细胞状态更新

        细胞状态结合遗忘门和输入门的结果进行更新。遗忘门的输出与前一个时间步的细胞状态相乘,表示保留的旧信息。输入门的输出与候选细胞状态相乘,表示写入的新信息。两者相加得到当前时间步的细胞状态。

C_{t} = f_{t} \ast C_{t-1}+i_{t} \ast \bar{C}_{t}

4、输出门(Output Gate)

        输出门决定细胞状态的哪些部分将作为输出。通过sigmoid激活函数,输出门的输出在0到1之间,表示每个细胞状态元素被输出的比例。细胞状态通过tanh激活函数进行非线性变换,然后与输出门的输出相乘,得到当前时间步的隐藏状态。

o_{t} = \sigma(W_{o} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{o})

h_{t} = o_{t} \ast tanh(C_{t})

三、python实现

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split# 生成正弦波数据
def generate_sine_wave(seq_length, num_samples):x = np.linspace(0, num_samples, num_samples)y = np.sin(x)data = []for i in range(len(y) - seq_length):data.append(y[i:i+seq_length+1])return np.array(data)# 定义LSTM模型
class LSTMModel(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers):super(LSTMModel, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)out, _ = self.lstm(x, (h0, c0))out = self.fc(out[:, -1, :])return out# 超参数设置
seq_length = 50
num_samples = 1000
input_size = 1
hidden_size = 50
output_size = 1
num_layers = 2
batch_size = 64
learning_rate = 0.001
num_epochs = 5
test_size = 0.2  # 测试集占比# 生成数据
data = generate_sine_wave(seq_length, num_samples)
X = data[:, :-1]
y = data[:, -1]# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=42)# 转换为Tensor
X_train = torch.tensor(X_train.reshape(-1, seq_length, input_size), dtype=torch.float32)
y_train = torch.tensor(y_train.reshape(-1, output_size), dtype=torch.float32)
X_test = torch.tensor(X_test.reshape(-1, seq_length, input_size), dtype=torch.float32)
y_test = torch.tensor(y_test.reshape(-1, output_size), dtype=torch.float32)# 创建数据加载器
train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = torch.utils.data.TensorDataset(X_test, y_test)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)# 初始化模型、损失函数和优化器
model = LSTMModel(input_size, hidden_size, output_size, num_layers)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# 训练模型
for epoch in range(num_epochs):model.train()for i, (inputs, labels) in enumerate(train_loader):outputs = model(inputs)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 测试模型
model.eval()
with torch.no_grad():predicted = []actual = []for inputs, labels in test_loader:outputs = model(inputs)predicted.extend(outputs.numpy())actual.extend(labels.numpy())# 绘制结果
plt.plot(actual, label='Actual data')
plt.plot(predicted, label='Predicted data')
plt.legend()
plt.show()

四、总结

        LSTM能够捕捉长时间依赖关系,使得模型在处理长序列数据时表现得比标准的RNN更好。但由于LSTM的计算依赖于前一个时间步的输出,这使得这样的网络结构难以并行化,在处理大规模数据时的效率较低。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/9400.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LangChain的开发流程

文章目录 LangChain的开发流程开发密钥指南3种使用密钥的方法编写一个取名程序 LangChain表达式 LangChain的开发流程 为了更深人地理解LangChain的开发流程,本文将以构建聊天机器人为实际案例进行详细演示。下图展示了一个设计聊天机器人的LLM应用程序。 除了Wb服务…

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】1.20 极值追踪:高效获取数据特征的秘诀

1.20 极值追踪:高效获取数据特征的秘诀 1.20.1 目录 #mermaid-svg-RBxy2YCCN23ydzFu {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-RBxy2YCCN23ydzFu .error-icon{fill:#552222;}#mermaid-svg-RBxy2YC…

Vscode的AI插件 —— Cline

简介 vscode的一款AI辅助吃插件,主要用来辅助创建和编辑文件,探索大型项目,使用浏览器并执行终端命令(需要多个tokens),可以使用模型上下文协议(MCP)来创建新工具并扩展自己(比较慢…

C++ unordered_map和unordered_set的使用,哈希表的实现

文章目录 unordered_map,unorder_set和map ,set的差异哈希表的实现概念直接定址法哈希冲突哈希冲突举个例子 负载因子将关键字转为整数哈希函数除法散列法/除留余数法 哈希冲突的解决方法开放定址法线性探测二次探测 开放定址法代码实现 哈希表的代码 un…

c#使用log4Net配置日志文件

1.# 写一个通用类 LogHelper using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks; using log4net;namespace WindowsFormsApplication22 {public class LogHelper{static ILog mylog LogManager.GetLogge…

WebSocket 详解:全双工通信的实现与应用

目录 一、什么是 WebSocket?(简介) 二、为什么需要 WebSocket? 三、HTTP 与 WebSocket 的区别 WebSocket 的劣势 WebSocket 的常见应用场景 WebSocket 握手过程 WebSocket 事件处理和生命周期 一、什么是 WebSocket&#xf…

Qt Ribbon使用实例

采用SARibbon创建简单的ribbon界面 实例代码如下所示&#xff1a; 1、头文件&#xff1a; #pragma once #include <SARibbonBar.h> #include "SARibbonMainWindow.h" class QTextEdit; class SAProjectDemo1 : public SARibbonMainWindow { Q_OBJECT pub…

认识小程序的基本组成结构

1.基本组成结构 2.页面的组成部分 3.json配置文件 4.app.json文件(全局配置文件&#xff09; 5.project.config.json文件 6.sitemap.json文件 7.页面的.json配置文件 通过window节点可以控制小程序的外观

JVM--类加载器

概念 类加载器&#xff1a;只参与加载过程中的字节码获取并加载到内存中的部分&#xff1b;java虚拟机提供给应用程序去实现获取类和接口字节码数据的一种技术&#xff0c;也就是说java虚拟机是允许程序员写代码去获取字节码信息 类加载是加载的第一步&#xff0c;主要有以下三…

51单片机开发:定时器中断

目标&#xff1a;利用定时器中断&#xff0c;每隔1s开启/熄灭LED1灯。 外部中断结构图如下图所示&#xff0c;要使用定时器中断T0&#xff0c;须开启TE0、ET0。&#xff1a; 系统中断号如下图所示&#xff1a;定时器0的中断号为1。 定时器0的工作方式1原理图如下图所示&#x…

Greenplum临时表未清除导致库龄过高处理

1.问题 Greenplum集群segment后台日志报错 2.回收库龄 master上执行 vacuumdb -F -d cxy vacuumdb -F -d template1 vacuumdb -F -d rptdb 3.回收完成后检查 仍然发现segment还是有库龄报警警告信息发出 4.检查 4.1 在master上检查库年龄 SELECT datname, datfrozen…

小程序-视图与逻辑

前言 1. 声明式导航 open-type"switchTab"如果没有写这个&#xff0c;因为是tabBar所以写这个&#xff0c;就无法跳转。路径开始也必须为斜线 open-type"navigate"这个可以不写 现在开始实现后退的效果 现在我们就在list页面里面实现后退 2.编程式导航…

Kotlin开发(六):Kotlin 数据类,密封类与枚举类

引言 想象一下&#xff0c;你是个 Kotlin 开发者&#xff0c;敲着代码忽然发现业务代码中需要一堆冗长的 POJO 类来传递数据。烦得很&#xff1f;别急&#xff0c;Kotlin 贴心的 数据类 能帮你自动生成 equals、hashCode&#xff0c;直接省时省力&#xff01;再想想需要多种状…

games101-作业2

图形管线 Vertex Processing 对顶点进行加工&#xff0c;使其变换到屏幕空间坐标。 Triangle Processing 将加工后的顶点组装成三角形&#xff0c;用于下一步的光栅化。 void rst::rasterizer::draw(pos_buf_id pos_buffer, ind_buf_id ind_buffer, col_buf_id col_buffer, Pr…

Baklib引领企业内容中台建设的新思路与应用案例

内容概要 在数字化转型的浪潮中&#xff0c;内容中台的概念逐渐成为企业实现高效运营的重要基础。内容中台不仅是信息资产的集中管理平台&#xff0c;更是企业在应对快速变化市场需求时的一种敏捷响应机制。通过搭建内容中台&#xff0c;企业能够有效整合各类资源&#xff0c;…

准备知识——旋转机械的频率和振动基础

旋转频率&#xff0c;也称为转速或旋转速率&#xff08;符号ν&#xff0c;小写希腊字母nu&#xff0c;也作n&#xff09;&#xff0c;是物体绕轴旋转的频率。其国际单位制单位是秒的倒数(s −1 )&#xff1b;其他常见测量单位包括赫兹(Hz)、每秒周期数(cps) 和每分钟转数(rpm)…

Java 大视界 -- Java 大数据在生物信息学中的应用与挑战(67)

&#x1f496;亲爱的朋友们&#xff0c;热烈欢迎来到 青云交的博客&#xff01;能与诸位在此相逢&#xff0c;我倍感荣幸。在这飞速更迭的时代&#xff0c;我们都渴望一方心灵净土&#xff0c;而 我的博客 正是这样温暖的所在。这里为你呈上趣味与实用兼具的知识&#xff0c;也…

【NLP251】NLP RNN 系列网络

NLP251 系列主要记录从NLP基础网络结构到知识图谱的学习 &#xff11;.原理及网络结构 &#xff11;.&#xff11;&#xff32;&#xff2e;&#xff2e; 在Yoshua Bengio论文中( http://proceedings.mlr.press/v28/pascanu13.pdf )证明了梯度求导的一部分环节是一个指数模型…

Unbutu虚拟机+eclipse+CDT编译调试环境搭建

问题1: 安装CDT&#xff0c;直接Help->eclipse Market space-> 搜cdt , install&#xff0c;等待重启即可. 问题2&#xff1a;C变量不识别vector ’could not be resolved 这是库的头文件没加好&#xff0c;右键Properties->C Build->Enviroment&#xff0c;增加…

关于opencv环境搭建问题:由于找不到opencv_worldXXX.dll,无法执行代码,重新安装程序可能会解决此问题

方法一&#xff1a;利用复制黏贴方法 打开opencv文件夹目录找到\opencv\build\x64\vc15\bin 复制该目录下所有文件&#xff0c;找到C:\Windows\System32文件夹&#xff08;注意一定是C盘&#xff09;黏贴至该文件夹重新打开VS。 方法二&#xff1a;直接配置环境 打开opencv文…