diff --git a/nbs/nixtla_client.ipynb b/nbs/nixtla_client.ipynb index 1d204229..7d0af32e 100644 --- a/nbs/nixtla_client.ipynb +++ b/nbs/nixtla_client.ipynb @@ -272,7 +272,6 @@ " self.weights_x: pd.DataFrame = None\n", " self.freq: str = self.base_freq\n", " self.drop_uid: bool = False\n", - " self.x_cols: List[str]\n", " self.input_size: int\n", " self.model_horizon: int\n", "\n", @@ -561,8 +560,7 @@ " X_df = df.drop(columns='y')\n", " x_cols = X_df.drop(columns=['unique_id', 'ds']).columns.to_list()\n", " X_df = self.preprocess_X_df(X_df)\n", - " self.x_cols = x_cols\n", - " return Y_df, X_df\n", + " return Y_df, X_df, x_cols\n", "\n", " def dataframes_to_dict(self, Y_df: pd.DataFrame, X_df: pd.DataFrame):\n", " to_dict_args = {'orient': 'split'}\n", @@ -604,7 +602,7 @@ " ):\n", " df, X_df = self.transform_inputs(df=df, X_df=X_df)\n", " main_logger.info('Preprocessing dataframes...')\n", - " Y_df, X_df = self.preprocess_dataframes(df=df, X_df=X_df)\n", + " Y_df, X_df, x_cols = self.preprocess_dataframes(df=df, X_df=X_df)\n", " self.set_model_params()\n", " if self.h > self.model_horizon:\n", " main_logger.warning(\n", @@ -659,7 +657,7 @@ " )\n", " if 'weights_x' in response_timegpt:\n", " self.weights_x = pd.DataFrame({\n", - " 'features': self.x_cols,\n", + " 'features': x_cols,\n", " 'weights': response_timegpt['weights_x'],\n", " })\n", " fcst_df = pd.DataFrame(**response_timegpt['forecast'])\n", @@ -693,7 +691,7 @@ " # exogenous variables are passed after df \n", " df, _ = self.transform_inputs(df=df, X_df=None)\n", " main_logger.info('Preprocessing dataframes...')\n", - " Y_df, X_df = self.preprocess_dataframes(df=df, X_df=None)\n", + " Y_df, X_df, x_cols = self.preprocess_dataframes(df=df, X_df=None)\n", " main_logger.info('Calling Anomaly Detector Endpoint...')\n", " y, x = self.dataframes_to_dict(Y_df, X_df)\n", " response_timegpt = self._call_api(\n", @@ -709,7 +707,7 @@ " )\n", " if 'weights_x' in response_timegpt:\n", " self.weights_x = pd.DataFrame({\n", - " 'features': self.x_cols,\n", + " 'features': x_cols,\n", " 'weights': response_timegpt['weights_x'],\n", " })\n", " anomalies_df = pd.DataFrame(**response_timegpt['forecast'])\n", diff --git a/nixtla/nixtla_client.py b/nixtla/nixtla_client.py index 8d2e2707..212bf813 100644 --- a/nixtla/nixtla_client.py +++ b/nixtla/nixtla_client.py @@ -188,7 +188,6 @@ def __init__( self.weights_x: pd.DataFrame = None self.freq: str = self.base_freq self.drop_uid: bool = False - self.x_cols: List[str] self.input_size: int self.model_horizon: int @@ -499,8 +498,7 @@ def preprocess_dataframes( X_df = df.drop(columns="y") x_cols = X_df.drop(columns=["unique_id", "ds"]).columns.to_list() X_df = self.preprocess_X_df(X_df) - self.x_cols = x_cols - return Y_df, X_df + return Y_df, X_df, x_cols def dataframes_to_dict(self, Y_df: pd.DataFrame, X_df: pd.DataFrame): to_dict_args = {"orient": "split"} @@ -546,7 +544,7 @@ def forecast( ): df, X_df = self.transform_inputs(df=df, X_df=X_df) main_logger.info("Preprocessing dataframes...") - Y_df, X_df = self.preprocess_dataframes(df=df, X_df=X_df) + Y_df, X_df, x_cols = self.preprocess_dataframes(df=df, X_df=X_df) self.set_model_params() if self.h > self.model_horizon: main_logger.warning( @@ -602,7 +600,7 @@ def forecast( if "weights_x" in response_timegpt: self.weights_x = pd.DataFrame( { - "features": self.x_cols, + "features": x_cols, "weights": response_timegpt["weights_x"], } ) @@ -637,7 +635,7 @@ def detect_anomalies(self, df: pd.DataFrame): # exogenous variables are passed after df df, _ = self.transform_inputs(df=df, X_df=None) main_logger.info("Preprocessing dataframes...") - Y_df, X_df = self.preprocess_dataframes(df=df, X_df=None) + Y_df, X_df, x_cols = self.preprocess_dataframes(df=df, X_df=None) main_logger.info("Calling Anomaly Detector Endpoint...") y, x = self.dataframes_to_dict(Y_df, X_df) response_timegpt = self._call_api( @@ -658,7 +656,7 @@ def detect_anomalies(self, df: pd.DataFrame): if "weights_x" in response_timegpt: self.weights_x = pd.DataFrame( { - "features": self.x_cols, + "features": x_cols, "weights": response_timegpt["weights_x"], } )