-
Notifications
You must be signed in to change notification settings - Fork 196
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add lag-llama experiment (#224) y la 🧀✖️☕️ 💅
- Loading branch information
Showing
7 changed files
with
537 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
download_lag_llama_code: | ||
@git clone https://github.com/time-series-foundation-models/lag-llama tempdir | ||
@cp -R tempdir/data/ . | ||
@cp -R tempdir/gluon_utils/ . | ||
@cp -R tempdir/lag_llama/ . | ||
@cp -R tempdir/requirements.txt lag-llama-requirements.txt | ||
@rm -rf tempdir | ||
|
||
download_lag_llama_model: | ||
@huggingface-cli download time-series-foundation-models/Lag-Llama lag-llama.ckpt --local-dir ./models/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# LagLLama is 40% less accurate than a simple SeasonalNaive and 1000x slower. | ||
|
||
We present a fully reproducible experiment showing that SeasonalNaive significantly outperforms LagLlama, a recently introduced open-source foundational model for time series forecasting (a deep learning architecture pre-trained on time series datasets). Specifically, **SeasonalNaive achieves 42%, 24%, and 16% better performance** in terms of MASE, MAPE, and CRPS respectively, and boasts **a 1,000x speed advantage**. These findings are based on an extensive analysis covering 105,289 unique time series from the M1, M3, M4, and Tourism datasets, which were omitted in the original LagLlama paper. | ||
|
||
# Introduction | ||
|
||
In the field of time series forecasting, recent developments have introduced foundational models such as LagLlama, which utilizes deep learning and extensive data for pretraining, aiming to enhance predictive performance and model complexity. LagLLama is to be praised as one of the first open-source foundational models. However, contrary to expectations, our analysis indicates that the traditional SeasonalNaive model, known for its straightforward approach of extending past seasonal trends into future predictions, outperforms LagLlama in terms of both accuracy and computational efficiency. | ||
|
||
## Empirical Evaluation | ||
|
||
The original paper uses 3,113 time series to assess the model performance. The original paper only reports CRPS and omits point forecast error metrics widely used in academia and industry, e.g. MASE and MAPE. | ||
|
||
Our evaluation encompasses 105,289 unique time series from different datasets, including M1, M3, M4, and Tourism, covering yearly, quarterly, monthly, weekly, daily, and hourly frequencies. This diverse dataset selection allows for a robust assessment of the models across various time series characteristics and forecasting horizons. We also reproduce results for Pedestrian Counts and Weather originally included in the paper/code to show that we are running LagLlama correctly. | ||
|
||
## Results | ||
|
||
The results are summarized in the following table, highlighting the performance metrics of MASE, MAPE, CRPS, and TIME (measured in seconds). The best results are indicated in **bold** for easy reference. | ||
|
||
<img width="965" alt="image" src="https://github.com/Nixtla/nixtla/assets/10517170/5047bce8-b683-4e07-9af3-8c864121a71b"> | ||
|
||
|
||
## Reproducibility | ||
|
||
To ensure the reproducibility of our findings, the experiments were conducted on an AWS g5.4xlarge GPU instance equipped with 16 vCPUs, 64 GiB of RAM, and an NVIDIA A10G Tensor Core GPU (24 GiB). The complete code can be found in this repo. | ||
|
||
### Instructions | ||
|
||
1. Create a python environment using: | ||
``` | ||
mamba env create -f environment.yml | ||
conda activate lag-llama | ||
``` | ||
|
||
2. Add lag-llama code to your environment | ||
|
||
``` | ||
make download_lag_llama_code | ||
``` | ||
|
||
5. Download lag-llama model | ||
|
||
``` | ||
make download_lag_llama_model | ||
``` | ||
|
||
4. Install lag-llama requirements | ||
|
||
``` | ||
pip install -r lag-llama-requirements.txt | ||
``` | ||
|
||
5. Run complete experiments reported in the table | ||
|
||
``` | ||
python -m src.main | ||
``` | ||
|
||
### References | ||
- **Lag-Llama Paper**: [Towards Foundation Models for Probabilistic Time Series Forecasting](https://arxiv.org/abs/2310.08278) | ||
- **SeasonalNaive Implementation**: [GitHub Repository](https://github.com/nixtla/statsforecast/) | ||
- **CRPS Replication Note**: The CRPS performance for `LagLlama` is replicated from the model's publicly available [Colab notebook](https://colab.research.google.com/drive/13HHKYL_HflHBKxDWycXgIUAHSeHRR5eo?usp=sharing), ensuring a fair comparison. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
name: lag-llama | ||
channels: | ||
- conda-forge | ||
- defaults | ||
- anaconda | ||
dependencies: | ||
- jupyterlab | ||
- pip | ||
- python=3.10 | ||
- pip: | ||
- datasetsforecast | ||
- fire | ||
- huggingface_hub[cli] | ||
- neuralforecast | ||
- orjson | ||
- statsforecast | ||
- utilsforecast | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
from time import time | ||
from typing import Iterable, List, Tuple | ||
|
||
import fire | ||
import pandas as pd | ||
import torch | ||
from gluonts.dataset import Dataset | ||
from gluonts.model.forecast import Forecast | ||
from gluonts.torch.model.predictor import PyTorchPredictor | ||
from tqdm import tqdm | ||
|
||
from lag_llama.gluon.estimator import LagLlamaEstimator | ||
from src.utils import ExperimentHandler | ||
|
||
|
||
def get_lag_llama_predictor( | ||
prediction_length: int, models_dir: str | ||
) -> PyTorchPredictor: | ||
model_path = f"{models_dir}/lag-llama.ckpt" | ||
map_location = torch.device("cuda:0") if torch.cuda.is_available() else "cpu" | ||
if map_location == "cpu": | ||
raise ValueError("cpu is not supported in lagllama (there is a bug)") | ||
ckpt = torch.load(model_path, map_location=map_location) | ||
estimator_args = ckpt["hyper_parameters"]["model_kwargs"] | ||
# this context length is reported in the paper | ||
context_length = 32 | ||
estimator = LagLlamaEstimator( | ||
ckpt_path=model_path, | ||
prediction_length=prediction_length, | ||
context_length=context_length, | ||
# estimator args | ||
input_size=estimator_args["input_size"], | ||
n_layer=estimator_args["n_layer"], | ||
n_embd_per_head=estimator_args["n_embd_per_head"], | ||
n_head=estimator_args["n_head"], | ||
scaling=estimator_args["scaling"], | ||
time_feat=estimator_args["time_feat"], | ||
) | ||
lightning_module = estimator.create_lightning_module() | ||
transformation = estimator.create_transformation() | ||
predictor = estimator.create_predictor(transformation, lightning_module) | ||
return predictor | ||
|
||
|
||
def gluonts_instance_fcst_to_df( | ||
fcst: Forecast, | ||
quantiles: List[float], | ||
model_name: str, | ||
) -> pd.DataFrame: | ||
point_forecast = fcst.mean | ||
h = len(point_forecast) | ||
dates = pd.date_range( | ||
fcst.start_date.to_timestamp(), | ||
freq=fcst.freq, | ||
periods=h, | ||
) | ||
fcst_df = pd.DataFrame( | ||
{ | ||
"ds": dates, | ||
"unique_id": fcst.item_id, | ||
model_name: point_forecast, | ||
} | ||
) | ||
for q in quantiles: | ||
fcst_df[f"{model_name}-q-{q}"] = fcst.quantile(q) | ||
return fcst_df | ||
|
||
|
||
def gluonts_fcsts_to_df( | ||
fcsts: Iterable[Forecast], | ||
quantiles: List[float], | ||
model_name: str, | ||
) -> pd.DataFrame: | ||
df = [] | ||
for fcst in tqdm(fcsts): | ||
fcst_df = gluonts_instance_fcst_to_df(fcst, quantiles, model_name) | ||
df.append(fcst_df) | ||
return pd.concat(df).reset_index(drop=True) | ||
|
||
|
||
def run_lag_llama( | ||
gluonts_dataset: Dataset, | ||
horizon: int, | ||
quantiles: List[float], | ||
models_dir: str, | ||
) -> Tuple[pd.DataFrame, float, str]: | ||
init_time = time() | ||
predictor = get_lag_llama_predictor(horizon, models_dir) | ||
fcsts = predictor.predict(gluonts_dataset, num_samples=100) | ||
model_name = "LagLlama" | ||
fcsts_df = gluonts_fcsts_to_df( | ||
fcsts, | ||
quantiles=quantiles, | ||
model_name=model_name, | ||
) | ||
total_time = time() - init_time | ||
return fcsts_df, total_time, model_name | ||
|
||
|
||
def main(dataset: str): | ||
exp = ExperimentHandler(dataset) | ||
fcst_df, total_time, model_name = run_lag_llama( | ||
gluonts_dataset=exp.gluonts_train_dataset, | ||
horizon=exp.horizon, | ||
quantiles=exp.quantiles, | ||
models_dir=exp.models_dir, | ||
) | ||
exp._save_results(fcst_df, total_time, model_name) | ||
|
||
|
||
if __name__ == "__main__": | ||
fire.Fire(main) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import logging | ||
import subprocess | ||
|
||
import pandas as pd | ||
|
||
from src.utils import ExperimentHandler | ||
|
||
logger = logging.getLogger(__name__) | ||
logger.setLevel(logging.INFO) | ||
|
||
not_included_datasets = [ | ||
"m1_yearly", | ||
"m1_quarterly", | ||
"m1_monthly", | ||
"m3_yearly", | ||
"m3_quarterly", | ||
"m3_monthly", | ||
"m3_other", | ||
"m4_yearly", | ||
"m4_quarterly", | ||
"m4_monthly", | ||
"m4_weekly", | ||
"m4_daily", | ||
"m4_hourly", | ||
"tourism_yearly", | ||
"tourism_quarterly", | ||
"tourism_monthly", | ||
] | ||
|
||
test_paper_datasets = [ | ||
"pedestrian_counts", | ||
"weather", | ||
] | ||
|
||
datasets = { | ||
"not_included": not_included_datasets, | ||
"test_set": test_paper_datasets, | ||
} | ||
|
||
|
||
def evaluate(): | ||
eval_df = [] | ||
prefix_process = ["python", "-m"] | ||
|
||
for name_group, groups in datasets.items(): | ||
for dataset in groups: | ||
logger.info(f"Evaluating {dataset}...") | ||
suffix_process = ["--dataset", dataset] | ||
process = ( | ||
lambda middle_process: prefix_process + middle_process + suffix_process | ||
) | ||
# running statsforecast and lagllama in separated | ||
# processes because gluonts sets multiprocessing context | ||
# see: https://github.com/awslabs/gluonts/blob/dev/src/gluonts/torch/__init__.py | ||
logger.info("Running SeasonalNaive") | ||
subprocess.run(process(["src.statsforecast_pipeline"])) | ||
logger.info("Running LagLLama") | ||
subprocess.run(process(["src.lag_llama_pipeline"])) | ||
logger.info("Running dataset evaluation") | ||
exp = ExperimentHandler(dataset) | ||
eval_dataset_df = exp.evaluate_models(["LagLlama", "SeasonalNaive"]) | ||
eval_dataset_df.insert(0, "paper", name_group) | ||
eval_df.append(eval_dataset_df) | ||
eval_df = pd.concat(eval_df).reset_index(drop=True) | ||
exp.save_dataframe(eval_df, "complete-results.csv") | ||
|
||
|
||
if __name__ == "__main__": | ||
evaluate() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import os | ||
from time import time | ||
from typing import List, Tuple | ||
|
||
import fire | ||
import pandas as pd | ||
from statsforecast import StatsForecast | ||
from statsforecast.models import SeasonalNaive | ||
|
||
from src.utils import ExperimentHandler | ||
|
||
|
||
def run_statsforecast( | ||
train_df: pd.DataFrame, | ||
horizon: int, | ||
freq: str, | ||
seasonality: int, | ||
level: List[int], | ||
) -> Tuple[pd.DataFrame, float, str]: | ||
os.environ["NIXTLA_ID_AS_COL"] = "true" | ||
models = [SeasonalNaive(season_length=seasonality)] | ||
init_time = time() | ||
sf = StatsForecast( | ||
models=models, | ||
freq=freq, | ||
n_jobs=-1, | ||
) | ||
fcsts_df = sf.forecast(df=train_df, h=horizon, level=level) | ||
total_time = time() - init_time | ||
model_name = repr(models[0]) | ||
return fcsts_df, total_time, model_name | ||
|
||
|
||
def main(dataset: str): | ||
exp = ExperimentHandler(dataset) | ||
fcst_df, total_time, model_name = run_statsforecast( | ||
train_df=exp.train_df, | ||
horizon=exp.horizon, | ||
freq=exp.freq, | ||
seasonality=exp.seasonality, | ||
level=exp.level, | ||
) | ||
fcst_df = exp._fcst_from_level_to_quantiles(fcst_df, model_name) | ||
exp._save_results(fcst_df, total_time, model_name) | ||
|
||
|
||
if __name__ == "__main__": | ||
fire.Fire(main) |
Oops, something went wrong.