AI时序预测: iTransformer算法代码深度解析

在之前的文章中,我对iTransformer的Paper进行了详细解析,具体文章如下:

文章链接:深度解析iTransformer:维度倒置与高效注意力机制的结合

今天,我将对iTransformer代码进行解析。回顾Paper,我们知道iTransformer通过简单地将注意力机制和前馈网络应用于倒置的维度上。具体而言,单个序列的时间点被嵌入为变量令牌(variate tokens),并利用注意力机制捕捉变量间的相关性;同时,前馈网络被应用于每个变量令牌,以学习非线性表示。

iTransformer 整体架构如下图所示,采用了 Transformer(Vaswani et al., 2017)的仅编码器(encoder-only)架构,包括嵌入层、投影层和 Transformer 块。接下来,我们看看每个模块是如何通过代码实现的。

图片

1. 嵌入层

 


import torch
import torch.nn as nnclass DataEmbedding_inverted(nn.Module):"""该类用于数据嵌入(Embedding),适用于时间序列建模或其他需要将输入转换为高维表示的任务。它通过线性变换将输入数据映射到 `d_model` 维度,并可选地结合额外的时间信息进行处理。参数:- c_in (int): 输入特征的维度(即变量数)。- d_model (int): 目标嵌入维度(即转换后的特征维度)。- embed_type (str, 可选): 嵌入类型,当前代码未使用该参数,默认值为 'fixed'。- freq (str, 可选): 频率信息(如 'h' 代表小时级别),当前代码未使用该参数。- dropout (float, 可选): Dropout 比例,控制神经元随机失活的概率,以防止过拟合。"""def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):super(DataEmbedding_inverted, self).__init__()# 线性映射层:将输入数据从 `c_in` 维度投影到 `d_model` 维度self.value_embedding = nn.Linear(c_in, d_model)# Dropout 层:用于在训练时随机丢弃部分神经元,以增强模型的泛化能力self.dropout = nn.Dropout(p=dropout)def forward(self, x, x_mark):"""前向传播函数,将输入 `x` 及可选的时间标记 `x_mark` 进行处理,返回嵌入后的表示。参数:- x (Tensor): 输入数据,形状为 [Batch, Time, Variate],其中:- Batch: 批量大小- Time: 时间步数- Variate: 变量数(即 `c_in`)- x_mark (Tensor 或 None): 时间标记信息,形状为 [Batch, Time, Extra_Features],如果为 None,则仅使用 `x` 进行嵌入。返回:- Tensor: 经过嵌入和 Dropout 处理后的数据,形状为 [Batch, Variate, d_model]。"""# 交换 `Time` 和 `Variate` 维度,调整形状以适配后续处理x = x.permute(0, 2, 1)  # 变换后形状:[Batch, Variate, Time]if x_mark is None:# 仅使用输入数据 `x` 进行嵌入x = self.value_embedding(x)else:# 如果提供了 `x_mark`,先调整其形状,再与 `x` 拼接后进行嵌入x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], dim=1))# 经过 Dropout 处理后返回最终的嵌入表示return self.dropout(x)

完整文章链接:AI时序预测: iTransformer算法代码深度解析

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/9987.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

某盾Blackbox参数参数逆向

以前叫同盾,现在改名了,叫小盾安全,好像不做验证码了

docker中运行的MySQL怎么修改密码

1,进入MySQL容器 docker exec -it 容器名 bash 我运行了 docker ps命令查看。正在运行的容器名称。可以看到MySQL的我起名为db docker exec -it db bash 这样就成功的进入到容器中了。 2,登录MySQL中 mysql -u 用户名 -p 回车 密码 mysql -u root -p roo…

春节期间,景区和酒店如何合理用工?

春节期间,景区和酒店如何合理用工? 春节期间,旅游市场将迎来高峰期。景区与酒店,作为旅游产业链中的两大核心环节,承载着无数游客的欢乐与期待。然而,也隐藏着用工管理的巨大挑战。如何合理安排人力资源&a…

初始化mysql报错cannot open shared object file: No such file or directory

报错展示 我在初始化msyql的时候报错:mysqld: error while loading shared libraries: libaio.so.1: cannot open shared object file: No such file or directory 解读: libaio包的作用是为了支持同步I/O。对于数据库之类的系统特别重要,因此…

C语言------数组从入门到精通

1.一维数组 目标:通过思维导图了解学习一维数组的核心知识点: 1.1定义 使用 类型名 数组名[数组长度]; 定义数组。 // 示例: int arr[5]; 1.2一维数组初始化 数组的初始化可以分为静态初始化和动态初始化两种方式。 它们的主要区别在于初始化的时机和内存分配的方…

Docker/K8S

文章目录 项目地址一、Docker1.1 创建一个Node服务image1.2 volume1.3 网络1.4 docker compose 二、K8S2.1 集群组成2.2 Pod1. 如何使用Pod(1) 运行一个pod(2) 运行多个pod 2.3 pod的生命周期2.4 pod中的容器1. 容器的生命周期2. 生命周期的回调3. 容器重启策略4. 自定义容器启…

【开源免费】基于SpringBoot+Vue.JS公交线路查询系统(JAVA毕业设计)

本文项目编号 T 164 ,文末自助获取源码 \color{red}{T164,文末自助获取源码} T164,文末自助获取源码 目录 一、系统介绍二、数据库设计三、配套教程3.1 启动教程3.2 讲解视频3.3 二次开发教程 四、功能截图五、文案资料5.1 选题背景5.2 国内…

< OS 有关 > Android 手机 SSH 客户端 app: connectBot

connectBot 开源且功能齐全的SSH客户端,界面简洁,支持证书密钥。 下载量超 500万 方便在 Android 手机上,连接 SSH 服务器,去运行命令。 Fail2ban 12小时内抓获的 IP ~ ~ ~ ~ rootjpn:~# sudo fail2ban-client status sshd Status for the jail: sshd …

中国股市“慢牛”行情的实现路径与展望

在现代经济体系中,股市不仅是企业融资的重要平台,也是投资者财富增值的关键渠道。一个健康、稳定、持续增长的股市,对于推动经济高质量发展、提升国家金融竞争力具有深远意义。近年来,“慢牛”行情成为众多投资者和市场参与者对我…

Linux Samba 低版本漏洞(远程控制)复现与剖析

目录 前言 漏洞介绍 漏洞原理 产生条件 漏洞影响 防御措施 复现过程 结语 前言 在网络安全的复杂生态中,系统漏洞的探索与防范始终是保障数字世界安全稳定运行的关键所在。Linux Samba 作为一款在网络共享服务领域应用极为广泛的软件,其低版本中…

ResNet 残差网络

目录 网络结构 残差块(Residual Block) ResNet网络结构示意图 残差块(Residual Block)细节 基本残差块(ResNet-18/34) Bottleneck残差块(ResNet-50/101/152) 残差连接类型对比 变体网…

【Unity3D】实现横版2D游戏角色二段跳、蹬墙跳、扶墙下滑

目录 一、二段跳、蹬墙跳 二、扶墙下滑 一、二段跳、蹬墙跳 GitHub - prime31/CharacterController2D 下载工程后直接打开demo场景:DemoScene(Unity 2019.4.0f1项目环境) Player物体上的CharacterController2D,Mask添加Wall层…

FPGA 使用 CLOCK_LOW_FANOUT 约束

使用 CLOCK_LOW_FANOUT 约束 您可以使用 CLOCK_LOW_FANOUT 约束在单个时钟区域中包含时钟缓存负载。在由全局时钟缓存直接驱动的时钟网段 上对 CLOCK_LOW_FANOUT 进行设置,而且全局时钟缓存扇出必须低于 2000 个负载。 注释: 当与其他时钟约束配合…

Excel 技巧21 - Excel中整理美化数据实例,Ctrl+T 超级表格(★★★)

本文讲Excel中如何整理美化数据的实例,以及CtrlT 超级表格的常用功能。 目录 1,Excel中整理美化数据 1-1,设置间隔行颜色 1-2,给总销量列设置数据条 1-3,根据总销量设置排序 1-4,加一个销售趋势列 2&…

Leetcode:219

1&#xff0c;题目 2&#xff0c;思路 第一种就是简单的暴力比对当时过年没细想 第二种&#xff1a; 用Map的特性key唯一&#xff0c;把数组的值作为Map的key值我们每加载一个元素都会去判断这个元素在Map里面存在与否如果存在进行第二个判断条件abs(i-j)<k,条件 符合直接…

MySQL(高级特性篇) 14 章——MySQL事务日志

事务有4种特性&#xff1a;原子性、一致性、隔离性和持久性 事务的隔离性由锁机制实现事务的原子性、一致性和持久性由事务的redo日志和undo日志来保证&#xff08;1&#xff09;REDO LOG称为重做日志&#xff0c;用来保证事务的持久性&#xff08;2&#xff09;UNDO LOG称为回…

芯片AI深度实战:进阶篇之vim内verilog实时自定义检视

本文基于Editor Integration | ast-grep&#xff0c;以及coc.nvim&#xff0c;并基于以下verilog parser(my-language.so&#xff0c;文末下载链接), 可以在vim中实时显示自定义的verilog 匹配。效果图如下&#xff1a; 需要的配置如下&#xff1a; 系列文章&#xff1a; 芯片…

C++:多继承习题5

题目内容&#xff1a; 先建立一个Point(点)类&#xff0c;包含数据成员x,y(坐标点)。以它为基类&#xff0c;派生出一个Circle(圆)类&#xff0c;增加数据成员r(半径)&#xff0c;再以Circle类为直接基类&#xff0c;派生出一个Cylinder(圆柱体)类&#xff0c;再增加数据成员h…

基于阿里云百炼大模型Sensevoice-1的语音识别与文本保存工具开发

基于阿里云百炼大模型Sensevoice-1的语音识别与文本保存工具开发 摘要 随着人工智能技术的不断发展&#xff0c;语音识别在会议记录、语音笔记等场景中得到了广泛应用。本文介绍了一个基于Python和阿里云百炼大模型的语音识别与文本保存工具的开发过程。该工具能够高效地识别东…

buu-pwn1_sctf_2016-好久不见29

这个也是栈溢出&#xff0c;不一样的点是&#xff0c;有replace替换&#xff0c;要输入0x3c字符&#xff08;60&#xff09;&#xff0c;Iyou 所以&#xff0c;20个I就行&#xff0c;找后面函数 输出提示信息&#xff0c;要求用户输入关于自己的信息。 使用fgets函数从标准输入…