SOFTS: 时间序列预测的最新模型以及Python使用示例

近年来,深度学习一直在时间序列预测中追赶着提升树模型,其中新的架构已经逐渐为最先进的性能设定了新的标准。

这一切都始于2020年的N-BEATS,然后是2022年的NHITS。2023年,PatchTST和TSMixer被提出,最近的iTransformer进一步提高了深度学习预测模型的性能。

这是2024年4月《SOFTS: Efficient Multivariate Time Series Forecasting with Series-Core Fusion》中提出的新模型,采用集中策略来学习不同序列之间的交互,从而在多变量预测任务中获得最先进的性能。

在本文中,我们详细探讨了SOFTS的体系结构,并介绍新的STar聚合调度(STAD)模块,该模块负责学习时间序列之间的交互。然后,我们测试将该模型应用于单变量和多变量预测场景,并与其他模型作为对比。

SOFTS介绍

SOFTS是 Series-cOre Fused Time Series的缩写,背后的动机来自于长期多元预测对决策至关重要的认识:

首先我们一直研究Transformer的模型,它们试图通过使用补丁嵌入和通道独立等技术(如PatchTST)来降低Transformer的复杂性。但是由于通道独立性,消除了每个序列之间的相互作用,因此可能会忽略预测信息。

iTransformer 通过嵌入整个序列部分地解决了这个问题,并通过注意机制处理它们。但是基于transformer的模型在计算上是复杂的,并且需要更多的时间来训练非常大的数据集。

另一方面有一些基于mlp的模型。这些模型通常很快,并产生非常强的结果,但当存在许多序列时,它们的性能往往会下降。

所以出现了SOFTS:研究人员建议使用基于mlp的STAD模块。由于是基于MLP的,所以训练速度很快。并且STAD模块,它允许学习每个序列之间的关系,就像注意力机制一样,但计算效率更高。

SOFTS架构

在上图中可以看到每个序列都是单独嵌入的,就像在iTransformer 中一样。

然后将嵌入发送到STAD模块。每个序列之间的交互都是集中学习的,然后再分配到各个系列并融合在一起。

最后再通过线性层产生预测。

这个体系结构中有很多东西需要分析,我们下面更详细地研究每个组件。

1、归一化与嵌入

首先使用归一化来校准输入序列的分布。使用了可逆实例的归一化(RevIn)。它将数据以单位方差的平均值为中心。然后每个系列分别进行嵌入,就像在iTransformer 模型。

在上图中我们可以看到,嵌入整个序列就像应用补丁嵌入,其中补丁长度等于输入序列的长度。

这样,嵌入就包含了整个序列在所有时间步长的信息。

然后将嵌入式系列发送到STAD模块。

2、STar Aggregate-Dispatch (STAD)

STAD模块是soft模型与其他预测方法的真正区别。使用集中式策略来查找所有时间序列之间的相互作用。

嵌入的序列首先通过MLP和池化层,然后将这个学习到的表示连接起来形成核(上图中的黄色块表示)。

核构建好了以后就进入了“重复”和“连接”的步骤,在这个步骤中,核表示被分派给每个系列。

MLP和池化层未捕获的信息还可以通过残差连接添加到核表示中。然后在融合(fuse)操作的过程中,核表示及其对应系列的残差都通过MLP层发送。最后的线性层采用STAD模块的输出来生成每个序列的最终预测。

与其他捕获通道交互的方法(如注意力机制)相比,STAD模块的主要优点之一是它降低了复杂性。

因为STAD模块具有线性复杂度,而注意力机制具有二次复杂度,这意味着STAD在技术上可以更有效地处理具有多个序列的大型数据集。

下面我们来实际使用SOFTS进行单变量和多变量场景的测试。

使用SOFTS预测

这里,我们使用 Electricity Transformer dataset 数据集。

这个数据集跟踪了中国某省两个地区的变压器油温。每小时和每15分钟采样一个数据集,总共有四个数据集。

我门使用neuralforecast库中的SOFTS实现,这是官方认可的库,并且这样我们可以直接使用和测试不同预测模型的进行对比。

在撰写本文时,SOFTS还没有集成在的neuralforecast版本中,所以我们需要使用源代码进行安装。

 pip install git+https://github.com/Nixtla/neuralforecast.git

然后就是从导入包开始。使用datasetsforecast以所需格式加载数据集,以便使用neuralforecast训练模型,并使用utilsforecast评估模型的性能。这就是我们使用neuralforecast的原因,因为他都是一套的

 import pandas as pdimport numpy as npimport matplotlib.pyplot as pltfrom datasetsforecast.long_horizon import LongHorizonfrom neuralforecast.core import NeuralForecastfrom neuralforecast.losses.pytorch import MAE, MSEfrom neuralforecast.models import SOFTS, PatchTST, TSMixer, iTransformerfrom utilsforecast.losses import mae, msefrom utilsforecast.evaluation import evaluate

编写一个函数来帮助加载数据集,以及它们的标准测试大小、验证大小和频率。

 def load_data(name):if name == "ettm1":Y_df, *_ = LongHorizon.load(directory='./', group='ETTm1')Y_df = Y_df[Y_df['unique_id'] == 'OT'] # univariate datasetY_df['ds'] = pd.to_datetime(Y_df['ds'])val_size = 11520test_size = 11520freq = '15T'elif name == "ettm2":Y_df, *_ = LongHorizon.load(directory='./', group='ETTm2')Y_df['ds'] = pd.to_datetime(Y_df['ds']) val_size = 11520test_size = 11520freq = '15T'return Y_df, val_size, test_size, freq

然后就可以对ETTm1数据集进行单变量预测。

1、单变量预测

加载ETTm1数据集,将预测范围设置为96个时间步长。

可以测试更多的预测长度,但我们这里只使用96。

 Y_df, val_size, test_size, freq = load_data('ettm1')horizon = 96

然后初始化不同的模型,我们将soft与TSMixer, iTransformer和PatchTST进行比较。

所有模型都使用的默认配置将最大训练步数设置为1000,如果三次后验证损失没有改善,则停止训练。

 models = [SOFTS(h=horizon, input_size=3*horizon, n_series=1, max_steps=1000, early_stop_patience_steps=3),TSMixer(h=horizon, input_size=3*horizon, n_series=1, max_steps=1000, early_stop_patience_steps=3),iTransformer(h=horizon, input_size=3*horizon, n_series=1, max_steps=1000, early_stop_patience_steps=3),PatchTST(h=horizon, input_size=3*horizon, max_steps=1000, early_stop_patience_steps=3)]

然后初始化NeuralForecast对象训练模型。并使用交叉验证来获得多个预测窗口,更好地评估每个模型的性能。

 nf = NeuralForecast(models=models, freq=freq)nf_preds = nf.cross_validation(df=Y_df, val_size=val_size, test_size=test_size, n_windows=None)nf_preds = nf_preds.reset_index()

评估计算了每个模型的平均绝对误差(MAE)和均方误差(MSE)。因为之前的数据是缩放的,因此报告的指标也是缩放的。

 ettm1_evaluation = evaluate(df=nf_preds, metrics=[mae, mse], models=['SOFTS', 'TSMixer', 'iTransformer', 'PatchTST'])

从上图可以看出,PatchTST的MAE最低,而softts、TSMixer和PatchTST的MSE是一样的。在这种特殊情况下,PatchTST仍然是总体上最好的模型。

这并不奇怪,因为PatchTST在这个数据集中是出了名的好,特别是对于单变量任务。下面我们开始测试多变量场景。

2、多变量预测

使用相同的load_data函数,我们现在为这个多变量场景使用ETTm2数据集。

 Y_df, val_size, test_size, freq = load_data('ettm2')horizon = 96

然后简单地初始化每个模型。我们只使用多变量模型来学习序列之间的相互作用,所以不会使用PatchTST,因为它应用通道独立性(意味着每个序列被单独处理)。

然后保留了与单变量场景中相同的超参数。只将n_series更改为7,因为有7个时间序列相互作用。

 models = [SOFTS(h=horizon, input_size=3*horizon, n_series=7, max_steps=1000, early_stop_patience_steps=3, scaler_type='identity', valid_loss=MAE()),TSMixer(h=horizon, input_size=3*horizon, n_series=7, max_steps=1000, early_stop_patience_steps=3, scaler_type='identity', valid_loss=MAE()),iTransformer(h=horizon, input_size=3*horizon, n_series=7, max_steps=1000, early_stop_patience_steps=3, scaler_type='identity', valid_loss=MAE())]

训练所有的模型并进行预测。

 nf = NeuralForecast(models=models, freq='15min')nf_preds = nf.cross_validation(df=Y_df, val_size=val_size, test_size=test_size, n_windows=None)nf_preds = nf_preds.reset_index()

最后使用MAE和MSE来评估每个模型的性能。

 ettm2_evaluation = evaluate(df=nf_preds, metrics=[mae, mse], models=['SOFTS', 'TSMixer', 'iTransformer'])

上图中可以看到到当在96的水平上预测时,TSMixer large在ETTm2数据集上的表现优于iTransformer和soft。

虽然这与soft论文的结果相矛盾,这是因为我们没有进行超参数优化,并且使用了96个时间步长的固定范围。

这个实验的结果可能不太令人印象深刻,我们只在固定预测范围的单个数据集上进行了测试,所以这不是SOFTS性能的稳健基准,同时也说明了SOFTS在使用时可能需要更多的时间来进行超参数的优化。

总结

SOFTS是一个很有前途的基于mlp的多元预测模型,STAD模块是一种集中式方法,用于学习时间序列之间的相互作用,其计算强度低于注意力机制。这使得模型能够有效地处理具有许多并发时间序列的大型数据集。

虽然在我们的实验中,SOFTS的性能可能看起来有点平淡无奇,但请记住,这并不代表其性能的稳健基准,因为我们只在固定视界的单个数据集上进行了测试。

但是SOFTS的思路还是非常好的,比如使用集中式学习时间序列之间的相互作用,并且使用低强度的计算来保证数据计算的效率,这都是值得我们学习的地方。

并且每个问题都需要其独特的解决方案,所以将SOFTS作为特定场景的一个测试选项是一个明智的选择。

SOFTS: Efficient Multivariate Time Series Forecasting with Series-Core Fusion

https://avoid.overfit.cn/post/6254097fd18d479ba7fd85efcc49abac

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/350058.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

显著提高iOS应用中Web页面的加载速度 - 提前下载页面的关键资源(如JavaScript、CSS和图像)

手动下载并缓存资源是一种有效的方式,可以确保在需要时资源已经在本地存储,这样可以显著提高加载速度。 缓存整个 web 页面的所有资源文件 具体实现步骤 下载和缓存资源:包括 HTML 文件、CSS、JavaScript 和图像。在应用启动时预加载资源。…

WinForm之TCP客户端通讯

目录 一 设计界面 二 后台代码 一 设计界面 二 后台代码 using System.Net.Sockets; using System.Text;namespace TCP网络客户端通讯 {public partial class Form1 : Form{public Form1(){InitializeComponent();}TcpClient tcpClient new TcpClient();private void conne…

停止游戏中的循环扣血显示

停止游戏中循环扣血并显示的具体实现方式会依赖于你的代码结构和游戏的逻辑。通常情况下,你可以通过以下方式来实现停止循环扣血和显示: 1、问题背景 在使用 Python 代码为游戏开发一个生命值条时,遇到了一个问题。代码使用了循环来减少生命…

atcoder abc357

A Sanitize Hands 问题&#xff1a; 思路&#xff1a;前缀和&#xff0c;暴力&#xff0c;你想咋做就咋做 代码&#xff1a; #include <iostream>using namespace std;const int N 2e5 10;int n, m; int a[N];int main() {cin >> n >> m;for(int i 1…

四轴飞行器、无人机(STM32、NRF24L01)

一、简介 此电路由STM32为主控芯片&#xff0c;NRF24L01、MPU6050为辅,当接受到信号时&#xff0c;处理对应的指令。 二、实物图 三、部分代码 void FlightPidControl(float dt) { volatile static uint8_t statusWAITING_1; switch(status) { case WAITING_1: //等待解锁 if…

2024最新最全【AIGC】学习零基础入门到精通,看完这一篇就够了!

这个文案就是由AI生成的哦&#xff01;&#xff01;&#xff01;&#xff01; AIGC&#xff08;AI-Generated Content&#xff09; 即人工智能生成内容&#xff0c;是指利用人工智能技术来创造各种形式的内容&#xff0c;包括文字、图像、视频、音频和游戏等。与专业生成内容…

图解Transformer学习笔记

教程是来自https://github.com/datawhalechina/learn-nlp-with-transformers/blob/main/docs/ 图解Transformer Attention为RNN带来了优点&#xff0c;那么有没有一种神经网络结构直接基于Attention构造&#xff0c;而不再依赖RNN、LSTM或者CNN的结构&#xff0c;这就是Trans…

【算法专题--链表】反转链表II--高频面试题(图文详解,小白一看就会!!!)

目录 一、前言 二、题目描述 三、解题方法 ⭐迭代法 --- 带哨兵位&#xff08;头节点&#xff09; &#x1f95d; 什么是哨兵位头节点&#xff1f; &#x1f34d; 解题思路 四、总结与提炼 五、共勉 一、前言 反转链表II这道题&#xff0c;可以说是--链表专题--&am…

RAG工作流在高效信息检索中的应用

介绍 RAG&#xff08;Retrieval Augmented Generation&#xff09;是一种突破知识限制、整合外部数据并增强上下文理解的方法。 由于其高效地整合外部数据而无需持续微调&#xff0c;RAG的受欢迎程度正在飙升。 让我们来探索RAG如何克服LLM的挑战&#xff01; LLM知识限制大…

【WEB前端2024】3D智体编程:乔布斯3D纪念馆-第38课-密室逃脱-3D互动剧情

【WEB前端2024】3D智体编程&#xff1a;乔布斯3D纪念馆-第38课-密室逃脱 使用dtns.network德塔世界&#xff08;开源的智体世界引擎&#xff09;&#xff0c;策划和设计《乔布斯超大型的开源3D纪念馆》的系列教程。dtns.network是一款主要由JavaScript编写的智体世界引擎&…

Flutter - Material3适配

demo 地址: https://github.com/iotjin/jh_flutter_demo 代码不定时更新&#xff0c;请前往github查看最新代码 Flutter - Material3适配 对比图具体实现一些组件的变化 代码实现Material2的ThemeDataMaterial3的ThemeData Material3适配官方文档 flutter SDK升级到3.16.0之后 …

C# WinForm —— 35 StatusStrip 介绍

1. 简介 状态栏 StatusStrip&#xff0c;默认在软件的最下方&#xff0c;用于显示系统时间、版本、进度条、账号、角色信息、操作位置信息等 可以在状态栏中添加的控件类型有&#xff1a;StatusLabel、ProgressBar、DropDownButton、SplitButton 2. 属性 属性解释(Name)控…

utm投影

一 概述 UTM (Universal Transverse Mercator)坐标系是由美国军方在1947提出的。虽然我们仍然将其看作与“高斯&#xff0d;克吕格”相似的坐标系统&#xff0c;但实际上UTM采用了网格的分带&#xff08;或分块&#xff09;。除在美国本土采用Clarke 1866椭球体以外&#xff0c…

树莓派等Linux开发板上使用 SSD1306 OLED 屏幕,bullseye系统 ubuntu,debian

Raspberry Pi OS Bullseye 最近发布了,随之而来的是许多改进,但其中大部分都在引擎盖下。没有那么多视觉差异,最明显的可能是新的默认桌面背景,现在是大坝或湖泊上的日落。https://www.the-diy-life.com/add-an-oled-stats-display-to-raspberry-pi-os-bullseye/ 通过这次操…

【GD32】 TIMER通用定时器学习+PWM输出占空比控制LED

扩展&#xff1a;对PWM波形的输出进行捕获 目录 一、简介二、具体功能描述1、时钟源的选择&#xff1a;2、预分频器&#xff1a;3、计数模式&#xff1a;向上计数模式&#xff1a;向下计数模式&#xff1a;中央对齐模式&#xff1a; 4、捕获/比较通道 输入捕获模式 输出比…

前端问题整理

Vue vue mvvm&#xff08;Model-View-ViewModel&#xff09;架构模式原理 Model 是数据层&#xff0c;即 vue 实例中的数据View 是视图层&#xff0c; 即 domViewModel&#xff0c;即连接Model和Vue的中间层&#xff0c;Vue实例就是ViewModelViewModel 负责将 Model 的变化反映…

TCGAbiolinks包学习

TCGAbiolinks 写在前面学习目的GDCquery GDCdownload GDC prepare中间遇到的报错下载蛋白质数据 写在前面 由于别人提醒我TCGA的数据可以利用TCGAbiolinks下载并处理&#xff0c;所以我决定阅读该包手册&#xff0c;主要是该包应该是有更新的&#xff0c;我看手册进行更新了&…

【CS.PL】Lua 编程之道: 简介与环境设置 - 进度8%

1 初级阶段 —— 简介与环境设置 文章目录 1 初级阶段 —— 简介与环境设置1.1 什么是 Lua&#xff1f;特点?1.2 Lua 的应用领域1.3 安装 Lua 解释器1.3.1 安装1.3.2 Lua解释器的结构 1.4 Lua执行方式1.4.0 程序段1.4.1 使用 Lua REPL&#xff08;Read-Eval-Print Loop&#x…

LAMP部署及应用

LAMP架构 LAMP架构是指一种常用的网站开发架构&#xff0c;它由以下几个组件组成&#xff1a; Linux操作系统&#xff1a;作为服务器的操作系统&#xff0c;LAMP架构通常使用Linux作为操作系统&#xff0c;因为Linux通常被认为是稳定和安全的。 Apache HTTP服务器&#xff1a…

iOS ReactiveCocoa MVVM

学习了在MVVM中如何使用RactiveCocoa&#xff0c;简单的写上一个demo。重点在于如何在MVVM各层之间使用RAC的信号来更方便的在各个层之间进行响应式数据交互。 demo需求&#xff1a;一个登录界面(登录界面只有账号和密码都有输入&#xff0c;登录按钮才可以点击操作)&#xff0…