ST-SSL:基于自监督学习的交通流预测模型

文章信息

文章题为“Spatio-Temporal Self-Supervised Learning for Traffic Flow Prediction”,是一篇发表于The Thirty-Seventh AAAI Conference on Artificial Intelligence (AAAI-23)的一篇论文。该论文主要针对交通流预测任务,结合自监督学习,衡量数据的时空异质性。

摘要

在智能交通系统中,准确预测不同时间段的城市交通流量是至关重要的。现有的方法存在两个关键的局限性:1、大多数模型集中预测所有区域的交通流量,而没有考虑空间异质性,即不同区域的交通流量分布可能存在偏差;2、现有模型无法捕捉时变交通模式引起的时间异质性,大多数现有模型通常是在所有时间段内与共享参数化空间进行时间相关性建模。为解决上述问题,文章提出了一种新的时空自监督学习(ST-SSL)的预测框架,该框架通过辅助的自监督学习范式,增强了交通模式表征,以反映时空异质性。具体而言,该模型构建在一个集成模块上,具有时间卷积和空间卷积。为实现自适应时空自监督学习,ST-SSL在属性层面和结构层面对交通流量图数据进行自适应增强。在增强的流量图的基础上,文章构建了两个基于自监督学习的辅助任务,通过时空异构感知增强对主要流量预测任务进行补充。文章的主要贡献如下:

  • 该文章是第一个提出一种新的自监督学习框架来模拟交通流预测的时空异质性。所提出的预测框架可能会对其他实际的时空应用(例如空气质量预测)有所帮助。

  • 文章提出了一种基于图结构时空图的自适应异构感知数据增强方案,以减弱噪声扰动对预测的影响。

  • 文章引入两个自监督学习任务来补充主要的交通预测任务,通过增强模型识别能力和对交通时空异质性的认识。

基本概念

空间区域文章将网络划分为N=I*J的网格,89a24e88b76dce1641e61ccff59ffc9a.png表示区域。

交通流量图(TFG)交通流量图定义为340fc7e02f61e2b549526705563c10ea.png,其中V表示节点集合,为边集合,44d9abbfa667c691b869c39484dd6a22.png为邻接矩阵,fbde9e51f6c2ce8ae486d5b56cd7a442.png表示历史T个时间步交通网络内各个区域的流入量和流出量序列。

问题定义给定若干历史时间步的交通流量图,文章所研究问题的目标是学习一个能够准确估计未来一个时间步内的所有区域的交通量的预测的函数。

模型框架

文章所提出的模型框架如下图中的图(a)所示。主要包括时空编码器(ST Encoder)、自适应图增广模块(Adaptive Graph Augmentation)。

5d1ebbdc34170f661cd23f77360e2739.png

时空编码器(ST Encoder):文章主要利用卷积神经网络构建时空编码器,将时间卷积分量与图卷积传播网络相结合,作为时空关系表示的主干。具体而言,时间特征提取层由基于门控机制的一维因果卷积构成,时间编码器沿时间维度进行卷积,如下式所示。

7702e7939bc6574cb9b30481c991ef39.png

其中,27797d298409f1b6e63d8f05e5e77971.png表示第t个时间步交通网络的嵌入矩阵,79a9818c2fd6d3b56ffbdf14850036b8.png表示其中的第n行,即与区域b7a346e11a13104dfd4b7425513c7c0f.png相关,D表示嵌入维度。进一步,空间特征提取层如下式所示。

45868f249ce5fa85fd778a43a3dcb133.png

文章的时空编码器采用“三明治”块结构构建,即以TC→SC→TC的顺序对数据进行处理。随着时空编码器的处理,时间维度最终变为0,从而得到最终的预测结果066c98af435b1cbb77344e01b9ba7a9b.png

自适应图增广模块(Adaptive Graph Augmentation on TFG):文章设计了两阶段的图增广方法,即流量级数据增广和图拓扑级结构增广。首先,文章给出了不同区域异质性的衡量方法。具体而言,对于区域,其嵌入序列的计算方式如下。

ad6b15d9408f75edcf6585b519ba068d.png

其中,指权重4544b8b906939f9f78fba8c387c89d87.png不同时间步嵌入序列的聚合表示,9958e46dc2136314f1f8582ed72265bf.png是可学习参数。该权重反映不同时间步与总流量规律的关系。基于上式,文章通过比对两个不同区域对应总流量规律的差异,从而反映不同区域的异质性。具体如下式所示,该值越大,则两个区域之间的相关性越强,因此异质性越小。

cdb561e39e48e0b6ceb417936a6a004d.png

基于上式,文章提出流量级数据增广图拓扑级结构增广。具体而言,流量级数据增广旨在基于概率掩盖第个时间步中相关性较弱的流量,其中服从二项分布582bcfcd4affdc3466589d3d220f21fc.png,其处理结果为bd781805ffeadac6d62102e4253a5d2e.png图拓扑级结构增广旨在对网络内所有区域进行分析,包括两个步骤:1、若两个区域的流量规律不是高度相关,即异质性较大,基于概率掩盖这两个区域之间的连接,服从二项分布352a94d30360c7ab0252b3dd570eeb36.png;2、若两个区域之间的流量规律异质性较小,则会依照服从二项分布da4f935945f7bd8b84f581261191173a.png的概率添加一条边。基于上述两阶段数据增广,得到新的TFG,如下所示。

2396d8bbaae6bc2f7c6fe3953a6e354d.png

基于自监督学习的空间异质性建模:给定经过增广的TFG,文章的目标是使区域嵌入在辅助自监督信号的情况下有效地保持空间异质性。为实现该目标,文章在区域级别上设计了一个基于软聚类的自监督学习任务,将区域映射到对应于不同城市区域功能的多个潜在表示空间。具体而言,文章生成K个聚类嵌入3e52a40767e8e0332d2650093cee5bbc.png。聚类过程如下式所示。

2b9bc48f53899088828a51f9c16c078e.png

其中,为D维向量,表示区域的区域嵌入。基于上式,区域的聚类指派如下。

c87eb2376edff56381309525c601deaa.png

为生成自监督学习的特征,所设计的辅助任务旨在预测用于原始区域嵌入生成的区域指派88123d815230f1b87669aa134a4f568d.png。对于区域,对应的自监督学习的损失计算方法如下所示。

bf29bd1cd899aa2a9f1959a973686bb7.png

对于所有区域而言,其损失计算如下。

588d58ad1873c9ae7d7042f0f8066324.png

上述聚类方法存在两方面问题:首先,生成的聚类指派矩阵是由聚沉成绩所产生的,每个区域的聚类指派求和可能不为1;其次,可能存在每个区域都有相同的分配。为解决上述我呢提,文章提出区域聚类的分布正则化。具体而言,文章采用最大熵原理,即9ddf465d5aa2df37e8ad2f10dc845343.png,定义可行解所构成的集合如下。

d523fbd71cbee12fab4c7f17d4077987.png

对于每一个可行的聚类指派,可以将嵌入矩阵355cb95cf103b73bd6857f5190f92378.png映射为聚类矩阵999fe3b1d41121f2bb004b25dcdc62fe.png。因此搜索可以通过最大化嵌入矩阵和聚类的相似度获得最优解,如下式所示。

4736812e12bd289a454d7d3f3aedb880.png

其中,tr()表示矩阵的迹,c870b96907b0fb15711c36a9e54e04ba.png表示熵函数,计算公式如下。

c3e067a9aa40972e21e37ea41fccf45e.png

基于自监督学习的时间异质性建模:文章进一步设计了一个自监督学习任务,通过强制时间步长特定的流量模式表示之间的差异,将时间异质性注入到时间感知区域嵌入中。具体而言,文章首先融合原始的TFG和增广后的TFG。

bccc9d7edcff58b1a39cbf2144dbf02a.png

进一步,将不同区域的特征聚合,从而获得第t个时间步网络级表示,计算方式如下。

ea2e28ea4de4d2143c5c2b7984a34773.png

为增强不同时间步表示的辨别能力,文章将网络级表示和区域级表示作为嵌入对618fbfd7adea990e2debb5942db9a1b9.png,其中若区域级表示和网络级表示为同一个时间步则为正值,反之,则为负值。最后,基于上述定义,时间异质性建模的损失函数为交叉熵损失函数,定义如下。

c402e75e316d6cb0911bd99b68d471ed.png

综上所述,模型的整体损失函数定义如下:

8a43fefe4db2e57dd2bb8f9d88bf07cf.png

4bfd43cb7e037bc4cf5534a0fcbe40fe.png

实验

文章在几个真实数据集上进行一系列实验,以评估ST-SSL的性能。数据集包括纽约Bike数据集和Taxi数据集,以及北京出租车数据集。这些实验旨在回答以下研究问题:

  • 问题1:与各种基线相比,ST-SSL的整体流量预测性能如何?

  • 问题2:设计的不同子模块对模型性能的贡献是什么?

  • 问题3:对于异构空间区域和不同时间段,ST-SSL的性能如何

  • 问题4:增广图和学习表征如何使模型受益?

问题1:在不同数据集上的实验结果如下图所示。可以看到ST-SSL的预测误差最低。

d60b0df1a17ab8e474bbbb13c3bd00c5.png

进一步,文章对所提出模型在不同区域的预测误差进行可视化,并比对了不同基线模型的预测误差,如下图所示。

1e8e06fbb747f6323ebe9419389a3288.png

问题2:为验证所提出模型不同子模块的影响,文章构建了四组模型的变体进行消融实验。具体而言,ST-SSL-sa表示该模型用随机边缘去除和增加的方式取代了图拓扑上的异构引导结构增广;ST-SSL-ta表示该模型使用随机交通量掩膜替换原有的基于异质性引导的流量增广;ST-SSL-sh表示该模型不使用空间异质建模模块;ST-SSL-th表示该模型不使用时间异质性建模模块。实验结果如下。

5a2c1b8b5378588e2eb2676ef4ea4201.png

问题3:为探究ST-SSL的鲁棒性,文章在北京出租车数据集上对具有异构数据分布的空间区域和具有不同模式的时间段进行了流量预测。对于空间异质性而言,文章利用历史交通数据的统计量,例如均值、中位数、标准差,将不同区域进行聚类。下图分别展示了不同区域的划分结果以及预测结果。

c8be4d3be5c3f81303565df0ee8a2735.png

对于时间异质性而言,文章将工作日分为四个时段,将节假日分为2各时段。下图分别展示了划分方法和预测结果。

ccdf9d2252226546bc17a8ff0e6864d1.png

问题4:文章通过定性分析的方法进行分析,在北京出租车数据集上进行实验。结果如下图所示。文章所提出的方法自适应地去除了具有异构交通模式的相邻区域之间的连接。同时,在城市潜在功能相似的遥远区域之间建立联系。通过这种方式,ST-SSL不仅可以消除低相互关联交通模式的区域连接,还可以捕获全球城市背景下的长期区域依赖关系

349b65c77fcee4901a12f6fc2d10e502.png

此外,为了进一步探究ST-SSL中的嵌入是如何提升预测精度的,文章对比了AGCRN和ST-SSL的预测结果,通过T-SNE方法进行可视化,如下图所示。

52564bd2ce9493bc8c1cf03bca6ef2c1.png

结论

文章提出一种新的时空自监督学习(ST-SSL)框架以解决交通预测问题。具体而言,文章整合了时间和空间卷积来编码时空交通模式。进一步,文章设计两个主要模块:1、一个由自适应图增强和基于聚类的生成任务组成的空间自监督学习范式;2、一个依赖于时间感知的对比任务的时间自监督学习范式,以空间和时间异质性感知的自监督信号补充主要的交通流量预测任务。在4个交通流数据集上的综合实验证明了ST-SSL算法的鲁棒性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/159853.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

香港学界呼吁RWA“在港先发”,构建基于港元稳定币的Web3生态!

2023年以来,市场对于RWA(Real World Assets)即真实世界资产“代币化”的讨论愈发频繁,一些观点认为 RWA将在下一轮加密资产牛市中成为焦点,部分Web3创业者和传统金融企业也快速将业务方向瞄准相关赛道,而被…

架构-设计原则

1、面向对象的SOLID 1.1 概述 SOLID是5个设计原则开头字母的缩写,其本身就有“稳定的”的意思,寓意是“遵从SOLID原则可以建立稳定、灵活、健壮的系统”。5个原则分别如下: Single Responsibility Principle(SRP)&am…

grafana api创建dashboard 记录

文章目录 json model导入申请api key创建dashboard删除dashboard json model导入 直接在ui通过json model 导入,开发自己用还好,但对非开发人员不太友好,故考虑通过api后台自动创建 api doc : https://grafana.com/docs/grafana/v9.3/devel…

R实现动态条件相关模型与GARCH模型结合研究中美股市动态相关性(DCC-GARCH模型)

大家好,我是带我去滑雪! 中美两国是全球最大的经济体,其经济活动对全球产业链和贸易体系都具有巨大影响。中美之间的经济互动包括大规模的贸易、投资和金融往来。这些互动不仅仅反映在经济数据上,还体现在股市上。中美股市的联动关…

吃瓜教程-模型的评估与选择

在训练集上的误差称为训练误差(training error)或经验误差(empirical error)。在测试集上的误差称为测试误差(test error)。学习器在所有新样本上的误差称为泛化误差(generalization error&…

drawio简介以及下载安装

drawio简介以及下载安装 drawio是一款非常强大的开源在线的流程图编辑器,支持绘制各种形式的图表,提供了 Web端与客户端支持,同时也支持多种资源类型的导出。 访问网址:draw.io或者直接使用app.diagrams.net直接打开可以使用在线版…

PyTorch 深度学习之处理多维特征的输入Multiple Dimension Input(六)

1.Multiple Dimension Logistic Regression Model 1.1 Mini-Batch (N samples) 8D->1D 8D->2D 8D->6D 1.2 Neural Network 学习能力太好也不行(学习到的是数据集中的噪声),最好的是要泛化能力,超参数尝试 Example, Arti…

软件工程与计算总结(九)软件体系结构基础

目录 ​编辑 一.体系结构的发展 二.理解体系结构 1.定义 2.区分体系结构的抽象与实现 3.部件 4.连接件 5.配置 三.体系结构风格初步 1.主程序/子程序 2.面向对象式 3.分层 4.MVC 一.体系结构的发展 小规模编程的重点在于模块内部的程序结构非常依赖于程序设计语言…

仪酷LabVIEW OD实战(3)——Object Detection+onnx工具包快速实现yolo目标检测

‍‍🏡博客主页: virobotics(仪酷智能):LabVIEW深度学习、人工智能博主 🎄所属专栏:『LabVIEW深度学习工具包』『仪酷LabVIEW目标检测工具包实战』 📑上期文章:『仪酷LabVIEW OD实战(2)——Obje…

E047-论坛漏洞分析及利用-针对Wordpress论坛进行信息收集与漏洞扫描的探索

任务实施: E047-论坛漏洞分析及利用-针对Wordpress论坛进行信息收集与漏洞扫描的探索 任务环境说明: 服务器场景:p9_kali-6(用户名:root;密码:toor) 服务器场景操作系统:Kali Li…

MPNN 模型:GNN 传递规则的实现

首先,假如我们定义一个极简的传递规则 A是邻接矩阵,X是特征矩阵, 其物理意义就是 通过矩阵乘法操作,批量把图中的相邻节点汇聚到当前节点。 但是由于A的对角线都是 0.因此自身的节点特征会被过滤掉。 图神经网络的核心是 吸周围…

mysql中的几种排名函数

mysql中的排名函数 mysql里面的排名函数&#xff0c;涉及有以下几个&#xff1a; rank()、dense_rank()、row_number() 1、rank() 函数 RANK() OVER (PARTITION BY <expression>[{,<expression>...}]ORDER BY <expression> [ASC|DESC], [{,<expression…

MySQL有时候命中索引有时候又不命中

索引失效的情况 -----可能 索引主要看where 、group by 、order by 1.组合索引不遵循最佳左前缀法制。最佳左前缀法制&#xff1a;如果索引了多列&#xff0c;要遵循最左前缀法则&#xff0c;指的是查询从索引的最左前列开始并且不跳过索引中的列。如组合索引为A B C 只有ABC,A…

C# RestoreFormer 图像修复

效果 项目 代码 using Microsoft.ML.OnnxRuntime; using Microsoft.ML.OnnxRuntime.Tensors; using OpenCvSharp; using System; using System.Collections.Generic; using System.Drawing; using System.Drawing.Imaging; using System.Windows.Forms;namespace 图像修复 {pu…

【SpringCloud-10】SCA-nacos

前言&#xff1a; 前面介绍的springcloud&#xff0c;可以看做第一代&#xff0c;称为&#xff1a;SCN&#xff08;spring cloud Netflix&#xff09;; 接下来介绍的是第二代&#xff1a;SCA&#xff08;spring cloud alibaba&#xff09;&#xff1b; SCA主要有以下组件&#…

Java|学习|异常

1.异常 1.1 异常 1.1.1 概述 异常&#xff1a;就是程序出现了不正常的情况。 Error&#xff1a;严重问题&#xff0c;不需要处理。 Exception&#xff1a;称为异常类&#xff0c;它表示程序本身可以处理的问题。 RuntimeException&#xff1a;在编译器不检查&#xff0c;出…

关于Skywalking Agent customize-enhance-trace对应用复杂参数类型取值

对于Skywalking Agent customize-enhance-trace 大家应该不陌生了&#xff0c;主要支持以非入侵的方式按用户自定义的Span跟踪对应的应用方法&#xff0c;并获取数据。 参考https://skywalking.apache.org/docs/skywalking-java/v9.0.0/en/setup/service-agent/java-agent/cust…

论文阅读:Rethinking Range View Representation for LiDAR Segmentation

来源ICCV2023 0、摘要 LiDAR分割对于自动驾驶感知至关重要。最近的趋势有利于基于点或体素的方法&#xff0c;因为它们通常产生比传统的距离视图表示更好的性能。在这项工作中&#xff0c;我们揭示了建立强大的距离视图模型的几个关键因素。我们观察到&#xff0c;“多对一”…

TCP/IP(九)TCP的连接管理(六)TIME_WAIT状态探究

一 TIME_WAIT探究 要明确TIME_WAIT状态在tcp四次挥手的阶段 ① 为什么 TIME_WAIT 等待的时间是 2MSL? 背景&#xff1a; 客户端在收到服务端第三次FIN挥手后,就会进入TIME_WAIT 状态,开启时长为2MSL的定时器1、MSL 是 Maximum Segment Lifetime 报文最大生存时间2、2MSL…

论文阅读之【Is GPT-4 a Good Data Analyst?(GPT-4是否是一位好的数据分析师)】

文章目录 论文阅读之【Is GPT-4 a Good Data Analyst?&#xff08;GPT-4是否是一位好的数据分析师&#xff09;】背景&#xff1a;数据分析师工作范围基于GPT-4的端到端数据分析框架将GPT-4作为数据分析师的框架的流程图 实验分析评估指标表1&#xff1a;GPT-4性能表现表2&…