深度学习--------------------------------门控循环单元GRU

目录

  • 候选隐状态
  • 隐状态
  • 门控循环单元GRU从零开始实现代码
    • 初始化模型参数
    • 定义隐藏状态的初始化函数
    • 定义门控循环单元模型
    • 训练
    • 该部分总代码
    • 简洁代码实现

做RNN的时候处理不了太长的序列,这是因为把整个序列信息全部放在隐藏状态里面,当时间很长的话,隐藏状态可能就会累计很多东西,所以对于前面很久以前的信息不易从中抽取出来了。

R t R_t Rt就是重置, Z t Z_t Zt就是更新
门是跟隐藏状态同样长度的一个向量,计算方式跟RNN的隐藏状态是一样的。
在这里插入图片描述

在这里插入图片描述




候选隐状态

在这里插入图片描述

在这里插入图片描述
假设 R t R_t Rt里面的元素靠近零的话,那么 R t R_t Rt点乘 H t − 1 H_{t-1} Ht1就会变得像零。(就等于是把上一个时刻的隐藏状态忘掉。)
如果全部设成0就变成了初始状态,等于这个时刻开始前面的信息全部不要。
如果全部设成1,就表示所有前面的信息全部拿过来做当前的更新。




隐状态

在这里插入图片描述

H t H_t Ht等于 Z t Z_t Zt按元素点乘上一次的隐藏状态+(1- Z t Z_t Zt)按元素点乘候选隐藏状态

Z t Z_t Zt是一个控制单元,叫做update gate。它是在0-1之间的数字。
假设 Z t Z_t Zt都等于1。(就是不更新过去的状态,把过去的状态放到现在)

在这里插入图片描述
假设 Z t Z_t Zt都等于0。(不直接拿过去的状态了,基本上看现在的更新状态)

Z t Z_t Zt里面全0,且 R t R_t Rt里面全1的时候就回到我们RNN的情况下。




门控循环单元GRU从零开始实现代码

import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 3
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

初始化模型参数

def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device) * 0.01# 定义一个函数,生成三组权重和偏置张量,用于不同的门控机制def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xz, W_hz, b_z = three()  # GRU多了这两行,更新门的权重和偏置W_xr, W_hr, b_r = three()  # GRU多了这两行,重置门的权重和偏置W_xh, W_hh, b_h = three()  # 候选隐藏状态的权重和偏置# 隐藏状态到输出的权重W_hq = normal((num_hiddens, num_outputs))# 输出的偏置b_q = torch.zeros(num_outputs, device=device)params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]# 遍历参数列表中所有参数for param in params:param.requires_grad_(True)return params



定义隐藏状态的初始化函数

定义隐状态的初始化函数init_gru_state。与之前定义的init_rnn_state函数一样,此函数返回一个形状为(批量大小,隐藏单元个数)的张量,张量的值全部为零。

def init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )



定义门控循环单元模型

def gru(inputs, state, params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)



训练

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)



该部分总代码

import torch
from torch import nn
from d2l import torch as d2l# 初始化模型参数
def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device) * 0.01# 定义一个函数,生成三组权重和偏置张量,用于不同的门控机制def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))# 初始化GRU中的权重和偏置# 权重和偏置用于控制更新门W_xz, W_hz, b_z = three()  # GRU多了这两行# 权重和偏置用于控制重置门W_xr, W_hr, b_r = three()  # GRU多了这两行W_xh, W_hh, b_h = three()W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params# 定义隐藏状态的初始化函数
def init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device),)# 定义门控循环单元模型
def gru(inputs, state, params):# 参数 params 解包为多个变量,分别表示模型中的权重和偏置W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []# 遍历输入序列中的每个时间步for X in inputs:# 更新门控机制 ZZ = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)# 重置门控机制 RR = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)# 将所有输出拼接在一起,并返回拼接后的结果和最终的隐藏状态return torch.cat(outputs, dim=0), (H,)batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params, init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
d2l.plt.show()

在这里插入图片描述

在这里插入图片描述




简洁代码实现

from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
d2l.plt.show()

在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/440082.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

jmeter操作数据库

jmeter操作数据库 一、打开数据库 二、jmeter下载驱动,安装jdbc驱动 1、下载好的驱动包 2、将驱动包复制粘贴 存放在包的路径下 (1)jdk下面 a、路径:jdk1\jre\lib b、jdk1\jre\lib\ext (2)jmeter下 a、…

SpringIoC容器的初识

一、SpringIoC容器的介绍 Spring IoC 容器,负责实例化、配置和组装 bean(组件)。容器通过读取配置元数据来获取有关要实例化、配置和组装组件的指令。配置元数据以 XML、Java 注解或 Java 代码形式表现。它允许表达组成应用程序的组件以及这…

基于依赖注入技术的.net core WebApi框架创建实例

依赖注入(Dependency Injection, DI)是一种软件设计模式,用于实现控制反转(Inversion of Control, IoC)。在ASP.NET Core中,依赖注入是内置的核心功能之一。它允许你将应用程序的组件解耦和配置&#xff0c…

Linux:进程入门(进程与程序的区别,进程的标识符,fork函数创建多进程)

往期文章:《Linux:深入了解冯诺依曼结构与操作系统》 Linux:深入理解冯诺依曼结构与操作系统-CSDN博客 目录 1. 概念 2. 描述进程 3. 深入理解进程的本质 4. 进程PID 4.1 指令获取PID 4.2 geipid函数获取PID 4.3 kill指令终止进程 …

Linux驱动开发(速记版)--GPIO子系统

第105章 GPIO 入门 105.1 GPIO 引脚分布 RK3568 有 5 组 GPIO:GPIO0 到 GPIO4。 每组 GPIO 又以 A0 到 A7,B0 到 B7,C0 到C7,D0 到 D7,作为区分的编号。 所以 RK3568 上的 GPIO 是不是应该有 5*4*8160 个呢&#xff1…

MySQL高阶2004-职员招聘人数

目录 题目 准备数据 分析数据 实现 题目 一家公司想雇佣新员工。公司的工资预算是 70000 美元。公司的招聘标准是: 雇佣最多的高级员工。在雇佣最多的高级员工后,使用剩余预算雇佣最多的初级员工。 编写一个SQL查询,查找根据上述标准雇…

男单新老对决:林诗栋VS马龙,巅峰之战

听闻了那场激动人心的新老对决,不禁让人热血沸腾。在这场乒乓球的巅峰之战中,林诗栋与马龙的对决无疑是一场视觉与技术的盛宴。 3:3的决胜局,两位选手的每一次挥拍都充满了策略与智慧,他们的每一次得分都让人心跳加速。 林诗栋&am…

Linux自动化构建工具Make/Makefile

make是一个命令 makefile是一个文件 touch 创建并用vim打开makefile 写入依赖对象和依赖方法 mycode是目标文件 第二行数依赖方法 以tab键开头 make makefile原理 makefile中写的是依赖关系和依赖方法 clean英语清理文件 后不用加源文件。.PHONY定义clean是伪目标。 make只…

动态SLAM总结二

文章目录 Mapping the Static Parts of Dynamic Scenes from 3D LiDAR Point Clouds Exploiting Ground Segmentation:(2021)RF-LIO:(2022)RH-Map:(2023)Mapless Online …

[C++]使用纯opencv部署yolov11-pose姿态估计onnx模型

【算法介绍】 使用纯OpenCV部署YOLOv11-Pose姿态估计ONNX模型是一项具有挑战性的任务,因为YOLOv11通常是用PyTorch等深度学习框架实现的,而OpenCV本身并不直接支持加载和运行PyTorch模型。然而,可以通过一些间接的方法来实现这一目标&#x…

POLYGON Nature - Low Poly 3D Art by Synty 树木植物

一个低多边形资源包,包含可以添加到现有多边形风格游戏中的树木、植物、地形、岩石、道具和特效 FX 资源。 为 POLYGON 系列提供混合样式树这一新增功能。弥合 POLYGON 与更传统的层级资源之间的差距。还提供了一组经典的 POLYGON 风格的树木和植被以满足你的需求。 该包还附带…

系统安全 - Linux /Docker 安全模型及实践

文章目录 导图Linux安全Linux 安全模型用户层权限管理的细节多用户环境中的权限管理文件权限与目录权限 最小权限原则的应用Linux 系统中的认证、授权和审计机制认证机制授权机制审计机制 小结 内网安全Docker安全1. Docker 服务隔离机制Namespace 机制Capabilities 机制CGroup…

JavaWeb - 8 - 请求响应 分层解耦

请求响应 请求(HttpServletRequest):获取请求数据 响应(HttpServletResponse):设置响应数据 BS架构:Browser/Server,浏览器/服务器架构模式。客户端只需要浏览器,应用程…

Oracle中MONTHS_BETWEEN()函数详解

文章目录 前言一、MONTHS_BETWEEN()的语法二、主要用途三、测试用例总结 前言 在Oracle数据库中,MONTHS_BETWEEN()函数可以用来计算两个日期之间的月份差。它返回一个浮点数,表示两个日期之间的整月数。 一、MONTHS_BETWEEN()的语法 MONTHS_BETWEEN(dat…

水下声呐数据集,带标注

水下声呐数据集,带标注 水下声呐数据集 数据集名称 水下声呐数据集 (Underwater Sonar Dataset) 数据集概述 本数据集是一个专门用于训练和评估水下目标检测与分类模型的数据集。数据集包含大量的水下声呐图像,每张图像都经过专业标注,标明…

vSAN05:vSAN延伸集群简介与创建、资源要求与计算、高级功能配置、维护、故障处理

目录 vSAN延伸集群延伸集群创建延伸集群的建议网络配置vSAN延伸集群的端口见证主机的资源要求vSAN延伸集群中见证节点带宽占用vSAN延伸集群的允许故障数vSAN延伸集群不同配置下的空间占用 vSAN延伸集群的HA配置vSAN延伸集群的DRS配置vSAN存储策略以及虚拟机/主机策略的互操作vS…

华为最新业绩出炉!上半年营收4175亿元,同比增长34%!

华为2024年上半年经营业绩分析:稳健发展,符合预期 [中国,深圳,2024年8月29日] 今日,华为发布了其2024年上半年的经营业绩,整体表现稳健,结果符合预期。在复杂多变的全球市场环境下,华为凭借强大的创新能力和市场洞察力,实现了销售收入和净利润的显著增长。 上半年,华…

C语言:预编译过程的剖析

目录 一.预定义符号和#define定义常量 二.#define定义宏 三.宏和函数的对比 四、#和##运算符 五、条件编译 在之前,我们已经介绍了.c文件在运行的过程图解,大的方面要经过两个方面。 一、翻译环境 1.预处理(预编译) 2.编译 3…

Spring Boot 整合 Keycloak

1、概览 本文将带你了解如何设置 Keycloak 服务器,以及如何使用 Spring Security OAuth2.0 将 Spring Boot 应用连接到 Keycloak 服务器。 2、Keycloak 是什么? Keycloak 是针对现代应用和服务的开源身份和访问管理解决方案。 Keycloak 提供了诸如单…

Unity MVC框架演示 1-1 理论分析

本文仅作学习笔记分享与交流,不做任何商业用途,该课程资源来源于唐老狮 1.一般的图解MVC 什么是MVC我就不说了,老生常谈,网上有大量的介绍,想看看这三层都起到什么职责?那就直接上图吧 2.我举一个栗子 我有…