用deepseek学大模型08-长短时记忆网络 (LSTM)

deepseek.com 从入门到精通长短时记忆网络(LSTM),着重介绍的目标函数,损失函数,梯度下降 标量和矩阵形式的数学推导,pytorch真实能跑的代码案例以及模型,数据, 模型应用场景和优缺点,及如何改进解决及改进方法数据推导。

从入门到精通长短时记忆网络 (LSTM)

参考:长短时记忆网络(LSTM)在序列数据处理中的优缺点分析
LSTM


1. LSTM 核心机制

LSTM 通过门控机制(遗忘门、输入门、输出门)和细胞状态(Cell State)解决 RNN 的梯度消失问题。

核心公式(时间步 t t t):

  1. 遗忘门(Forget Gate):
    f t = σ ( W f [ h t − 1 , x t ] + b f ) \mathbf{f}_t = \sigma\left( \mathbf{W}_f [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_f \right) ft=σ(Wf[ht1,xt]+bf)
  2. 输入门(Input Gate):
    i t = σ ( W i [ h t − 1 , x t ] + b i ) \mathbf{i}_t = \sigma\left( \mathbf{W}_i [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_i \right) it=σ(Wi[ht1,xt]+bi)
    C ~ t = tanh ⁡ ( W C [ h t − 1 , x t ] + b C ) \tilde{\mathbf{C}}_t = \tanh\left( \mathbf{W}_C [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_C \right) C~t=tanh(WC[ht1,xt]+bC)
  3. 细胞状态更新
    C t = f t ⊙ C t − 1 + i t ⊙ C ~ t \mathbf{C}_t = \mathbf{f}_t \odot \mathbf{C}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{C}}_t Ct=ftCt1+itC~t
  4. 输出门(Output Gate):
    o t = σ ( W o [ h t − 1 , x t ] + b o ) \mathbf{o}_t = \sigma\left( \mathbf{W}_o [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_o \right) ot=σ(Wo[ht1,xt]+bo)
    h t = o t ⊙ tanh ⁡ ( C t ) \mathbf{h}_t = \mathbf{o}_t \odot \tanh(\mathbf{C}_t) ht=ottanh(Ct)

2. 目标函数与损失函数
  • 目标函数:最小化预测与真实值的差异(监督学习)。
  • 损失函数(以分类任务交叉熵为例):
    L = − 1 T ∑ t = 1 T ∑ c = 1 C y ^ t , c log ⁡ ( y t , c ) L = -\frac{1}{T} \sum_{t=1}^T \sum_{c=1}^C \mathbf{\hat{y}}_{t,c} \log(\mathbf{y}_{t,c}) L=T1t=1Tc=1Cy^t,clog(yt,c)
    其中 C C C为类别数, y ^ \mathbf{\hat{y}} y^为真实标签的 one-hot 编码。

3. 梯度下降与数学推导

LSTM 的梯度反向传播通过细胞状态 C t \mathbf{C}_t Ct和门控机制稳定梯度流动。

标量形式推导(以遗忘门 f t \mathbf{f}_t ft为例):
∂ L ∂ f t = ∂ L ∂ h t ⋅ ∂ h t ∂ C t ⋅ ∂ C t ∂ f t \frac{\partial L}{\partial \mathbf{f}_t} = \frac{\partial L}{\partial \mathbf{h}_t} \cdot \frac{\partial \mathbf{h}_t}{\partial \mathbf{C}_t} \cdot \frac{\partial \mathbf{C}_t}{\partial \mathbf{f}_t} ftL=htLCthtftCt
其中:
∂ C t ∂ f t = C t − 1 ⊙ f t ⊙ ( 1 − f t ) \frac{\partial \mathbf{C}_t}{\partial \mathbf{f}_t} = \mathbf{C}_{t-1} \odot \mathbf{f}_t \odot (1 - \mathbf{f}_t) ftCt=Ct1ft(1ft)

矩阵形式推导(链式法则):
∂ L ∂ W f = ∑ t = 1 T ( δ f , t ⋅ [ h t − 1 , x t ] T ) \frac{\partial L}{\partial \mathbf{W}_f} = \sum_{t=1}^T \left( \delta_{f,t} \cdot [\mathbf{h}_{t-1}, \mathbf{x}_t]^T \right) WfL=t=1T(δf,t[ht1,xt]T)
其中 δ f , t \delta_{f,t} δf,t为遗忘门的梯度误差:
δ f , t = ∂ L ∂ f t ⊙ σ ′ ( ⋅ ) \delta_{f,t} = \frac{\partial L}{\partial \mathbf{f}_t} \odot \sigma'(\cdot) δf,t=ftLσ()


4. PyTorch 代码案例
import torch
import torch.nn as nn
import matplotlib.pyplot as plt# 数据生成:正弦波 + 噪声
time = torch.arange(0, 100, 0.1)
data = torch.sin(time) + 0.1 * torch.randn(len(time))# 转换为序列数据(窗口长度=20)
def create_sequences(data, seq_length=20):X, y = [], []for i in range(len(data)-seq_length):X.append(data[i:i+seq_length])y.append(data[i+seq_length])return torch.stack(X).unsqueeze(-1), torch.stack(y).unsqueeze(-1)X, y = create_sequences(data)
X_train, y_train = X[:800], y[:800]  # 划分训练集和测试集
X_test, y_test = X[800:], y[800:]# 定义 LSTM 模型
class LSTMModel(nn.Module):def __init__(self, input_size=1, hidden_size=64, output_size=1):super().__init__()self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):out, (h_n, c_n) = self.lstm(x)  # out: (batch, seq_len, hidden_size)out = self.fc(out[:, -1, :])    # 取最后一个时间步return outmodel = LSTMModel()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练
epochs = 100
train_loss = []
for epoch in range(epochs):optimizer.zero_grad()outputs = model(X_train)loss = criterion(outputs, y_train)loss.backward()nn.utils.clip_grad_norm_(model.parameters(), 0.5)  # 梯度裁剪optimizer.step()train_loss.append(loss.item())# 可视化训练损失
plt.plot(train_loss)
plt.title("Training Loss")
plt.show()# 预测
model.eval()
with torch.no_grad():train_pred = model(X_train)test_pred = model(X_test)# 绘制结果
plt.figure(figsize=(12, 5))
plt.plot(data.numpy(), label="True Data")
plt.plot(range(20, 820), train_pred.numpy(), label="Train Predictions")
plt.plot(range(820, len(data)), test_pred.numpy(), label="Test Predictions")
plt.legend()
plt.show()

5. 应用场景与优缺点
  • 应用场景
    • 时间序列预测(股票价格、天气)
    • 自然语言处理(文本生成、机器翻译)
    • 语音识别
  • 优点
    • 解决长程依赖问题
    • 通过门控机制稳定梯度流动
    • 可处理变长序列
  • 缺点
    • 计算复杂度高(参数多)
    • 对短序列可能过拟合
    • 训练时间较长

6. 改进方法及数学推导
  1. GRU(门控循环单元)
    简化 LSTM,合并遗忘门和输入门:
    z t = σ ( W z [ h t − 1 , x t ] ) \mathbf{z}_t = \sigma(\mathbf{W}_z [\mathbf{h}_{t-1}, \mathbf{x}_t]) zt=σ(Wz[ht1,xt])
    r t = σ ( W r [ h t − 1 , x t ] ) \mathbf{r}_t = \sigma(\mathbf{W}_r [\mathbf{h}_{t-1}, \mathbf{x}_t]) rt=σ(Wr[ht1,xt])
    h ~ t = tanh ⁡ ( W [ r t ⊙ h t − 1 , x t ] ) \tilde{\mathbf{h}}_t = \tanh(\mathbf{W} [\mathbf{r}_t \odot \mathbf{h}_{t-1}, \mathbf{x}_t]) h~t=tanh(W[rtht1,xt])
    h t = ( 1 − z t ) ⊙ h t − 1 + z t ⊙ h ~ t \mathbf{h}_t = (1 - \mathbf{z}_t) \odot \mathbf{h}_{t-1} + \mathbf{z}_t \odot \tilde{\mathbf{h}}_t ht=(1zt)ht1+zth~t

  2. 双向 LSTM(Bi-LSTM)
    同时捕捉前向和后向依赖:
    h t → = LSTM ( x t , h t − 1 → ) \overrightarrow{\mathbf{h}_t} = \text{LSTM}(\mathbf{x}_t, \overrightarrow{\mathbf{h}_{t-1}}) ht =LSTM(xt,ht1 )
    h t ← = LSTM ( x t , h t + 1 ← ) \overleftarrow{\mathbf{h}_t} = \text{LSTM}(\mathbf{x}_t, \overleftarrow{\mathbf{h}_{t+1}}) ht =LSTM(xt,ht+1 )
    h t = [ h t → , h t ← ] \mathbf{h}_t = [\overrightarrow{\mathbf{h}_t}, \overleftarrow{\mathbf{h}_t}] ht=[ht ,ht ]

  3. 注意力机制
    增强对关键时间步的关注:
    α t = softmax ( v T tanh ⁡ ( W h h t + W s s ) ) \alpha_t = \text{softmax}(\mathbf{v}^T \tanh(\mathbf{W}_h \mathbf{h}_t + \mathbf{W}_s \mathbf{s})) αt=softmax(vTtanh(Whht+Wss))
    c = ∑ t = 1 T α t h t \mathbf{c} = \sum_{t=1}^T \alpha_t \mathbf{h}_t c=t=1Tαtht


7. 关键改进的数学验证(以 GRU 为例)
  • 梯度稳定性
    GRU 的更新门 z t \mathbf{z}_t zt控制历史信息的保留比例,梯度可沿两条路径传播:
    ∂ h t ∂ h t − 1 = ( 1 − z t ) + z t ⊙ ∂ h ~ t ∂ h t − 1 \frac{\partial \mathbf{h}_t}{\partial \mathbf{h}_{t-1}} = (1 - \mathbf{z}_t) + \mathbf{z}_t \odot \frac{\partial \tilde{\mathbf{h}}_t}{\partial \mathbf{h}_{t-1}} ht1ht=(1zt)+ztht1h~t
    避免传统 RNN 的连乘梯度。

通过上述内容,您可全面掌握 LSTM 的理论基础、实际实现及优化方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/21335.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

以ChatGPT为例解析大模型背后的技术

目录 1、大模型分类 2、为什么自然语言处理可计算? 2.1、One-hot分类编码(传统词表示方法) 2.2、词向量 3、Transformer架构 3.1、何为注意力机制? 3.2、注意力机制在 Transformer 模型中有何意义? 3.3、位置编…

鸿道Intewell操作系统:赋能高端装备制造,引领国产数控系统迈向新高度

在当今全球制造业竞争日益激烈的时代,高端装备制造作为国家核心竞争力的重要组成部分,其发展水平直接影响着一个国家的综合实力。而CNC数控系统,作为高端装备制造的“大脑”,对于提升装备的精度、效率和智能化水平起着关键作用。鸿…

mac开发环境配置笔记

1. 终端配置 参考: Mac终端配置笔记-CSDN博客 2. 下载JDK 到 oracle官网 下载jdk: oracle官网 :Java Downloads | Oraclemac的芯片为Intel系列下载 x64版本的jdk;为Apple Mx系列使用 Arm64版本;oracle官网下载时报错:400 Bad R…

【Python爬虫(29)】爬虫数据生命线:质量评估与监控全解

【Python爬虫】专栏简介:本专栏是 Python 爬虫领域的集大成之作,共 100 章节。从 Python 基础语法、爬虫入门知识讲起,深入探讨反爬虫、多线程、分布式等进阶技术。以大量实例为支撑,覆盖网页、图片、音频等各类数据爬取&#xff…

大模型工具大比拼:SGLang、Ollama、VLLM、LLaMA.cpp 如何选择?

简介:在人工智能飞速发展的今天,大模型已经成为推动技术革新的核心力量。无论是智能客服、内容创作,还是科研辅助、代码生成,大模型的身影无处不在。然而,面对市场上琳琅满目的工具,如何挑选最适合自己的那…

测评雷龙出品的CS SD NAND贴片式TF卡

一、前言 在现代科技飞速发展的背景下,存储解决方案的创新与进步成为了推动各行各业发展的重要力量。这篇文章讲解雷龙公司出品的CS SD NAND贴片式TF卡的深度测评。这款产品不仅以其小巧精致的设计脱颖而出,更凭借其卓越的性能和可靠性,在众…

Hadoop一 HDFS分布式文件系统

一 分布式文件存储 了解为什么海量数据需要使用分布式存储技术 100T数据太大,单台服务器无法承担。于是: 分布式服务器集群 靠数量取胜,多台服务器组合,才能Hold住,如下 分布式不仅仅是解决了能存的问题&#xff…

windows下docker使用笔记

目录 镜像的配置 镜像的拉取 推荐镜像源列表(截至2025年2月测试有效) 配置方法 修改容器名字 如何使用卷 创建不同的容器,每个容器中有不同的mysql和java版本(不推荐) 1. 安装 Docker Desktop(Win…

1005 K 次取反后最大化的数组和(贪心)

文章目录 题目[](https://leetcode.cn/problems/maximize-sum-of-array-after-k-negations/)算法原理源码总结 题目 如上图,k是取反的次数,在数组【4,-1,3】中,当k 1,把-2取反为2,和为9;在数组…

java毕业设计之医院门诊挂号系统(源码+文档)

风定落花生,歌声逐流水,大家好我是风歌,混迹在java圈的辛苦码农。今天要和大家聊的是一款基于ssm的医院门诊挂号系统。项目源码以及部署相关请联系风歌,文末附上联系信息 。 项目简介: 医院门诊挂号系统的主要使用者…

深入学习解析:183页可编辑PPT华为市场营销MPR+LTC流程规划方案

华为终端正面临销售模式转型的关键时刻,旨在通过构建MPRLTC项目,以规避对运营商定制的过度依赖,并探索新的增长路径。项目核心在于建设一套全新的销售流程与IT系统,支撑双品牌及自有品牌的战略发展。 项目总体方案聚焦于四大关键议…

JUC并发—8.并发安全集合一

大纲 1.JDK 1.7的HashMap的死循环与数据丢失 2.ConcurrentHashMap的并发安全 3.ConcurrentHashMap的设计介绍 4.ConcurrentHashMap的put操作流程 5.ConcurrentHashMap的Node数组初始化 6.ConcurrentHashMap对Hash冲突的处理 7.ConcurrentHashMap的并发扩容机制 8.Concu…

Cython学习笔记1:利用Cython加速Python运行速度

Cython学习笔记1:利用Cython加速Python运行速度 CythonCython 的核心特点:利用Cython加速Python运行速度1. Cython加速Python运行速度原理2. 不使用Cython3. 使用Cython加速(1)使用pip安装 cython 和 setuptools 库(2&…

DApp 开发入门指南

DApp 开发入门指南 🔨 1. DApp 基础概念 1.1 什么是 DApp? 去中心化应用(DApp)是基于区块链的应用程序,特点是: 后端运行在区块链网络前端可以是任何框架使用智能合约处理业务逻辑数据存储在区块链上 1…

基于Spring Security 6的OAuth2 系列之二十 - 高级特性--令牌交换(Token Exchange)

之所以想写这一系列,是因为之前工作过程中使用Spring Security OAuth2搭建了网关和授权服务器,但当时基于spring-boot 2.3.x,其默认的Spring Security是5.3.x。之后新项目升级到了spring-boot 3.3.0,结果一看Spring Security也升级…

瑞芯微RV1126部署YOLOv8全流程:环境搭建、pt-onnx-rknn模型转换、C++推理代码、错误解决、优化、交叉编译第三方库

目录 1 环境搭建 2 交叉编译opencv 3 模型训练 4 模型转换 4.1 pt模型转onnx模型 4.2 onnx模型转rknn模型 4.2.1 安装rknn-toolkit 4.2.2 onn转成rknn模型 5 升级npu驱动 6 C++推理源码demo 6.1 原版demo 6.2 增加opencv读取图片的代码 7 交叉编译x264 ffmepg和op…

【开源】编译器,在线操作

目录 1. 思绪思维导图:simple mind map2. Markdown:md-editor-v33. 文档:wangEditor4. 电子表格:Luckysheet5. 幻灯片:PPTist6. 白板:excalidraw7. 流程图:drawio 1. 思绪思维导图:…

跳表(Skip List)详解

一、什么是跳表? 跳表是一种基于有序链表的高效数据结构,通过建立多级索引实现快速查询。它在平均情况下支持O(log n)时间复杂度的搜索、插入和删除操作,性能接近平衡树,但实现更为简单。 二、核心原理 1. 层级结构 底层为完整…

【Quest开发】全身跟踪

软件:Unity 2022.3.51f1c1、vscode、Meta XR All in One SDK V72 硬件:Meta Quest3 最终效果:能像meta的操作室沉浸场景一样根据头盔移动来推断用户姿势,实现走路、蹲下、手势匹配等功能 需要借助UnityMovement这个包 GitHub …

25年2月通信基础知识补充:多普勒频移与多普勒扩展、3GPP TDL信道模型

看文献过程中不断发现有太多不懂的基础知识,故长期更新这类blog不断补充在这过程中学到的知识。由于这些内容与我的研究方向并不一定强相关,故记录不会很深入请见谅。 【通信基础知识补充7】25年2月通信基础知识补充1 一、多普勒频移与多普勒扩展傻傻分不…