机器学习笔记:门控循环单元的建立

目录

介绍

结构

模型原理

重置门与更新门

候选隐状态

输出隐状态

模型实现

引入数据

初始化参数

定义模型

训练与预测

简洁实现GRU

思考


介绍

门控循环单元(Gated Recurrent Unit,简称GRU)是循环神经网络一种较为复杂的构成形式,其用途也是处理时序数据,相比具有单隐藏状态的RNN,GRU具有忘记的能力,可以忘记无用的数据。

结构

与传统RNN相比,GRU的结构引入了的概念,比RNN复杂许多,不过可以看出,其输入仍然是X_t和上一时间步隐状态H_{t-1},输出仍然是本时间步隐状态H_t。区别在于“细胞”内部结构,RNN只需要将H和X分别处理,之后结合在一起,激活函数激活后将其输出即可。而GRU内部处理十分复杂。

模型原理

我们以处理的顺序来依次讲解各个组成部分的模型原理。

重置门与更新门

首先介绍重置门(reset gate)R_t更新门(update gate)Z_t。 我们把它们设计成(0,1)区间中的向量。 重置门允许我们控制“可能还想记住”的过去状态的数量; 更新门将允许我们控制新状态中有多少个是旧状态的副本。后面还会再提到两个门的具体作用。

重置门和更新门的计算公式如下所示,由于使用sigmoid函数,R_tZ_t的值在(0,1)区间内。

R_t= \sigma (X_t \cdot W_{xr} + H_{t-1} \cdot W_{hr}+b_r)

Z_t= \sigma (X_t \cdot W_{xz} + H_{t-1} \cdot W_{hz}+b_z)

候选隐状态

候选隐状态的计算公式如下,是RNN中计算公式的升级版。(\bigodot是哈达玛积)

\tilde{H}_t = tanh(X_t\cdot W_{xh}+(H_{t-1}\bigodot R_t)\cdot W_{hh}+b_h)

当重置门R的值接近1时,则候选隐状态的计算与RNN一致,当重置门R的值接近0时,则候选隐状态计算时会完全“忘记”之前的值。

输出隐状态

输出隐状态需要更新门,候选隐状态和上一阶段隐状态共同计算得到。

H_t=Z_t \bigodot H_{t-1}+ (1-Z_t)\bigodot \tilde{H_t}

由公式可以看出,当Z_t接近0时,隐状态即为候选隐状态,当Z_t接近1时,隐状态即为上一阶段隐状态,更新门决定隐状态中有多少部分进行更新。

模型实现

引入数据

我们从零开始实现一个GRU,首先引入相关的库,并定义相关的一系列超参数。

from mxnet import np, npx
from mxnet.gluon import rnn
from d2l import mxnet as d2lnpx.set_np()batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

初始化参数

将需要学习的参数进行初始化。

def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return np.random.normal(scale=0.01, size=shape, ctx=device)def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),np.zeros(num_hiddens, ctx=device))W_xz, W_hz, b_z = three()  # 更新门参数W_xr, W_hr, b_r = three()  # 重置门参数W_xh, W_hh, b_h = three()  # 候选隐状态参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = np.zeros(num_outputs, ctx=device)# 附加梯度params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.attach_grad()return params

定义模型

定义门控循环单元模型, 模型的架构与基本的循环神经网络单元是相同的, 只是权重更新公式更为复杂。

def init_gru_state(batch_size, num_hiddens, device):return (np.zeros(shape=(batch_size, num_hiddens), ctx=device), )
def gru(inputs, state, params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:Z = npx.sigmoid(np.dot(X, W_xz) + np.dot(H, W_hz) + b_z)R = npx.sigmoid(np.dot(X, W_xr) + np.dot(H, W_hr) + b_r)H_tilda = np.tanh(np.dot(X, W_xh) + np.dot(R * H, W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = np.dot(H, W_hq) + b_qoutputs.append(Y)return np.concatenate(outputs, axis=0), (H,)

训练与预测

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

运行结果如下:

perplexity 1.1, 10510.3 tokens/sec on gpu(0)
time travelleryou can show black is white by argument said filby
travelleryou can show black is white by argument said filby

简洁实现GRU

mxnet框架中自带GRU的API,可以直接调用。GRU唯一需要的参数就是隐藏单元的数量。

接下来根据上一篇文章中定义好的train_ch8进行反向计算更新参数并进行预测即可。

gru_layer = rnn.GRU(num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

运行结果如下:

perplexity 1.1, 183591.3 tokens/sec on gpu(0)
time traveller for so it will be convenient to speak of himwas e
travelleryou can show black is white by argument said filby

思考

  1. 如果仅仅实现门控循环单元的一部分,例如,只有一个重置门或一个更新门会怎样?

  2. 比较rnn.RNNrnn.GRU的不同实现对运行时间、困惑度和输出字符串的影响。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/398056.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【网络编程】UDP通信基础模型实现

udpSer.c #include<myhead.h> #define SER_IP "192.168.119.143" #define SER_PORT 7777 int main(int argc, const char *argv[]) {//1.创建int sfd socket(AF_INET,SOCK_DGRAM,0);if(sfd -1){perror("socket error");return -1;}//2.连接struct…

element-ui周选择器,如何获取年、周、起止日期?

说明 版本&#xff1a;vue2、element-ui2.15.14 element-ui的日期选择器可以设为周&#xff0c;即typeweek&#xff0c;官方示例如下&#xff1a; 如果你什么都不操作&#xff0c;那么获取的周的值为&#xff1a; value1: Tue Aug 06 2024 00:00:00 GMT0800 (中国标准时间)如…

asp.net医院权限管理系统

医院管理的设计与实现程序 医院管理系统asp.netsqlserver 医院权限管理系统sqlserver 挂号管理 挂号类型管理 挂号登记 挂号查询 药品管理 计量单位管理药 品分类管理 药品编辑 病人资料 病人资料录入 病人资料编辑 病人资料查询 住院管理 住院登记 住院查询办理出院 病例管理 …

鸿蒙HarmonyOS开发:如何灵活运用动画效果提升用户体验

文章目录 一、动画概述1、动画的目的 二、显式动画 (animateTo)1、接口2、参数3、AnimateParam对象说明4、示例5、效果 三、属性动画 (animation)1、接口2、参数3、AnimateParam对象说明4、系统可动画属性4、示例5、效果 一、动画概述 动画的原理是在一个时间段内&#xff0c;…

HAProxy原理及实例

目录 目录 haproxy简介 haproxy的基本信息 haproxy下载并查看版本 haproxy的基本配置信息 global配置 ​编辑多进程和多线程 启用多进程 启用多线程 haproxy开启多线程和多进程有什么用 proxies配置 defaults frontend backend listen socat工具 实例&#xff1a…

Particle Swarm Optimization粒子群算法

目录 1.粒子群算法入门 1.1 简单的优化问题 1.1.1 盲目搜索 1.1.2 粒子群算法流程图 1.1.3 粒子群算法的核心公式 1.1.4 预设参数 1.1.5 初始化粒子的位置和速度 1.1.6 计算适应度 1.1.7 循环体&#xff1a;更新粒子速度和位置 1.1.8 模型改进 2.深入研究粒子群算法 …

CLEFT 基于高效大语言模型和快速微调的语言-图像对比学习

CLEFT: Language-Image Contrastive Learning with Efficient Large Language Model and Prompt Fine-Tuning github.com paper CLEFT是一种新型的对比语言图像预训练框架&#xff0c;专为医学图像而设计。它融合了医学LLM的预训练、高效微调和提示上下文学习&#xff0c;展…

【Linux】线程同步与互斥

目录 线程相关问题 线程安全 常见的线程安全的情况 常见的线程不安全的情况 可重入函数与不可重入函数 常见不可重入的情况 常见可重入的情况 可重入与线程安全的关系 联系 区别 线程同步与互斥 互斥锁 使用 死锁 死锁的四个必要条件 如何避免死锁 条件变量 同…

Unity读取Android外部文件

最近近到个小需求,需要读Android件夹中的图片.在这里做一个记录. 首先读写部分,这里以图片为例子: 一读写部分 写入部分: 需要注意的是因为只有这个地址支持外部读写,所以这里用到的地址都以 :Application.persistentDataPath为地址起始. private Texture2D __CaptureCamera…

【JavaEE】初步认识多线程

&#x1f525;个人主页&#xff1a; 中草药 &#x1f525;专栏&#xff1a;【Java】登神长阶 史诗般的Java成神之路 &#x1f3b7; 一.线程 1.概念 线程&#xff08;Thread&#xff09;是在计算机科学中&#xff0c;特别是操作系统领域里的一个关键概念。它是操作系统能够进行…

Android中的Binder

binder是Android平台的一种跨进程通信&#xff08;IPC&#xff09;机制&#xff0c;从应用层角度来说&#xff0c;binder是客户端和服务端进行通信的媒介。 ipc原理 ipc通信指的是两个进程之间交换数据&#xff0c;如图中的client进程和server进程。 Android为每个进程提供了…

【聚类算法】

聚类算法是一种无监督学习方法&#xff0c;用于将数据集中的数据点自动分组到不同的类别中&#xff0c;这些类别也称为“簇”或“群”。聚类的目标是让同一簇内的数据点尽可能相似&#xff0c;而不同簇之间的数据点尽可能不相似。聚类算法广泛应用于多种领域&#xff0c;如数据…

xtrabackup搭建MySQL 8.0 主从复制

xtrabackup搭建MySQL 8.0 主从复制 安装MySQL 8.0.37安装xtrabackupGTIDs初始化从库参考&#xff1a;GTID概述GTID相较与传统复制的优势GTID自身存在哪些限制GTID工作原理简单介绍如何开启GTID复制GTID与传统模式建立复制时候语句的不同点传统复制GTID复制 GTID同步状态简单解析…

Linux系统编程 day09 线程同步

Linux系统编程 day09 线程同步 1.互斥锁2.死锁3.读写锁4.条件变量&#xff08;生产者消费者模型&#xff09;5.信号量 1.互斥锁 互斥锁是一种同步机制&#xff0c;用于控制多个线程对共享资源的访问&#xff0c;确保在同一时间只有一个线程可以访问特定的资源或执行特定的操作…

机器学习第一课

1.背景 有监督学习&#xff1a;有标签&#xff08;连续变量&#xff08;回归问题&#xff1a;时间序列等&#xff09;、分类变量&#xff08;分类&#xff09;&#xff09; 无监督学习&#xff1a;没有标签&#xff08;聚类、关联&#xff08;相关性分析&#xff1a;哪些相关…

代码随想录算法训练营Day35 | 01背包问题 | 416. 分割等和子集

今日任务 01背包问题 题目链接&#xff1a; https://kamacoder.com/problempage.php?pid1046题目描述&#xff1a; Code #include <iostream> #include <vector> #include <functional> #include <algorithm>using namespace std;int main(void)…

工作随记:我在OL8.8部署oracle rac遇到的问题

文章目录 一、安装篇问题1&#xff1a;[INS-08101] Unexpected error while executing the action at state:supportedosCheck问题1解决办法&#xff1a;问题2&#xff1a;[INS-06003] Failed to setup passwordless SSH connectivity with thefollowing nodeis): [xxxx1, xxxx…

go语言后端开发学习(四) —— 在go项目中使用Zap日志库

一.前言 在之前的文章中我们已经介绍过如何使用logrus包来作为我们在gin框架中使用的日志中间件&#xff0c;而今天我们要介绍的就是我们如何在go项目中如何集成Zap来作为日志中间件 二.Zap的安装与快速使用 和安装其他第三方包没什么区别&#xff0c;我们下载Zap包只需要执…

pod详解 list-watch机制 预选优选策略 如何指定节点调度pod

K8S是通过 list-watch 机制实现每个组件的协同工作 controller-manager、scheduler、kubelet 通过 list-watch 机制监听 apiserver 发出的事件&#xff0c;apiserver 也会监听 etcd 发出的事件 scheduler的调度策略&#xff1a; 预选策略&#xff08;Predicates&#xff09;…

Pytorch_cuda版本的在线安装命令

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 运行效果如下&#xff1a; 这个方法是直接从pytorch官网进行在线下载和安装。 cu121&#xff0c;表示当前您安装的cuda版本是12.1