【动手学深度学习-Pytorch版】门控循环单元GRU

关于GRU的笔记

支持隐状态的门控:这意味着模型有专门的机制来确定应该何时更新隐状态, 以及应该何时重置隐状态。 这些机制是可学习的,并且能够解决了上面列出的问题。 例如,如果第一个词元非常重要, 模型将学会在第一次观测之后不更新隐状态。 同样,模型也可以学会跳过不相关的临时观测。 最后,模型还将学会在需要的时候重置隐状态。 下面我们将详细讨论各类门控。

formula:门 是和隐藏状态同样的一个向量
重置门: R t = σ ( X t ∗ W x r + H t − 1 ∗ W h r + b r ) R_t = σ(X_t * W_{xr} + H_{t-1} * W_{hr} + b_r) Rt=σ(XtWxr+Ht1Whr+br)
更新门: Z t = σ ( X t ∗ W x z + H t − 1 ∗ W h z + b z ) Z_t = σ(X_t * W_{xz} + H_{t-1} * W_{hz} + b_z) Zt=σ(XtWxz+Ht1Whz+bz)

候选隐状态
☆ H t = t a n h ( X t ∗ W x h + ( R t ⊙ H t − 1 ) ∗ W h h + b h ) ^☆H_t = tanh(X_t * W_{xh} + (R_t ⊙ H_{t-1}) * W_{hh} + b_h) Ht=tanh(XtWxh+(RtHt1)Whh+bh)
当重置门的项接近于1时,就可以恢复到一个普通的循环神经网络RNN的模型对于重置门的项接近于0时,候选隐状态是以X_t作为输入的多层感知机的结果,它除去了隐状态H_t-1对当前的影响任何预先存在的隐状态都会被重置为默认值

隐状态
更新门Z_t仅需要在H_t-1和(star)H_t之间进行按元素的凸组合就可以实现这个目标。
H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ ☆ H t H_t = Z_t ⊙ H_{t-1} + (1 - Z_t) ⊙ ^☆H_t Ht=ZtHt1+(1Zt)Ht
每当更新门{Z_t}接近于1时,模型就倾向于只保留旧状态,此时来自于 X t X_t Xt的信息基本上都会被忽略,当前的隐状态只依赖于上一次的 H ( t − 1 ) H_(t-1) H(t1)相反,当 Z t Z_t Zt接近于0时,新的隐状态 H t H_t Ht就会接近于候选隐状态 ☆ H t ^☆H_t Ht
优点:这些设计可以帮助我们处理循环神经网络中的梯度消失的问题,并且更好地捕获时间步距离很长的序列的依赖关系。例如:如果整个子序列的所有时间步的更新门都接近于1,则无论序列的长度如何,在序列起始时间步的旧隐状态都很容易保留并传递到序列的结束。
在这里插入图片描述
在这里插入图片描述

综上
重置门有助于捕获序列中的短期依赖关系;
更新门有助于捕获序列中的长期依赖关系。

GRU从零开始实现

import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)# 初始化模型参数
def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device) * 0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xz, W_hz, b_z = three()  # 更新门参数W_xr, W_hr, b_r = three()  # 重置门参数W_xh, W_hh, b_h = three()  # 候选隐状态参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params# 定义模型
def init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device),)# gru模型
def gru(inputs, state, params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:# 更新门:Z_t = σ(X_t * W_xz + H_T-1 * W_hz + bz)Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)# 重置门:R_t = σ(X_t * W_xr + H_t - 1 * W_hr + br)R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)# (star)H_t = tanh(X_t * W_xh + (R_t ⊙ H_t-1) * W_hh + bh)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)# H_t = Z_t ⊙ H_t - 1 + (1 - Z_t) ⊙ (star)H_tH = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)# 训练与预测
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500,1
model = d2l.RNNModelScratch(len(vocab),num_hiddens,device,get_params,init_gru_state,gru)d2l.train_ch8(model,train_iter,vocab,lr,num_epochs,device)

在这里插入图片描述

简洁实现GRU

import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xz, W_hz, b_z = three()  # 更新门参数W_xr, W_hr, b_r = three()  # 重置门参数W_xh, W_hh, b_h = three()  # 候选隐状态参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return paramsdef init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )def gru(inputs, state, params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/140881.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Docker文档阅读笔记-How to Commit Changes to a Docker Image with Examples

介绍 在工作中使用Docker镜像和容器,用得最多的就是如何提交修改过的Docker镜像。当提交修改后,就会在原有的镜像上创建一个新的镜像。 本博文说明如何提交一个新的Docker镜像。 前提 ①有一个可以直接访问服务器的运行终端; ②帐号需要r…

云计算安全:保护数字资产的前沿策略

文章目录 1. 云计算安全威胁1.1 数据泄露1.2 身份认证问题1.3 无法预测的网络攻击1.4 集中攻击 2. 云计算安全最佳实践2.1 身份和访问管理(IAM)2.2 数据加密2.3 安全审计和监控2.4 多重身份验证(MFA) 3. 安全自动化3.1 基础设施即…

springboot 获取参数

1.获取简单参数 2.实体对象参数

#倍增 #国旗计划

文章目录 题目&#xff1a;题解代码 题目&#xff1a; 国旗计划 题解 三个技巧&#xff1a; 断环成链&#xff1a; 具体而言就是&#xff1a; if(w[i].R < w[i].L) w[i].R m; m是环的长度&#xff1b; 贪心&#xff1a; 选择一个区间i后&#xff0c;下一个区间只能从左端…

第15篇ESP32 idf框架 wifi联网_WiFi AP模式_手机连接到esp32开发板

第1篇:Arduino与ESP32开发板的安装方法 第2篇:ESP32 helloword第一个程序示范点亮板载LED 第3篇:vscode搭建esp32 arduino开发环境 第4篇:vscodeplatformio搭建esp32 arduino开发环境 ​​​​​​第5篇:doit_esp32_devkit_v1使用pmw呼吸灯实验 第6篇:ESP32连接无源喇叭播…

电子电气架构——无感刷写(Vector)协议栈方案介绍

电子电气架构——无感刷写(Vector)协议栈方案介绍 我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 对学习而言,学习之后的思考、思考之后的行动、行动之后的改变更重要,如果不盯住内层的改变量…

BIOMOD2模型、MaxEnt模型物种分布模拟,生物多样性生境模拟,论文写作

①基于R语言BIOMOD2模型的物种分布模拟实践技术应用 针对我国目前已有自然保护区普遍存在保护目标不明确、保护成效低下和保护空缺依然存在等问题&#xff0c;科学的鉴定生物多样性热点保护区域与保护空缺显得刻不容缓。 BIOMOD2提供运行多达10余种物种分布模拟模型&#xff0c…

【Ambari】银河麒麟V10 ARM64架构_安装Ambari2.7.6HDP3.3.1(HiDataPlus)

&#x1f341; 博主 "开着拖拉机回家"带您 Go to New World.✨&#x1f341; &#x1f984; 个人主页——&#x1f390;开着拖拉机回家_大数据运维-CSDN博客 &#x1f390;✨&#x1f341; &#x1fa81;&#x1f341; 希望本文能够给您带来一定的帮助&#x1f338;文…

接口测试练习步骤

在接触接口测试过程中补了很多课&#xff0c; 终于有点领悟接口测试的根本&#xff1b; 偶是个实用派&#xff5e;&#xff0c;那么现实中没有用的东西&#xff0c;基本上我都不会有很大的概念&#xff1b; 下面给的是接口测试的统一大步骤&#xff0c;其实就是让我们对接口…

长期用眼不再怕!NineData SQL 窗口支持深色模式

您有没有尝试过被明亮的显示器闪瞎眼的经历&#xff1f; 在夜间或低光环境下&#xff0c;明亮的界面会导致许多用眼健康问题&#xff0c;例如长时间使用导致的眼睛疲劳、干涩和不适感&#xff0c;同时夜间还可能会抑制褪黑素分泌&#xff0c;给您的睡眠质量带来影响。 这些问…

前端性能测试工具-lighthouse

Lighthouse简介 Lighthouse 是 Google 的一款开源工具&#xff0c;它可以作为一个 Chrome 扩展程序运行&#xff0c;或从命令行运行。只需要给 Lighthouse 提供一个要审查的网址&#xff0c;它将针对此页面运行一连串的测试&#xff0c;然后生成一个页面性能的报告。 Lightho…

计算机毕业设计 基于SSM+Vue的医院门诊互联电子病历管理信息系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍&#xff1a;✌从事软件开发10年之余&#xff0c;专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精…

使用 Docker 安装 Elasticsearch (本地环境 M1 Mac)

Elasticsearchkibana下载安装 docker pull elasticsearch:7.16.2docker run --name es -d -e ES_JAVA_OPTS“-Xms512m -Xmx512m” -e “discovery.typesingle-node” -p 9200:9200 -p 9300:9300 elasticsearch:7.16.2docker pull kibana:7.16.2docker run --name kibana -e EL…

Ubuntu 20.04中docker-compose部署Nightingale

lsb_release -r可以看到操作系统版本是20.04&#xff0c;uname -r可以看到内核版本是5.5.19。 sudo apt install -y docker-compose安装docker-compose。 完成之后如下图&#xff1a; cd /opt/n9e/docker/进入到/opt/n9e/docker/里边。 docker-compose up -d进行部署。 …

华为OD机试 - 最小传输时延 - 深度优先搜索DFS(Java 2023 B卷 100分)

目录 专栏导读一、题目描述二、输入描述三、输出描述四、解题思路五、Java算法源码六、效果展示1、输入2、输出3、说明计算源节点1到目的节点5&#xff0c;符合要求的时延集合 华为OD机试 2023B卷题库疯狂收录中&#xff0c;刷题点这里 专栏导读 本专栏收录于《华为OD机试&…

牛客java训练题 day1

9.24 day1 Q 1. this 指针是用来干什么的&#xff1f; 2.基类和派生类分别是指什么&#xff1f; 3.为什么方法中不能写静态变量 4. 解释一下ASCII码和ANSI码和两者的区别 5.简述j ava.io java.sql java.awt java.rmi 分别是什么类型的包 6. 看下面一段代码&#xff1a;…

软考高级+系统架构设计师教程+第二版新版+电子版pdf

注意&#xff01;&#xff01;&#xff01; 系统架构设计师出新版教程啦&#xff0c;2022年11月出版。所以今年下半年是新版第一次考试&#xff0c;不要再复习老版教程了&#xff0c;内容改动挺大的。 【内容简介】系统架构设计师教程&#xff08;第2版&#xff09;作为全国计…

字符函数和字符串函数(C语言进阶)

字符函数和字符串函数 一.求字符串长度1.strlen 二.长度不受限制的字符串函数介绍1.strcpy2.strcat3.strcmp 前言 C语言中对字符和字符串的处理很是频繁&#xff0c;但是C语言本身是没有字符串类型的&#xff0c;字符串通常放在常量字符串中或者字符数组中。 字符串常量适用于那…

1*1的卷积核如何实现降维/升维?

在众多网络中&#xff0c;1*1的卷积核被引入用来实现输入数据通道数的改变。 举例说明&#xff0c;如果输入数据格式为X*Y*6&#xff0c;X*Y为数据矩阵&#xff0c;6为通道数&#xff0c;如果希望输出数据格式为X*Y*5&#xff0c;使用5个1*1*6的卷积核即可。 其转换过程类似于…

如何在32位MCU用printf()函数打印64位数据

1. 在32位MCU上定义64位变量&#xff1a; unsigned long long time_base; unsigned long long temp_time;2. 调用打印函数&#xff1a; printf("RFID:time_base:%d\r\n",time_base); printf("RFID:temp_time:%d\r\n",temp_time); printf("RFID:Ru…