58、深度学习-自学之路-自己搭建深度学习框架-19、RNN神经网络梯度消失和爆炸的原因(从公式推导方向来说明),通过RNN的前向传播和反向传播公式来理解。

一、RNN神经网络的前向传播图如下:

时间步 t=1:
x₁ → (W_x) → [RNN Cell] → h₁ → (W_y) → y₁
           ↑ (W_h)
          h₀ (初始隐藏状态)

时间步 t=2:
x₂ → (W_x) → [RNN Cell] → h₂ → (W_y) → y₂
           ↑ (W_h)
          h₁

时间步 t=3:
x₃ → (W_x) → [RNN Cell] → h₃ → (W_y) → y₃
           ↑ (W_h)
          h₂

过程解释

时间步 t=1:

  1. 输入

    • 输入 x₁ 是第一个时间步的输入数据(例如,一个词向量或时间序列数据点)。

  2. 权重作用

    • 输入 x₁ 通过权重矩阵 W_x 进行线性变换:

                W_x · x₁

               2. 初始隐藏状态 h₀ 通过权重矩阵 W_h 进行线性变换:

                W_h · h₀        

   3.RNN Cell 计算

  • 将变换后的输入和隐藏状态相加,并加上偏置项 b_h,然后通过激活函数 σ(如 tanh 或 ReLU):

        h₁ = σ(W_h · h₀ + W_x · x₁ + b_h)    

  h₁ 是第一个时间步的隐藏状态,包含了当前输入 x₁ 和前一个隐藏状态 h₀ 的信息。

4.输出计算

  • 隐藏状态 h₁ 通过权重矩阵 W_y 进行线性变换,并加上偏置项 b_y,然后通过激活函数 σ

        y₁ = σ(W_y · h₁ + b_y)

  • y₁ 是第一个时间步的输出(例如,预测的下一个词或时间序列值)。

时间步 t=2:

x₂ → (W_x) → [RNN Cell] → h₂ → (W_y) → y₂
           ↑ (W_h)
          h₁

过程解释

  1. 输入

    • 输入 x₂ 是第二个时间步的输入数据。

  2. 权重作用

    • 输入 x₂ 通过权重矩阵 W_x 进行线性变换:

      W_x · x₂

    • 前一个隐藏状态 h₁ 通过权重矩阵 W_h 进行线性变换:

      W_h · h₁

  3. RNN Cell 计算

    • 将变换后的输入和隐藏状态相加,并加上偏置项 b_h,然后通过激活函数 σ

      h₂ = σ(W_h · h₁ + W_x · x₂ + b_h)

    • h₂ 是第二个时间步的隐藏状态,包含了当前输入 x₂ 和前一个隐藏状态 h₁ 的信息。

  4. 输出计算

    • 隐藏状态 h₂ 通过权重矩阵 W_y 进行线性变换,并加上偏置项 b_y,然后通过激活函数 σ

      y₂ = σ(W_y · h₂ + b_y)

    • y₂ 是第二个时间步的输出。

时间步 t=3

复制

x₃ → (W_x) → [RNN Cell] → h₃ → (W_y) → y₃↑ (W_h)h₂
过程解释
  1. 输入

    • 输入 x₃ 是第三个时间步的输入数据。

  2. 权重作用

    • 输入 x₃ 通过权重矩阵 W_x 进行线性变换:

      W_x · x₃

    • 前一个隐藏状态 h₂ 通过权重矩阵 W_h 进行线性变换:

      W_h · h₂

  3. RNN Cell 计算

    • 将变换后的输入和隐藏状态相加,并加上偏置项 b_h,然后通过激活函数 σ

      h₃ = σ(W_h · h₂ + W_x · x₃ + b_h)

    • h₃ 是第三个时间步的隐藏状态,包含了当前输入 x₃ 和前一个隐藏状态 h₂ 的信息。

  4. 输出计算

    • 隐藏状态 h₃ 通过权重矩阵 W_y 进行线性变换,并加上偏置项 b_y,然后通过激活函数 σ

      y₃ = σ(W_y · h₃ + b_y)

    • y₃ 是第三个时间步的输出。

通过上面的公式的观察,大家可以看到一个问题就是:

一共有3个时间步,也就是信息向前传播了三次,然后每次传播使用的输入的权重层是同一个权重、隐藏层对应的权重层是是同一个权重值。

第一个神经网络输出的隐藏层h1给到第二个神经网络,

第二个神经网络的输出隐藏层h2给到第三个神经网络。

第三个神经网络的输出隐藏层h3应该会给到第四个神经网络,如果有的话。

然后每一个神经网络都会有一个预测值y1 y2 y3

如果我们的输出方式是多输入多输出,那么我们每一个预测值y1 y2 y3都会对应一个真实值ture1、ture2、ture3

然后对应着三个误差值loss1 loss2 loss3,然后把loss1 + loss2 + loss3 =L

这个就是前向传播的过程。

二、RNN神经网络的反向传播:

反向传播我们从T=3开始往后传播:

首先:y₃ = σ(W_y · h₃ + b_y)

瞬时函数为L

(1)计算输出层的梯度

损失函数对输出 y_t 的梯度:这个里面的t 你可以认为是3,方便你理解。

        

损失函数对隐藏状态 h_t 的梯度:h_3方便你理解。

然后我们考虑一下,从L 到 h_3所经过的路线:

 L  --->  y3  --->h_3

上式子中:

:t=3

              :t=3

                   :T=3,前面传递的三个预测值和真实值之间的误差值之和

从L到h3 经过了L  --->  y3  --->h_3这个路径。

所以L 对h_3的求导为:

其实我们求导到最后应该要要对权重的求导,因为最后要通过修改权重来学习内容。

也就是:

第一个是误差对输入层权重的导数

第二个是误差对隐藏层权重的导数

第三个是误差对预测层权重的导数

然后我们先看一下误差对输入层权重的求导,

现在我们考虑一下L到Wx的路径有哪些,

第一条是:L --->y3 --->h3 --->Wx

第二条是:L --->y3 --->h3 --->h2 --->Wx

第三条是:L --->y3 --->h3 --->h2 --->h1 --->Wx

然后把第一条路径的导数加上第二条路径的导数再加上第三条路径的导数就是L对Wx的求导。

这个是我自己推导的,可能有些地方不够严谨,但是具体的过程是正确的,从最后的公式我们可以看到距离Y3 最远的X1的前面的值是激活函数的导数的三次方乘以隐藏层权重Wh的平方,那么如果权重的值远小于1,平方后再乘以激活函数的3次方肯定已经远小于1,非常接近0了,那么最后由x1能给L对Wx的导数的值提供的影响就大大减小了,那么x1对L的影响就大大减小了,那么就导致了梯度的消失。对x1信息的遗忘。所有RNN不能够处理很长文本的原因。

梯度爆炸是因为,如果Wh很大比如说是100,平方就是10000,激活函数是0.2,最后的结果就是

80,那么就是说x1对整体的影响可以达到80这么多。如果再传递一层,就更大了。这就导致了梯度爆炸,误差难以收敛。

还有L对Wh的导数

L对Wy的导数都是同样推导的。

不知道大家能不能理解。可以自己动手推导一下。然后就好理解了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/26218.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Qt-信号与槽】connect函数的用法

🏠个人主页:Yui_ 🍑操作环境:Qt Creator 🚀所属专栏:Qt 文章目录 1.信号和槽的概念1.1 信号的本质1.2 槽的本质1.3 补充说明2. 信号和槽的使用2.1 connect函数介绍2.2 connect函数的简单使用2.2.1 图形化方…

ESP32+Mixly-WiFi

#include <WiFi.h> #include <TimeLib.h> #include <NtpClientLib.h>int8_t timeZone 8; // 时区设置&#xff0c;东八区为8 const PROGMEM char *ntpServer "ntp1.aliyun.com"; // NTP服务器地址void setup(){Serial.begin(9600); //初始化串口…

Python 数据可视化(一)熟悉Matplotlib

目录 一、安装包 二、先画个折线图 1、修改标签文字和线条粗细 2、内置样式 3、scatter() 绘制散点图 4、scatter() 绘制多个点 5、设置样式 6、保存绘图 数据可视化指的是通过可视化表示来探索和呈现数据集内的规律。 一、安装包 win R 打开终端 安装 Matplotlib&…

目标检测——数据处理

1. Mosaic 数据增强 Mosaic 数据增强步骤: (1). 选择四个图像&#xff1a; 从数据集中随机选择四张图像。这四张图像是用来组合成一个新图像的基础。 (2) 确定拼接位置&#xff1a; 设计一个新的画布(输入size的2倍)&#xff0c;在指定范围内找出一个随机点&#xff08;如…

从零开始用react + tailwindcss + express + mongodb实现一个聊天程序(六) 导航栏 和 个人信息设置

1.导航栏&#xff08;navbar&#xff09; 在components下面 创建NavBar.jsx import { MessageSquare,Settings,User,LogOut} from "lucide-react" import {Link} from "react-router-dom" import { useAuthStore } from "../store/useAuthStore&qu…

如何通过 LlamaIndex 将数据导入 Elasticsearch

作者&#xff1a;来自 Elastic Andre Luiz 逐步介绍如何使用 RAG 和 LlamaIndex 提取数据并进行搜索。 在本文中&#xff0c;我们将使用 LlamaIndex 来索引数据&#xff0c;从而实现一个常见问题搜索引擎。 Elasticsearch 将作为我们的向量数据库&#xff0c;实现向量搜索&am…

yunedit-post ,api测试比postman更好

postman应该是大家最熟悉的api测试软件了&#xff0c;但是由于它是外国软件&#xff0c;使用它的高端功能注册和缴费都比较麻烦。生成在线文档分享也经常无法访问被拦截掉。 这里可以推荐一下yunedit-post&#xff0c;该有的功能都有。 https://www.yunedit.com/postdetail …

Gopeed 各种类型的文件资源下载器 v1.6.7 中文版

Gopeed 是一款由 Go 和 Flutter 开发的下载器。它具有简洁美观的界面以及强大的功能&#xff0c;支持 HTTP、BitTorrent、Magnet 等协议&#xff0c;并且可以在全平台上使用。 开发语言及技术&#xff1a;Gopeed 采用 Go 和 Flutter 进行开发。Go 语言具有高效、简洁的特点&am…

3d投影到2d python opencv

目录 cv2.projectPoints 投影 矩阵计算投影 cv2.projectPoints 投影 cv2.projectPoints() 是 OpenCV 中的一个函数&#xff0c;用于将三维空间中的点&#xff08;3D points&#xff09;投影到二维图像平面上。这在计算机视觉中经常用于相机标定、物体姿态估计、3D物体与2D图…

Linux操作系统5-进程信号3(信号产生总结与核心转储)

上篇文章&#xff1a;Linux操作系统5-进程信号2&#xff08;信号的4种产生方式&#xff0c;signal系统调用&#xff09;-CSDN博客 本篇Gitee仓库&#xff1a;myLerningCode/l25 橘子真甜/Linux操作系统与网络编程学习 - 码云 - 开源中国 (gitee.com) 本篇重点&#xff1a;核心…

Linux《基础开发工具(上)》

在之前的篇章当中我们已经了解了Linux当中基本的指令以及相关的知识&#xff0c;那么接下来在本篇当中就开始学基本的开发工具&#xff0c;在此我们一共要了解6大开发工具&#xff0c;在此将这些工具的学习分为上中下篇&#xff0c;在本篇当中我们首先要来学习的是yun以及vim,一…

kali liux的下载

Kali Linux | Penetration Testing and Ethical Hacking Linux Distributionhttps://www.kali.org/ VMware虚拟机https://pan.quark.cn/s/aa869ffbf184 【补充一个今天学到的知识昂和内容无关:&#xff08;遥感&#xff09;指非接触的远距离探测技术&#xff0c;使用传感器探…

微软AI900认证备考全攻略:开启AI职业进阶之路

在当今数字化时代&#xff0c;人工智能&#xff08;AI&#xff09;正深刻地改变着我们的工作和生活。微软AI900认证作为AI领域的权威认证之一&#xff0c;不仅为技术爱好者提供了深入探索AI的机会&#xff0c;更是开启AI职业进阶之路的重要敲门砖。以下是一份全面的备考攻略&am…

【Mark】记录用宝塔+Nginx+worldpress+域名遇到的跨域,301,127.0.0.1,CSS加载失败问题

背景 想要用宝塔搭建worldpress&#xff0c;然后用域名直接转https&#xff0c;隐藏掉ipport。 结果被折磨了1天&#xff0c;一直在死活在301&#xff0c;127.0.0.1打转 还有css加载不了的情况 因为worldpress很多是301重定向的&#xff0c;所以改到最后我都不知道改了什么&am…

算法题001——移动零

移动零 力扣——移动零点击链接即可跳转 这道题的数组被划分为两个区间&#xff0c;前一个区间为 非零元素&#xff0c;而后一个指针是 零元素 我们运用双指针&#xff0c;先定义两个指针&#xff0c;分别为 dest 和 cur , cur用来遍历整个数组&#xff0c;而 dest 表示我们…

Selenium自动化测试:如何搭建自动化测试环境,搭建环境过程应该注意的问题

最近也有很多人私下问我&#xff0c;selenium学习难吗&#xff0c;基础入门的学习内容很多是3以前的版本资料&#xff0c;对于有基础的人来说&#xff0c;3到4的差别虽然有&#xff0c;但是不足以影响自己&#xff0c;但是对于没有学过的人来说&#xff0c;通过资料再到自己写的…

mysql 全方位安装教程

下载 MySQL 【官网下载地址】 注意要选择较大的哪个安装包&#xff0c;小的安装包是一个安装器。 我们不用登录&#xff0c;直接下载 直接运行下载好的安装包 MySQL如果是 安装包安装, 可以图形化界面自主配置 如果是压缩包解压, 可以配置 配置文件, 可以解压安装到指定的…

深入刨析 之C++ string类

欢迎来到干货小仓库&#xff01;&#xff01;&#xff01; 没有完美的计划&#xff0c;每个人都在试验的过程中渐渐清晰!!! 1.标准库的string类 a. string是表示字符串的字符串类。 b. 该类的接口与常规容器的接口基本相同&#xff0c;再添加了一些专门用来操作string的常规操…

【AI论文】MedVLM-R1:通过强化学习激励视觉语言模型(VLMs)的医疗推理能力

摘要&#xff1a;推理是推进医学影像分析的关键前沿领域&#xff0c;其中透明度和可信度对于赢得临床医生信任和获得监管批准起着核心作用。尽管医学视觉语言模型&#xff08;VLMs&#xff09;在放射学任务中展现出巨大潜力&#xff0c;但大多数现有VLM仅给出最终答案&#xff…

深入理解并实现自定义 unordered_map 和 unordered_set

亲爱的读者朋友们&#x1f603;&#xff0c;此文开启知识盛宴与思想碰撞&#x1f389;。 快来参与讨论&#x1f4ac;&#xff0c;点赞&#x1f44d;、收藏⭐、分享&#x1f4e4;&#xff0c;共创活力社区。 在 C 的标准模板库&#xff08;STL&#xff09;中&#xff0c;unorder…