机器学习入门--门控循环单元(GRU)原理与实践

GRU模型

随着深度学习领域的快速发展,循环神经网络(RNN)已成为自然语言处理(NLP)等领域中常用的模型之一。但是,在RNN中,如果时间步数较大,会导致梯度消失或爆炸的问题,这影响了模型的训练效果。为了解决这个问题,研究人员提出了新的模型,其中GRU是其中的一种。

本文将介绍GRU的数学原理、代码实现,并通过pytorch和sklearn的数据集进行试验,最后对该模型进行总结。

数学原理

GRU是一种门控循环单元(Gated Recurrent Unit)模型。与传统的RNN相比,它具有更强的建模能力和更好的性能。

重置门和更新门

在GRU中,每个时间步有两个状态:隐藏状态 h t h_t ht和更新门 r t r_t rt。。更新门控制如何从先前的状态中获得信息,而隐藏状态捕捉序列中的长期依赖关系。

GRU的核心思想是使用“门”来控制信息的流动。这些门是由sigmoid激活函数控制的,它们决定了哪些信息被保留和传递。
在每个时间步 t t t,GRU模型执行以下操作:

1.计算重置门
r t = σ ( W r [ x t , h t − 1 ] ) r_t = \sigma(W_r[x_t, h_{t-1}]) rt=σ(Wr[xt,ht1])
其中, W r W_r Wr是权重矩阵, σ \sigma σ表示sigmoid函数。重置门 r t r_t rt告诉模型是否要忽略先前的隐藏状态 h t − 1 h_{t-1} ht1,并只依赖于当前输入
x t x_t xt

2.计算更新门
z t = σ ( W z [ x t , h t − 1 ] ) z_t = \sigma(W_z[x_t, h_{t-1}]) zt=σ(Wz[xt,ht1])
其中,更新门 z t z_t zt告诉模型新的隐藏状态 h t h_t ht在多大程度上应该使用先前的状态 h t − 1 h_{t-1} ht1

候选隐藏状态和隐藏状态

在计算完重置门和更新门之后,我们可以计算候选隐藏状态 h ~ t \tilde{h}_{t} h~t和隐藏状态 h t h_t ht

1.计算候选隐藏状态
h ~ t = tanh ⁡ ( W [ x t , r t ∗ h t − 1 ] ) \tilde{h}_{t} = \tanh(W[x_t, r_t * h_{t-1}]) h~t=tanh(W[xt,rtht1])
其中, W W W是权重矩阵。候选隐藏状态 h ~ t \tilde{h}_{t} h~t利用当前输入 x t x_t xt和重置门 r t r_t rt来估计下一个可能的隐藏状态。

2.计算隐藏状态
h t = ( 1 − z t ) ∗ h t − 1 + z t ∗ h ~ t h_{t} = (1 - z_t) * h_{t-1} + z_t * \tilde{h}_{t} ht=(1zt)ht1+zth~t
这是GRU的最终隐藏状态公式。它在候选隐藏状态 h ~ t \tilde{h}_{t} h~t和先前的隐藏状态 h t h_t ht之间进行加权,其中权重由更新门 z t z_t zt控制。

代码实现

下面是使用pytorch和sklearn的房价数据集实现GRU的示例代码:

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt# 加载数据集并进行标准化
data = load_boston()
X = data.data
y = data.target
scaler = StandardScaler()
X = scaler.fit_transform(X)
y = y.reshape(-1, 1)# 转换为张量
X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
y = torch.tensor(y, dtype=torch.float32)# 定义GRU模型
class GRUNet(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(GRUNet, self).__init__()self.hidden_size = hidden_sizeself.gru = nn.GRU(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):out, _ = self.gru(x)out = self.fc(out[:, -1, :])return outinput_size = X.shape[2]
hidden_size = 32
output_size = 1
model = GRUNet(input_size, hidden_size, output_size)# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
num_epochs = 10000
loss_list = []
for epoch in range(num_epochs):optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch+1) % 100 == 0:loss_list.append(loss.item())print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')# 可视化损失曲线
plt.plot(range(100), loss_list)
plt.xlabel('num_epochs')
plt.ylabel('loss of GRU Training')
plt.show()# 预测新数据
new_data_point = X[0].reshape(1, 1, -1)
prediction = model(new_data_point)
print(f'Predicted value: {prediction.item()}')

上述代码首先加载并标准化房价数据集,然后定义了一个包含GRU层和全连接层的GRUNet模型,并使用均方误差作为损失函数和Adam优化器进行训练。训练完成后,使用matplotlib库绘制损失曲线(如下图所示),并使用训练好的模型对新的数据点进行预测。
GRU 损失曲线

总结

GRU是一种门控循环单元模型,它通过更新门和重置门,有效地解决了梯度消失或爆炸的问题。在本文中,我们介绍了GRU的数学原理、代码实现和代码解释,并通过pytorch和sklearn的房价数据集进行了试验。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/258406.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《山雨欲来-知道创宇 2023 年度 APT 威胁分析总结报告》

下载链接: https://pan.baidu.com/s/1eaIOyTk12d9mcuqDGzMYYQ?pwdzdcy 提取码: zdcy

【sgCreateTableColumn】自定义小工具:敏捷开发→自动化生成表格列html代码(表格列生成工具)[基于el-table-column]

源码 <template><!-- 前往https://blog.csdn.net/qq_37860634/article/details/136126479 查看使用说明 --><div :class"$options.name"><div class"sg-head">表格列生成工具</div><div class"sg-container"…

python in Vscode

背景 对于后端的语言选择&#xff1a; python&#xff0c;java&#xff0c;JavaScript备选。 选择Python 原因&#xff1a;可能是非IT专业的人中&#xff0c;会Python的人比较多。 目的 之前使用的IDE是VSCODE&#xff0c;在WSL的环境下使用。现在需要在在WSL的VSCODE下使…

使用Properties类读取配置文件

读取配置文件 使用Properties类读取配置文件。 Properties类本质上是个hashmap 常用方法 getProperty ( String key)&#xff1a; 用指定的键在此属性列表中搜索属性。也就是通过参数 key &#xff0c;得到 key 所对应的 value。load ( InputStream inStream)&#xff1a; 从输…

字符串拼接 - 华为OD统一考试(C卷)

OD统一考试&#xff08;C卷&#xff09; 分值&#xff1a; 200分 题解&#xff1a; Java / Python / C 题目描述 给定 M 个字符( a-z ) &#xff0c;从中取出任意字符(每个字符只能用一次)拼接成长度为 N 的字符串&#xff0c;要求相同的字符不能相邻。 计算出给定的字符列表…

Maui blazor ios 按设备类型设置是否启用safeArea

需求&#xff0c;新做了个app&#xff0c; 使用的是maui blazor技术&#xff0c;里面用了渐变背景&#xff0c;在默认启用SafeArea情况下&#xff0c;底部背景很突兀 由于现版本maui在SafeArea有点bug&#xff0c;官方教程的<ContentPage SafeAreafalse不生效&#xff0c;于…

【机器学习】数据清洗之识别重复点

&#x1f388;个人主页&#xff1a;甜美的江 &#x1f389;欢迎 &#x1f44d;点赞✍评论⭐收藏 &#x1f917;收录专栏&#xff1a;机器学习 &#x1f91d;希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出指正&#xff0c;让我们共同学习、交流进步…

STM32 HAL库 STM32CubeMX -- IWDG(独立看门狗)

STM32 HAL库 STM32CubeMX -- IWDG 一、IWDG简介二、独立看门狗的工作原理三、驱动函数初始化函数HAL IWDG Init()初始化函数HAL IWDG Init()其他宏函数 四、超时时间计算第一种办法第二种办法&#xff08;推荐&#xff09; 一、IWDG简介 看门狗(Watchdog)就是MCU上的一种特殊的…

SORA:OpenAI最新文本驱动视频生成大模型技术报告解读

Video generation models as world simulators&#xff1a;作为世界模拟器的视频生成模型 1、概览2、Turning visual data into patches&#xff1a;将视觉数据转换为补丁3、Video compression network&#xff1a;视频压缩网络4、Spacetime Latent Patches&#xff1a;时空潜在…

SAP PP学习笔记- 豆知识02 - 品目要谁来维护?怎么决定更不更新品目的数量金额?

其实都是在品目类型的Customize中设定的。 咱们这里简单试着说一下什么场景使用。 1&#xff0c;SAP中品目有很多View&#xff0c;都要由哪些部门来维护呢&#xff1f; 其实就是谁用谁维护呗。 在新建一个品目的时候&#xff0c;品目Type本身就决定了该品目要由哪些部门来维…

【STM32 CubeMX】串口编程DMA

文章目录 前言一、DMA方式1.1 DMA是什么1.2 CubeMX配置DMA1.3 DMA方式函数使用DMA的发送接收函数 总结 前言 在嵌入式系统中&#xff0c;串口通信是一项至关重要的功能&#xff0c;它允许单片机与外部设备进行数据交换&#xff0c;如传感器、显示器或其他设备。然而&#xff0…

【数据结构】16 二叉树的定义,性质,存储结构(以及先序、后序、中序遍历)

二叉树 一个二叉树是一个有穷的结点集合。 它是由根节点和称为其左子树和右子树的两个不相交的二叉树组成的。 二叉树可具有以下5种形态。 性质 一个二叉树第i层的最大结点数为 2 i − 1 2^{i-1} 2i−1, i ≥ 1 i \geq 1 i≥1 每层最大结点可以对应完美二叉树&#xff08;…

阿里云服务器租用收费标准价格表(2024年更新)

2024年最新阿里云服务器租用费用优惠价格表&#xff0c;轻量2核2G3M带宽轻量服务器一年61元&#xff0c;折合5元1个月&#xff0c;新老用户同享99元一年服务器&#xff0c;2核4G5M服务器ECS优惠价199元一年&#xff0c;2核4G4M轻量服务器165元一年&#xff0c;2核4G服务器30元3…

Gitee入门之工具的安装

一、gitee是什么&#xff1f; Gitee&#xff08;码云&#xff09;是由开源中国社区在2013年推出的一个基于Git的代码托管平台&#xff0c;它提供中国本土化的代码托管服务。它旨在为个人、团队和企业提供稳定、高效、安全的云端软件开发协作平台&#xff0c;具备代码质量分析、…

React18原理: 核心包结构与两大工作循环

React核心包结构 1 ) react react基础包&#xff0c;只提供定义 react组件(ReactElement)的必要函数一般来说需要和渲染器(react-dom,react-native)一同使用在编写react应用的代码时, 大部分都是调用此包的api比如, 我们定义组件的时候&#xff0c;就是它提供的class Demo ext…

springboot198基于springboot的智能家居系统

基于Springboot的智能家居系统 **[摘要]**社会和科技的不断进步带来更便利的生活&#xff0c;计算机技术也越来越平民化。二十一世纪是数据时代&#xff0c;各种信息经过统计分析都可以得到想要的结果&#xff0c;所以也可以更好的为人们工作、生活服务。智能家居是家庭的重要…

【排序算法】C语言排序(桶排序,冒泡排序,选择排序,插入排序,快速排序)

目录 什么是排序&#xff1f;1、桶排序 概念思路demo运行效果 2、冒泡排序 动图演示概念思路demo运行效果 3、选择排序 动图演示概念思路demo运行结果 4、插入排序 动图演示概念思路demo运行效果 5、快速排序 动图演示概念思路demo运行结果 什么是排序&#xff1f; 排序&…

“分布式透明化”在杭州银行核心系统上线之思考

导读 随着金融行业数字化转型的需求&#xff0c;银行核心系统的升级改造成为重要议题。杭州银行成功上线以 TiDB 为底层数据库的新一代核心业务系统&#xff0c;该实践采用应用与基础设施解耦、分布式透明化的设计开发理念&#xff0c;推动银行核心系统的整体升级。 本文聚焦…

easyx搭建项目-永七大作战(割草游戏)

永七大作战 游戏介绍&#xff1a; 永七大作战 游戏代码链接&#xff1a;永七大作战 提取码&#xff1a;ABCD 不想水文了&#xff0c;直接献出源码&#xff0c;表示我的诚意

Shellcode免杀对抗(Python)

Shellcode Python免杀&#xff0c;绕过360安全卫士、火绒安全、Defender Python基于cs/msf的上线 cs 执行代码2种可供选择 执行代码 1&#xff1a; rwxpage ctypes.windll.kernel32.VirtualAlloc(0, len(shellcode), 0x1000, 0x40) ctypes.windll.kernel32.RtlMoveMemory…