Added support for additional estimators for multiseries datasets (#4385)
* Initial commit

* Updated tests

* Added additional drop NaN test case

* Updated release notes

* Reverted series ID name

* Moved infer feature types

* Added clarifying comments and updated test

* Consolidated code and added additional clarifying comments

* Code cleanup

* Added support for ndarrays for featurizer
christopherbunn authored Jan 31, 2024
1 parent c93b8f2 commit ba6617a
Showing 24 changed files with 253 additions and 105 deletions.
1 change: 1 addition & 0 deletions docs/source/release_notes.rst
@@ -2,6 +2,7 @@ Release Notes
 -------------
 **Future Releases**
     * Enhancements
+        * Added support for additional estimators for multiseries datasets :pr:`4385`
     * Fixes
         * Fixed bug in `_downcast_nullable_y` causing woodwork initialization issues :pr:`4369`
         * Fixed multiseries prediction interval labels :pr:`4377`
@@ -47,10 +47,12 @@ class CatBoostRegressor(Estimator):
     supported_problem_types = [
         ProblemTypes.REGRESSION,
         ProblemTypes.TIME_SERIES_REGRESSION,
+        ProblemTypes.MULTISERIES_TIME_SERIES_REGRESSION,
     ]
     """[
         ProblemTypes.REGRESSION,
         ProblemTypes.TIME_SERIES_REGRESSION,
+        ProblemTypes.MULTISERIES_TIME_SERIES_REGRESSION,
     ]"""

     def __init__(
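To see what these one-line additions change in practice, here is a quick sanity check (a sketch, assuming these public import paths; the same pattern applies to every regressor updated below):

from evalml.pipelines.components import CatBoostRegressor
from evalml.problem_types import ProblemTypes

# The multiseries problem type is now advertised alongside the existing
# regression problem types, so this estimator becomes eligible for
# multiseries time series regression searches.
assert (
    ProblemTypes.MULTISERIES_TIME_SERIES_REGRESSION
    in CatBoostRegressor.supported_problem_types
)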
@@ -55,10 +55,12 @@ class DecisionTreeRegressor(Estimator):
     supported_problem_types = [
         ProblemTypes.REGRESSION,
         ProblemTypes.TIME_SERIES_REGRESSION,
+        ProblemTypes.MULTISERIES_TIME_SERIES_REGRESSION,
     ]
     """[
         ProblemTypes.REGRESSION,
         ProblemTypes.TIME_SERIES_REGRESSION,
+        ProblemTypes.MULTISERIES_TIME_SERIES_REGRESSION,
     ]"""

     def __init__(
@@ -33,10 +33,12 @@ class ElasticNetRegressor(Estimator):
     supported_problem_types = [
         ProblemTypes.REGRESSION,
         ProblemTypes.TIME_SERIES_REGRESSION,
+        ProblemTypes.MULTISERIES_TIME_SERIES_REGRESSION,
     ]
     """[
         ProblemTypes.REGRESSION,
         ProblemTypes.TIME_SERIES_REGRESSION,
+        ProblemTypes.MULTISERIES_TIME_SERIES_REGRESSION,
     ]"""

     def __init__(
@@ -56,10 +56,12 @@ class ExtraTreesRegressor(Estimator):
     supported_problem_types = [
         ProblemTypes.REGRESSION,
         ProblemTypes.TIME_SERIES_REGRESSION,
+        ProblemTypes.MULTISERIES_TIME_SERIES_REGRESSION,
     ]
     """[
         ProblemTypes.REGRESSION,
         ProblemTypes.TIME_SERIES_REGRESSION,
+        ProblemTypes.MULTISERIES_TIME_SERIES_REGRESSION,
     ]"""

     def __init__(
@@ -68,7 +68,10 @@ class LightGBMRegressor(Estimator):
         ProblemTypes.REGRESSION,
         ProblemTypes.TIME_SERIES_REGRESSION,
     ]
-    """[ProblemTypes.REGRESSION]"""
+    """[
+        ProblemTypes.REGRESSION,
+        ProblemTypes.TIME_SERIES_REGRESSION,
+    ]"""

     SEED_MIN = 0
     SEED_MAX = SEED_BOUNDS.max_bound
@@ -28,10 +28,12 @@ class LinearRegressor(Estimator):
     supported_problem_types = [
         ProblemTypes.REGRESSION,
         ProblemTypes.TIME_SERIES_REGRESSION,
+        ProblemTypes.MULTISERIES_TIME_SERIES_REGRESSION,
     ]
     """[
         ProblemTypes.REGRESSION,
         ProblemTypes.TIME_SERIES_REGRESSION,
+        ProblemTypes.MULTISERIES_TIME_SERIES_REGRESSION,
     ]"""

     def __init__(self, fit_intercept=True, n_jobs=-1, random_seed=0, **kwargs):
@@ -37,10 +37,12 @@ class RandomForestRegressor(Estimator):
     supported_problem_types = [
         ProblemTypes.REGRESSION,
         ProblemTypes.TIME_SERIES_REGRESSION,
+        ProblemTypes.MULTISERIES_TIME_SERIES_REGRESSION,
     ]
     """[
         ProblemTypes.REGRESSION,
         ProblemTypes.TIME_SERIES_REGRESSION,
+        ProblemTypes.MULTISERIES_TIME_SERIES_REGRESSION,
     ]"""

     def __init__(
@@ -40,10 +40,12 @@ class XGBoostRegressor(Estimator):
     supported_problem_types = [
         ProblemTypes.REGRESSION,
         ProblemTypes.TIME_SERIES_REGRESSION,
+        ProblemTypes.MULTISERIES_TIME_SERIES_REGRESSION,
     ]
     """[
         ProblemTypes.REGRESSION,
         ProblemTypes.TIME_SERIES_REGRESSION,
+        ProblemTypes.MULTISERIES_TIME_SERIES_REGRESSION,
     ]"""

     # xgboost supports seeds from -2**31 to 2**31 - 1 inclusive. these limits ensure the random seed generated below
@@ -1,4 +1,5 @@
"""Transformer to drop rows specified by row indices."""
import pandas as pd
from woodwork import init_series

from evalml.pipelines.components.transformers import Transformer
@@ -43,12 +44,24 @@ def transform(self, X, y=None):
         y_t = infer_feature_types(y) if y is not None else None

         X_t_schema = X_t.ww.schema
+        y_t_logical = None
+        y_t_semantic = None
         if y_t is not None:
-            y_t_logical = y_t.ww.logical_type
+            if isinstance(y_t, pd.DataFrame):
+                y_t_logical = y_t.ww.logical_types
+            else:
+                y_t_logical = y_t.ww.logical_type
             y_t_semantic = y_t.ww.semantic_tags

         X_t, y_t = drop_rows_with_nans(X_t, y_t)
         X_t.ww.init_with_full_schema(X_t_schema)
         if y_t is not None:
-            y_t = init_series(y_t, logical_type=y_t_logical, semantic_tags=y_t_semantic)
+            if isinstance(y_t, pd.DataFrame):
+                y_t.ww.init(logical_types=y_t_logical, semantic_tags=y_t_semantic)
+            else:
+                y_t = init_series(
+                    y_t,
+                    logical_type=y_t_logical,
+                    semantic_tags=y_t_semantic,
+                )
         return X_t, y_t
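Outside the diff, the woodwork round-trip the new DataFrame branch performs looks roughly like this (a sketch with made-up data; dropna stands in for evalml's drop_rows_with_nans, and the component also restores semantic tags):

import pandas as pd
import woodwork as ww  # importing woodwork registers the .ww accessor

y = pd.DataFrame({"target|a": [1.0, None, 3.0], "target|b": [4.0, 5.0, None]})
y.ww.init()
logical_types = y.ww.logical_types  # captured per column before rows are dropped

y_dropped = y.dropna()  # plain pandas op; the woodwork schema is lost here
y_dropped.ww.init(logical_types=logical_types)  # re-initialize, as transform() now does
assert y_dropped.ww.schema is not None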
@@ -127,16 +127,32 @@ def fit(self, X, y=None):
         if self.time_index is None:
             raise ValueError("time_index cannot be None!")

-        # For the multiseries case, where we only want the start delay lag for the baseline
-        if isinstance(y, pd.DataFrame):
-            self.statistically_significant_lags = [self.start_delay]
-        else:
-            self.statistically_significant_lags = self._find_significant_lags(
-                y,
-                conf_level=self.conf_level,
-                start_delay=self.start_delay,
-                max_delay=self.max_delay,
+        if y is None:
+            # Set lags to all possible lag values
+            self.statistically_significant_lags = np.arange(
+                self.start_delay,
+                self.start_delay + self.max_delay + 1,
             )
+        else:
+            # For the multiseries case, each series ID has individualized lag values
+            if isinstance(y, pd.Series) or isinstance(y, np.ndarray):
+                y = pd.DataFrame(y)
+
+            self.statistically_significant_lags = {}
+            for column in y.columns:
+                self.statistically_significant_lags[
+                    column
+                ] = self._find_significant_lags(
+                    y[column],
+                    conf_level=self.conf_level,
+                    start_delay=self.start_delay,
+                    max_delay=self.max_delay,
+                )
+            if len(y.columns) == 1:
+                self.statistically_significant_lags = (
+                    self.statistically_significant_lags[column]
+                )
         return self

     @staticmethod
@@ -160,31 +176,28 @@ def _encode_X_while_preserving_index(X_categorical):
     @staticmethod
     def _find_significant_lags(y, conf_level, start_delay, max_delay):
         all_lags = np.arange(start_delay, start_delay + max_delay + 1)
-        if y is not None:
-            # Compute the acf and find its peaks
-            acf_values, ci_intervals = acf(
-                y,
-                nlags=len(y) - 1,
-                fft=True,
-                alpha=conf_level,
-            )
-            peaks, _ = find_peaks(acf_values)
-            # Significant lags are the union of:
-            # 1. the peaks (local maxima) that are significant
-            # 2. The significant lags among the first 10 lags.
-            # We then filter the list to be in the range [start_delay, start_delay + max_delay]
-            index = np.arange(len(acf_values))
-            significant = np.logical_or(ci_intervals[:, 0] > 0, ci_intervals[:, 1] < 0)
-            first_significant_10 = index[:10][significant[:10]]
-            significant_lags = (
-                set(index[significant]).intersection(peaks).union(first_significant_10)
-            )
-            # If no lags are significant get the first lag
-            significant_lags = sorted(significant_lags.intersection(all_lags)) or [
-                start_delay,
-            ]
-        else:
-            significant_lags = all_lags
+        # Compute the acf and find its peaks
+        acf_values, ci_intervals = acf(
+            y,
+            nlags=len(y) - 1,
+            fft=True,
+            alpha=conf_level,
+        )
+        peaks, _ = find_peaks(acf_values)
+        # Significant lags are the union of:
+        # 1. the peaks (local maxima) that are significant
+        # 2. The significant lags among the first 10 lags.
+        # We then filter the list to be in the range [start_delay, start_delay + max_delay]
+        index = np.arange(len(acf_values))
+        significant = np.logical_or(ci_intervals[:, 0] > 0, ci_intervals[:, 1] < 0)
+        first_significant_10 = index[:10][significant[:10]]
+        significant_lags = (
+            set(index[significant]).intersection(peaks).union(first_significant_10)
+        )
+        # If no lags are significant get the first lag
+        significant_lags = sorted(significant_lags.intersection(all_lags)) or [
+            start_delay,
+        ]
         return significant_lags

     def _compute_rolling_transforms(self, X, y, original_features):
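The lag-selection logic itself is unchanged, only de-indented now that the y=None case is handled in fit. The same ACF recipe can be run standalone on synthetic data (a sketch using statsmodels and scipy directly):

import numpy as np
from scipy.signal import find_peaks
from statsmodels.tsa.stattools import acf

rng = np.random.default_rng(0)
y = np.sin(np.arange(200) * 2 * np.pi / 12) + rng.normal(0, 0.1, 200)  # period-12 signal

acf_values, ci_intervals = acf(y, nlags=len(y) - 1, fft=True, alpha=0.05)
peaks, _ = find_peaks(acf_values)
index = np.arange(len(acf_values))
# a lag counts as significant when its confidence interval excludes zero
significant = np.logical_or(ci_intervals[:, 0] > 0, ci_intervals[:, 1] < 0)
significant_lags = set(index[significant]).intersection(peaks).union(
    index[:10][significant[:10]],
)
print(sorted(significant_lags))  # multiples of 12 should dominate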
@@ -234,7 +247,25 @@ def _delay_df(
             col = data[col_name]
             if categorical_columns and col_name in categorical_columns:
                 col = X_categorical[col_name]
-            for t in self.statistically_significant_lags:
+            # Lags are stored in a dict for multiseries problems
+            # Returns the lags corresponding to the series ID value
+            if isinstance(self.statistically_significant_lags, dict):
+                from evalml.pipelines.utils import MULTISERIES_SEPARATOR_SYMBOL
+
+                col_series_id = (
+                    MULTISERIES_SEPARATOR_SYMBOL
+                    + col_name.split(MULTISERIES_SEPARATOR_SYMBOL)[-1]
+                )
+                for (
+                    series_id_target_name,
+                    lag_list,
+                ) in self.statistically_significant_lags.items():
+                    if series_id_target_name.endswith(col_series_id):
+                        lags = lag_list
+                        break
+            else:
+                lags = self.statistically_significant_lags
+            for t in lags:
                 lagged_features[self.df_colname_prefix.format(col_name, t)] = col.shift(
                     t,
                 )
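In plain Python, the suffix-based lookup behaves as below (a hypothetical illustration; MULTISERIES_SEPARATOR_SYMBOL is assumed here to be the pipe character evalml uses to join target names and series IDs):

statistically_significant_lags = {
    "target|store_1": [1, 7],
    "target|store_2": [1, 14],
}
col_name = "feature|store_2"
col_series_id = "|" + col_name.split("|")[-1]  # -> "|store_2"

# match the feature column to its series by suffix
for series_id_target_name, lag_list in statistically_significant_lags.items():
    if series_id_target_name.endswith(col_series_id):
        lags = lag_list
        break

assert lags == [1, 14]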
44 changes: 39 additions & 5 deletions evalml/pipelines/multiseries_regression_pipeline.py
@@ -83,6 +83,7 @@ def _fit(self, X, y):

         self.component_graph.fit(X_unstacked, y_unstacked)
         self.input_feature_names = self.component_graph.input_feature_names
+        self.series_id_target_names = y_unstacked.columns

     def predict_in_sample(
         self,
@@ -144,7 +145,7 @@ def predict_in_sample(
         ]
         y_overlapping_features = [
             feature
-            for feature in y_train_unstacked.columns
+            for feature in self.series_id_target_names
             if feature in y_unstacked.columns
         ]
         y_unstacked = y_unstacked[y_overlapping_features]
@@ -154,7 +155,6 @@ def predict_in_sample(
         y_train_unstacked = infer_feature_types(y_train_unstacked)
         X_unstacked = infer_feature_types(X_unstacked)
         y_unstacked = infer_feature_types(y_unstacked)
-
         unstacked_predictions = super().predict_in_sample(
             X_unstacked,
             y_unstacked,
@@ -163,16 +163,50 @@ def predict_in_sample(
             objective,
             calculating_residuals,
         )
+        unstacked_predictions = unstacked_predictions[
+            [
+                series_id_target
+                for series_id_target in self.series_id_target_names
+                if series_id_target in unstacked_predictions.columns
+            ]
+        ]
+
+        # Add `time_index` column to index for generating stacked datetime column in `stack_data()`
+        unstacked_predictions.index = X_unstacked[self.time_index]
         stacked_predictions = stack_data(
             unstacked_predictions,
-            include_series_id=include_series_id,
+            include_series_id=True,
             series_id_name=self.series_id,
         )
+        # Move datetime index into separate date column to use when merging later
+        stacked_predictions = stacked_predictions.reset_index(drop=False)
+
+        sp_dtypes = {
+            self.time_index: X[self.time_index].dtype,
+            self.series_id: X[self.series_id].dtype,
+            self.input_target_name: y.dtype,
+        }
+        stacked_predictions = stacked_predictions.astype(sp_dtypes)
+
+        # Order prediction based on input (date, series_id)
+        output_cols = (
+            [self.series_id, self.input_target_name]
+            if include_series_id
+            else [self.input_target_name]
+        )
+        stacked_predictions = pd.merge(
+            X,
+            stacked_predictions,
+            on=[self.time_index, self.series_id],
+        )[output_cols]
+
         # Index will start at the unstacked index, so we need to reset it to the original index
         stacked_predictions.index = X.index
-        stacked_predictions = infer_feature_types(stacked_predictions)
-        return stacked_predictions
+
+        if not include_series_id:
+            return infer_feature_types(stacked_predictions[self.input_target_name])
+        else:
+            return infer_feature_types(stacked_predictions)

     def get_forecast_period(self, X):
         """Generates all possible forecasting time points based on latest data point in X.
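Toy data makes the new merge-based reordering easier to see (a sketch; the date, series, and target column names stand in for self.time_index, self.series_id, and self.input_target_name):

import pandas as pd

# Input rows arrive ordered by (date, series) with an arbitrary index.
X = pd.DataFrame(
    {
        "date": pd.to_datetime(["2020-01-01", "2020-01-01", "2020-01-02", "2020-01-02"]),
        "series": ["a", "b", "a", "b"],
    },
    index=[10, 11, 12, 13],
)
# Stacked predictions come back grouped by series instead.
stacked_predictions = pd.DataFrame(
    {
        "date": pd.to_datetime(["2020-01-01", "2020-01-02", "2020-01-01", "2020-01-02"]),
        "series": ["a", "a", "b", "b"],
        "target": [1.0, 2.0, 3.0, 4.0],
    },
)
# Merging onto X restores the input row order; the index is then reset to X's.
ordered = pd.merge(X, stacked_predictions, on=["date", "series"])[["series", "target"]]
ordered.index = X.index
print(ordered)  # rows now follow X: (a, b) for each date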