lag-llama源码解读(Lag-Llama: Towards Foundation Models for Time Series Forecasting)

Lag-Llama: Towards Foundation Models for Time Series Forecasting
文章内容:
时间序列预测任务,单变量预测单变量,基于Llama大模型,在zero-shot场景下模型表现优异。创新点,引入滞后特征作为协变量来进行预测。

获得不同频率的lag,来自glunoTS库里面的源码

def _make_lags(middle: int, delta: int) -> np.ndarray:"""Create a set of lags around a middle point including +/- delta."""return np.arange(middle - delta, middle + delta + 1).tolist()def get_lags_for_frequency(freq_str: str,lag_ub: int = 1200,num_lags: Optional[int] = None,num_default_lags: int = 7,
) -> List[int]:"""Generates a list of lags that that are appropriate for the given frequencystring.By default all frequencies have the following lags: [1, 2, 3, 4, 5, 6, 7].Remaining lags correspond to the same `season` (+/- `delta`) in previous`k` cycles. Here `delta` and `k` are chosen according to the existing code.Parameters----------freq_strFrequency string of the form [multiple][granularity] such as "12H","5min", "1D" etc.lag_ubThe maximum value for a lag.num_lagsMaximum number of lags; by default all generated lags are returned.num_default_lagsThe number of default lags; by default it is 7."""# Lags are target values at the same `season` (+/- delta) but in the# previous cycle.def _make_lags_for_second(multiple, num_cycles=3):# We use previous ``num_cycles`` hours to generate lagsreturn [_make_lags(k * 60 // multiple, 2) for k in range(1, num_cycles + 1)]def _make_lags_for_minute(multiple, num_cycles=3):# We use previous ``num_cycles`` hours to generate lagsreturn [_make_lags(k * 60 // multiple, 2) for k in range(1, num_cycles + 1)]def _make_lags_for_hour(multiple, num_cycles=7):# We use previous ``num_cycles`` days to generate lagsreturn [_make_lags(k * 24 // multiple, 1) for k in range(1, num_cycles + 1)]def _make_lags_for_day(multiple, num_cycles=4, days_in_week=7, days_in_month=30):# We use previous ``num_cycles`` weeks to generate lags# We use the last month (in addition to 4 weeks) to generate lag.return [_make_lags(k * days_in_week // multiple, 1)for k in range(1, num_cycles + 1)] + [_make_lags(days_in_month // multiple, 1)]def _make_lags_for_week(multiple, num_cycles=3):# We use previous ``num_cycles`` years to generate lags# Additionally, we use previous 4, 8, 12 weeksreturn [_make_lags(k * 52 // multiple, 1) for k in range(1, num_cycles + 1)] + [[4 // multiple, 8 // multiple, 12 // multiple]]def _make_lags_for_month(multiple, num_cycles=3):# We use previous ``num_cycles`` years to generate lagsreturn [_make_lags(k * 12 // multiple, 1) for k in range(1, num_cycles + 1)]# multiple, granularity = get_granularity(freq_str)offset = to_offset(freq_str)# normalize offset name, so that both `W` and `W-SUN` refer to `W`offset_name = norm_freq_str(offset.name)if offset_name == "A":lags = []elif offset_name == "Q":assert (offset.n == 1), "Only multiple 1 is supported for quarterly. Use x month instead."lags = _make_lags_for_month(offset.n * 3.0)elif offset_name == "M":lags = _make_lags_for_month(offset.n)elif offset_name == "W":lags = _make_lags_for_week(offset.n)elif offset_name == "D":lags = _make_lags_for_day(offset.n) + _make_lags_for_week(offset.n / 7.0)elif offset_name == "B":lags = _make_lags_for_day(offset.n, days_in_week=5, days_in_month=22) + _make_lags_for_week(offset.n / 5.0)elif offset_name == "H":lags = (_make_lags_for_hour(offset.n)+ _make_lags_for_day(offset.n / 24)+ _make_lags_for_week(offset.n / (24 * 7)))# minuteselif offset_name == "T":lags = (_make_lags_for_minute(offset.n)+ _make_lags_for_hour(offset.n / 60)+ _make_lags_for_day(offset.n / (60 * 24))+ _make_lags_for_week(offset.n / (60 * 24 * 7)))# secondelif offset_name == "S":lags = (_make_lags_for_second(offset.n)+ _make_lags_for_minute(offset.n / 60)+ _make_lags_for_hour(offset.n / (60 * 60)))else:raise Exception("invalid frequency")# flatten lags list and filterlags = [int(lag) for sub_list in lags for lag in sub_list if 7 < lag <= lag_ub]lags = list(range(1, num_default_lags + 1)) + sorted(list(set(lags)))return lags[:num_lags]

第一部分,生成以middle为中心,以delta为半径的区间[middle-delta,middle+delta] ,这很好理解,比如一周的周期是7天,周期大小在7天附近波动很正常。
在这里插入图片描述

第二部分,对于年月日时分秒这些不同的采样频率,采用不同的具体的函数来确定lags,其中有一个参数num_cycle,进一步利用了周期性,我们考虑间隔1、2、3、…num个周期的时间点之间的联系
在这里插入图片描述
原理类似于这张图,这种周期性的重复性体现在邻近的多个周期上

在这里插入图片描述

lag的用途

计算各类窗口大小

计算采样窗口大小

window_size = estimator.context_length + max(estimator.lags_seq) + estimator.prediction_length# Here we make a window slightly bigger so that instance sampler can sample from each window# An alternative is to have exact size and use different instance sampler (e.g. ValidationSplitSampler)
window_size = 10 * window_size
# We change ValidationSplitSampler to add min_pastestimator.validation_sampler = ValidationSplitSampler(min_past=estimator.context_length + max(estimator.lags_seq),min_future=estimator.prediction_length,)
  1. 构建静态特征
lags = lagged_sequence_values(self.lags_seq, prior_input, input, dim=-1)#构建一个包含给定序列的滞后值的数组static_feat = torch.cat((loc.abs().log1p(), scale.log()), dim=-1)
expanded_static_feat = unsqueeze_expand(static_feat, dim=-2, size=lags.shape[-2]
)return torch.cat((lags, expanded_static_feat, time_feat), dim=-1), loc, scale

数据集准备过程

对每个数据集采样,window_size=13500,也挺离谱的

 train_data, val_data = [], []for name in TRAIN_DATASET_NAMES:new_data = create_sliding_window_dataset(name, window_size)train_data.append(new_data)new_data = create_sliding_window_dataset(name, window_size, is_train=False)val_data.append(new_data)

采样的具体过程,这里有个问题,样本数量很小的数据集,实际采样窗口大小小于设定的window_size,后续会如何对齐呢?

文章设置单变量预测单变量,所以样本进行了通道分离,同一样本的不同特征被采样为不同的样本

def create_sliding_window_dataset(name, window_size, is_train=True):#划分非重叠的滑动窗口数据集,window_size是对数据集采样的数量,对每个数据集只取前windowsize个样本# Splits each time series into non-overlapping sliding windowsglobal_id = 0freq = get_dataset(name, path=dataset_path).metadata.freq#从数据集中获取时间频率data = ListDataset([], freq=freq)#创建空数据集dataset = get_dataset(name, path=dataset_path).train if is_train else get_dataset(name, path=dataset_path).test#获取原始数据集for x in dataset:windows = []#划分滑动窗口#target:滑动窗口的目标值#start:滑动窗口的起始位置#item_id,唯一标识符#feat_static_cat:静态特征数组for i in range(0, len(x['target']), window_size):windows.append({'target': x['target'][i:i+window_size],'start': x['start'] + i,'item_id': str(global_id),'feat_static_cat': np.array([0]),})global_id += 1data += ListDataset(windows, freq=freq)return data

合并数据集

# Here weights are proportional to the number of time series (=sliding windows)weights = [len(x) for x in train_data]# Here weights are proportinal to the number of individual points in all time series# weights = [sum([len(x["target"]) for x in d]) for d in train_data]train_data = CombinedDataset(train_data, weights=weights)val_data = CombinedDataset(val_data, weights=weights)
class CombinedDataset:def __init__(self, datasets, seed=None, weights=None):self._seed = seedself._datasets = datasetsself._weights = weightsn_datasets = len(datasets)if weights is None:#如果未提供权重,默认平均分配权重self._weights = [1 / n_datasets] * n_datasetsdef __iter__(self):return CombinedDatasetIterator(self._datasets, self._seed, self._weights)def __len__(self):return sum([len(ds) for ds in self._datasets])

网络结构

lagllama

class LagLlamaModel(nn.Module):def __init__(self,max_context_length: int,scaling: str,input_size: int,n_layer: int,n_embd: int,n_head: int,lags_seq: List[int],rope_scaling=None,distr_output=StudentTOutput(),num_parallel_samples: int = 100,) -> None:super().__init__()self.lags_seq = lags_seqconfig = LTSMConfig(n_layer=n_layer,n_embd=n_embd,n_head=n_head,block_size=max_context_length,feature_size=input_size * (len(self.lags_seq)) + 2 * input_size + 6,rope_scaling=rope_scaling,)self.num_parallel_samples = num_parallel_samplesif scaling == "mean":self.scaler = MeanScaler(keepdim=True, dim=1)elif scaling == "std":self.scaler = StdScaler(keepdim=True, dim=1)else:self.scaler = NOPScaler(keepdim=True, dim=1)self.distr_output = distr_outputself.param_proj = self.distr_output.get_args_proj(config.n_embd)self.transformer = nn.ModuleDict(dict(wte=nn.Linear(config.feature_size, config.n_embd),h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),ln_f=RMSNorm(config.n_embd),))

主要是transformer里面首先是一个线性层,然后加了n_layer个Block,最后是RMSNorm,接下来解析Block的代码

在这里插入图片描述

Block

class Block(nn.Module):def __init__(self, config: LTSMConfig) -> None:super().__init__()self.rms_1 = RMSNorm(config.n_embd)self.attn = CausalSelfAttention(config)self.rms_2 = RMSNorm(config.n_embd)self.mlp = MLP(config)self.y_cache = Nonedef forward(self, x: torch.Tensor, is_test: bool) -> torch.Tensor:if is_test and self.y_cache is not None:# Only use the most recent one, rest is in cachex = x[:, -1:]x = x + self.attn(self.rms_1(x), is_test)y = x + self.mlp(self.rms_2(x))if is_test:if self.y_cache is None:self.y_cache = y  # Build cacheelse:self.y_cache = torch.cat([self.y_cache, y], dim=1)[:, 1:]  # Update cachereturn y

代码看到这里不太想继续看了,太多glunoTS库里面的函数了,我完全不熟悉这个库,看起来太痛苦了,还有很多的困惑,最大的困惑就是数据是怎么对齐的,怎么输入到Llama里面的,慢慢看吧

其他

来源
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/226842.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

c语言-位操作符练习题

文章目录 前言一、n&(n-1)的运用场景(n为整数)二、&1 和 >>的应用场景总结 前言 本篇文章介绍利用c语言的位操作符解决一些练习题&#xff0c;目的是掌握各个位操作符的使用和应用场景。 表1.1为c语言中的位操作符 操作符含义&按位与|按位或^按位异或~按位…

Python爬虫教程30:Selenium网页元素,定位的8种方法!

Selenium可以驱动浏览器&#xff0c;完成各种网页浏览器的模拟操作&#xff0c;比如模拟点击等。要想操作一个元素&#xff0c;首先应该识别这个元素。人有各种的特征&#xff08;属性&#xff09;&#xff0c;我们可以通过其特征找到人&#xff0c;如通过身份证号、姓名、家庭…

NFC物联网智能购物车设计方案

智能购物车是综合利用计算机网络、射频识别技术、数据库技术、单片机于一体的设备具有先进性、便于管理性、经济性、普适性。基于NFC (Near Field Communication&#xff0c;近场通信)技术的智能购物车&#xff0c;能够大幅缩短结账排队时间&#xff0c;实现“无感支付”。NFC是…

深入浅出理解转置卷积Conv2DTranspose

温故而知新&#xff0c;可以为师矣&#xff01; 一、参考资料 论文&#xff1a;A guide to convolution arithmetic for deep learning github源码&#xff1a;Convolution arithmetic bilibili视频&#xff1a;转置卷积&#xff08;transposed convolution&#xff09; 转置…

【Linux】深挖进程地址空间

> 作者简介&#xff1a;დ旧言~&#xff0c;目前大二&#xff0c;现在学习Java&#xff0c;c&#xff0c;c&#xff0c;Python等 > 座右铭&#xff1a;松树千年终是朽&#xff0c;槿花一日自为荣。 > 目标&#xff1a;熟悉【Linux】进程地址空间 > 毒鸡汤&#xff…

git 常用命令总结

git 工作原理图&#xff1a; git 常用命令及解释: 命令解释例子git init在当前目录初始化一个新的 Git 仓库。git initgit clone <repository>克隆一个远程仓库到本地。git clone https://github.com/example/repository.gitgit add <file>将文件的变化添加到暂存…

MongoDB文档操作

3.3 文档操作 3.1 文档介绍 文档的数据结构和 JSON 基本一样。 所有存储在集合中的数据都是 BSON 格式。 BSON 是一种类似 JSON 的二进制形式的存储格式&#xff0c;是 Binary JSON 的简称。 文档是一组键值(key-value)对(即 BSON)&#xff0c;一个简单的文档例子如下&…

Ubuntu安装K8S(1.28版本,基于containrd)

原文网址&#xff1a;Ubuntu安装K8S(1.28版本&#xff0c;基于containrd&#xff09;-CSDN博客 简介 本文介绍Ubuntu安装K8S的方法。 官网文档&#xff1a;这里 1.安装K8S 1.让apt支持SSL传输 sudo apt-get update sudo apt-get -y install apt-transport-https ca-certi…

web三层架构

目录 1.什么是三层架构 2.运用三层架构的目的 2.1规范代码 2.2解耦 2.3代码的复用和劳动成本的减少 3.各个层次的任务 3.1web层&#xff08;表现层) 3.2service 层(业务逻辑层) 3.3dao 持久层(数据访问层) 4.结合mybatis简单实例演示 1.什么是三层架构 三层架构就是把…

UG装配设计概念

装配的概念&#xff1a;简单说就是将多个零件按照要求组装的过程就叫装配 装配设计的优势&#xff1a; 1、预见产品设计的不足&#xff0c;特别是多零件的配合 2、便于团队协作 3、方便数据管理 4、优化装配工艺 装配设计的两种方法&#xff1a; 1、自下而上&#xff08;自…

【开源】基于Vue+SpringBoot的贫困地区人口信息管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 人口信息管理模块2.2 精准扶贫管理模块2.3 特殊群体管理模块2.4 案件信息管理模块2.5 物资补助模块 三、系统设计3.1 用例设计3.2 数据库设计3.2.1 人口表3.2.2 扶贫表3.2.3 特殊群体表3.2.4 案件表3.2.5 物资补助表 四…

毫米波雷达:从 3D 走向 4D

1 毫米波雷达已广泛应用于汽车 ADAS 系统 汽车智能驾驶需要感知层、决策层、执行层三大核心系统的高效配合&#xff0c;其中感知层通过传感器探知周围的环境。汽车智能驾驶感知层将真实世界的视觉、物理、事件等信息转变成数字信号&#xff0c;为车辆了解周边环境、制定驾驶操…

恶意软件分析沙箱在网络安全策略中处于什么位置?

恶意软件分析沙箱提供了一种全面的恶意软件分析方法&#xff0c;包括静态和动态技术。这种全面的评估可以更全面地了解恶意软件的功能和潜在影响。然而&#xff0c;许多组织在确定在其安全基础设施中实施沙箱的最有效方法方面面临挑战。让我们看一下可以有效利用沙盒解决方案的…

pytest pytest-emoji通过表情包展示执行状态

pytest-emoji 是一个用于在 Pytest 测试运行期间显示 emoji 表情的插件。它可以为测试结果添加一些有趣的表情符号&#xff0c;以增加测试报告的可读性和趣味性。 使用 pytest-emoji 插件非常简单&#xff0c;只需按照以下步骤进行操作&#xff1a; 首先&#xff0c;确保已经安…

Golang 链表的基础知识

文章目录 链表链表基础知识部分链表的存储方式链表的定义链表的操作性能分析相关leetcode题目 链表 更多有关于go链表的内容可以见这篇文章链表的创建和读取 链表基础知识部分 什么是链表&#xff0c;链表是一种通过指针串联在一起的线性结构&#xff0c;每一个节点由两部分…

docker 安装可视化工具 Protainer 以及 汉化

一、创建保存数据的卷 安装网址&#xff1a;Install Portainer BE with Docker on Linux - Portainer Documentation docker pull portainer/portainer二、根据portainer镜像创建容器 docker run -d -p 8000:8000 -p 9000:9000\ --name portainer --restartalways \ -v /var/r…

大数据技术发展史

今天我们常说的大数据技术&#xff0c;其实起源于Google在2004年前后发表的三篇论文&#xff0c;也就是我们经常听到的“三驾马车”&#xff0c;分别是分布式文件系统GFS、大数据分布式计算框架MapReduce和NoSQL数据库系统BigTable。 你知道&#xff0c;搜索引擎主要就做两件事…

Vue3使用的Compostion Api和Vue2使用的Options Api有什么不同?

我们介绍Compostion Api和Options Api的区别之前&#xff0c;先来说一下为什么会推出来Composition Api&#xff0c;解决了什么问题&#xff1f; Vue2开发项目使用Options Api存在的问题 代码的可读性和维护性随着组件的变大业务的增多而变得差代码的共享和重用性存在缺点不支…

electron 菜单栏打开指定url页面菜单实现方法

electron 菜单栏打开指定url页面菜单 可以是本地URL也可以是远程的URL 自动判断跳转 以下代码可以在主进程main.js里面也可以是在独立的模块文件里面 const { BrowserWindow } require(electron);//定义窗口加载URL export const winURL process.env.NODE_ENV development …

WEB 3D技术 three.js 色彩空间讲解

上文 WEB 3D技术 three.js 设置环境贴图 高光贴图 场景设置 光照贴图 我们讲了基础材质的各种纹理 但是 我们的图片 到了界面场景中 好像绿的程度有点不太一样了 这里的话 涉及到我们的色彩空间 他有两种 一种是线性的 一种是 sRGB类型的 线性呢 就是根据光照强度 去均匀分…