Pyraformer复现心得

Pyraformer复现心得

引用

Liu, Shizhan, et al. “Pyraformer: Low-complexity pyramidal attention for long-range time series modeling and forecasting.” International conference on learning representations. 2021.

代码部分

def long_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):enc_out = self.encoder(x_enc, x_mark_enc)[:, -1, :]#B,dmodel*3dec_out = self.projection(enc_out).view(enc_out.size(0), self.pred_len, -1)#B,pre,Nreturn dec_out

预测部分就这么长

x_dec, x_mark_dec, mask=None都没用到

enc_out = self.encoder(x_enc, x_mark_enc)[:, -1, :]
#B,dmodel*3
  • 直接进入encoder
def forward(self, x_enc, x_mark_enc):seq_enc = self.enc_embedding(x_enc, x_mark_enc)
  • 重构了encoder和decoder,跟transformer的很不一样
x = self.value_embedding(x) + self.temporal_embedding(x_mark) + self.position_embedding(x)
return self.dropout(x)
  • embedding方法跟former一样
mask = self.mask.repeat(len(seq_enc), 1, 1).to(x_enc.device)

用pyra的方式获取pam掩码

def get_mask(input_size, window_size, inner_size):"""Get the attention mask of PAM-Naive"""# Get the size of all layersall_size = []all_size.append(input_size)for i in range(len(window_size)):layer_size = math.floor(all_size[i] / window_size[i])all_size.append(layer_size)seq_length = sum(all_size)mask = torch.zeros(seq_length, seq_length)# get intra-scale maskinner_window = inner_size // 2for layer_idx in range(len(all_size)):start = sum(all_size[:layer_idx])for i in range(start, start + all_size[layer_idx]):left_side = max(i - inner_window, start)right_side = min(i + inner_window + 1, start + all_size[layer_idx])mask[i, left_side:right_side] = 1# get inter-scale maskfor layer_idx in range(1, len(all_size)):start = sum(all_size[:layer_idx])for i in range(start, start + all_size[layer_idx]):left_side = (start - all_size[layer_idx - 1]) + \(i - start) * window_size[layer_idx - 1]if i == (start + all_size[layer_idx] - 1):right_side = startelse:right_side = (start - all_size[layer_idx - 1]) + (i - start + 1) * window_size[layer_idx - 1]mask[i, left_side:right_side] = 1mask[left_side:right_side, i] = 1mask = (1 - mask).bool()return mask, all_size

接着进入卷积层

seq_enc = self.conv_layers(seq_enc)

先构建CSCM卷积

class Bottleneck_Construct(nn.Module):"""Bottleneck convolution CSCM"""
temp_input = self.down(enc_input).permute(0, 2, 1)
all_inputs = []
self.down = Linear(d_model, d_inner)

下采样

for i in range(len(self.conv_layers)):temp_input = self.conv_layers[i](temp_input)all_inputs.append(temp_input)

堆叠很多次卷积,这个跟former是一样的

class ConvLayer(nn.Module):def __init__(self, c_in, window_size):super(ConvLayer, self).__init__()self.downConv = nn.Conv1d(in_channels=c_in,out_channels=c_in,kernel_size=window_size,stride=window_size)self.norm = nn.BatchNorm1d(c_in)self.activation = nn.ELU()def forward(self, x):x = self.downConv(x)x = self.norm(x)x = self.activation(x)return x

将N次卷积的结果拼接起来

all_inputs = torch.cat(all_inputs, dim=2).transpose(1, 2)#
all_inputs = self.up(all_inputs)
all_inputs = torch.cat([enc_input, all_inputs], dim=1)
self.up = Linear(d_inner, d_model)
all_inputs = self.norm(all_inputs)
return all_inputs
self.norm = nn.LayerNorm(d_model)

之后在跟原始输入拼接起来

  • 卷积layer完了之后是encoderlayer
def forward(self, enc_input, slf_attn_mask=None):attn_mask = RegularMask(slf_attn_mask)
enc_output, _ = self.slf_attn(enc_input, enc_input, enc_input, attn_mask=attn_mask)

进到encoder里面,到了熟悉的former框架

def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):#后面俩参数应该是作者指定的B, L, _ = queries.shape#B,seq,dmodel_, S, _ = keys.shapeH = self.n_heads
#其实L和S是一个数queries = self.query_projection(queries).view(B, L, H, -1)#B, L, H, dmodel/hkeys = self.key_projection(keys).view(B, S, H, -1)#一样的计算方法values = self.value_projection(values).view(B, S, H, -1)#H 表示头的数量-1 表示自动计算该维度
  • encoder的注意力用的fullattention。并且用到了掩码

回到pyra的encoder

self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout, normalize_before=normalize_before)
def forward(self, x):residual = xif self.normalize_before:x = self.layer_norm(x)x = F.gelu(self.w_1(x))x = self.dropout(x)x = self.w_2(x)x = self.dropout(x)x = x + residualif not self.normalize_before:x = self.layer_norm(x)return x
indexes = self.indexes.repeat(seq_enc.size(0), 1, 1, seq_enc.size(2)).to(seq_enc.device)
#B,seq,3,dmodel
indexes = indexes.view(seq_enc.size(0), -1, seq_enc.size(2))
#B,seq+pred,dmodel
all_enc = torch.gather(seq_enc, 1, indexes)
##B,seq+pred,dmodel
seq_enc = all_enc.view(seq_enc.size(0), self.all_size[0], -1)
#B,seq,dmodel*3
return seq_enc

总结

x_dec, x_mark_dec, mask=None都没用到

  • 直接进入encoder

重构了encoder和decoder,跟transformer的很不一样

embedding方法跟former一样

encoder的注意力用的fullattention,并且用到了掩码

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/465226.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

DAY21|二叉树Part08|LeetCode: 669. 修剪二叉搜索树、108.将有序数组转换为二叉搜索树、538.把二叉搜索树转换为累加树

目录 LeetCode: 669. 修剪二叉搜索树 基本思路 C代码 LeetCode: 108.将有序数组转换为二叉搜索树 基本思路 C代码 LeetCode: 538.把二叉搜索树转换为累加树 基本思路 C代码 LeetCode: 669. 修剪二叉搜索树 力扣代码链接 文字讲解:LeetCode: 669. 修剪二叉搜…

ubuntu20.04安装ros与rosdep

目录 前置配置 配置apt清华源 配置ros软件源 添加ros安装源(中科大软件源) 设置秘钥 更新源 ros安装 安装ros 初始化 rosdep 更新 rosdep 设置环境变量 安装 rosinstall 安装验证 启动海龟仿真器 操控海龟仿真器 rosdep安装更新 安装 使用…

高亚科技签约酸动力,助力研发管理数字化升级

近日,中国企业管理软件资深服务商高亚科技与广东酸动力生物科技有限公司(以下简称“酸动力”)正式签署合作协议。借助高亚科技的8Manage PM项目管理软件,酸动力将进一步优化项目过程跟踪与节点监控,提升研发成果的高效…

CSRF与SSRF

csrf(跨站请求伪造)的原理: csrf全称是跨站请求伪造(cross-site request forgery),也被称为one-click attack 或者 session riding scrf攻击利用网站对于用户网页浏览器的信任,劫持用户当前已登录的web应用程序,去执行分用户本意的操作。 利…

享元模式-实现大颗粒度对象缓存机制

详解 享元模式是一种结构型设计模式,其主要目的是通过共享尽可能多的相同部分来有效地支持大量细粒度的对象。它通过将对象的属性分为内在属性(可以共享、不随环境变化的部分)和外在属性(根据场景变化、不能共享的部分&#xff0…

HTML 基础标签——结构化标签<html>、<head>、<body>

文章目录 1. <html> 标签2. <head> 标签3. <body> 标签4. <div> 标签5. <span> 标签小结 在 HTML 文档中&#xff0c;使用特定的结构标签可以有效地组织和管理网页内容。这些标签不仅有助于浏览器正确解析和渲染页面&#xff0c;还能提高网页的可…

新华三H3CNE网络工程师认证—VLAN的配置

VLAN&#xff08;虚拟局域网&#xff09;是一种在逻辑上划分网络的技术&#xff0c;它可以将一个物理网络分割成多个虚拟网络&#xff0c;从而实现不同组的设备之间的隔离。在配置VLAN时&#xff0c;通常涉及到三种端口类型&#xff1a;Access、Trunk和Hybrid。Access端口用于连…

R语言*号标识显著性差异判断组间差异是否具有统计意义

前言 该R代码用于对Iris数据集进行多组比较分析&#xff0c;探讨不同鸢尾花品种在不同测量变量&#xff08;花萼和花瓣长度与宽度&#xff09;上的显著性差异。通过将数据转换为长格式&#xff0c;并利用ANOVA和Tukey检验&#xff0c;代码生成了不同品种间的显著性标记&#x…

手边酒店多商户版V2源码独立部署_博纳软云

新版采用laraveluniapp开发&#xff0c;为更多平台小程序开发提供坚实可靠的底层架构基础。后台UI全部重写&#xff0c;兼容手机端管理。 全新架构、会员卡、钟点房、商城、点餐、商户独立管理

Multi Agents协作机制设计及实践

01 多智能体协作机制的背景概述 在前述博客中&#xff0c;我们利用LangChain、AutoGen等开发框架构建了一个数据多智能体的平台&#xff0c;并使用了LangChain的Multi-Agents框架。然而&#xff0c;在实施过程中&#xff0c;我们发现现有的框架存在一些局限性&#xff0c;这些…

ReactPress—基于React的免费开源博客CMS内容管理系统

ReactPress Github项目地址&#xff1a;https://github.com/fecommunity/reactpress 欢迎提出宝贵的建议&#xff0c;感谢Star。 ![ReactPress](https://i-blog.csdnimg.cn/direct/0720f155edaa4eadba796f4d96d394d7.png#pic_center ReactPress 是使用React开发的开源发布平台&…

如何在一个 Docker 容器中运行多个进程 ?

在容器化的世界里&#xff0c;Docker 彻底改变了开发人员构建、发布和运行应用程序的方式。Docker 容器封装了运行应用程序所需的所有依赖项&#xff0c;使其易于跨不同环境一致地部署。然而&#xff0c;在单个 Docker 容器中管理多个进程可能具有挑战性&#xff0c;这就是 Sup…

【JavaEE初阶 — 多线程】线程安全问题 & synchronized

目录 1. 什么是线程安全问题 (1) 观察线程不安全 (2) 线程安全的概念 2. 造成线程安全的原因 (1)线程调度的随机性 问题描述 解决方案 (2)修改共享数据&#xff06;原子性问题 问题描述 解决方案 3.synchronized 关键字 1. synchronized 的特性 (1) …

产品经理的重要性

一直觉得产品经理很重要&#xff0c;这几年写了好几篇和产品经理相关的思考。2020年写过对产品经理的一些思考的文章&#xff0c;2021年&#xff0c;写了一篇对如何分析项目的思考&#xff0c;2024年写了如何与PM探讨项目。 今天还想再写一篇&#xff0c;主要是最近很有感慨。…

Hunyuan-Large:推动AI技术进步的下一代语言模型

腾讯近期推出了基于Transformer架构的混合专家&#xff08;MoE&#xff09;模型——Hunyuan-Large&#xff08;Hunyuan-MoE-A52B&#xff09;。该模型目前是业界开源的最大MoE模型之一&#xff0c;拥有3890亿总参数和520亿激活参数&#xff0c;展示了极强的计算能力和资源优化优…

【Linux系列】利用 CURL 发送 POST 请求

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

通义灵码实操—飞机大战游戏

通义灵码实操—飞机大战游戏 有没有想象过自己独立编写一个有趣的小游戏。在本实践课程中&#xff0c;你不仅可以实现这个想法&#xff0c;而且还将得到通义灵码智能编程助手的支持与指导。我们将携手步入编程的神奇世界&#xff0c;以一种简洁、高效且具有创造性的方式&#…

lora训练模型 打造个人IP

准备工作 下载秋叶炼丹器整理自己的照片下载底膜 https://rentry.org/lycoris-experiments 实操步骤 解压整合包 lora-scripts,先点击“更新” 训练图片收集 比如要训练一个自己头像的模型&#xff0c;就可以拍一些自己的照片&#xff08;20-50张&#xff0c;最少15张&…

Caffeine 手动策略缓存 put() 方法源码解析

BoundedLocalManualCache put() 方法源码解析 先看一下BoundedLocalManualCache的类图 com.github.benmanes.caffeine.cache.BoundedLocalCache中定义的BoundedLocalManualCache静态内部类。 static class BoundedLocalManualCache<K, V> implements LocalManualCache&…

Spring Boot框架下的教育导师匹配系统

第一章 绪论 1.1 选题背景 如今的信息时代&#xff0c;对信息的共享性&#xff0c;信息的流通性有着较高要求&#xff0c;尽管身边每时每刻都在产生大量信息&#xff0c;这些信息也都会在短时间内得到处理&#xff0c;并迅速传播。因为很多时候&#xff0c;管理层决策需要大量信…