鍵.png)
使用這些基本 REST API 最佳實(shí)踐構(gòu)建出色的 API
在原論文《Lag-Llama: Towards Foundation Models for Probabilistic Time Series Forecasting》中,Lag-Llama作為單變量概率預(yù)測(cè)的通用大模型提出。
在本文中,我們將探討Lag-Llama的架構(gòu)、功能以及訓(xùn)練方式。還會(huì)通過代碼將lagllama應(yīng)用于一個(gè)預(yù)測(cè)項(xiàng)目中,并將其與其他深度學(xué)習(xí)方法Temporal Fusion Transformer (TFT) 和DeepAR進(jìn)行性能比較。
lagllama是為單變量概率預(yù)測(cè)而構(gòu)建的,它使用不依賴于頻率的通用方法來標(biāo)記時(shí)間序列數(shù)據(jù)。這樣模型可以很好地泛化到不可見的頻率。
laglllama的標(biāo)記策略是使用一組指定的滯后特征。
它將從這個(gè)列表中為給定的數(shù)據(jù)集選擇所有合適的頻率: 季度、月、周、天、小時(shí)、秒
也就是說,如果以每日頻率提供數(shù)據(jù)集,lag – llama將嘗試使用每日滯后(t-1),每周滯后(t-7),每月滯后(t-30)等構(gòu)建特征。
策略如下圖所示。
從上圖中,我們還可以看到模型構(gòu)建了其他靜態(tài)協(xié)變量,例如秒/分、小時(shí)/天等等,直到季度/年。雖然這可以很好地推廣到所有類型的時(shí)間序列,但它有一個(gè)致命的缺點(diǎn):由于固定的滯后指數(shù)列表,輸入令牌可能會(huì)變得非常大。
例如,查看每小時(shí)數(shù)據(jù)的每月頻率需要730個(gè)時(shí)間步。這意味著除了所有靜態(tài)協(xié)變量之外,輸入令牌的長度至少為730。
Lag-Llama是一個(gè)基于transformer的純解碼器模型,其靈感來自大型語言模型LLaMA的體系結(jié)構(gòu)。它利用Transformer體系結(jié)構(gòu)來解析輸入token,并將它們映射到具有置信區(qū)間的未來預(yù)測(cè)。
從圖中可以看到輸入標(biāo)記是滯后時(shí)間步長和靜態(tài)協(xié)變量的拼接。輸入序列通過線性投影層將特征映射到解碼器內(nèi)部注意力模塊的隱藏維度。另外就是在最后的輸出,序列被發(fā)送到一個(gè)分布頭負(fù)責(zé)輸出一個(gè)概率分布。
在推理過程中,輸入序列生成下一個(gè)時(shí)間點(diǎn)的分布。然后通過自回歸,模型逐個(gè)生成剩余的預(yù)測(cè)序列,直到達(dá)到設(shè)置的長度。
生成預(yù)測(cè)的自回歸過程有效地允許模型為其預(yù)測(cè)生成不確定性區(qū)間。但是這里的問題就是如果序列很長,自回歸的方式會(huì)將錯(cuò)誤擴(kuò)大。
作為一個(gè)基礎(chǔ)模型,Lag-Llama顯然是在大量的時(shí)間序列數(shù)據(jù)語料庫上訓(xùn)練的,因此該模型可以很好地泛化未見過的時(shí)間序列并進(jìn)行零樣本預(yù)測(cè)。
論文中說:Lag-Llama在來自不同領(lǐng)域的27個(gè)時(shí)間序列數(shù)據(jù)集上進(jìn)行了訓(xùn)練,如能源、交通、經(jīng)濟(jì)等。
數(shù)據(jù)包含7965個(gè)單變量時(shí)間序列,總計(jì)約3.52億個(gè)令牌。
所有數(shù)據(jù)集都是開源的,包括ethth, Exchange和Weather等。
因?yàn)榇a已經(jīng)開源,所以我們可以直接測(cè)試,我們首先使用Lag-Llama的零樣本預(yù)測(cè)能力,并將其性能與特定數(shù)據(jù)模型(如TFT和DeepAR)進(jìn)行比較。
Lag-Llama的實(shí)現(xiàn)是建立在GluonTS之上的,所以我們還需要安裝這個(gè)庫。實(shí)驗(yàn)使用了澳大利亞電力需求數(shù)據(jù)集,該數(shù)據(jù)集包含五個(gè)單變量時(shí)間序列,以半小時(shí)的頻率跟蹤能源需求。
這里有個(gè)說明:Lag-Llama目前的實(shí)現(xiàn)是初期階段。并且存還在積極開發(fā)中,后面可能還會(huì)有很大的調(diào)整,因?yàn)槟壳斑€沒加入微調(diào)的功能。
!git clone https://github.com/time-series-foundation-models/lag-llama/
cd lag-llama
pip install -r requirements.txt --quiet
然后需要我們從HuggingFace下載模型的權(quán)重。
!huggingface-cli download time-series-foundation-models/Lag-Llama lag-llama.ckpt --local-dir /content/lag-llama
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import torch
from itertools import islice
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.repository.datasets import get_dataset
from lag_llama.gluon.estimator import LagLlamaEstimator
可以直接從GluonTS加載數(shù)據(jù)集。
dataset = get_dataset("australian_electricity_demand")
backtest_dataset = dataset.test prediction_length = dataset.metadata.prediction_length
context_length = 3 * prediction_length
簡單地初始化模型并使用LagLlamaEstimator對(duì)象。
ckpt = torch.load("lag-llama.ckpt", map_location=torch.device('cuda:0'))
estimator_args = ckpt["hyper_parameters"]["model_kwargs"]
estimator = LagLlamaEstimator( ckpt_path="lag-llama.ckpt",
prediction_length=prediction_length,
context_length=context_length,
input_size=estimator_args["input_size"],
n_layer=estimator_args["n_layer"],
n_embd_per_head=estimator_args["n_embd_per_head"],
n_head=estimator_args["n_head"],
scaling=estimator_args["scaling"],
time_feat=estimator_args["time_feat"])
lightning_module = estimator.create_lightning_module()
transformation = estimator.create_transformation()
predictor = estimator.create_predictor(transformation, lightning_module)
使用make_evaluation_predictions函數(shù)生成零樣本的預(yù)測(cè)。
forecast_it, ts_it = make_evaluation_predictions(
dataset=backtest_dataset,
predictor=predictor)
這個(gè)函數(shù)返回生成器。我們需要把它們轉(zhuǎn)換成列表。
forecasts = list(forecast_it)
tss = list(ts_it)
GluonTS可以使用Evaluator對(duì)象方便地計(jì)算不同的性能指標(biāo)。
evaluator = Evaluator()
agg_metrics, ts_metrics = evaluator(iter(tss), iter(forecasts))
RMSE為481.57。
我們還可以隨意地將預(yù)測(cè)可視化。
plt.figure(figsize=(20, 15))
date_formater = mdates.DateFormatter('%b, %d')
plt.rcParams.update({'font.size': 15})
for idx, (forecast, ts) in islice(enumerate(zip(forecasts, tss)), 4):
ax = plt.subplot(2, 2, idx+1)
plt.plot(ts[-4 * dataset.metadata.prediction_length:].to_timestamp(), label="target")
forecast.plot( color='g')
plt.xticks(rotation=60)
ax.xaxis.set_major_formatter(date_formater)
ax.set_title(forecast.item_id)
plt.gcf().tight_layout()
plt.legend()
plt.show()
上圖可以看到模型對(duì)數(shù)據(jù)做出了合理的預(yù)測(cè),盡管它在第四個(gè)序列(圖的右下角)上確實(shí)存在問題。
另外由于 Lag-Llama實(shí)現(xiàn)了概率預(yù)測(cè),可以得到預(yù)測(cè)的不確定性區(qū)間。
我們?cè)跀?shù)據(jù)集上訓(xùn)練TFT和DeepAR模型,看看它們是否能表現(xiàn)得更好。
為了節(jié)省時(shí)間,我們將訓(xùn)練設(shè)置為5個(gè)epoch。
from gluonts.torch import TemporalFusionTransformerEstimator, DeepAREstimator
tft_estimator = TemporalFusionTransformerEstimator(
prediction_length=prediction_length,
context_length=context_length,
freq="30min",
trainer_kwargs={"max_epochs": 5})
deepar_estimator = DeepAREstimator(
prediction_length=prediction_length,
context_length=context_length,
freq="30min",
trainer_kwargs={"max_epochs": 5})
訓(xùn)練過程。
tft_predictor = tft_estimator.train(dataset.train)
deepar_predictor = deepar_estimator.train(dataset.train)
訓(xùn)練完成后,生成預(yù)測(cè)并計(jì)算RMSE。
tft_forecast_it, tft_ts_it = make_evaluation_predictions(
dataset=backtest_dataset,
predictor=tft_predictor)
deepar_forecast_it, deepar_ts_it = make_evaluation_predictions(
dataset=backtest_dataset,
predictor=deepar_predictor)
tft_forecasts = list(tft_forecast_it)
tft_tss = list(tft_ts_it)
deepar_forecasts = list(deepar_forecast_it)
deepar_tss = list(deepar_ts_it)
# Get evaluation metrics
tft_agg_metrics, tft_ts_metrics = evaluator(iter(tft_tss), iter(tft_forecasts))
deepar_agg_metrics, deepar_ts_metrics = evaluator(iter(deepar_tss), iter(deepar_forecasts))
下表突出顯示了性能最好的模型。
可以看到只訓(xùn)練了5個(gè)epoch這兩個(gè)模型都取得了比Lag-Llama更好的結(jié)果,TFT是目前表現(xiàn)最好的模型,DeepAR的表現(xiàn)也優(yōu)于laglama。
雖然laglllama的表現(xiàn)似乎不盡如人意,但該模型沒有經(jīng)過微調(diào)(零樣本學(xué)習(xí)本身就比較困難些)。
在嘗試了TimeGPT和Lag-Llama之后,Lag-Llama算是構(gòu)建開源預(yù)測(cè)模型的第一步,但與TimeGPT相比,它在功能方面存在不足。TimeGPT可以處理多變量時(shí)間序列、不規(guī)則時(shí)間戳,并實(shí)現(xiàn)共形預(yù)測(cè),與使用laglama等固定分布相比,這是一種更穩(wěn)健的量化不確定性的方式。
laglllama是一個(gè)開源的基礎(chǔ)模型,只用于單變量概率預(yù)測(cè),性能表現(xiàn)也比較有限。相信在不久的將來會(huì)看到更多的開源預(yù)測(cè)模型出現(xiàn),他們的表現(xiàn)可能會(huì)得到改善,這代表了該領(lǐng)域的一個(gè)重大轉(zhuǎn)變。
文章轉(zhuǎn)自微信公眾號(hào)@算法進(jìn)階
對(duì)比大模型API的內(nèi)容創(chuàng)意新穎性、情感共鳴力、商業(yè)轉(zhuǎn)化潛力
一鍵對(duì)比試用API 限時(shí)免費(fèi)