Skip to content

Commit

Permalink
support forecast_fitted_values in distributed (#732)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Dec 19, 2023
1 parent 89f2e02 commit 27eb69b
Show file tree
Hide file tree
Showing 11 changed files with 754 additions and 397 deletions.
4 changes: 4 additions & 0 deletions action_files/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import os

import numpy as np
import pandas as pd
import pytest

from statsforecast.utils import generate_series

os.environ['NIXTLA_ID_AS_COL'] = '1'


@pytest.fixture
def n_series():
Expand Down
9 changes: 8 additions & 1 deletion action_files/test_dask.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import dask.dataframe as dd
import pytest

from .utils import pipeline, pipeline_with_level
from .utils import pipeline, pipeline_with_level, pipeline_fitted


def to_distributed(df):
Expand All @@ -21,3 +21,10 @@ def test_dask_flow(horizon, sample_data, n_series):

def test_dask_flow_with_level(horizon, sample_data, n_series):
pipeline_with_level(*sample_data, n_series, horizon)

@pytest.mark.parametrize('use_x', [True, False])
def test_dask_flow_with_fitted(horizon, use_x, sample_data):
    """Run the fitted-values pipeline on the sample data, with and without X_df.

    ``sample_data`` unpacks into the series and the ``X_df`` frame; when
    ``use_x`` is false, ``None`` is passed to the pipeline in place of ``X_df``.
    """
    series, exog = sample_data
    pipeline_fitted(series, exog if use_x else None, horizon)
9 changes: 8 additions & 1 deletion action_files/test_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
import ray

from .utils import pipeline, pipeline_with_level
from .utils import pipeline, pipeline_with_level, pipeline_fitted


def to_distributed(df):
Expand All @@ -21,3 +21,10 @@ def test_ray_flow(horizon, sample_data, n_series):
@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python >= 3.8")
def test_ray_flow_with_level(horizon, sample_data, n_series):
pipeline_with_level(*sample_data, n_series, horizon)

@pytest.mark.parametrize('use_x', [True, False])
def test_ray_flow_with_fitted(horizon, use_x, sample_data):
    """Run the fitted-values pipeline on the sample data, with and without X_df.

    ``sample_data`` unpacks into the series and the ``X_df`` frame; when
    ``use_x`` is false, ``None`` is passed to the pipeline in place of ``X_df``.
    """
    series, exog = sample_data
    pipeline_fitted(series, exog if use_x else None, horizon)
9 changes: 8 additions & 1 deletion action_files/test_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
from pyspark.sql import SparkSession

from .utils import pipeline, pipeline_with_level
from .utils import pipeline, pipeline_with_level, pipeline_fitted


@pytest.fixture
Expand All @@ -25,3 +25,10 @@ def test_spark_flow(horizon, sample_data, n_series):
@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python >= 3.8")
def test_spark_flow_with_level(horizon, sample_data, n_series):
pipeline_with_level(*sample_data, n_series, horizon)

@pytest.mark.parametrize('use_x', [True, False])
def test_spark_flow_with_fitted(horizon, use_x, sample_data):
    """Run the fitted-values pipeline on the sample data, with and without X_df.

    ``sample_data`` unpacks into the series and the ``X_df`` frame; when
    ``use_x`` is false, ``None`` is passed to the pipeline in place of ``X_df``.
    """
    series, exog = sample_data
    pipeline_fitted(series, exog if use_x else None, horizon)
49 changes: 33 additions & 16 deletions action_files/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import fugue.api as fa
import pandas as pd

from statsforecast.core import StatsForecast
from statsforecast.models import (
ADIDA,
Expand All @@ -23,24 +25,24 @@

def pipeline(series, X_df, n_series, horizon, id_col='unique_id', time_col='ds', target_col='y'):
models = [
ADIDA(),
ADIDA(),
AutoARIMA(season_length=7),
ARIMA(season_length=7, order=(0, 1, 2)),
CrostonClassic(),
CrostonClassic(),
CrostonOptimized(),
CrostonSBA(),
CrostonSBA(),
AutoETS(season_length=7),
HistoricAverage(),
IMAPA(),
HistoricAverage(),
IMAPA(),
Naive(),
RandomWalkWithDrift(),
SeasonalExponentialSmoothing(season_length=7, alpha=0.1),
SeasonalNaive(season_length=7),
SeasonalWindowAverage(season_length=7, window_size=4),
SimpleExponentialSmoothing(alpha=0.1),
TSB(alpha_d=0.1, alpha_p=0.3),
WindowAverage(window_size=4)
]
RandomWalkWithDrift(),
SeasonalExponentialSmoothing(season_length=7, alpha=0.1),
SeasonalNaive(season_length=7),
SeasonalWindowAverage(season_length=7, window_size=4),
SimpleExponentialSmoothing(alpha=0.1),
TSB(alpha_d=0.1, alpha_p=0.3),
WindowAverage(window_size=4)
]
sf = StatsForecast(
models=models,
freq='D',
Expand All @@ -59,9 +61,7 @@ def pipeline(series, X_df, n_series, horizon, id_col='unique_id', time_col='ds',
assert cv.columns.tolist() == [id_col, time_col, 'cutoff', target_col] + [m.alias for m in models]

def pipeline_with_level(series, X_df, n_series, horizon):
models = [
AutoARIMA(season_length=7),
]
models = [AutoARIMA(season_length=7)]
sf = StatsForecast(
models=models,
freq='D',
Expand All @@ -75,3 +75,20 @@ def pipeline_with_level(series, X_df, n_series, horizon):
cv = fa.as_pandas(sf.cross_validation(df=series, n_windows=n_windows, h=horizon, level=[80]))
assert cv.shape[0] == n_series * n_windows * horizon
assert cv.columns.tolist() == ['unique_id', 'ds', 'cutoff', 'y', 'AutoARIMA', 'AutoARIMA-lo-80', 'AutoARIMA-hi-80']

def pipeline_fitted(series, X_df, horizon):
    """Assert that fitted values from the given frames match the pandas ones.

    The same ``StatsForecast`` instance is used for two forecasts: first on
    pandas copies of ``series`` / ``X_df`` (the reference), then on the
    frames exactly as received (presumably distributed dataframes -- confirm
    against the callers' fixtures). The fitted values of both runs must be
    equal after sorting and dtype alignment.

    Parameters
    ----------
    series : dataframe accepted by ``fa.as_pandas`` and ``StatsForecast.forecast``
    X_df : same kind of dataframe, or ``None`` to forecast without it
    horizon : value forwarded as ``h`` to ``forecast``
    """
    local_series = fa.as_pandas(series)
    local_X = fa.as_pandas(X_df) if X_df is not None else None
    sf = StatsForecast(models=[SeasonalNaive(season_length=7)], freq='D')

    # Reference run on the pandas copies.
    sf.forecast(df=local_series, h=horizon, X_df=local_X, level=[80, 90], fitted=True)
    expected = sf.forecast_fitted_values()

    # Second run on the inputs as received; fitted values are read again
    # from the same StatsForecast object.
    sf.forecast(df=series, h=horizon, X_df=X_df, level=[80, 90], fitted=True)
    actual = fa.as_pandas(sf.forecast_fitted_values())
    actual = actual.sort_values(['unique_id', 'ds']).reset_index(drop=True)
    actual = actual[expected.columns]
    # fugue returns nullable and pyarrow dtypes
    actual = actual.astype(expected.dtypes)

    pd.testing.assert_frame_equal(expected, actual)
11 changes: 9 additions & 2 deletions nbs/src/core/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1908,8 +1908,8 @@
" )\n",
" assert df is not None\n",
" engine = make_execution_engine(infer_by=[df])\n",
" backend = make_backend(engine)\n",
" return backend.forecast(\n",
" self._backend = make_backend(engine)\n",
" return self._backend.forecast(\n",
" models=self.models,\n",
" fallback_model=self.fallback_model, \n",
" freq=self.freq, \n",
Expand All @@ -1923,6 +1923,13 @@
" time_col=time_col,\n",
" target_col=target_col,\n",
" )\n",
"\n",
" def forecast_fitted_values(self):\n",
" if hasattr(self, '_backend'):\n",
" res = self._backend.forecast_fitted_values()\n",
" else:\n",
" res = super().forecast_fitted_values()\n",
" return res\n",
" \n",
" def cross_validation(\n",
" self,\n",
Expand Down
Loading

0 comments on commit 27eb69b

Please sign in to comment.