【动手学深度学习-pytorch】9.2长短期记忆网络(LSTM)

长期以来,隐变量模型存在着长期信息保存和短期输入缺失的问题。 解决这一问题的最早方法之一是长短期存储器(long short-term memory,LSTM) (Hochreiter and Schmidhuber, 1997)。 它有许多与门控循环单元( 9.1节)一样的属性。 有趣的是,长短期记忆网络的设计比门控循环单元稍微复杂一些, 却比门控循环单元早诞生了近20年.

门控记忆元 cell

  • 长短期记忆网络引入了记忆元(memory cell),或简称为单元(cell)
  • 为了控制记忆元,我们需要许多门。输入门 输出门 遗忘门
  • 其中一个门用来从单元中输出条目,我们将其称为输出门(output gate)。 另外一个门用来决定何时将数据读入单元,我们将其称为输入门(input gate)。 我们还需要一种机制来重置单元的内容,由遗忘门(forget gate)来管理, 这种设计的动机与门控循环单元相同, 能够通过专用机制决定什么时候记忆或忽略隐状态中的输入。 让我们看看这在实践中是如何运作的。

输入门、忘记门和输出门

就如在门控循环单元中一样, 当前时间步的输入和前一个时间步的隐状态 作为数据送入长短期记忆网络的门中, 如 图9.2.1所示。 它们由三个具有sigmoid激活函数的全连接层处理, 以计算输入门、遗忘门和输出门的值。 因此,这三个门的值都在
的范围内。
在这里插入图片描述
在这里插入图片描述

候选记忆元

在这里插入图片描述

记忆元

在这里插入图片描述

隐状态

在这里插入图片描述

只有隐状态会传递到输出层,而记忆元完全属于内部信息

从零开始实现

现在,我们从零开始实现长短期记忆网络。 与 8.5节中的实验相同, 我们首先加载时光机器数据集。

import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

初始化模型参数

如前所述,超参数num_hiddens定义隐藏单元的数量。 我们按照标准差
的高斯分布初始化权重,并将偏置项设为0.

def get_lstm_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xi, W_hi, b_i = three()  # 输入门参数W_xf, W_hf, b_f = three()  # 遗忘门参数W_xo, W_ho, b_o = three()  # 输出门参数W_xc, W_hc, b_c = three()  # 候选记忆元参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,b_c, W_hq, b_q]for param in params:param.requires_grad_(True)return params

定义模型

在初始化函数中, 长短期记忆网络的隐状态需要返回一个额外的记忆元, 单元的值为0,形状为(批量大小,隐藏单元数)。 因此,我们得到以下的状态初始化。

def init_lstm_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device),torch.zeros((batch_size, num_hiddens), device=device))

实际模型的定义与我们前面讨论的一样: 提供三个门和一个额外的记忆元。 请注意,只有隐状态才会传递到输出层, 而记忆元不直接参与输出计算。

def lstm(inputs, state, params):[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,W_hq, b_q] = params(H, C) = stateoutputs = []for X in inputs:I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)C = F * C + I * C_tildaH = O * torch.tanh(C)Y = (H @ W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H, C)

训练和预测

让我们通过实例化 8.5节中 引入的RNNModelScratch类来训练一个长短期记忆网络, 就如我们在 9.1节中所做的一样。

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

简洁实现

使用高级API,我们可以直接实例化LSTM模型。 高级API封装了前文介绍的所有配置细节。 这段代码的运行速度要快得多, 因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

总结

  • 长短期记忆网络,包含三个门:输入门、忘记门和遗忘门。其中遗忘门用于重置单元的内容,通过专用的机制决定什么时候记忆或者忽略状态中的输入。

  • 长短期记忆网络的隐藏层输出包括“隐状态”和“记忆元”。只有隐状态会传递到输出层,而记忆元完全属于内部信息。

  • 长短期记忆网络可以缓解梯度消失和梯度爆炸。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/291429.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

目标检测评价标准

主要借鉴:https://github.com/rafaelpadilla/Object-Detection-Metrics?tabreadme-ov-file 主要评价指标、术语: Intersection Over Union (IOU):两个检测框交集面积与并集面积的比值 True Positive (TP):IOU大于阈值的检测框…

uniapp实现列表动态添加

1.效果图&#xff1a; 2.代码实现&#xff1a; 这里没有用uniapp提供的uni-list控件 <template> <view id"app"> <!-- 这里为了让标题&#xff08;h&#xff09;居中展示&#xff0c;给h标签设置了父标签&#xff0c;并设置父标签text-…

物联网实战--入门篇之(四)嵌入式-UART驱动

目录 一、串口简介 二、串口驱动设计 三、串口发送 四、串口接收处理 五、PM2.5数据接收处理 六、printf重定义 七、总结 一、串口简介 串口在单片机的开发中属于非常常用的外设&#xff0c;最基本的都会预留一个调试串口用来输出调试信息&#xff0c;串口时序这里就不谈…

既有理论深度又有技术细节——深度学习计算机视觉

推荐序 我曾经试图找到一本既有理论深度、知识广度&#xff0c;又有技术细节、数学原理的关于深度学习的书籍&#xff0c;供自己学习&#xff0c;也推荐给我的学生学习。虽浏览文献无数&#xff0c;但一直没有心仪的目标。两周前&#xff0c;刘升容女士将她的译作《深度学习计…

java中的单例模式

一、描述 单例模式就是程序中一个类只能有一个对象实例 举个例子: //引出单例模式&#xff0c;一个类中只能由一个对象实例 public class Singleton1 {private static Singleton1 instance new Singleton1();//通过这个方法来获取实例public static Singleton1 getInstance…

Verilog语法回顾--门级和开关级模型

目录 门和开关的声明 门和开关类型 支持驱动强度的门 延迟 实例数组 and&#xff0c;nand&#xff0c;nor&#xff0c;or&#xff0c;xor&#xff0c;xnor buf&#xff0c;not bufif1&#xff0c;bufif0&#xff0c;notif1&#xff0c;notif0 MOS switches Bidirecti…

TSINGSEE青犀智慧工厂视频汇聚与安全风险智能识别和预警方案

在智慧工厂的建设中&#xff0c;智能视频监控方案扮演着至关重要的角色。它不仅能够实现全方位、无死角的监控&#xff0c;还能够通过人工智能技术&#xff0c;实现智能识别、预警和分析&#xff0c;为工厂的安全生产和高效运营提供有力保障。 TSINGSEE青犀智慧工厂智能视频监…

公司官网怎么才会被百度收录

在互联网时代&#xff0c;公司官网是企业展示自身形象、产品与服务的重要窗口。然而&#xff0c;即使拥有精美的官网&#xff0c;如果不被搜索引擎收录&#xff0c;就无法被用户发现。本文将介绍公司官网如何被百度收录的一些方法和步骤。 1. 创建和提交网站地图 创建网站地图…

C语言例3-5:阅读下列程序,写出程序运行的结果。

代码如下&#xff1a; #include <stdio.h> int main(void) {int i1,s3;do{si;if(s%70) continue;else i;}while(s<15);printf("%d",i);return 0; } 结果如下&#xff1a; 分析&#xff1a; s314437741111617i3468

四、e2studio VS STM32CubeIDE之STM32CubeIDE线程安全解决方案

目录 一、概述/目的 二、原因和办法 三、线程安全问题的描述 四、STM32解决方案 4.1 通用策略 4.2 RTOS策略 4.3 策略的讲解 4.3.1 裸机应用(策略2、3) 4.3.2 RTOS应用(策略4、5) 五、关键源码 四、e2studio VS STM32CubeIDE之STM32CubeIDE线程安全解决方案 一、概述…

Spring Boot简介及案例

文章目录 Spring Boot简介以下是一个简单的 Spring Boot Web 应用实例**步骤 1&#xff1a;创建 Spring Boot 项目****步骤 2&#xff1a;编写 RESTful 控制器****步骤 3&#xff1a;配置主类****步骤 4&#xff1a;运行并测试应用** Spring Boot简介 Spring Boot 是一个用于简…

怎么让ChatGPT批量写作原创文章

随着人工智能技术的不断发展&#xff0c;自然语言处理模型在文本生成领域的应用也日益广泛。ChatGPT作为其中的佼佼者之一&#xff0c;凭借其强大的文本生成能力和智能对话特性&#xff0c;为用户提供了一种高效、便捷的批量产出内容的解决方案。以下将就ChatGPT批量写作内容进…

【AI】命令行调用大模型

&#x1f308;个人主页: 鑫宝Code &#x1f525;热门专栏: 闲话杂谈&#xff5c; 炫酷HTML | JavaScript基础 ​&#x1f4ab;个人格言: "如无必要&#xff0c;勿增实体" 文章目录 【AI】命令行调用大模型引入正文初始化项目撰写脚本全局安装 成果展示 【AI】命令…

Ubuntu20.04LTS+uhd3.15+gnuradio3.8.1源码编译及安装

文章目录 前言一、卸载本地 gnuradio二、安装 UHD 驱动三、编译及安装 gnuradio四、验证 前言 本地 Ubuntu 环境的 gnuradio 是按照官方指导使用 ppa 的方式安装 uhd 和 gnuradio 的&#xff0c;也是最方便的方法&#xff0c;但是存在着一个问题&#xff0c;就是我无法修改底层…

亚信安全联合人保财险推出数字安全保障险方案,双重保障企业数字化转型

数字化发展&#xff0c;新兴技术的应用与落地带来网络攻击的进一步演进升级&#xff0c;同时全球产业链供应链融合协同的不断加深&#xff0c;更让网络威胁的影响范围与危害程度不断加剧。 企业单纯依靠自身安全能力建设&#xff0c;能否跟上网络威胁的进化速度&#xff1f;能否…

Day49:WEB攻防-文件上传存储安全OSS对象分站解析安全解码还原目录执行

目录 文件-解析方案-目录执行权限&解码还原 目录执行权限 解码还原 文件-存储方案-分站存储&OSS对象 分站存储 OSS对象存储 知识点&#xff1a; 1、文件上传-安全解析方案-目录权限&解码还原 2、文件上传-安全存储方案-分站存储&OSS对象 文件-解析方案-目…

【深耕 Python】Data Science with Python 数据科学(2)jupyter-lab和numpy数组

关于数据科学环境的建立&#xff0c;可以参考我的博客&#xff1a;【深耕 Python】Data Science with Python 数据科学&#xff08;1&#xff09;环境搭建 Jupyter代码片段1&#xff1a;简单数组的定义和排序 import numpy as np np.array([1, 2, 3]) a np.array([9, 6, 2, …

深入解析快速排序算法

深入解析快速排序算法 一、快速排序算法简介二、快速排序算法过程三、快速排序算法示例四、快速排序算法分析1. 时间复杂度&#xff1a;2. 空间复杂度&#xff1a;3. 稳定性&#xff1a; 五、快速排序算法优化1. 优化基准元素的选择&#xff1a;2. 优化小数组的排序&#xff1a…

WIFI驱动移植实验:WIFI从路由器动态获取IP地址与联网

一. 简介 前面两篇文章&#xff0c;一篇文章实现了WIFI联网前要做的工作&#xff0c;另一篇文章配置了WIFI配置文件&#xff0c;进行了WIFI热点的连接。文章如下&#xff1a; WIFI驱动移植实验&#xff1a;WIFI 联网前的工作-CSDN博客 WIFI驱动移植实验&#xff1a;连接WIF…

工业镜头常用参数之实效F(Fno.)和像圈

Fno. 工业镜头中常用到的参数F&#xff0c;有时候用F/#&#xff0c;Fno.来表示&#xff0c;指的是镜头通光能力的参数。它可用镜头焦距及入瞳直径来表示&#xff0c;也可通过镜头数值孔径&#xff08;NA&#xff09;和光学放大倍率&#xff08;β&#xff09;来计算。有效Fno.…