循环神经网络的变体模型-LSTM、GRU

一.LSTM(长短时记忆网络)

1.1基本介绍

长短时记忆网络(Long Short-Term Memory,LSTM)是一种深度学习模型,属于循环神经网络(Recurrent Neural Network,RNN)的一种变体。LSTM的设计旨在解决传统RNN中遇到的长序列依赖问题,以更好地捕捉和处理序列数据中的长期依赖关系。

下面是LSTM的内部结构图

LSTM

LSTM为了改善梯度消失,引入了一种特殊的存储单元,该存储单元被设计用于存储和提取长期记忆。与传统的RNN不同,LSTM包含三个关键的门(gate)来控制信息的流动,这些门分别是遗忘门(Forget Gate)、输入门(Input Gate)和输出门(Output Gate)。

LSTM的结构允许它有效地处理和学习序列中的长期依赖关系,这在许多任务中很有用,如自然语言处理、语音识别和时间序列预测。由于其能捕获长期记忆,LSTM成为深度学习中重要的组件之一。

1.2 主要组成部分和工作原理

首先我们先弄明白LSTM单元中的每个符号的含义。每个黄色方框表示一个神经网络层,由权值,偏置以及激活函数组成;每个粉色圆圈表示元素级别操作;箭头表示向量流向;相交的箭头表示向量的拼接;分叉的箭头表示向量的复制。
图中元素的节点信息

以下是LSTM的主要组成部分和工作原理:

  1. 细胞状态(Cell State):
    细胞状态是LSTM网络的主要存储单元,用于存储和传递长期记忆。细胞状态在序列的每一步都会被更新。在LSTM中,细胞状态负责保留网络需要记住的信息,以便更好地处理长期依赖关系。在每个时间步,LSTM通过一系列的操作来更新细胞状态。这些操作包括遗忘门、输入门和输出门的计算。细胞状态在这些门的帮助下动态地保留和遗忘信息。
    细胞状态

  2. 遗忘门(Forget Gate):
    遗忘门决定哪些信息应该被遗忘,从而允许网络丢弃不重要的信息。它通过一个sigmoid激活函数生成一个介于0和1之间的值,用于控制细胞状态中信息的丢失程度。
    遗忘门的计算过程如下:
    2.1 输入:
    上一时刻的隐藏状态(或者是输入数据的向量)
    当前时刻的输入数据
    2.2 计算遗忘门的值:
    将上一时刻的隐藏状态和当前时刻的输入数据拼接在一起。
    通过一个带有sigmoid激活函数的全连接层(通常称为遗忘门层)得到介于0和1之间的值。
    这个值表示细胞状态中哪些信息应该被保留(接近1),哪些信息应该被遗忘(接近0)。
    2.3 遗忘操作:
    将上一时刻的细胞状态与遗忘门的输出相乘,以决定保留哪些信息。
    2.4数学表达式如下:
    遗忘门的输出:
    遗忘门

其中:
W f 和 b f 是遗忘门的权重矩阵和偏置向量。 W_f 和 b_f是遗忘门的权重矩阵和偏置向量。 Wfbf是遗忘门的权重矩阵和偏置向量。
h t − 1 ​是上一时刻的隐藏状态。 h_{t−1}​ 是上一时刻的隐藏状态。 ht1是上一时刻的隐藏状态。
x t 是当前时刻的输入数据。 x_t是当前时刻的输入数据。 xt是当前时刻的输入数据。
σ 是 s i g m o i d 激活函数。 σ 是sigmoid激活函数。 σsigmoid激活函数。

遗忘门的输出 ft 决定了细胞状态中上一时刻信息的保留程度。这个机制允许LSTM网络在处理时间序列数据时更有效地记住长期依赖关系。

  1. 输入门(Input Gate):
    输入门负责确定在当前时间步骤中要添加到细胞状态的新信息。类似于遗忘门,输入门使用sigmoid激活函数产生一个介于0和1之间的值,表示要保留多少新信息,并使用tanh激活函数生成一个新的候选值。
    在这里插入图片描述输入门的计算过程如下:
(1)输入门的输出计算:将上一时刻的隐藏状态(或者是输入数据)和当前时刻的输入数据拼接在一起。通过一个带有sigmoid激活函数的全连接层得到介于0和1之间的值。这个值表示要保留的新信息的程度。
(2)生成新的候选值:将上一时刻的隐藏状态(或者是输入数据)和当前时刻的输入数据拼接在一起。通过一个带有tanh激活函数的全连接层得到一个新的候选值(介于-1和1之间)。
(3)更新细胞状态的操作:将输入门的输出与新的候选值相乘,得到要添加到细胞状态的新信息。
  1. 输出门(Output Gate):
    输出门(Output Gate)在LSTM中控制细胞在特定时间步上的输出。输出门使用sigmoid激活函数产生介于0和1之间的值,这个值决定了在当前时间步细胞状态中有多少信息被输出。同时,输出门的输出与细胞状态经过tanh激活函数后的值相乘,产生最终的LSTM输出。

输出门的计算过程如下:

输出门的输出计算:将上一时刻的隐藏状态(或者是输入数据)和当前时刻的输入数据拼接在一起。通过一个带有sigmoid激活函数的全连接层得到介于0和1之间的值。这个值表示在当前时间步细胞状态中有多少信息要输出。
生成最终的LSTM输出:将当前时刻的细胞状态经过tanh激活函数,得到介于-1和1之间的值。将输出门的输出与tanh激活函数的细胞状态相乘,产生最终的LSTM输出。

在这里插入图片描述

1.3 LSTM的基础代码实现

以下是一个基础的实现,其中包括多层双向LSTM的前向传播。请注意,这个实现仍然是一个简化版本,实际应用中可能需要更多的调整和优化。

import numpy as npdef sigmoid(x):return 1 / (1 + np.exp(-x))def tanh(x):return np.tanh(x)def lstm_cell(xt, a_prev, c_prev, parameters):# 从参数中提取权重和偏置Wf = parameters["Wf"]bf = parameters["bf"]Wi = parameters["Wi"]bi = parameters["bi"]Wo = parameters["Wo"]bo = parameters["bo"]Wc = parameters["Wc"]bc = parameters["bc"]# 合并输入和上一个时间步的隐藏状态concat = np.concatenate((a_prev, xt), axis=0)# 遗忘门ft = sigmoid(np.dot(Wf, concat) + bf)# 输入门it = sigmoid(np.dot(Wi, concat) + bi)# 更新细胞状态cct = tanh(np.dot(Wc, concat) + bc)c_next = ft * c_prev + it * cct# 输出门ot = sigmoid(np.dot(Wo, concat) + bo)# 更新隐藏状态a_next = ot * tanh(c_next)# 保存计算中间结果,以便反向传播cache = (xt, a_prev, c_prev, a_next, c_next, ft, it, ot, cct)return a_next, c_next, cachedef lstm_forward(x, a0, parameters):n_x, m, T_x = x.shapen_a = a0.shape[0]a = np.zeros((n_a, m, T_x))c = np.zeros_like(a)caches = []a_prev = a0c_prev = np.zeros_like(a_prev)for t in range(T_x):xt = x[:, :, t]a_next, c_next, cache = lstm_cell(xt, a_prev, c_prev, parameters)a[:,:,t] = a_nextc[:,:,t] = c_nextcaches.append(cache)a_prev = a_nextc_prev = c_nextreturn a, c, cachesdef lstm_model_forward(x, parameters):caches = []a = xc_list = []for layer in parameters:a, c, layer_cache = lstm_forward(a, np.zeros_like(a[:, :, 0]), layer)caches.append(layer_cache)c_list.append(c)return a, c_list, cachesdef dense_layer_forward(a, parameters):W = parameters["W"]b = parameters["b"]z = np.dot(W, a) + ba_next = sigmoid(z)return a_next, zdef model_forward(x, parameters_lstm, parameters_dense):a_lstm, c_list, caches_lstm = lstm_model_forward(x, parameters_lstm)a_dense = a_lstm[:, :, -1]z_dense_list = []for layer_dense in parameters_dense:a_dense, z_dense = dense_layer_forward(a_dense, layer_dense)z_dense_list.append(z_dense)return a_dense, c_list, caches_lstm, z_dense_list# 示例数据和参数
np.random.seed(1)
x = np.random.randn(10, 5, 3)  # 10个样本,每个样本5个时间步,每个时间步3个特征# LSTM参数
parameters_lstm = [{"Wf": np.random.randn(5, 8), "bf": np.random.randn(5, 1),"Wi": np.random.randn(5, 8), "bi": np.random.randn(5, 1),"Wo": np.random.randn(5, 8), "bo": np.random.randn(5, 1),"Wc": np.random.randn(5, 8), "bc": np.random.randn(5, 1)},{"Wf": np.random.randn(3, 8), "bf": np.random.randn(3, 1),"Wi": np.random.randn(3, 8), "bi": np.random.randn(3, 1),"Wo": np.random.randn(3, 8), "bo": np.random.randn(3, 1),"Wc": np.random.randn(3, 8), "bc": np.random.randn(3, 1)}
]# Dense层参数
parameters_dense = [{"W": np.random.randn(1, 5), "b": np.random.randn(1, 1)},{"W": np.random.randn(1, 5), "b": np.random.randn(1, 1)}
]# 进行正向传播
a_dense, c_list, caches_lstm, z_dense_list = model_forward(x, parameters_lstm, parameters_dense)# 打印输出形状
print("a_dense.shape:", a_dense.shape)

二.GRU(门控循环单元)

GRU

2.1 GRU的基本介绍

门控循环单元(GRU,Gated Recurrent Unit)是一种用于处理序列数据的循环神经网络(RNN)变体,旨在解决传统RNN中的梯度消失问题,并提供更好的长期依赖建模。GRU引入了门控机制,类似于LSTM,但相对于LSTM,GRU结构更加简单。

GRU包含两个门:更新门(Update Gate)和重置门(Reset Gate)。这两个门允许GRU网络决定在当前时间步更新细胞状态的程度以及如何利用先前的隐藏状态。

重置门(Reset Gate)的计算:

通过一个sigmoid激活函数计算重置门的输出。重置门决定了在当前时间步,应该忽略多少先前的隐藏状态信息。

更新门(Update Gate)的计算:

通过一个sigmoid激活函数计算更新门的输出。更新门决定了在当前时间步,应该保留多少先前的隐藏状态信息。

候选隐藏状态的计算:

通过tanh激活函数计算一个候选的隐藏状态。

新的隐藏状态的计算:

通过更新门和候选隐藏状态计算新的隐藏状态。

2.2 GRU的代码实现

以下是使用PyTorch库实现基本的门控循环单元(GRU)的代码。PyTorch提供了GRU的高级API,可以轻松实现和使用。下面是一个简单的例子:

import torch
import torch.nn as nn# 定义GRU模型
class SimpleGRU(nn.Module):def __init__(self, input_size, hidden_size):super(SimpleGRU, self).__init__()self.gru = nn.GRU(input_size, hidden_size)def forward(self, x, hidden=None):output, hidden = self.gru(x, hidden)return output, hidden# 示例数据和模型参数
input_size = 3
hidden_size = 5
seq_len = 1  # 序列长度
batch_size = 1# 创建GRU模型
gru_model = SimpleGRU(input_size, hidden_size)# 将输入数据转换为PyTorch的Tensor
x = torch.randn(seq_len, batch_size, input_size)# 前向传播
output, hidden = gru_model(x)# 打印输出形状
print("Output shape:", output.shape)
print("Hidden shape:", hidden.shape)

以下是使用NumPy库实现基本的门控循环单元(GRU)的代码。这个实现是一个简化版本,其中包含更新门和重置门的计算,以及候选隐藏状态和新的隐藏状态的计算。

import numpy as npdef sigmoid(x):return 1 / (1 + np.exp(-x))def tanh(x):return np.tanh(x)def gru_cell(a_prev, x, parameters):# 从参数中提取权重和偏置W_r = parameters["W_r"]b_r = parameters["b_r"]W_z = parameters["W_z"]b_z = parameters["b_z"]W_a = parameters["W_a"]b_a = parameters["b_a"]# 计算重置门r_t = sigmoid(np.dot(W_r, np.concatenate([a_prev, x])) + b_r)# 计算更新门z_t = sigmoid(np.dot(W_z, np.concatenate([a_prev, x])) + b_z)# 计算候选隐藏状态tilde_a_t = tanh(np.dot(W_a, np.concatenate([r_t * a_prev, x])) + b_a)# 计算新的隐藏状态a_t = (1 - z_t) * a_prev + z_t * tilde_a_t# 保存计算中间结果,以便反向传播cache = (a_prev, x, r_t, z_t, tilde_a_t, a_t)return a_t, cache# 示例数据和参数
np.random.seed(1)
a_prev = np.random.randn(5, 1)  # 上一时刻的隐藏状态
x = np.random.randn(3, 1)  # 当前时刻的输入数据# GRU参数
parameters = {"W_r": np.random.randn(5, 8),"b_r": np.random.randn(5, 1),"W_z": np.random.randn(5, 8),"b_z": np.random.randn(5, 1),"W_a": np.random.randn(5, 8),"b_a": np.random.randn(5, 1)
}# 单个GRU单元的前向传播
a_t, cache = gru_cell(a_prev, x, parameters)# 打印输出形状
print("a_t.shape:", a_t.shape)

本文参考了以下链接:http://colah.github.io/posts/2015-08-Understanding-LSTMs/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/242994.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Wincoot消除拉式图里outlier的多种策略

在蛋白质中,氨基酸残基之间由肽键相连。 由于肽键的平面性,每个残基的构象都可以用两个扭转角来描述。而空间位阻的关系又导致这两个扭转角只能取有限的值,并用拉式图(Ramachandran plot)来表述它们的允许区域。 拉式图可被用来描述蛋白质结构模型的总体质量。 对于确定…

万物简单AIoT 端云一体实战案例学习 之 快速开始

学物联网,来万物简单IoT物联网!! 下图是本案的3步导学,每个步骤中实现的功能请参考图中的说明。 1、简介 物联网具有场景多且复杂、链路长且开发门槛高等特点,让很多想学习或正在学习物联网的学生或开发者有点不知所措,甚至直接就放弃了。    万物简单AIoT物联网教育…

(2)(2.1) Andruav Android Cellular(一)

文章目录 前言 1 Andruav 是什么? 2 Andruav入门 3 Andruav FPV 4 Andruav GCS App​​​​​​​ 前言 Andruav 是一个基于安卓的互联系统,它将安卓手机作为公司计算机,为你的无人机和遥控车增添先进功能。 1 Andruav 是什么&#xff…

接口自动化测试框架设计

文章目录 接口测试的定义接口测试的意义接口测试的测试用例设计接口测试的测试用例设计方法postman主要功能请求体分类JSON数据类型postman内置参数postman变量全局变量环境变量 postman断言JSON提取器正则表达式提取器Cookie提取器postman加密接口签名 接口自动化测试基础getp…

【5G 接口协议】N2接口协议NGAP(NG Application Protocol)介绍

博主未授权任何人或组织机构转载博主任何原创文章,感谢各位对原创的支持! 博主链接 本人就职于国际知名终端厂商,负责modem芯片研发。 在5G早期负责终端数据业务层、核心网相关的开发工作,目前牵头6G算力网络技术标准研究。 博客…

Miracast手机高清投屏到电视(免费)

版权声明 本文原创作者:谷哥的小弟作者博客地址:http://blog.csdn.net/lfdfhl Miracast概述 Miracast是一种无线显示标准,它允许支持Miracast的设备之间通过Wi-Fi直接共享音频和视频内容,实现屏幕镜像或扩展显示。这意味着你可以…

【项目日记(三)】内存池的整体框架设计

💓博主CSDN主页:杭电码农-NEO💓   ⏩专栏分类:项目日记-高并发内存池⏪   🚚代码仓库:NEO的学习日记🚚   🌹关注我🫵带你做项目   🔝🔝 开发环境: Visual Studio 2022 项目日…

k8s的对外服务---ingress

service的作用体现在两个方面: 集群内部:不断追踪pod的变化。他会更新endpoint中的pod对象,基于pod的IP地址不断变化的一种服务发现机制。 集群外部:类似负载均衡器,把流量IP端口,不涉及转发url(http、htt…

如何本地部署虚VideoReTalking

环境: Win10专业版 VideoReTalking 问题描述: 如何本地部署虚VideoReTalking 解决方案: VideoReTalking是一个强大的开源AI对嘴型工具,它是我目前使用过的AI对嘴型工具中效果最好的一个!它是由西安电子科技大学、…

医学图像的数据增强技术 --- 切割-拼接数据增强(CS-DA)

医学图像的新型数据增强技术 CS-DA 核心思想自然图像和医学图像之间的关键差异CS-DA 步骤确定增强后的数据数量 代码复现 CS-DA 核心思想 论文链接:https://arxiv.org/ftp/arxiv/papers/2210/2210.09099.pdf 大多数用于医学分割的数据增强技术最初是在自然图像上开…

H5嵌入小程序适配方案

时间过去了两个多月,2024已经到来,又老了一岁。头发也掉了好多。在这两个月时间里都忙着写页面,感觉时间过去得很快。没有以前那么轻松了。也不是遇到了什么难点技术,而是接手了一个很烂得项目。能有多烂,一个页面发起…

Vue2移动端项目使用$router.go(-1)不生效问题记录

目录 1、this.$router.go(-1) 改成 this.$router.back() 2、存储 from.path,使用 this.$router.push 3、hash模式中使用h5新增的onhashchange事件做hack处理 4、this.$router.go(-1) 之前添加一个 replace 方法 问题背景 : 在 Vue2 的一个移动端开发…

tag 标签

tag 标签 在使用 Git 版本控制的过程中,会产生大量的版本。如果我们想对某些重要版本进行记录,就可以给仓库历史中的某一个commit 打上标签,用于标识。 在本章中,我们将会学习如何列出已有的标签、如何创建和删除新的标签、以及…

【动态规划】【广度优先搜索】【状态压缩】847 访问所有节点的最短路径

作者推荐 视频算法专题 本文涉及知识点 动态规划汇总 广度优先搜索 状态压缩 LeetCode847 访问所有节点的最短路径 存在一个由 n 个节点组成的无向连通图,图中的节点按从 0 到 n - 1 编号。 给你一个数组 graph 表示这个图。其中,graph[i] 是一个列…

数学建模学习笔记||层次分析法

评价类问题 解决评价类问题首先需要想到一下三个问题 我们评价的目标是什么我们为了达到这个目标有哪几种可行方案评价的准则或者说指标是什么 对于以上三个问题,我们可以根据题目中的背景材料,常识以及网上收集到的参考资料进行结合,从而筛…

问题:Feem无法发送信息OR无法连接(手机端无法发给电脑端)

目录 前言 问题分析 资源、链接 其他问题 前言 需要在小米手机、华为平板、Dell电脑之间传输文件,试过安装破解的华为电脑管家、小米的MIUI文件传输等,均无果。(小米“远程管理”ftp传输倒是可以,但速度太慢了,且…

js实现九九乘法表

效果图 代码 <!DOCTYPE html> <html><head><meta charset"utf-8"><title></title></head><body><script type"text/javascript">// 输出乘法口诀表// document.write () 空格 " " 换行…

java黑马学习笔记

数组 变量存在栈中&#xff0c;变量值存放在堆中。 数组反转 public class test{public static void main(String[] args){//目标&#xff1a;完成数组反转int[] arr {10,20,30,40,50};for (int i 0,j arr.length - 1;i < j;i,j--){int tep arr[j]; //后一个值赋给临时…

微前端-无界wujie

无界微前端方案基于 webcomponent 容器 iframe 沙箱&#xff0c;能够完善的解决适配成本、样式隔离、运行性能、页面白屏、子应用通信、子应用保活、多应用激活、vite 框架支持、应用共享等用户的核心诉求。 主项目安装无界 vue2项目&#xff1a;npm i wujie-vue2 -S vue3项目…

wayland(xdg_wm_base) + egl + opengles 最简实例

文章目录 前言一、ubuntu 下相关环境准备1. 获取 xdg_wm_base 依赖的相关文件2. 查看 ubuntu 上安装的opengles 版本3. 查看 weston 所支持的 窗口shell 接口种类二、xdg_wm_base 介绍三、egl_wayland_demo1.egl_wayland_demo2_0.c2.egl_wayland_demo3_0.c3. xdg-shell-protoco…