Skip to content

Commit

Permalink
store x_cols as a variable instead of attribute (#318)
Browse files: browse the repository at this point in the history
  • Loading branch information
jmoralez authored Apr 29, 2024
1 parent e22c027 commit 24ccc22
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 14 deletions.
12 changes: 5 additions & 7 deletions nbs/nixtla_client.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,6 @@
" self.weights_x: pd.DataFrame = None\n",
" self.freq: str = self.base_freq\n",
" self.drop_uid: bool = False\n",
- " self.x_cols: List[str]\n",
" self.input_size: int\n",
" self.model_horizon: int\n",
"\n",
Expand Down Expand Up @@ -561,8 +560,7 @@
" X_df = df.drop(columns='y')\n",
" x_cols = X_df.drop(columns=['unique_id', 'ds']).columns.to_list()\n",
" X_df = self.preprocess_X_df(X_df)\n",
- " self.x_cols = x_cols\n",
- " return Y_df, X_df\n",
+ " return Y_df, X_df, x_cols\n",
"\n",
" def dataframes_to_dict(self, Y_df: pd.DataFrame, X_df: pd.DataFrame):\n",
" to_dict_args = {'orient': 'split'}\n",
Expand Down Expand Up @@ -604,7 +602,7 @@
" ):\n",
" df, X_df = self.transform_inputs(df=df, X_df=X_df)\n",
" main_logger.info('Preprocessing dataframes...')\n",
- " Y_df, X_df = self.preprocess_dataframes(df=df, X_df=X_df)\n",
+ " Y_df, X_df, x_cols = self.preprocess_dataframes(df=df, X_df=X_df)\n",
" self.set_model_params()\n",
" if self.h > self.model_horizon:\n",
" main_logger.warning(\n",
Expand Down Expand Up @@ -659,7 +657,7 @@
" )\n",
" if 'weights_x' in response_timegpt:\n",
" self.weights_x = pd.DataFrame({\n",
- " 'features': self.x_cols,\n",
+ " 'features': x_cols,\n",
" 'weights': response_timegpt['weights_x'],\n",
" })\n",
" fcst_df = pd.DataFrame(**response_timegpt['forecast'])\n",
Expand Down Expand Up @@ -693,7 +691,7 @@
" # exogenous variables are passed after df \n",
" df, _ = self.transform_inputs(df=df, X_df=None)\n",
" main_logger.info('Preprocessing dataframes...')\n",
- " Y_df, X_df = self.preprocess_dataframes(df=df, X_df=None)\n",
+ " Y_df, X_df, x_cols = self.preprocess_dataframes(df=df, X_df=None)\n",
" main_logger.info('Calling Anomaly Detector Endpoint...')\n",
" y, x = self.dataframes_to_dict(Y_df, X_df)\n",
" response_timegpt = self._call_api(\n",
Expand All @@ -709,7 +707,7 @@
" )\n",
" if 'weights_x' in response_timegpt:\n",
" self.weights_x = pd.DataFrame({\n",
- " 'features': self.x_cols,\n",
+ " 'features': x_cols,\n",
" 'weights': response_timegpt['weights_x'],\n",
" })\n",
" anomalies_df = pd.DataFrame(**response_timegpt['forecast'])\n",
Expand Down
12 changes: 5 additions & 7 deletions nixtla/nixtla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@ def __init__(
self.weights_x: pd.DataFrame = None
self.freq: str = self.base_freq
self.drop_uid: bool = False
- self.x_cols: List[str]
self.input_size: int
self.model_horizon: int

Expand Down Expand Up @@ -499,8 +498,7 @@ def preprocess_dataframes(
X_df = df.drop(columns="y")
x_cols = X_df.drop(columns=["unique_id", "ds"]).columns.to_list()
X_df = self.preprocess_X_df(X_df)
- self.x_cols = x_cols
- return Y_df, X_df
+ return Y_df, X_df, x_cols

def dataframes_to_dict(self, Y_df: pd.DataFrame, X_df: pd.DataFrame):
to_dict_args = {"orient": "split"}
Expand Down Expand Up @@ -546,7 +544,7 @@ def forecast(
):
df, X_df = self.transform_inputs(df=df, X_df=X_df)
main_logger.info("Preprocessing dataframes...")
- Y_df, X_df = self.preprocess_dataframes(df=df, X_df=X_df)
+ Y_df, X_df, x_cols = self.preprocess_dataframes(df=df, X_df=X_df)
self.set_model_params()
if self.h > self.model_horizon:
main_logger.warning(
Expand Down Expand Up @@ -602,7 +600,7 @@ def forecast(
if "weights_x" in response_timegpt:
self.weights_x = pd.DataFrame(
{
- "features": self.x_cols,
+ "features": x_cols,
"weights": response_timegpt["weights_x"],
}
)
Expand Down Expand Up @@ -637,7 +635,7 @@ def detect_anomalies(self, df: pd.DataFrame):
# exogenous variables are passed after df
df, _ = self.transform_inputs(df=df, X_df=None)
main_logger.info("Preprocessing dataframes...")
- Y_df, X_df = self.preprocess_dataframes(df=df, X_df=None)
+ Y_df, X_df, x_cols = self.preprocess_dataframes(df=df, X_df=None)
main_logger.info("Calling Anomaly Detector Endpoint...")
y, x = self.dataframes_to_dict(Y_df, X_df)
response_timegpt = self._call_api(
Expand All @@ -658,7 +656,7 @@ def detect_anomalies(self, df: pd.DataFrame):
if "weights_x" in response_timegpt:
self.weights_x = pd.DataFrame(
{
- "features": self.x_cols,
+ "features": x_cols,
"weights": response_timegpt["weights_x"],
}
)
Expand Down

0 comments on commit 24ccc22

Please sign in to comment.