时间序列预测实战(十二)DLinear模型实现滚动长期预测并可视化预测结果

官方论文地址->官方论文地址

官方代码地址->官方代码地址

个人修改代码->个人修改的代码已经上传CSDN免费下载

一、本文介绍

本文给大家带来是DLinear模型,DLinear是一种用于时间序列预测(TSF)的简单架构,DLinear的核心思想是将时间序列分解为趋势和剩余序列,并分别使用两个单层线性网络对这两个序列进行建模以进行预测(值得一提的是DLinear的出现是为了挑战Transformer在实现序列预测中有效性)本文的讲解内容包括:模型原理、数据集介绍、参数讲解、模型训练和预测、结果可视化、训练个人数据集,讲解顺序如下->

预测类型->单元预测、多元预测

适用对象->如果你的配置不是很好这个模型应该很适合你因为参数量很小训练速度很快

二、模型原理

DLinear模型出现是为了调整Transformer的有效性从而存在,Transformer的设计都十分的复杂和需要大量的参数,所以作者提出了一种简单的结构DLinear(参数量我实验过程中确实非常小)

DLinear的核心思想是将时间序列分解为趋势和剩余序列,并分别使用两个单层线性网络对这两个序列进行建模以进行预测。

具体地,DLinear如何工作的关键点如下

  1. 时间序列分解:DLinear将输入的时间序列分解为两部分——趋势部分和剩余部分。这种分解有助于分别处理时间序列中的长期趋势和短期波动。

  2. 单层线性网络:对于趋势和剩余序列,DLinear分别使用两个单层的线性网络进行建模。这种简单的架构使得DLinear在处理时间序列时既高效又有效。

  3. 预测任务:在进行预测时,DLinear结合这两个网络的输出来生成最终的时间序列预测。

总结->可以看出DLinear的核心结构真的十分简单就包括一个分解和两个线性网络进行建模最后经过一个简单的相加就输出了结果。

模型的网络结构图如下所示->

图片分析->可以看到和我们上面讲的一样,数据从输入进来经过两个分支,一个为趋势性一个为剩余序列,然后分别经过一个线性层处理(这里的提到的线性层就是普通的全连接层),然后将结果进行简单的拼接就完成了结果的输出(这就这样的简单模型结果比过程十分复杂的Transformer模型效果要好->我自己实验效果确实要好,我拿2020年的bestpaper和普通的Transformer都进行了对比效果确实要有提升)。

下面的图片是一个简单的线性层(普通的全连接层)提取数据的过程图->

这里把模型的代码结构放出来方便大家根据讲解和代码进行对比。

class moving_avg(nn.Module):"""Moving average block to highlight the trend of time series"""def __init__(self, kernel_size, stride):super(moving_avg, self).__init__()self.kernel_size = kernel_sizeself.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)def forward(self, x):# padding on the both ends of time seriesfront = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)x = torch.cat([front, x, end], dim=1)x = self.avg(x.permute(0, 2, 1))x = x.permute(0, 2, 1)return xclass series_decomp(nn.Module):"""Series decomposition block"""def __init__(self, kernel_size):super(series_decomp, self).__init__()self.moving_avg = moving_avg(kernel_size, stride=1)def forward(self, x):moving_mean = self.moving_avg(x)res = x - moving_meanreturn res, moving_meanclass Model(nn.Module):"""Decomposition-Linear"""def __init__(self, configs):super(Model, self).__init__()self.seq_len = configs.seq_lenself.pred_len = configs.pred_len# Decompsition Kernel Sizekernel_size = 25self.decompsition = series_decomp(kernel_size)self.individual = configs.individualself.channels = configs.enc_inif self.individual:self.Linear_Seasonal = nn.ModuleList()self.Linear_Trend = nn.ModuleList()for i in range(self.channels):self.Linear_Seasonal.append(nn.Linear(self.seq_len,self.pred_len))self.Linear_Trend.append(nn.Linear(self.seq_len,self.pred_len))# Use this two lines if you want to visualize the weights# self.Linear_Seasonal[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))# self.Linear_Trend[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))else:self.Linear_Seasonal = nn.Linear(self.seq_len,self.pred_len)self.Linear_Trend = nn.Linear(self.seq_len,self.pred_len)# Use this two lines if you want to visualize the weights# self.Linear_Seasonal.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))# self.Linear_Trend.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))def forward(self, x):# x: [Batch, Input length, Channel]seasonal_init, trend_init = self.decompsition(x)seasonal_init, trend_init = seasonal_init.permute(0,2,1), trend_init.permute(0,2,1)if self.individual:seasonal_output = torch.zeros([seasonal_init.size(0),seasonal_init.size(1),self.pred_len],dtype=seasonal_init.dtype).to(seasonal_init.device)trend_output = torch.zeros([trend_init.size(0),trend_init.size(1),self.pred_len],dtype=trend_init.dtype).to(trend_init.device)for i in range(self.channels):seasonal_output[:,i,:] = self.Linear_Seasonal[i](seasonal_init[:,i,:])trend_output[:,i,:] = self.Linear_Trend[i](trend_init[:,i,:])else:seasonal_output = self.Linear_Seasonal(seasonal_init)trend_output = self.Linear_Trend(trend_init)x = seasonal_output + trend_outputreturn x.permute(0,2,1) # to [Batch, Output length, Channel]

我看论文的内容大比分都是对比实验,因为DLinear的产生就是为了质疑Transformer所以他和各种Transformer的模型进行对比试验,因为本篇文章就是DLinear的实战案例,对比的部分我就不讲了,大家有兴趣可以看看论文内容在最上面我已经提供了链接。 

三、数据集介绍

所用到的数据集为某公司的业务水平评估和其它参数具体的内容我就介绍了估计大家都是想用自己的数据进行训练模型,这里展示部分图片给大家提供参考。

四、参数讲解

模型的参数如下(大部分都是一些公共参数并不涉及模型)->

parser = argparse.ArgumentParser(description='DLinearNet Multivariate Time Series Forecasting')# basic configparser.add_argument('--train', type=bool, default=True, help='Whether to conduct training')parser.add_argument('--rollingforecast', type=bool, default=True, help='rolling forecast True or False')parser.add_argument('--rolling_data_path', type=str, default='ETTh1-Test.csv', help='rolling data file')parser.add_argument('--show_results', type=bool, default=True, help='Whether show forecast and real results graph')parser.add_argument('--model', type=str, default='SCINet',help='Model name')# data loaderparser.add_argument('--root_path', type=str, default='./data/', help='root path of the data file')parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')parser.add_argument('--features', type=str, default='MS',help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')parser.add_argument('--freq', type=str, default='h',help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')parser.add_argument('--checkpoints', type=str, default='./models/', help='location of model models')# forecasting taskparser.add_argument('--seq_len', type=int, default=126, help='input sequence length')parser.add_argument('--label_len', type=int, default=64, help='start token length')parser.add_argument('--pred_len', type=int, default=4, help='prediction sequence length')# modelparser.add_argument('--individual', action='store_true', default=False,help='DLinear: a linear layer for each variate(channel) individually')parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')parser.add_argument('--c_out', type=int, default=1, help='output size')parser.add_argument('--dropout', type=float, default=0.05, help='dropout')parser.add_argument('--embed', type=str, default='timeF',help='time features encoding, options:[timeF, fixed, learned]')parser.add_argument('--activation', type=str, default='gelu', help='activation')# optimizationparser.add_argument('--num_workers', type=int, default=0, help='data loader num workers')parser.add_argument('--train_epochs', type=int, default=10, help='train epochs')parser.add_argument('--batch_size', type=int, default=16, help='batch size of train input data')parser.add_argument('--learning_rate', type=float, default=0.001, help='optimizer learning rate')parser.add_argument('--loss', type=str, default='mse', help='loss function')parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate')# GPUparser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')parser.add_argument('--device', type=int, default=0, help='gpu')

模型的详细参数讲解如下(如果你想训练你自己的数据集可以仔细看看)->

参数名称参数类型参数讲解
0trainbool是否进行训练,如果你单纯只想进行预测设置为False即可,
1rollingforecastbool是否进行滚动预测,如果是则设置为True,如果不进行滚动预测则进行正常的预测
2rolling-data-pathstr如果进行滚动预测则需要添加新的和训练文件相同格式的数据
3show_resultsbool是否保存预测值和真实值的对比
4modelstr定义的模型名称
5root_pathstr这个才是你文件的路径,不要到具体的文件,到目录级别即可。
6data_pathstr这个填写你文件的具体名称。
7featuresstr这个是特征有三个选项M,MS,S。分别是多元预测多元,多元预测单元,单元预测单元。
8targetstr这个是你数据集中你想要预测那一列数据,假设我预测的是油温OT列就输入OT即可。
9freqstr时间的间隔,你数据集每一条数据之间的时间间隔。
10checkpointsstr训练出来的模型保存路径
11seq_lenint用过去的多少条数据来预测未来的数据
12label_lenint可以理解为更高的权重占比的部分要小于seq_len
13pred_lenint预测未来多少个时间点的数据
14enc_inint你数据有多少列,要减去时间那一列,这里我是输入8列数据但是有一列是时间所以就填写7
15dec_inint同上
16individualbool这个就是我们上面提到的两个线性层,如果为True我们则对每一个通道用单独的线性层处理,False则为所有的通道用一个线性层
17c_outint这里有一些不同如果你的features填写的是M那么和上面就一样,如果填写的MS那么这里要输入1因为你的输出只有一列数据。
18dropoutfloat这个应该都理解不说了,丢弃的概率,防止过拟合的。
19embedstr时间特征的编码方式,默认为"timeF"
20activationstr激活函数
21num_workersint线程windows大家最好设置成0否则会报线程错误,linux系统随便设置。
22train_epochsint训练的次数
23batch_sizeint一次往模型力输入多少条数据
24learning_ratefloat学习率。
25lossstr     损失函数,默认为"mse"
26lradjstr     学习率的调整方式,默认为"type1"
27use_gpubool是否使用GPU训练,根据自身来选择
28gpuintGPU的编号

五、模型训练和预测

1.项目目录结构

项目的目录构造如下->

其中data为训练用的数据放的地方,layers为模型结构存放的地方,models为训练保存的训练模型,results为可视化结果保存的图片和滚动预测的结果,util为一些工具。 

2.模型训练

当我们经过上面的参数讲解之后,我们可以开始训练模型了,控制台输出如下->

3.滚动预测 

这里进行滚动预测的控制台输出->

4.结果展示 

运行结果后,结果保存到同级目录下(下图为预测值和真实值的对比)-> 

5.结果分析 

可以看到预测值和真实值之间的差距还可以,但是这个模型的参数量少得可怜,不得不得质疑Transformer模型的有效性~

六、训练你个人数据集

这个模型我在写的过程中为了节省大家训练自己数据集,我基本上把大部分的参数都写好了,需要大家注意的就是如果要进行滚动预测下面的参数要设置为True。

    parser.add_argument('--rollingforecast', type=bool, default=True, help='rolling forecast True or False')

如果上面的参数设置为True那么下面就要提供一个进行滚动预测的数据集该数据集的格式要和你训练模型的数据集格式完全一致(重要!!!),如果没有可以考虑在自己数据的尾部剪切一部分,不要粘贴否则数据模型已经训练过了的话预测就没有效果了。 

    parser.add_argument('--rolling_data_path', type=str, default='ETTh1-Test.csv', help='rolling data file')

其它的没什么可以讲的了大部分的修改操作在参数讲解的部分我都详细讲过了,这里的滚动预测可能是大家想看的所以摘出来详细讲讲。 

总结

到此本文已经全部讲解完成了,希望能够帮助到大家,在这里也给大家推荐一些我其它的博客的时间序列实战案例讲解,其中有数据分析的讲解就是我前面提到的如何设置参数的分析博客,最后希望大家订阅我的专栏,本专栏均分文章均分98,并且免费阅读。

概念理解 

15种时间序列预测方法总结(包含多种方法代码实现)

数据分析

时间序列预测中的数据分析->周期性、相关性、滞后性、趋势性、离群值等特性的分析方法

机器学习——难度等级(⭐⭐)

时间序列预测实战(四)(Xgboost)(Python)(机器学习)图解机制原理实现时间序列预测和分类(附一键运行代码资源下载和代码讲解)

深度学习——难度等级(⭐⭐⭐⭐)

时间序列预测实战(五)基于Bi-LSTM横向搭配LSTM进行回归问题解决

时间序列预测实战(七)(TPA-LSTM)结合TPA注意力机制的LSTM实现多元预测

时间序列预测实战(三)(LSTM)(Python)(深度学习)时间序列预测(包括运行代码以及代码讲解)

时间序列预测实战(十一)用SCINet实现滚动预测功能(附代码+数据集+原理介绍)

Transformer——难度等级(⭐⭐⭐⭐)

时间序列预测模型实战案例(八)(Informer)个人数据集、详细参数、代码实战讲解

时间序列预测模型实战案例(一)深度学习华为MTS-Mixers模型

个人创新模型——难度等级(⭐⭐⭐⭐⭐)

时间序列预测实战(十)(CNN-GRU-LSTM)通过堆叠CNN、GRU、LSTM实现多元预测和单元预测

传统的时间序列预测模型(⭐⭐)

时间序列预测实战(二)(Holt-Winter)(Python)结合K-折交叉验证进行时间序列预测实现企业级预测精度(包括运行代码以及代码讲解)

时间序列预测实战(六)深入理解ARIMA包括差分和相关性分析

融合模型——难度等级(⭐⭐⭐)

时间序列预测实战(九)PyTorch实现融合移动平均和LSTM-ARIMA进行长期预测

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/188607.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Ansible自动化运维工具及模块

目录 一、Ansible 1.ansible简介 2、ansible的特性 二、ansible的部署 1)管理端安装ansible 2)配置主机清单 3)配置密钥对验证 三、ansible命令块模块 1)command模块 2)shell模块 3)cron模块 4)…

Jdk 1.8 for mac 详细安装教程(含版本切换)

Jdk 1.8 for mac 详细安装教程(含版本切换) 官网下载链接 https://www.oracle.com/cn/java/technologies/downloads/#java8-mac 一、选择我们需要安装的jdk版本,这里以jdk8为例,下载 macOS 版本,M芯片下载ARM64版本…

数据结构之双向链表

目录 引言 链表的分类 双向链表的结构 双向链表的实现 定义 创建新节点 初始化 打印 尾插 头插 判断链表是否为空 尾删 头删 查找与修改 指定插入 指定删除 销毁 顺序表和双向链表的优缺点分析 源代码 dlist.h dlist.c test.c 引言 数据结构…

网络通信TCP、UDP详解

目录 IP 和端口 网络传输中的 2 个对象:server 和 client 两种传输方式:TCP/UDP TCP 和 UDP 原理上的区别 为何存在 UDP 协议 TCP/UDP 网络通信大概交互图 IP 和端口 所有的数据传输,都有三个要素 :源、目的、长度。 怎么表…

ZYNQ_project:IP_ram_pll_test

例化MMCM ip核,产生100Mhz,100Mhz并相位偏移180,50Mhz,25Mhz的时钟信号。 例化单口ram,并编写读写控制器,实现32个数据的写入与读出。 模块框图: 代码: module ip_top(input …

基于FPGA的PS端的Si5340的控制

1、功能 Si5340/41-D可以输出任意频率,当然有范围,100Hz1GHz。外部输入为24M或者4854M的XTAL,VCO在13500~14256Mhz之间,控制接口采用IIC或者SPI。 芯片架构图 2、IIC控制方式 3、直接上控制代码 使用米联客ZU3EG,将…

git使用笔记

0.记录使用经验 1.提交和push代码 git add .添加修改 git commit -m "提交日志" git push origin branch_name推送分支名称代码到远程服务器对应分支 1.1日常操作 git status查看仓库状态 git branch查看分支 git branch -a查看所有分支【包含远程】 git checkou…

如何从存档服务器上完全删除PDM用户

当创建新用户时使用“PDM 登录”类型(如下图),PDM用户名和密码会存储于存档服务器的注册表中。 存档服务器的注册表位置如下: HKEY_LOCAL_MACHINE\SOFTWARE\SolidWorks\Applications\PDMWorks Enterprise\ArchiveServer\ConisioU…

在 Microsoft Word 中启用护眼模式

在 Microsoft Word 中启用护眼模式 在使用 Microsoft Word 365 或 Word 2019(Windows)版本时,启用护眼模式(也称为“夜间模式”)可以有效减轻屏幕亮度,有助于减少眼睛疲劳。以下是启用护眼模式的步骤&…

Linux centos系统中添加磁盘

为了学习与训练文件系统或磁盘的分区、格式化和挂载/卸载,我们需要为虚拟机添加磁盘。根据需要,可以添加多块不同大小的磁盘。具体操作讨论如下,供参考。 一、添加 1.开机前 有两个地方,可选择打开添加硬盘对话框 (1)双击左侧…

深度学习模型基于Python+TensorFlow+Django的垃圾识别系统

欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 要使用Python、TensorFlow和Django构建一个垃圾识别系统,您可以按照以下步骤进行操作: 安装…

Learn runqlat in 5 minutes

内容预告 learn X in 5 系列第一篇. 本篇主要介绍进程时延统计方式和 rawtracepoint. runqlat "高负载场景下应用为何卡顿", "进程 A 为什么得不到调度". 当我们在工作生活中产生这样的疑问, 目标进程的调度时延是一个不错的观测切入点. runqlat 可以帮…

2022最新版-李宏毅机器学习深度学习课程-P50 BERT的预训练和微调

模型输入无标签文本(Text without annotation),通过消耗大量计算资源预训练(Pre-train)得到一个可以读懂文本的模型,在遇到有监督的任务是微调(Fine-tune)即可。 最具代表性是BERT&…

在线生成二维码--支持彩色二维码和包含Logo

具体请前往:在线二维码生成工具--可将网址等内容生成为指定大小,指定颜色的彩色二维码,同时支持添加Logo

数据结构:Map和Set(2):相关OJ题目

目录 136. 只出现一次的数字 - 力扣(LeetCode) 771. 宝石与石头 - 力扣(LeetCode) 旧键盘 (20)__牛客网 (nowcoder.com) 138. 随机链表的复制 - 力扣(LeetCode) 692. 前K个高频单词 - 力扣&#xff08…

linux_day02

1、链接:LN 一个点表示当前工作目录,两个点表示上一层工作目录; 目录的本质:文件(该文件储存目录项,以链表的形式链接,每个结点都是目录项,创建文件相当于把目录项添加到链表中&…

【Unity之UI编程】编写一个面板交互界面需要注意的细节

👨‍💻个人主页:元宇宙-秩沅 👨‍💻 hallo 欢迎 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍💻 本文由 秩沅 原创 👨‍💻 收录于专栏:Uni…

devops完整搭建教程(gitlab、jenkins、harbor、docker)

devops完整搭建教程(gitlab、jenkins、harbor、docker) 文章目录 devops完整搭建教程(gitlab、jenkins、harbor、docker)1.简介:2.工作流程:3.优缺点4.环境说明5.部署前准备工作5.1.所有主机永久关闭防火墙…

[PHP]Kodexplorer可道云 v4.47

KodExplorer可道云,原名芒果云,是基于Web技术的私有云和在线文件管理系统,由上海岱牧网络有限公司开发,发布于2012年6月。致力于为用户提供安全可控、可靠易用、高扩展性的私有云解决方案。 用户只需通过简单环境搭建,…