DCRNN解读(论文+代码)

一、引言

        作者首先提出:空间结构是非欧几里得且有方向性的,未来的交通速度受下游交通影响大于上游交通。虽然卷积神经网络(CNN)在部分研究中用于建模空间相关性,但其主要适用于欧几里得空间(例如二维图像),而非更为复杂的非欧几里得图结构。此外,现有的图卷积研究大多局限于无向图​。

        在此背景下,作者将交通传感器之间的空间相关性表示为有向图上的扩散过程,通过扩散卷积操作捕捉空间依赖性,提出了扩散卷积递归神经网络(DCRNN)。

二、方法

1. 空间依赖建模

        首先,在空间依赖建模上,使用了扩散模型(Diffusion Mode)。这个模型首先定义了一个马尔卡夫链。这是一个随机过程,用于模拟图上信息从一个节点传播到其他节点的方式,通过随机游走来捕捉节点间的空间依赖性。当马尔科夫过程经过多次迭代或多个时间步(步数达到一定程度)后,它会逐渐达到一个稳态分布。在这个稳态分布下,每个节点与其他节点的连接强度(或称扩散影响力)将变得稳定,不再随时间变化。

        关于扩散卷模型的更多知识,CSDN 这位博主讲的非常好:扩散模型 (Diffusion Model) 之最全详解图解-CSDN博客

        说回论文。在空间依赖建模上想要用到扩散模型(扩散卷积),其核心思想如下:

  • 对于每个节点 i,我们考虑它在不同步数 k 下从其他节点接收到的影响。
  • 对每一个步数 k,我们使用正向转移矩阵的 k 次幂和反向转移矩阵的 k 次幂来表示扩散的传播过程。
  • 在步数 k 时,通过 θ(k,1)​ 和 θ(k,2)​ 来控制正向和反向扩散的权重。
  • 最后,将每一步的结果加和,以捕捉多步扩散过程中的节点间依赖关系。

        有了这个思想,就不难理解论文中图信号 X 与滤波器 fθ​ 的扩散卷积操作定义:

        式中,X 是一个 N×P 的矩阵,X:,p 就表示第 p 个节点的所有特征值(如速度)。fθ​ 是扩散卷积的滤波器,其作用是控制和调整扩散卷积的影响范围和特性(类似于图卷积中的卷积核)。W 是图的加权邻接矩阵。Do​ 是 W 的出度对角矩阵,表示每个节点的出度。那么这两项相乘后的矩阵表示一个随机游走过程,也就是说,矩阵的每个元素表示从节点 i 到节点 j 的条件概率,即在随机游走中,从 i 到达 j 的概率。

        那么结合这个公式,作者的扩散卷积代码就很容易理解:

with tf.variable_scope(scope):if self._max_diffusion_step == 0:  # 根据 _max_diffusion_step 控制扩散层数,0 表示无扩散passelse:for support in self._supports:# 将 support(稀疏邻接矩阵)与 x0(或更新的 x1)相乘,模拟信息在图上扩散。x1 = tf.sparse_tensor_dense_matmul(support, x0)x = self._concat(x, x1)# 将扩散结果 x1、x2 依次拼接到 x 上for k in range(2, self._max_diffusion_step + 1):x2 = 2 * tf.sparse_tensor_dense_matmul(support, x1) - x0  # 切比雪夫多项式算法x = self._concat(x, x2)x1, x0 = x2, x1# 合并扩散结果:(batch_size * num_nodes, input_size * num_matrices)num_matrices = len(self._supports) * self._max_diffusion_step + 1  # Adds for x itself.x = tf.reshape(x, shape=[num_matrices, self._num_nodes, input_size, batch_size])x = tf.transpose(x, perm=[3, 1, 2, 0])  # (batch_size, num_nodes, input_size, order)x = tf.reshape(x, shape=[batch_size * self._num_nodes, input_size * num_matrices])# 应用权重和偏置,得到卷积输出weights = tf.get_variable('weights', [input_size * num_matrices, output_size], dtype=dtype,initializer=tf.contrib.layers.xavier_initializer())x = tf.matmul(x, weights)  # (batch_size * self._num_nodes, output_size)biases = tf.get_variable("biases", [output_size], dtype=dtype,initializer=tf.constant_initializer(bias_start, dtype=dtype))x = tf.nn.bias_add(x, biases)

        先通过 utils.calculate_random_walk_matrix(adj_mx).T 计算出(Do逆)与 W 的乘积 support (并且转化为了稀疏矩阵以便高效运算)。在上面的代码中,x0、x1 代表着代表着不同时间步的特征,也就是公式中的 X:,p;x2 是切比雪夫多项式的算法(优化拉普拉斯矩阵的高次幂计算);_max_diffusion_step 就是公式中的k,代表扩散步数;最后 weights 也就是公式中的 θk,代表扩散权重。

        基于上述卷积操作,可以构建一个扩散卷积层,增强模型的表达能力。将 P-维特征映射到 Q-维输出。那么输出矩阵X(N×P)经过激活函数a,就转化为了输出矩阵H(N×Q)。

# Reshape res back to 2D: (batch_size, num_node, state_dim) -> (batch_size, num_node * state_dim)
return tf.reshape(x, [batch_size, self._num_nodes * output_size])

2. 时间动态建模

        在时间依赖建模中,作者使用了递归神经网络(RNN)的变体——门控循环单元(GRU)。并且使用扩散卷积替换了 GRU 中的矩阵乘法。那么定义如下:

        上述式子中,∗G​ 表示扩散卷积(用扩散卷积去处理 Xt 和 Ht),Θr、Θu、ΘC​ 是相应的滤波器参数(也就是原始 GRU 中的权重参数)。那么接下来就能像 GRU 那样进行多步预测。

        接下来看这部分的代码实现:

with tf.variable_scope(scope or "dcgru_cell"):  # 添加变量的作用域(前缀)# 1.计算更新门u和重置门rwith tf.variable_scope("gates"):output_size = 2 * self._num_units# We start with bias of 1.0 to not reset and not update.# 判断使用哪种方法计算更新门和重置门if self._use_gc_for_ru:fn = self._gconvelse:fn = self._fcvalue = tf.nn.sigmoid(fn(inputs, state, output_size, bias_start=1.0))# 拆分并调整重置门和更新门的形状value = tf.reshape(value, (-1, self._num_nodes, output_size))r, u = tf.split(value=value, num_or_size_splits=2, axis=-1)r = tf.reshape(r, (-1, self._num_nodes * self._num_units))u = tf.reshape(u, (-1, self._num_nodes * self._num_units))# 2.计算候选状态 cwith tf.variable_scope("candidate"):c = self._gconv(inputs, r * state, self._num_units)if self._activation is not None:c = self._activation(c)# 3. 计算输出和新状态output = new_state = u * state + (1 - u) * c

        不管是第1步计算更新门 u 和 重置门 r ,还是第2步计算候选状态,都用到了扩散卷积函数 _gconv(公式中的 *G )。其中,偏置值不设置时默认为0,激活函数不设置时默认为 tanh。

        在多步预测中,模型在生成每个步骤的预测时,依赖前一步的输出,但如果某一步的预测错误,会导致后续预测受到影响,从而引发错误逐步积累,最终显著降低预测精度。因此,作者团队为了缓解训练和测试期间输入分布不一致的问题,引入了计划抽样方法​(Scheduled Sampling)。在训练过程中,计划抽样不是每次都让模型在每一步中直接使用前一步的真实观测值,而是引入一个采样概率 ϵ,按一定概率从真实观测值中抽样,按另一概率从模型的预测结果中抽样。随着训练的进行,这个采样概率逐渐从依赖真实观测值过渡到依赖预测值,最终在测试阶段模型只依赖于自己的预测。也就是说,采样概率 ϵ 会逐渐从1将为0。

        计划抽样的代码如下:

# 控制每一步解码输入是使用模型的预测结果 prev,还是使用真实的标签值 labels[i]
def _loop_function(prev, i):if is_training:# Return either the model's prediction or the previous ground truth in training.if use_curriculum_learning:  # 使用课程学习(模仿人类学习的特点,由简单到困难来学习课程)c = tf.random_uniform((), minval=0, maxval=1.)# 基于全局步数 global_step 计算采样阈值 thresholdthreshold = self._compute_sampling_threshold(global_step, cl_decay_steps)# 当随机数 c 小于 threshold 时,选择 labels[i](真实值);否则使用 prev(预测值)result = tf.cond(tf.less(c, threshold), lambda: labels[i], lambda: prev)else:result = labels[i]else:# Return the prediction of the model in testing.result = prevreturn result

3. 代码结构

         作者在 DCRNN 模型设计上分为了三个代码文件,分别是 dcrnn_cell.py、dcrnn_model.py和dcrnn_supervisor.py。一般而言,cell 文件通常定义的是一个神经网络中的基本计算单元模块。而 model 文件定义了整个神经网络模型的结构,它将各个 cell 组合起来,实现从输入到输出的完整计算图。supervisor 文件通常负责训练和评估的流程管理,它调用 model 文件中的模型进行训练和推理,设置优化流程,监控训练状态。

        dcnn_cell.py 中主要实现了扩散卷积和 GRU 的计算。其核心也就是上面1和2部分的代码。

        dcrnn_model.py 的代码主要实现以下几个功能:

  1. 将 DCGRUCell 聚合为一个多层的 GRU 单元
  2. 设置训练时的特殊方法(例如上文的计划抽样)
  3. “编码(encoding)”和“解码(decoding)”

        什么是编码和解码呢?它是指在序列到序列(Seq2Seq)模型中,将输入数据转换为潜在表示(编码)并生成输出序列(解码)的过程。在普通的 RNN 中,输入和输出的处理方式是逐时间步的,每个时间步的输入都会产生一个对应的输出。这只适用于固定长度的输入和输出序列,在不同长度的输入输出序列上表现不佳。相反, Seq2Seq 这种结构比普通的 RNN 更适合处理不同长度的输入和输出序列,尤其适合于交通预测等多步预测任务。具体知识点可见这位博主的讲解:Seq2Seq 模型详解_seq2seq模型-CSDN博客。

        而在 DCRNNModel 类中,编码器将输入数据处理成隐藏状态(enc_state),这个状态浓缩了输入的特征;解码器enc_state 作为起点,逐步生成未来时刻的预测值。这部分的代码在后文实验对比会提到,代码如下所示:

# 创建多层RNN单元
encoding_cells = [cell] * num_rnn_layers  # 在编码阶段将使用多个相同的RNN单元
decoding_cells = [cell] * (num_rnn_layers - 1) + [cell_with_projection]  # 在解码的最后一层使用具有输出投影的单元,以确保输出维度正确。
encoding_cells = tf.contrib.rnn.MultiRNNCell(encoding_cells, state_is_tuple=True)
decoding_cells = tf.contrib.rnn.MultiRNNCell(decoding_cells, state_is_tuple=True)# 构建编码器和解码器
_, enc_state = tf.contrib.rnn.static_rnn(encoding_cells, inputs, dtype=tf.float32)
outputs, final_state = legacy_seq2seq.rnn_decoder(labels, enc_state, decoding_cells, loop_function=_loop_function)

        dcrnn_supervisor.py 的代码主要实现以下几个功能:

  1. 初始化参数配置
  2. 配置日志系统,便于后续调试和复现。
  3. 数据准备
  4. 通过 DCRNNModel初始化训练和测试模型的对象
  5. 初始化学习率变量
  6. 配置优化器(默认Adam优化器)
  7. 定义损失函数
  8. 配置梯度裁剪与优化操作
  9. 配置模型保存器

        其中第4步是这一个python文件的关键代码。

# 4. 构建模型
scaler = self._data['scaler']  # 标准化
with tf.name_scope('Train'):  # 训练模式with tf.variable_scope('DCRNN', reuse=False):self._train_model = DCRNNModel(is_training=True, scaler=scaler,batch_size=self._data_kwargs['batch_size'],adj_mx=adj_mx, **self._model_kwargs)with tf.name_scope('Test'):  # 测试模式with tf.variable_scope('DCRNN', reuse=True):self._test_model = DCRNNModel(is_training=False, scaler=scaler,batch_size=self._data_kwargs['test_batch_size'],adj_mx=adj_mx, **self._model_kwargs)

三、相关工作

        作者首先提出了以往研究的一些缺陷:

        本文提出的 DCRNN 与上述方法不同:它将传感器网络建模为加权有向图,并利用扩散卷积捕捉空间依赖关系。通过在卷积中结合双向随机游走,DCRNN 能够更灵活地捕捉上游和下游的交通影响。此外,DCRNN 结合序列到序列学习框架及计划抽样技术,以更好地处理长期预测中的误差累积问题。

四、实验

        作者使用了两个数据集实验,分别是 METR-LA 数据集和 PEMS-BAY 数据集。其中70% 的数据用于训练,20% 用于测试,剩余 10% 用于验证。

        作者分别在后文分别讨论了时间建模和空间建模的效果。

        首先在空间依赖建模上,选取了 DCRNN 的变体—— DCRNN-NoConv 和 DCRNN-UniConv。前者忽略空间依赖,后者使用单向游走(欧几里得图结构)。实验对比如下:

        如果在空间依赖建模上不适用传播卷积而使用切比雪夫图卷积(GCRNN),那么结果也是显而易见的。

        在时间依赖建模上,作者使用了 DCNN 和 DCRNN-SEQ 来做对比。前者是静态输入的卷积神经网络,而后者加入了 Seq2Seq 框架处理。而本文使用的 DCRNN 是在 DCRNN-SEQ 的基础上添加了计划抽样方法。经过实验,DCRNN 的效果表现最好。

五、总结

        总结这篇论文的创新点如下:

  1. 使用传播模型的双向随机游走建立空间模型;
  2. 使用 GRU 捕捉时间动态;
  3. 结合了编码器-解码器架构;
  4. 计划抽样技术。 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/464615.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

StandardThreadExecutor源码解读与使用(tomcat的线程池实现类)

🏷️个人主页:牵着猫散步的鼠鼠 🏷️系列专栏:Java源码解读-专栏 🏷️个人学习笔记,若有缺误,欢迎评论区指正 目录 目录 1.前言 2.线程池基础知识回顾 2.1.线程池的组成 2.2.工作流程 2…

Unreal5从入门到精通之如何解决在VR项目在头显中卡顿的问题

前言 以前我们使用Unity开发VR,Unity提供了非常便利的插件和工具来做VR。但是由于Unity的渲染效果不如Unreal,现在我们改用Unreal来做VR了,所有的VR相关的配置和操作都要重新学习。 今天就来总结一下,我在开发VR过程中碰到的所有问题。 1.编辑器,以VR运行 默认运行方式…

centos7 kafka高可用集群安装及测试

前言 用三台虚拟机centos7 搭建高可用集群,及测试方法 高可用搭建的方法,参考:https://blog.csdn.net/u011197085/article/details/134070318 高可用搭建 1、安装配置zookeeper集群 下载zookeeper 注:zookeeper链接如果失效&a…

Redis(2):内存模型

一、Redis内存统计 工欲善其事必先利其器,在说明Redis内存之前首先说明如何统计Redis使用内存的情况。 在客户端通过redis-cli连接服务器后(后面如无特殊说明,客户端一律使用redis-cli),通过info命令可以查看内存使用情…

C++笔试题之实现一个定时器

一.定时器(timer)的需求 1.执行定时任务的时,主线程不阻塞,所以timer必须至少持有一个线程用于执行定时任务 2.考虑到timer线程资源的合理利用,一个timer需要能够管理多个定时任务,所以timer要支持增删任务…

0.STM32F1移植到F0的各种经验总结

1.结构体的声明需放在函数的最前面 源代码: /*开启时钟*/RCC_APB2PeriphClockCmd(RCC_APB2Periph_USART1, ENABLE); //开启USART1的时钟RCC_APB2PeriphClockCmd(RCC_APB2Periph_GPIOA, ENABLE); //开启GPIOA的时钟/*GPIO初始化*/GPIO_InitTypeDef GPIO_InitStructu…

在Microsoft Outlook日历中添加多个时区

在Microsoft Outlook日历中添加多个时区 1.单击Outlook中的文件选项卡,单击选项 2.左侧菜单中选择日历 3.向下滚动到时区部分,并标记当前时区,比如China 4.选中“显示第二个时区”框 5.选择第二个时区并给它一个标签,比如Germa…

为啥学习数据结构和算法

基础知识就像是一座大楼的地基,它决定了我们的技术高度。而要想快速做出点事情,前提条件一定是基础能力过硬,“内功”要到位。 想要通关大厂面试,千万别让数据结构和算法拖了后腿 我们学任何知识都是为了“用”的,是为…

爬虫学习4

from threading import Thread#创建任务 def func(name):for i in range(100):print(name,i)if __name__ __main__:#创建线程t1 Thread(targetfunc,args("1"))t2 Thread(targetfunc, args("2"))t1.start()t2.start()print("我是诛仙剑")from …

【Maven】——基础入门,插件安装、配置和简单使用,Maven如何设置国内源

阿华代码,不是逆风,就是我疯 你们的点赞收藏是我前进最大的动力!! 希望本文内容能够帮助到你!! 目录 引入: 一:Maven插件的安装 1:环境准备 2:创建项目 二…

Vue中使用echarts生成地图步骤详解

1.创建容器元素 <div class"map" id"map" style"width:1000px;height:1000px;"></div> 2.Vue项目引入world.js(我这里的演示是世界地图&#xff0c;不同地图对应js文件不一样) world.js文件包含&#xff1a; 地理坐标数据&#xff…

docker安装低版本的jenkins-2.346.3,在线安装对应版本插件失败的解决方法

提示&#xff1a;写完文章后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、网上最多的默认解决方法1、jenkins界面配置清华源2、替换default.json文件 二、解决低版本Jenkins在线安装插件问题1.手动下载插件并导入2.低版本jenkins在…

算法专题:栈

目录 1. 删除字符串中的所有相邻重复项 1.1 算法原理 1.2 算法代码 2. 844. 比较含退格的字符串 2.1 算法原理 2.2 算法原理 3. 基本计算器 II 3.1 算法原理 3.2 算法代码 4. 字符串解码 4.1 算法原理 4.2 算法代码 5. 验证栈序列 5.1 算法原理 5.2 算法代码 1.…

ZDH权限-扩展支持数据权限

目录 项目源码 预览地址 安装包下载地址 ZDH权限模块 ZDH权限扩展更细粒度方案 第一种方案&#xff1a; 第二种方案&#xff1a; ZDH权限扩展支持数据权限-新增属性 总结 感谢支持 项目源码 zdh_web: GitHub - zhaoyachao/zdh_web: 大数据采集,抽取平台 预览地址 后…

交换机的基本配置

交换机的基本配置 实验题目实验目的实验任务实验设备实验环境实验步骤VLAN 的简单配置跨交换机 vlan 的配置主机配置信息表解释&#xff1a; vlan 间路由 实验题目 交换机的基本配置。 实验目的 1) 理解交换机的原理和应用场景&#xff1b; 2) 交换机的基本指令系统&#xf…

借助 Aspose.Words,使用 C# 从 Word 文档中删除页面

如果您正在寻找一种快速删除 Word 文档中不相关、过时或空白页的方法&#xff0c;那么您来对地方了。在这篇博文中&#xff0c;我们将学习如何使用 C# 从 Word 文档中删除页面。我们将逐步引导您完成该过程&#xff0c;提供清晰的示例&#xff0c;以帮助您以编程方式高效地从 W…

华为 HarmonyOS NEXT 原生应用开发: 动画的基础使用(属性、显示、专场)动画

2024年11月5日 LiuJinTao 文章目录 鸿蒙中动画的使用一、属性动画 - animation属性动画代码示例 二、显示动画 - AnimateTo三、专场动画 鸿蒙中动画的使用 一、属性动画 - animation 属性动画代码示例 /*** 属性动画的演示*/ Entry Component struct Index {State selfWidth:…

基于STM32的手式电视机遥控器设计

引言 本项目基于STM32微控制器设计了一个手式电视机遥控器系统&#xff0c;通过集成加速度传感器和陀螺仪&#xff0c;实现手势识别和遥控功能。该遥控器系统可以通过简单的手势操作实现对电视机的音量调节、频道切换和开关机控制等功能。项目涉及到硬件设计、手势识别算法和红…

基于SpringBoot+微信小程序+协同过滤算法+二维码订单位置跟踪的农产品销售平台-新

✌全网粉丝20W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取项目下载方式&#x1f345; 一、项目背景介绍&#xff1a; “农产品商城”小程序…

论文阅读-用于点云分析的自组织网络

目前存在的问题&#xff1a; 原始的SOM&#xff08;1&#xff09;训练结果与初始节点高度相关&#xff08;2&#xff09;样本更新规则取决于输入点的顺序3D 卷积神经网络&#xff08;需要将数据转换为体素&#xff0c;存在分辨率损失和计算成本上涨的问题&#xff09;、PointN…