[PyTorch][chapter 46][LSTM -1]

前言:

           长短期记忆网络(LSTM,Long Short-Term Memory)是一种时间循环神经网络,是为了解决一般的RNN(循环神经网络)存在的长期依赖问题而专门设计出来的。

目录:

  1.      背景简介
  2.      LSTM Cell
  3.      LSTM 反向传播算法
  4.      为什么能解决梯度消失
  5.       LSTM 模型的搭建


一  背景简介:

       1.1  RNN

         RNN 忽略o_t,L_t,y_t 模型可以简化成如下

      

       

          图中Rnn Cell 可以很清晰看出在隐藏状态h_t=f(x_t,h_{t-1})

            得到 h_t后:

              一方面用于当前层的模型损失计算,另一方面用于计算下一层的h_{t+1}

    由于RNN梯度消失的问题,后来通过LSTM 解决 

       1.2 LSTM 结构

        


二  LSTM  Cell

   LSTMCell(RNNCell) 结构

          

          前向传播算法 Forward

         2.1   更新: forget gate 忘记门

             f_t=\sigma(W_fh_{t-1}+U_{t}x_t+b_f)

             将值朝0 减少, 激活函数一般用sigmoid

             输出值[0,1]

         2.2 更新: Input gate 输入门

                i_t=\sigma(W_ih_{t-1}+U_ix_t+b_i)

                决定是不是忽略输入值

    

           2.3 更新: 候选记忆单元

                    a_t=\widetilde{c_t}=tanh(W_a h_{t-1}+U_ax_t+b_a)

           2.4 更新: 记忆单元

               c_t=f_t \odot c_{t-1}+i_t \odot a_t

             2.5  更新: 输出门

                决定是否使用隐藏值

                 o_t=\sigma(W_oh_{t-1}+U_ox_t+b_0)  

           2.6. 隐藏状态

                h_t=o_t \odot tanh(c_t)

           2.7  模型输出

                  \hat{y_t}=\sigma(Vh_t+b)

LSTM 门设计的解释一:

 输入门 ,遗忘门,输出门 不同取值组合的时候,记忆单元的输出情况


三  LSTM 反向传播推导

      3.1 定义两个\delta_t

             \delta_h^t=\frac{\partial L}{\partial h_t}

            \delta_c^t=\frac{\partial L}{\partial C_t}

    3.2  定义损失函数

            损失函数L(t)分为两部分: 

             时刻t的损失函数 l(t)

             时刻t后的损失函数L(t+1)

              L(t)=\left\{\begin{matrix} l(t)+L(t+1), if: t<T\\ l(t), if: t=T \end{matrix}\right.

      3.3 最后一个时刻\tau

              

 这里面要注意这里的o^{\tau}= Vh_{\tau}+c

    证明一下第二项,主要应用到微分的两个性质,以及微分和迹的关系:

   

   dl= tr((\frac{\partial L^{\tau}}{\partial h^{\tau}})^Tdh^{\tau})  ... 公式1: 微分和迹的关系

       =tr((\delta_h^{\tau})^Tdh^{\tau})

     因为

    h^{\tau}=o^{\tau} \odot tanh(c^{\tau})

   dh_T=o^{\tau}\odot(d(tanh (c^{\tau})))

           =o^{\tau} \odot (1-tanh^2(c^{\tau})) \odot dc^{\tau}

     带入上面公式1:

      dl= tr((\delta_h^{\tau})^T (o^{\tau}\odot(1-tanh^2(c^{\tau}))\odot dc^{\tau})

           =tr((\delta_h^{\tau} \odot o^{\tau} \odot(1-tanh^2(c^{\tau}))^Tdc^{\tau})

    所以

3.4   链式求导过程

       求导结果:

 

  这里详解一下推导过程:

  这是一个符合函数求导:先把h 写成向量形成

h=\begin{bmatrix} o_1*tanh(c_1)\\ o_2*tanh(c_2) \\ .... \\ o_n*tanh(c_n) \end{bmatrix}

 ------------------------------------------------------------   

 第一项: 

             

         h_{t+1}=o_{t+1}\odot tanh(c_{t+1})

         o_{t+1}=\sigma(W_oh_t+U_ox_{t+1}+b_0)

        设 a_{t+1}=W_oh_t+U_ox_{t+1}+b_0

           则    \frac{\partial h_{t+1}}{\partial h_{t}}=\frac{\partial h_{t+1}}{\partial o_{t+1}}\frac{\partial o_{t+1}}{\partial a_{t+1}}\frac{\partial a_{t+1}}{\partial h_{t}}

 

            其中:(利用矩阵求导的定义法 分子布局原理)

                    \frac{\partial h_{t+1}}{\partial o_{t+1}}=diag(tanh(c^{t+1})) 是一个对角矩阵

                  o=\begin{bmatrix} \sigma(a_1)\\ \sigma(a_2) \\ .... \\ \sigma(a_n) \end{bmatrix}

                 \frac{\partial o_{t+1}}{\partial a_{t+1}}=diag(o_{t+1}\odot(1-o_{t+1}))

                 \frac{\partial a_{t+1}}{\partial h_{t}}=W_o

                 几个连乘起来就是第一项

               

第二项

    c_{t+1}=f_{t+1}\odot c_t+i_{t+1}\odot a_{t+1}

   f_{t+1}=\sigma(W_fh_t+U_tx_{t+1}+b_f)

   i_{t+1}=\sigma(W_ih_t+U_i x_{t+1}+b_i)

  a_{t+1}=tanh(W_a h_t +U_ax_t +b_a)

参考:

   h=\begin{bmatrix} o_1*tanh(c_1)\\ o_2*tanh(c_2) \\ .... \\ o_n*tanh(c_n) \end{bmatrix}

其中:

\frac{\partial h_{t+1}}{\partial c^{t+1}}=diag(o^{t+1}\odot (1-tanh^2(c^{t+1}))

\frac{\partial h_{t+1}}{\partial h_{t}}=\frac{\partial h_{t+1}}{\partial c_{t+1}}\frac{\partial c_{t+1}}{\partial f_{t+1}}\frac{\partial f_{t+1}}{\partial h_{t}}

 \frac{\partial c_{t+1}}{\partial f_{t+1}}=diag(c^{t})

 \frac{\partial a_{t+1}}{\partial h_{t}}=diag(f_t \odot(1-f_t))W_f

其它也是相似,就有了上面的求导结果


四  为什么能解决梯度消失

    

     4.1 RNN 梯度消失的原理

                ,复旦大学邱锡鹏书里面 有更加详细的解释,通过极大假设:

在梯度计算中存在梯度的k 次方连乘 ,导致 梯度消失原理。

    4.2  LSTM 解决梯度消失 解释1:

            通过上面公式发现梯度计算中是加法运算,不存在连乘计算,

            极大概率降低了梯度消失的现象。

    4.3  LSTM 解决梯度 消失解释2:

              记忆单元c  作用相当于ResNet的残差部分.  

   比如f_{t}=1,\hat{c_t}=0 时候,\frac{\partial c_t}{\partial c_{t-1}}=1,不会存在梯度消失。

       


五 模型的搭建

   

    我们最后发现:

    O_t,C_t,H_t 的维度必须一致,都是hidden_size

    通过C_t,则 I_t,F_t,\tilde{c} 最后一个维度也必须是hidden_size

    

# -*- coding: utf-8 -*-
"""
Created on Thu Aug  3 15:11:19 2023@author: chengxf2
"""# -*- coding: utf-8 -*-
"""
Created on Wed Aug  2 15:34:25 2023@author: chengxf2
"""import torch
from torch import nn
from d21 import torch as d21def normal(shape,devices):data = torch.randn(size= shape, device=devices)*0.01return datadef get_lstm_params(input_size, hidden_size,categorize_size,devices):#隐藏门参数W_xf= normal((input_size, hidden_size), devices)W_hf = normal((hidden_size, hidden_size),devices)b_f = torch.zeros(hidden_size,devices)#输入门参数W_xi= normal((input_size, hidden_size), devices)W_hi = normal((hidden_size, hidden_size),devices)b_i = torch.zeros(hidden_size,devices)#输出门参数W_xo= normal((input_size, hidden_size), devices)W_ho = normal((hidden_size, hidden_size),devices)b_o = torch.zeros(hidden_size,devices)#临时记忆单元W_xc= normal((input_size, hidden_size), devices)W_hc = normal((hidden_size, hidden_size),devices)b_c = torch.zeros(hidden_size,devices)#最终分类结果参数W_hq = normal((hidden_size, categorize_size), devices)b_q = torch.zeros(categorize_size,devices)params =[W_xf,W_hf,b_f,W_xi,W_hi,b_i,W_xo,W_ho,b_o,W_xc,W_hc,b_c,W_hq,b_q]for param in params:param.requires_grad_(True)return paramsdef init_lstm_state(batch_size, hidden_size, devices):cell_init = torch.zeros((batch_size, hidden_size),device=devices)hidden_init = torch.zeros((batch_size, hidden_size),device=devices)return (cell_init, hidden_init)def lstm(inputs, state, params):[W_xf,W_hf,b_f,W_xi,W_hi,b_i,W_xo,W_ho,b_o,W_xc,W_hc,b_c,W_hq,b_q] = params    (H,C) = stateoutputs= []for x in inputs:#input gateI = torch.sigmoid((x@W_xi)+(H@W_hi)+b_i)F = torch.sigmoid((x@W_xf)+(H@W_hf)+b_f)O = torch.sigmoid((x@W_xo)+(H@W_ho)+b_o)C_tmp = torch.tanh((x@W_xc)+(H@W_hc)+b_c)C = F*C+I*C_tmpH = O*torch.tanh(C)Y = (H@W_hq)+b_qoutputs.append(Y)return torch.cat(outputs, dim=0),(H,C)def main():batch_size,num_steps =32, 35train_iter, cocab= d21.load_data_time_machine(batch_size, num_steps)if __name__ == "__main__":main()


 参考

 

CSDN

https://www.cnblogs.com/pinard/p/6519110.html

57 长短期记忆网络(LSTM)【动手学深度学习v2】_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/83956.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mac-右键-用VSCode打开

1.点击访达&#xff0c;搜索自动操作 2.选择快速操作 3.执行shell脚本 替换代码如下&#xff1a; for f in "$" doopen -a "Visual Studio Code" "$f" donecommand s保存会出现一个弹框&#xff0c;保存为“用VSCode打开” 5.使用

Dockerfile 简单实战

将flask项目打包成镜像 1. 准备flask文件 创建 app.py 文件&#xff0c;内容如下 from flask import Flask app Flask(__name__)app.route(/) def hello_world():return Hello Worldif __name__ __main__:app.run(host0.0.0.0, port8000, debugTrue) 并开启外网访问&#xf…

C#--设计模式之单例模式

单例模式大概是所有设计模式中最简单的一种&#xff0c;如果在面试时被问及熟悉哪些设计模式&#xff0c;你可能第一个答的就是单例模式。 单例模式的实现分为两种&#xff1a; 饿汉式&#xff1a;在静态构造函数执行时就立即实例化。懒汉式&#xff1a;在程序执行过程中第一…

vue3报错

这是因为eslint对代码的要求严格导致的&#xff0c;可以在package.json里面删掉"eslint:recommended"&#xff0c;然后重启就可以正常运行了

长城汽车正式进军东盟市场,多款智能新能源亮相印尼车展

长城汽车在2023年印尼国际车展&#xff08;GAIKINDO Indonesia International Auto Show&#xff09;揭幕GWM品牌系列车型&#xff0c;包括坦克500 HEV、哈弗H6 HEV、哈弗JOLION HEV以及欧拉好猫。这一战略旨在进一步打入印尼市场。 长城汽车宣布将正式进军东盟市场&#xff0c…

jumpserver命令记录膨胀问题

一.背景 jumpserver堡垒机针对只是接管ssh来说&#xff0c;正常操作Linux的指令记录应该不会太多&#xff0c;每天有个几千条都已经算很多了。所以默认jumpserver采用MySQL作为存储介质本身也没啥问题。但是我们使用jumpserver对【MySQL应用】进行了托管&#xff0c;导致查询SQ…

HttpRunner自动化测试之httprunner运行方式

httprunner运行方式&#xff1a; httprunner在进行接口测试的时候&#xff0c;有两种运行方式 方式一&#xff1a;通过命令行&#xff08;CLI&#xff09;运行&#xff0c;核心命令如下 hrun&#xff1a;httprunner的缩写&#xff0c;功能与httprunner完全相同 例&#xff1a…

培训报名小程序-订阅消息发送

目录 1 创建API2 获取模板参数3 编写自定义代码4 添加订单编号5 发送消息6 发布预览 我们上一篇讲解了小程序如何获取用户订阅消息授权&#xff0c;用户允许我们发送模板消息后&#xff0c;按照模板的参数要求&#xff0c;我们需要传入我们想要发送消息的内容给模板&#xff0c…

Jenkins+Docker+SpringCloud微服务持续集成

JenkinsDockerSpringCloud微服务持续集成 JenkinsDockerSpringCloud持续集成流程说明SpringCloud微服务源码概述本地运行微服务本地部署微服务 Docker安装和Dockerfile制作微服务镜像Harbor镜像仓库安装及使用在Harbor创建用户和项目上传镜像到Harbor从Harbor下载镜像 微服务持…

网络安全设备及部署

什么是等保定级&#xff1f; 之前了解了下等保定级&#xff0c;接下里做更加深入的探讨 文章目录 一、网路安全大事件1.1 震网病毒1.2 海康威视弱口令1.3 物联网Mirai病毒1.4 专网 黑天安 事件1.5 乌克兰停电1.6 委内瑞拉电网1.7 棱镜门事件1.8 熊猫烧香 二、法律法规解读三、安…

【AI】Python调用讯飞星火大模型接口,轻松实现文本生成

随着chatGPT的出现&#xff0c;通用大模型已经成为了研究的热点&#xff0c;由于众所周知的原因&#xff0c;亚太地区调用经常会被禁&#xff0c;在国内&#xff0c;讯飞星火大模型是一个非常优秀的中文预训练模型。本文将介绍如何使用Python调用讯飞星火大模型接口&#xff0c…

全球飞机电磁阀总体规模分析

电磁阀是一种液压管路的电磁装置&#xff0c;通过使用电流产生磁场&#xff0c;从而驱动螺线管&#xff0c;控制阀中流体的流动。电磁阀作为流体控制自动化系统的执行器之一&#xff0c;有着结构紧凑、尺寸小、重量轻、密封良好、维修简便和可靠性高、节能降耗的特点&#xff0…

SpringBoot 的事务及使用

一、事务的常识 1、事务四特性&#xff08;ACID&#xff09; A 原子性&#xff1a;事务是最小单元,不可再分隔的一个整体。C 一致性&#xff1a;事务中的方法要么同时成功,要么都不成功,要不都失败。I 隔离性&#xff1a;多个事务操作数据库中同一个记录或多个记录时,对事务进…

【高频面试题】多线程篇

文章目录 一、线程的基础知识1.线程与进程的区别2.并行和并发有什么区别&#xff1f;3.创建线程的方式有哪些&#xff1f;3.1.Runnable 和 Callable 有什么区别&#xff1f;3.2.run()和 start()有什么区别&#xff1f; 4.线程包括哪些状态&#xff0c;状态之间是如何变化的4.1.…

C++核心编程:C++中的引用

C中的引用 引用的基本语法 作用&#xff1a;给变量起别名 语法&#xff1a;数据类型 & 别名 原名 //比如给一个int变量a命名一个别名 b int &b a;b 20; cout<< a << endl;//a 20引用的注意事项 引用必须初始化 int &b;//错误的引用在初始化后&…

问道管理:沪指窄幅震荡跌0.18%,有色、汽车等板块走低

3日早盘&#xff0c;沪指盘中窄幅震动下探&#xff0c;创业板逆市上扬&#xff1b;两市半日成交不足5000亿元&#xff0c;北向资金净卖出超15亿元。 到午间收盘&#xff0c;沪指跌0.18%报3255.88点&#xff0c;深成指跌0.23%&#xff0c;创业板指涨0.2%&#xff1b;两市算计成交…

剑指 Offer 37. 序列化二叉树

文章目录 题目描述简化题目思路分析 题目描述 请实现两个函数&#xff0c;分别用来序列化和反序列化二叉树。 你需要设计一个算法来实现二叉树的序列化与反序列化。这里不限定你的序列 / 反序列化算法执行逻辑&#xff0c;你只需要保证一个二叉树可以被序列化为一个字符串并且将…

Spring AOP(AOP概念,组成成分,实现,原理)

目录 1. 什么是Spring AOP&#xff1f; 2. 为什么要用AOP&#xff1f; 3. AOP该怎么学习&#xff1f; 3.1 AOP的组成 &#xff08;1&#xff09;切面&#xff08;Aspect&#xff09; &#xff08;2&#xff09;连接点&#xff08;join point&#xff09; &#xff08;3&a…

SpringMVC概述、SpringMVC的工作流程、创建SpringMVC的项目

&#x1f40c;个人主页&#xff1a; &#x1f40c; 叶落闲庭 &#x1f4a8;我的专栏&#xff1a;&#x1f4a8; c语言 数据结构 javaweb 石可破也&#xff0c;而不可夺坚&#xff1b;丹可磨也&#xff0c;而不可夺赤。 Spring MVC入门 一、Spring MVC概述二、入门案例2.1导入Sp…

【go语言学习笔记】04 Go 语言工程管理

文章目录 一、质量保证1. 单元测试1.1 定义1.2 Go 语言的单元测试1.3 单元测试覆盖率 2. 基准测试2.1 定义2.2 Go 语言的基准测试2.3 计时方法2.4 内存统计2.5 并发基准测试2.6 基准测试实战 3. 特别注意 二、性能优化1. 代码规范检查1.1 定义1.2 golangci-lint1.2.1 安装1.2.2…