From 2ae45068d27d48bb038870dd2cb142f416a9d025 Mon Sep 17 00:00:00 2001
From: Constantin Weberpals
Date: Mon, 24 Jun 2024 17:40:03 +0200
Subject: [PATCH 01/39] enable re-training

---
 neuralprophet/forecaster.py | 55 +++++++++++++++++++++++--------------
 1 file changed, 35 insertions(+), 20 deletions(-)

diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py
index d80fcef14..7ec036c0c 100644
--- a/neuralprophet/forecaster.py
+++ b/neuralprophet/forecaster.py
@@ -962,7 +962,7 @@ def fit(
         pd.DataFrame
             metrics with training and potentially evaluation metrics
         """
-        if self.fitted:
+        if self.fitted and not continue_training:
             raise RuntimeError("Model has been fitted already. Please initialize a new model to fit again.")

         # Configuration
@@ -2645,23 +2645,23 @@ def _init_train_loader(self, df, num_workers=0):
             torch DataLoader
         """
         df, _, _, _ = df_utils.prep_or_copy_df(df)
-        # if not self.fitted:
-        self.config_normalization.init_data_params(
-            df=df,
-            config_lagged_regressors=self.config_lagged_regressors,
-            config_regressors=self.config_regressors,
-            config_events=self.config_events,
-            config_seasonality=self.config_seasonality,
-        )
+        if not self.fitted:
+            self.config_normalization.init_data_params(
+                df=df,
+                config_lagged_regressors=self.config_lagged_regressors,
+                config_regressors=self.config_regressors,
+                config_events=self.config_events,
+                config_seasonality=self.config_seasonality,
+            )
         df = _normalize(df=df, config_normalization=self.config_normalization)
-        # if not self.fitted:
-        if self.config_trend.changepoints is not None:
-            # scale user-specified changepoint times
-            df_aux = pd.DataFrame({"ds": pd.Series(self.config_trend.changepoints)})
-
-            df_normalized = _normalize(df=df_aux, config_normalization=self.config_normalization)
-            self.config_trend.changepoints = df_normalized["t"].values  # type: ignore
+        if not self.fitted:
+            if self.config_trend.changepoints is not None:
+                # scale user-specified changepoint times
+                df_aux = pd.DataFrame({"ds": pd.Series(self.config_trend.changepoints)})
+
+                df_normalized = _normalize(df=df_aux, config_normalization=self.config_normalization)
+                self.config_trend.changepoints = df_normalized["t"].values  # type: ignore

         # df_merged, _ = df_utils.join_dataframes(df)
         # df_merged = df_merged.sort_values("ds")
@@ -2740,6 +2740,13 @@ def _train(
         pd.DataFrame
             metrics
         """
+        # Test
+        if continue_training:
+            checkpoint_path = self.metrics_logger.checkpoint_path
+            print(checkpoint_path)
+            checkpoint = torch.load(checkpoint_path)
+            print(checkpoint.keys())
+
         # Set up data the training dataloader
         df, _, _, _ = df_utils.prep_or_copy_df(df)
         train_loader = self._init_train_loader(df, num_workers)
@@ -2748,12 +2755,20 @@ def _train(
         # Internal flag to check if validation is enabled
         validation_enabled = df_val is not None

-        # Init the model, if not continue from checkpoint
+        # Load model and optimizer state from checkpoint if continue_training is True
         if continue_training:
-            raise NotImplementedError(
-                "Continuing training from checkpoint is not implemented yet. This feature is planned for one of the \
-                upcoming releases."
-            )
+            checkpoint_path = self.metrics_logger.checkpoint_path
+            checkpoint = torch.load(checkpoint_path)
+            self.model = self._init_model()
+            # TODO: fix size mismatch for trend.trend_changepoints_t: copying a param with shape torch.Size([11]) from checkpoint, the shape in current model is torch.Size([12]).
+            self.model.load_state_dict(checkpoint["state_dict"], strict=False)
+            self.optimizer.load_state_dict(checkpoint["optimizer_states"][0])
+            self.trainer.current_epoch = checkpoint["epoch"] + 1
+            if "lr_schedulers" in checkpoint:
+                self.lr_scheduler.load_state_dict(checkpoint["lr_schedulers"][0])
+            print(f"Resuming training from epoch {self.trainer.current_epoch}")
+            # TODO: remove print, checkpoint['lr_schedulers']
+            print(f"Resuming training from epoch {self.trainer.current_epoch}")
         else:
             self.model = self._init_model()
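PATCH 01 restores training state directly from the PyTorch Lightning checkpoint that the metrics logger wrote. A minimal standalone sketch of the same mechanism — the `model`, `optimizer`, and `scheduler` objects here are hypothetical stand-ins, while the checkpoint keys `epoch`, `state_dict`, `optimizer_states`, and `lr_schedulers` are the ones Lightning writes by default:

    import torch

    def resume_state(model, optimizer, scheduler, checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        # strict=False tolerates shape changes such as the
        # trend_changepoints_t mismatch noted in the TODO above
        model.load_state_dict(checkpoint["state_dict"], strict=False)
        optimizer.load_state_dict(checkpoint["optimizer_states"][0])
        if "lr_schedulers" in checkpoint:
            scheduler.load_state_dict(checkpoint["lr_schedulers"][0])
        return checkpoint["epoch"] + 1  # epoch to resume from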
From 900c8d5d4c41ab48d3d8b44c47b40ccaf4b85128 Mon Sep 17 00:00:00 2001
From: Constantin Weberpals
Date: Sat, 29 Jun 2024 14:10:56 +0200
Subject: [PATCH 02/39] update scheduler

---
 neuralprophet/forecaster.py | 81 ++++++++++++++++++++++++++++---------
 1 file changed, 61 insertions(+), 20 deletions(-)

diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py
index 7ec036c0c..fd372680b 100644
--- a/neuralprophet/forecaster.py
+++ b/neuralprophet/forecaster.py
@@ -2654,6 +2654,7 @@ def _init_train_loader(self, df, num_workers=0):
                 config_seasonality=self.config_seasonality,
             )

+        print("Changepoints:", self.config_trend.changepoints)
         df = _normalize(df=df, config_normalization=self.config_normalization)
         if not self.fitted:
             if self.config_trend.changepoints is not None:
@@ -2746,11 +2747,10 @@ def _train(
             print(checkpoint_path)
             checkpoint = torch.load(checkpoint_path)
             print(checkpoint.keys())
-
-        # Set up data the training dataloader
-        df, _, _, _ = df_utils.prep_or_copy_df(df)
-        train_loader = self._init_train_loader(df, num_workers)
-        dataset_size = len(df)  # train_loader.dataset
+            print("Current model trend changepoints:", self.model.trend.trend_changepoints_t)
+            # self.model = time_net.TimeNet.load_from_checkpoint(checkpoint_path)
+            # self.model.load_state_dict(checkpoint["state_dict"], strict=False)
+            print(self.model.train_loader)

         # Internal flag to check if validation is enabled
         validation_enabled = df_val is not None
@@ -2759,20 +2759,55 @@ def _train(
         if continue_training:
             checkpoint_path = self.metrics_logger.checkpoint_path
             checkpoint = torch.load(checkpoint_path)
-            self.model = self._init_model()
-            # TODO: fix size mismatch for trend.trend_changepoints_t: copying a param with shape torch.Size([11]) from checkpoint, the shape in current model is torch.Size([12]).
-            self.model.load_state_dict(checkpoint["state_dict"], strict=False)
-            self.optimizer.load_state_dict(checkpoint["optimizer_states"][0])
-            self.trainer.current_epoch = checkpoint["epoch"] + 1
-            if "lr_schedulers" in checkpoint:
-                self.lr_scheduler.load_state_dict(checkpoint["lr_schedulers"][0])
-            print(f"Resuming training from epoch {self.trainer.current_epoch}")
-            # TODO: remove print, checkpoint['lr_schedulers']
-            print(f"Resuming training from epoch {self.trainer.current_epoch}")
+
+            # Load model state
+            self.model.load_state_dict(checkpoint["state_dict"])
+
+            # Adjust epochs
+            additional_epochs = 10
+            previous_epochs = self.config_train.epochs  # Get the number of epochs already trained
+            new_total_epochs = previous_epochs + additional_epochs
+            self.config_train.epochs = new_total_epochs
+
+            # Reinitialize optimizer with loaded model parameters
+            optimizer = torch.optim.AdamW(self.model.parameters())
+
+            # Load optimizer state
+            if "optimizer_states" in checkpoint and checkpoint["optimizer_states"]:
+                optimizer.load_state_dict(checkpoint["optimizer_states"][0])
+
+            self.config_train.optimizer = optimizer
+
+            # Calculate total steps and steps already taken
+            steps_per_epoch = len(self.model.train_loader)
+            total_steps = steps_per_epoch * new_total_epochs
+            steps_taken = steps_per_epoch * previous_epochs
+
+            # Create new scheduler with updated total steps
+            self.config_train.scheduler = torch.optim.lr_scheduler.OneCycleLR(
+                optimizer=optimizer,
+                total_steps=total_steps,
+                max_lr=10,
+                pct_start=(total_steps - steps_taken) / total_steps,  # Adjust the percentage of remaining steps
+            )
+
+            # Manually update the scheduler's step count
+            for _ in range(steps_taken):
+                self.config_train.scheduler.step()
+
+            print(f"Scheduler: {self.config_train.scheduler}")
+            print(
+                f"Total steps: {total_steps}, Steps taken: {steps_taken}, Remaining steps: {total_steps - steps_taken}"
+            )
+
         else:
-            self.model = self._init_model()
+            # Set up data the training dataloader
+            df, _, _, _ = df_utils.prep_or_copy_df(df)
+            train_loader = self._init_train_loader(df, num_workers)
+            dataset_size = len(df)  # train_loader.dataset

-        self.model.train_loader = train_loader
+            self.model = self._init_model()
+            self.model.train_loader = train_loader

         # Init the Trainer
         self.trainer, checkpoint_callback = utils.configure_trainer(
@@ -2785,9 +2820,15 @@ def _train(
             progress_bar_enabled=progress_bar_enabled,
             metrics_enabled=metrics_enabled,
             checkpointing_enabled=checkpointing_enabled,
-            num_batches_per_epoch=len(train_loader),
+            num_batches_per_epoch=len(self.model.train_loader),
         )

+        # TODO: find out why scheduler not updated
+        if continue_training:
+            self.trainer.lr_schedulers = [
+                {"scheduler": self.config_train.scheduler, "interval": "step", "frequency": 1}
+            ]
+
         # Tune hyperparams and train
         if validation_enabled:
             # Set up data the validation dataloader
@@ -2812,7 +2853,7 @@ def _train(
             start = time.time()
             self.trainer.fit(
                 self.model,
-                train_loader,
+                self.model.train_loader,
                 val_loader,
                 ckpt_path=self.metrics_logger.checkpoint_path if continue_training else None,
             )
@@ -2834,7 +2875,7 @@ def _train(
             start = time.time()
             self.trainer.fit(
                 self.model,
-                train_loader,
+                self.model.train_loader,
                 ckpt_path=self.metrics_logger.checkpoint_path if continue_training else None,
             )
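PATCH 02 rebuilds a OneCycleLR schedule over the enlarged step budget and then replays the steps already taken, because OneCycleLR is defined over a fixed `total_steps`. A self-contained sketch of that fast-forward trick (the parameter tensor, learning rates, and step counts are illustrative, not the patch's values):

    import torch

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.AdamW([param], lr=0.1)

    total_steps, steps_taken = 200, 120
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=0.1,
        total_steps=total_steps,
        pct_start=(total_steps - steps_taken) / total_steps,
    )
    # Replay the steps already completed so the resumed run continues
    # where the schedule left off.
    for _ in range(steps_taken):
        scheduler.step()

PATCH 03 below abandons this replay approach in favor of a scheduler that can resume naturally.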
From f1355eb5dd37dbcd018f34d0a11d1a338f109b20 Mon Sep 17 00:00:00 2001
From: Constantin Weberpals
Date: Mon, 1 Jul 2024 00:19:54 +0200
Subject: [PATCH 03/39] change scheduler for continued training

---
 neuralprophet/forecaster.py | 53 ++++++++++-----------------
 neuralprophet/time_net.py   | 23 +++++++++++-----
 2 files changed, 28 insertions(+), 48 deletions(-)

diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py
index fd372680b..cd7422fe2 100644
--- a/neuralprophet/forecaster.py
+++ b/neuralprophet/forecaster.py
@@ -2741,20 +2741,15 @@ def _train(
         pd.DataFrame
             metrics
         """
-        # Test
-        if continue_training:
-            checkpoint_path = self.metrics_logger.checkpoint_path
-            print(checkpoint_path)
-            checkpoint = torch.load(checkpoint_path)
-            print(checkpoint.keys())
-            print("Current model trend changepoints:", self.model.trend.trend_changepoints_t)
-            # self.model = time_net.TimeNet.load_from_checkpoint(checkpoint_path)
-            # self.model.load_state_dict(checkpoint["state_dict"], strict=False)
-            print(self.model.train_loader)

         # Internal flag to check if validation is enabled
         validation_enabled = df_val is not None

+        # Set up data the training dataloader
+        df, _, _, _ = df_utils.prep_or_copy_df(df)
+        train_loader = self._init_train_loader(df, num_workers)
+        dataset_size = len(df)  # train_loader.dataset
+
         # Load model and optimizer state from checkpoint if continue_training is True
         if continue_training:
             checkpoint_path = self.metrics_logger.checkpoint_path
@@ -2763,8 +2758,11 @@ def _train(
             # Load model state
             self.model.load_state_dict(checkpoint["state_dict"])

+            # Set continue_training flag in model to update scheduler correctly
+            self.model.continue_training = True
+
             # Adjust epochs
-            additional_epochs = 10
+            additional_epochs = 50
             previous_epochs = self.config_train.epochs  # Get the number of epochs already trained
             new_total_epochs = previous_epochs + additional_epochs
             self.config_train.epochs = new_total_epochs
@@ -2778,34 +2776,7 @@ def _train(

             self.config_train.optimizer = optimizer

-            # Calculate total steps and steps already taken
-            steps_per_epoch = len(self.model.train_loader)
-            total_steps = steps_per_epoch * new_total_epochs
-            steps_taken = steps_per_epoch * previous_epochs
-
-            # Create new scheduler with updated total steps
-            self.config_train.scheduler = torch.optim.lr_scheduler.OneCycleLR(
-                optimizer=optimizer,
-                total_steps=total_steps,
-                max_lr=10,
-                pct_start=(total_steps - steps_taken) / total_steps,  # Adjust the percentage of remaining steps
-            )
-
-            # Manually update the scheduler's step count
-            for _ in range(steps_taken):
-                self.config_train.scheduler.step()
-
-            print(f"Scheduler: {self.config_train.scheduler}")
-            print(
-                f"Total steps: {total_steps}, Steps taken: {steps_taken}, Remaining steps: {total_steps - steps_taken}"
-            )
-
         else:
-            # Set up data the training dataloader
-            df, _, _, _ = df_utils.prep_or_copy_df(df)
-            train_loader = self._init_train_loader(df, num_workers)
-            dataset_size = len(df)  # train_loader.dataset
-
             self.model = self._init_model()

         self.model.train_loader = train_loader
@@ -2823,11 +2794,9 @@ def _train(
             num_batches_per_epoch=len(self.model.train_loader),
         )

-        # TODO: find out why scheduler not updated
         if continue_training:
-            self.trainer.lr_schedulers = [
-                {"scheduler": self.config_train.scheduler, "interval": "step", "frequency": 1}
-            ]
+            print("setting up optimizers again")
+            # self.trainer.strategy.setup_optimizers(self.trainer)

diff --git a/neuralprophet/time_net.py b/neuralprophet/time_net.py
index ea3c4b2f3..5b18525ac 100644
--- a/neuralprophet/time_net.py
+++ b/neuralprophet/time_net.py
@@ -63,6 +63,7 @@ def __init__(
         num_seasonalities_modelled: int = 1,
         num_seasonalities_modelled_dict: dict = None,
         meta_used_in_model: bool = False,
+        continue_training: bool = False,
     ):
         """
         Parameters
@@ -306,6 +307,9 @@ def __init__(
         else:
             self.config_regressors.regressors = None

+        # Continued training
+        self.continue_training = continue_training
+
     @property
     def ar_weights(self) -> torch.Tensor:
         """sets property auto-regression weights for regularization. Update if AR is modelled differently"""
@@ -863,12 +867,19 @@ def configure_optimizers(self):
         optimizer = self._optimizer(self.parameters(), lr=self.learning_rate, **self.config_train.optimizer_args)

         # Scheduler
-        lr_scheduler = self._scheduler(
-            optimizer,
-            max_lr=self.learning_rate,
-            total_steps=self.trainer.estimated_stepping_batches,
-            **self.config_train.scheduler_args,
-        )
+        if self.continue_training:
+            # Update initial learning rate to the last learning rate for continued training
+            last_lr = optimizer.param_groups[0]["lr"]
+            lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
+            for param_group in optimizer.param_groups:
+                param_group["initial_lr"] = last_lr
+        else:
+            lr_scheduler = self._scheduler(
+                optimizer,
+                max_lr=self.learning_rate,
+                total_steps=self.trainer.estimated_stepping_batches,
+                **self.config_train.scheduler_args,
+            )

         return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
++++++++----------------------------- neuralprophet/time_net.py | 23 +++++++++++----- 2 files changed, 28 insertions(+), 48 deletions(-) diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py index fd372680b..cd7422fe2 100644 --- a/neuralprophet/forecaster.py +++ b/neuralprophet/forecaster.py @@ -2741,20 +2741,15 @@ def _train( pd.DataFrame metrics """ - # Test - if continue_training: - checkpoint_path = self.metrics_logger.checkpoint_path - print(checkpoint_path) - checkpoint = torch.load(checkpoint_path) - print(checkpoint.keys()) - print("Current model trend changepoints:", self.model.trend.trend_changepoints_t) - # self.model = time_net.TimeNet.load_from_checkpoint(checkpoint_path) - # self.model.load_state_dict(checkpoint["state_dict"], strict=False) - print(self.model.train_loader) # Internal flag to check if validation is enabled validation_enabled = df_val is not None + # Set up data the training dataloader + df, _, _, _ = df_utils.prep_or_copy_df(df) + train_loader = self._init_train_loader(df, num_workers) + dataset_size = len(df) # train_loader.dataset + # Load model and optimizer state from checkpoint if continue_training is True if continue_training: checkpoint_path = self.metrics_logger.checkpoint_path @@ -2763,8 +2758,11 @@ def _train( # Load model state self.model.load_state_dict(checkpoint["state_dict"]) + # Set continue_training flag in model to update scheduler correctly + self.model.continue_training = True + # Adjust epochs - additional_epochs = 10 + additional_epochs = 50 previous_epochs = self.config_train.epochs # Get the number of epochs already trained new_total_epochs = previous_epochs + additional_epochs self.config_train.epochs = new_total_epochs @@ -2778,34 +2776,7 @@ def _train( self.config_train.optimizer = optimizer - # Calculate total steps and steps already taken - steps_per_epoch = len(self.model.train_loader) - total_steps = steps_per_epoch * new_total_epochs - steps_taken = steps_per_epoch * previous_epochs - - # Create new scheduler with updated total steps - self.config_train.scheduler = torch.optim.lr_scheduler.OneCycleLR( - optimizer=optimizer, - total_steps=total_steps, - max_lr=10, - pct_start=(total_steps - steps_taken) / total_steps, # Adjust the percentage of remaining steps - ) - - # Manually update the scheduler's step count - for _ in range(steps_taken): - self.config_train.scheduler.step() - - print(f"Scheduler: {self.config_train.scheduler}") - print( - f"Total steps: {total_steps}, Steps taken: {steps_taken}, Remaining steps: {total_steps - steps_taken}" - ) - else: - # Set up data the training dataloader - df, _, _, _ = df_utils.prep_or_copy_df(df) - train_loader = self._init_train_loader(df, num_workers) - dataset_size = len(df) # train_loader.dataset - self.model = self._init_model() self.model.train_loader = train_loader @@ -2823,11 +2794,9 @@ def _train( num_batches_per_epoch=len(self.model.train_loader), ) - # TODO: find out why scheduler not updated if continue_training: - self.trainer.lr_schedulers = [ - {"scheduler": self.config_train.scheduler, "interval": "step", "frequency": 1} - ] + print("setting up optimizers again") + # self.trainer.strategy.setup_optimizers(self.trainer) # Tune hyperparams and train if validation_enabled: diff --git a/neuralprophet/time_net.py b/neuralprophet/time_net.py index ea3c4b2f3..5b18525ac 100644 --- a/neuralprophet/time_net.py +++ b/neuralprophet/time_net.py @@ -63,6 +63,7 @@ def __init__( num_seasonalities_modelled: int = 1, num_seasonalities_modelled_dict: dict = None, 
From f9969285c8c8b8f8f5cf22a6b8850b58323c576a Mon Sep 17 00:00:00 2001
From: Constantin Weberpals
Date: Mon, 1 Jul 2024 21:04:07 +0200
Subject: [PATCH 05/39] fix metrics logging

---
 neuralprophet/forecaster.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py
index c972651e3..5ad4ad73f 100644
--- a/neuralprophet/forecaster.py
+++ b/neuralprophet/forecaster.py
@@ -2888,8 +2888,12 @@ def _train(
         if not metrics_enabled:
             return None
+
         # Return metrics collected in logger as dataframe
-        metrics_df = pd.DataFrame(self.metrics_logger.history)
+        if self.metrics_logger.history is not None:
+            metrics_df = pd.DataFrame(self.metrics_logger.history)
+        else:
+            metrics_df = pd.DataFrame()
         return metrics_df

     def restore_trainer(self, accelerator: Optional[str] = None):
From f9a77f8da770298d285f7f41f532d897a102316a Mon Sep 17 00:00:00 2001
From: Constantin Weberpals
Date: Fri, 5 Jul 2024 10:19:56 +0200
Subject: [PATCH 06/39] include feedback

---
 neuralprophet/forecaster.py | 35 +++++++++--------------------------
 1 file changed, 9 insertions(+), 26 deletions(-)

diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py
index 5ad4ad73f..7cf43d696 100644
--- a/neuralprophet/forecaster.py
+++ b/neuralprophet/forecaster.py
@@ -979,7 +979,12 @@ def fit(
             metrics with training and potentially evaluation metrics
         """
         if self.fitted and not continue_training:
-            raise RuntimeError("Model has been fitted already. Please initialize a new model to fit again.")
+            raise RuntimeError(
+                "Model has been fitted already. If you want to continue training please set the flag continue_training."
+            )
+
+        if continue_training and epochs is None:
+            raise ValueError("Continued training requires setting the number of epochs to train for.")

         # Configuration
         if epochs is not None:
@@ -1065,11 +1070,8 @@ def fit(
             or any(value != 1 for value in self.num_seasonalities_modelled_dict.values())
         )

-        if self.fitted is True and not continue_training:
-            log.error("Model has already been fitted. Re-fitting may break or produce different results.")
-
         if continue_training and self.metrics_logger.checkpoint_path is None:
-            log.error("Continued training requires checkpointing in model.")
+            log.error("Continued training requires checkpointing in model to continue from last epoch.")

         self.max_lags = df_utils.get_max_num_lags(
             n_lags=self.n_lags, config_lagged_regressors=self.config_lagged_regressors
         )
@@ -2777,34 +2779,15 @@ def _train(
         # Load model and optimizer state from checkpoint if continue_training is True
         if continue_training:
-            checkpoint_path = self.metrics_logger.checkpoint_path
-            checkpoint = torch.load(checkpoint_path)
-
-            # Load model state
-            self.model.load_state_dict(checkpoint["state_dict"])
+            previous_epoch = self.model.current_epoch

             # Set continue_training flag in model to update scheduler correctly
             self.model.continue_training = True

-            previous_epoch = checkpoint["epoch"]
             # Adjust epochs
-            if self.config_train.epochs:
-                additional_epochs = self.config_train.epochs
-            else:
-                additional_epochs = previous_epoch
-            # Get the number of epochs already trained
-            new_total_epochs = previous_epoch + additional_epochs
+            new_total_epochs = previous_epoch + self.config_train.epochs
             self.config_train.epochs = new_total_epochs

-            # Reinitialize optimizer with loaded model parameters
-            optimizer = torch.optim.AdamW(self.model.parameters())
-
-            # Load optimizer state
-            if "optimizer_states" in checkpoint and checkpoint["optimizer_states"]:
-                optimizer.load_state_dict(checkpoint["optimizer_states"][0])
-
-            self.config_train.optimizer = optimizer
-
         else:
             self.model = self._init_model()
From 7ad761d00a93f3b4a9ad54d41116ed0b3c2cb91a Mon Sep 17 00:00:00 2001
From: Constantin Weberpals
Date: Fri, 5 Jul 2024 14:56:27 +0200
Subject: [PATCH 07/39] get correct optimizer states

---
 neuralprophet/configure.py  |  4 ++++
 neuralprophet/forecaster.py |  6 ++++++
 neuralprophet/time_net.py   | 17 ++++++++++++++---
 3 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/neuralprophet/configure.py b/neuralprophet/configure.py
index 5b54b202e..7a8dcf7c7 100644
--- a/neuralprophet/configure.py
+++ b/neuralprophet/configure.py
@@ -104,6 +104,7 @@ class Train:
     n_data: int = field(init=False)
     loss_func_name: str = field(init=False)
     lr_finder_args: dict = field(default_factory=dict)
+    optimizer_state: dict = field(default_factory=dict)

     def __post_init__(self):
         # assert the uncertainty estimation params and then finalize the quantiles
@@ -239,6 +240,9 @@ def get_reg_delay_weight(self, e, iter_progress, reg_start_pct: float = 0.66, re
             delay_weight = 1
         return delay_weight

+    def set_optimizer_state(self, optimizer_state: dict):
+        self.optimizer_state = optimizer_state
+

 @dataclass
 class Trend:
diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py
index 7cf43d696..ed35413aa 100644
--- a/neuralprophet/forecaster.py
+++ b/neuralprophet/forecaster.py
@@ -2779,15 +2779,21 @@ def _train(
         # Load model and optimizer state from checkpoint if continue_training is True
         if continue_training:
+            checkpoint_path = self.metrics_logger.checkpoint_path
+            checkpoint = torch.load(checkpoint_path)
+
             previous_epoch = self.model.current_epoch

             # Set continue_training flag in model to update scheduler correctly
             self.model.continue_training = True
+            self.model.start_epoch = previous_epoch

             # Adjust epochs
             new_total_epochs = previous_epoch + self.config_train.epochs
             self.config_train.epochs = new_total_epochs

+            self.config_train.set_optimizer_state(checkpoint["optimizer_states"][0])
+
         else:
             self.model = self._init_model()
diff --git a/neuralprophet/time_net.py b/neuralprophet/time_net.py
index 8413d8782..e635b9dda 100644
--- a/neuralprophet/time_net.py
+++ b/neuralprophet/time_net.py
@@ -64,6 +64,7 @@ def __init__(
         num_seasonalities_modelled_dict: dict = None,
         meta_used_in_model: bool = False,
         continue_training: bool = False,
+        start_epoch: int = 0,
     ):
         """
         Parameters
@@ -309,6 +310,7 @@ def __init__(
         # Continued training
         self.continue_training = continue_training
+        self.start_epoch = start_epoch

     @property
     def ar_weights(self) -> torch.Tensor:
@@ -870,11 +872,20 @@ def configure_optimizers(self):

         # Scheduler
         if self.continue_training:
+            optimizer.load_state_dict(self.config_train.optimizer_state)
+
             # Update initial learning rate to the last learning rate for continued training
-            last_lr = optimizer.param_groups[0]["lr"]
-            lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
+            last_lr = float(optimizer.param_groups[0]["lr"])  # Ensure it's a float
+
+            batches_per_epoch = len(self.train_dataloader())
+            total_batches_processed = self.start_epoch * batches_per_epoch
+
             for param_group in optimizer.param_groups:
-                param_group["initial_lr"] = last_lr
+                param_group["initial_lr"] = (last_lr,)
+
+            lr_scheduler = lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
+                optimizer, gamma=0.95, last_epoch=total_batches_processed - 1
+            )
         else:
             lr_scheduler = self._scheduler(
                 optimizer,
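PATCH 07 resumes ExponentialLR via `last_epoch`, which in PyTorch requires an `initial_lr` entry in every optimizer param group. A minimal sketch of that contract (the values are illustrative; note that the patch itself assigns `initial_lr` as a one-element tuple `(last_lr,)` and doubles the `lr_scheduler =` assignment, both of which look like typos):

    import torch

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.AdamW([param], lr=1e-3)

    # Required when constructing a scheduler with last_epoch > -1:
    for group in optimizer.param_groups:
        group["initial_lr"] = group["lr"]

    batches_already_processed = 50
    scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer, gamma=0.95, last_epoch=batches_already_processed - 1
    )
    print(scheduler.get_last_lr())  # decay continues from the optimizer's current lr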
From b14d20b2fe9d357d10719b61081eda39132a0345 Mon Sep 17 00:00:00 2001
From: Constantin Weberpals
Date: Fri, 5 Jul 2024 15:01:07 +0200
Subject: [PATCH 08/39] fix tests

---
 tests/test_utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/test_utils.py b/tests/test_utils.py
index de57fd5fd..a1f8c5874 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -21,6 +21,7 @@
 YOS_FILE = os.path.join(DATA_DIR, "yosemite_temps.csv")
 NROWS = 512
 EPOCHS = 10
+ADDITIONAL_EPOCHS = 5
 LR = 1.0
 BATCH_SIZE = 64

@@ -112,5 +113,5 @@ def test_continue_training():
         n_changepoints=0,
     )
     metrics = m.fit(df, checkpointing=True, freq="D")
-    metrics2 = m.fit(df, freq="D", continue_training=True)
+    metrics2 = m.fit(df, freq="D", continue_training=True, epochs=ADDITIONAL_EPOCHS)
     assert metrics["Loss"].min() >= metrics2["Loss"].min()
From 9fe34012e214a9634c55387822bf26efb269b1a1 Mon Sep 17 00:00:00 2001
From: Constantin Weberpals
Date: Mon, 8 Jul 2024 11:52:44 +0200
Subject: [PATCH 09/39] enable setting the scheduler

---
 neuralprophet/configure.py  | 65 ++++++++++++++++++++++++++++-------
 neuralprophet/forecaster.py | 16 ++++++++-
 neuralprophet/time_net.py   | 19 +++++++----
 3 files changed, 82 insertions(+), 18 deletions(-)

diff --git a/neuralprophet/configure.py b/neuralprophet/configure.py
index 7a8dcf7c7..79be7fc5a 100644
--- a/neuralprophet/configure.py
+++ b/neuralprophet/configure.py
@@ -94,7 +94,7 @@ class Train:
     optimizer: Union[str, Type[torch.optim.Optimizer]]
     quantiles: List[float] = field(default_factory=list)
     optimizer_args: dict = field(default_factory=dict)
-    scheduler: Optional[Type[torch.optim.lr_scheduler.OneCycleLR]] = None
+    scheduler: Optional[Type[torch.optim.lr_scheduler._LRScheduler]] = None
     scheduler_args: dict = field(default_factory=dict)
     newer_samples_weight: float = 1.0
     newer_samples_start: float = 0.0
@@ -193,16 +193,59 @@ def set_scheduler(self):
         Set the scheduler and scheduler args.
         The scheduler is not initialized yet as this is done in configure_optimizers in TimeNet.
         """
-        self.scheduler = torch.optim.lr_scheduler.OneCycleLR
-        self.scheduler_args.update(
-            {
-                "pct_start": 0.3,
-                "anneal_strategy": "cos",
-                "div_factor": 10.0,
-                "final_div_factor": 10.0,
-                "three_phase": True,
-            }
-        )
+        self.scheduler_args.clear()
+        if isinstance(self.scheduler, str):
+            if self.scheduler.lower() == "onecyclelr":
+                self.scheduler = torch.optim.lr_scheduler.OneCycleLR
+                self.scheduler_args.update(
+                    {
+                        "pct_start": 0.3,
+                        "anneal_strategy": "cos",
+                        "div_factor": 10.0,
+                        "final_div_factor": 10.0,
+                        "three_phase": True,
+                    }
+                )
+            elif self.scheduler.lower() == "steplr":
+                self.scheduler = torch.optim.lr_scheduler.StepLR
+                self.scheduler_args.update(
+                    {
+                        "step_size": 10,
+                        "gamma": 0.1,
+                    }
+                )
+            elif self.scheduler.lower() == "exponentiallr":
+                self.scheduler = torch.optim.lr_scheduler.ExponentialLR
+                self.scheduler_args.update(
+                    {
+                        "gamma": 0.95,
+                    }
+                )
+            elif self.scheduler.lower() == "reducelronplateau":
+                self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau
+                self.scheduler_args.update(
+                    {
+                        "mode": "min",
+                        "factor": 0.1,
+                        "patience": 10,
+                    }
+                )
+            elif self.scheduler.lower() == "cosineannealinglr":
+                self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR
+                self.scheduler_args.update(
+                    {
+                        "T_max": 50,
+                    }
+                )
+            else:
+                raise NotImplementedError(f"Scheduler {self.scheduler} is not supported.")
+        elif self.scheduler is None:
+            self.scheduler = torch.optim.lr_scheduler.ExponentialLR
+            self.scheduler_args.update(
+                {
+                    "gamma": 0.95,
+                }
+            )

     def set_lr_finder_args(self, dataset_size, num_batches):
         """
diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py
index ed35413aa..2e12d974e 100644
--- a/neuralprophet/forecaster.py
+++ b/neuralprophet/forecaster.py
@@ -451,6 +451,7 @@ def __init__(
         accelerator: Optional[str] = None,
         trainer_config: dict = {},
         prediction_frequency: Optional[dict] = None,
+        scheduler: Optional[str] = "onecyclelr",
     ):
         self.config = locals()
         self.config.pop("self")
@@ -509,6 +510,7 @@ def __init__(
         self.config_train = configure.Train(
             quantiles=quantiles,
             learning_rate=learning_rate,
+            scheduler=scheduler,
             epochs=epochs,
             batch_size=batch_size,
             loss_func=loss_func,
@@ -921,6 +923,7 @@ def fit(
         continue_training: bool = False,
         num_workers: int = 0,
         deterministic: bool = False,
+        scheduler: Optional[str] = None,
     ):
         """Train, and potentially evaluate model.
@@ -986,6 +989,18 @@ def fit(
         if continue_training and epochs is None:
             raise ValueError("Continued training requires setting the number of epochs to train for.")

+        if continue_training:
+            if scheduler is not None:
+                self.config_train.scheduler = scheduler
+            else:
+                self.config_train.scheduler = None
+            self.config_train.set_scheduler()
+
+        if scheduler is not None and not continue_training:
+            log.warning(
+                "Scheduler can only be set in fit when continuing training. Please set the scheduler when initializing the model."
+            )
+
         # Configuration
         if epochs is not None:
             self.config_train.epochs = epochs
@@ -2681,7 +2696,6 @@ def _init_train_loader(self, df, num_workers=0):
             config_seasonality=self.config_seasonality,
         )

-        print("Changepoints:", self.config_trend.changepoints)
         df = _normalize(df=df, config_normalization=self.config_normalization)
         if not self.fitted:
             if self.config_trend.changepoints is not None:
diff --git a/neuralprophet/time_net.py b/neuralprophet/time_net.py
index e635b9dda..3bbfbe66b 100644
--- a/neuralprophet/time_net.py
+++ b/neuralprophet/time_net.py
@@ -883,16 +883,23 @@ def configure_optimizers(self):
             for param_group in optimizer.param_groups:
                 param_group["initial_lr"] = (last_lr,)

-            lr_scheduler = lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
-                optimizer, gamma=0.95, last_epoch=total_batches_processed - 1
-            )
-        else:
             lr_scheduler = self._scheduler(
                 optimizer,
-                max_lr=self.learning_rate,
-                total_steps=self.trainer.estimated_stepping_batches,
                 **self.config_train.scheduler_args,
             )
+        else:
+            if self._scheduler == torch.optim.lr_scheduler.OneCycleLR:
+                lr_scheduler = self._scheduler(
+                    optimizer,
+                    max_lr=self.learning_rate,
+                    total_steps=self.trainer.estimated_stepping_batches,
+                    **self.config_train.scheduler_args,
+                )
+            else:
+                lr_scheduler = self._scheduler(
+                    optimizer,
+                    **self.config_train.scheduler_args,
+                )

         return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
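With PATCH 09 the scheduler becomes user-selectable by name. A brief usage sketch based on the options the patch maps in set_scheduler (the other constructor arguments are the library's existing ones):

    from neuralprophet import NeuralProphet

    # Default one-cycle policy
    m = NeuralProphet(scheduler="onecyclelr", epochs=10, learning_rate=1.0)

    # Alternatives wired up by set_scheduler(): "steplr", "exponentiallr",
    # "reducelronplateau", "cosineannealinglr"
    m2 = NeuralProphet(scheduler="StepLR", epochs=10, learning_rate=1.0)

The comparison uses `self.scheduler.lower()`, so "StepLR" and "steplr" are equivalent.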
From 00f2e25e13aca00fb58d8951214fb27ad0da792a Mon Sep 17 00:00:00 2001
From: Constantin Weberpals
Date: Mon, 8 Jul 2024 12:31:51 +0200
Subject: [PATCH 10/39] update for onecyclelr

---
 neuralprophet/forecaster.py |  2 +-
 neuralprophet/time_net.py   | 29 ++++++++++++++---------------
 2 files changed, 15 insertions(+), 16 deletions(-)

diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py
index 2e12d974e..0a1202c6c 100644
--- a/neuralprophet/forecaster.py
+++ b/neuralprophet/forecaster.py
@@ -436,6 +436,7 @@ def __init__(
         batch_size: Optional[int] = None,
         loss_func: Union[str, torch.nn.modules.loss._Loss, Callable] = "SmoothL1Loss",
         optimizer: Union[str, Type[torch.optim.Optimizer]] = "AdamW",
+        scheduler: Optional[str] = "onecyclelr",
         newer_samples_weight: float = 2,
         newer_samples_start: float = 0.0,
         quantiles: List[float] = [],
@@ -451,7 +452,6 @@ def __init__(
         accelerator: Optional[str] = None,
         trainer_config: dict = {},
         prediction_frequency: Optional[dict] = None,
-        scheduler: Optional[str] = "onecyclelr",
     ):
diff --git a/neuralprophet/time_net.py b/neuralprophet/time_net.py
index 3bbfbe66b..28f4058b5 100644
--- a/neuralprophet/time_net.py
+++ b/neuralprophet/time_net.py
@@ -871,35 +871,34 @@ def configure_optimizers(self):
         optimizer = self._optimizer(self.parameters(), lr=self.learning_rate, **self.config_train.optimizer_args)

         # Scheduler
+        self._scheduler = self.config_train.scheduler
+
         if self.continue_training:
             optimizer.load_state_dict(self.config_train.optimizer_state)

             # Update initial learning rate to the last learning rate for continued training
             last_lr = float(optimizer.param_groups[0]["lr"])  # Ensure it's a float

-            batches_per_epoch = len(self.train_dataloader())
-            total_batches_processed = self.start_epoch * batches_per_epoch
-
             for param_group in optimizer.param_groups:
                 param_group["initial_lr"] = (last_lr,)

+            if self._scheduler == torch.optim.lr_scheduler.OneCycleLR:
+                log.warning("OneCycleLR scheduler is not supported for continued training. Switching to ExponentialLR")
+                self._scheduler = torch.optim.lr_scheduler.ExponentialLR
+                self.config_train.scheduler_args = {"gamma": 0.95}
+
+        if self._scheduler == torch.optim.lr_scheduler.OneCycleLR:
             lr_scheduler = self._scheduler(
                 optimizer,
+                max_lr=self.learning_rate,
+                total_steps=self.trainer.estimated_stepping_batches,
                 **self.config_train.scheduler_args,
             )
         else:
-            if self._scheduler == torch.optim.lr_scheduler.OneCycleLR:
-                lr_scheduler = self._scheduler(
-                    optimizer,
-                    max_lr=self.learning_rate,
-                    total_steps=self.trainer.estimated_stepping_batches,
-                    **self.config_train.scheduler_args,
-                )
-            else:
-                lr_scheduler = self._scheduler(
-                    optimizer,
-                    **self.config_train.scheduler_args,
-                )
+            lr_scheduler = self._scheduler(
+                optimizer,
+                **self.config_train.scheduler_args,
+            )

         return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
From 5f103d8837fc53fd5639efd260900c8d171341ee Mon Sep 17 00:00:00 2001
From: Constantin Weberpals
Date: Tue, 9 Jul 2024 14:01:36 +0200
Subject: [PATCH 11/39] add tests and adapt docstring

---
 neuralprophet/configure.py  | 11 +----------
 neuralprophet/forecaster.py | 24 +++++++++++++++++++++++-
 tests/test_utils.py         | 31 +++++++++++++++++++++++++++++++
 3 files changed, 55 insertions(+), 11 deletions(-)

diff --git a/neuralprophet/configure.py b/neuralprophet/configure.py
index 79be7fc5a..5cc5edca7 100644
--- a/neuralprophet/configure.py
+++ b/neuralprophet/configure.py
@@ -190,7 +190,7 @@ def set_optimizer(self):

     def set_scheduler(self):
         """
-        Set the scheduler and scheduler args.
+        Set the scheduler and scheduler arg depending on the user selection.
         The scheduler is not initialized yet as this is done in configure_optimizers in TimeNet.
         """
@@ -221,15 +221,6 @@ def set_scheduler(self):
                         "gamma": 0.95,
                     }
                 )
-            elif self.scheduler.lower() == "reducelronplateau":
-                self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau
-                self.scheduler_args.update(
-                    {
-                        "mode": "min",
-                        "factor": 0.1,
-                        "patience": 10,
-                    }
-                )
             elif self.scheduler.lower() == "cosineannealinglr":
                 self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR
                 self.scheduler_args.update(
diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py
index 0a1202c6c..05cdc3f75 100644
--- a/neuralprophet/forecaster.py
+++ b/neuralprophet/forecaster.py
@@ -301,6 +301,20 @@ class NeuralProphet:
         >>> m = NeuralProphet(collect_metrics=["MSE", "MAE", "RMSE"])
         >>> # use custorm torchmetrics names
         >>> m = NeuralProphet(collect_metrics={"MAPE": "MeanAbsolutePercentageError", "MSLE": "MeanSquaredLogError",
+        scheduler : str, torch.optim.lr_scheduler._LRScheduler
+            Type of learning rate scheduler to use.
+
+            Options
+                * (default) ``OneCycleLR``: One Cycle Learning Rate scheduler
+                * ``StepLR``: Step Learning Rate scheduler
+                * ``ExponentialLR``: Exponential Learning Rate scheduler
+                * ``CosineAnnealingLR``: Cosine Annealing Learning Rate scheduler
+
+            Examples
+            --------
+            >>> from neuralprophet import NeuralProphet
+            >>> # Step Learning Rate scheduler
+            >>> m = NeuralProphet(scheduler="StepLR")

     COMMENT
     Uncertainty Estimation
@@ -975,6 +989,13 @@ def fit(
             Note: using multiple workers and therefore distributed training might significantly increase
             the training time since each batch needs to be copied to each worker for each epoch. Keeping
             all data on the main process might be faster for most datasets.
+        scheduler : str
+            Type of learning rate scheduler to use for continued training. If None, uses ExponentialLR as
+            default as specified in the model config.
+            Options
+                * ``StepLR``: Step Learning Rate scheduler
+                * ``ExponentialLR``: Exponential Learning Rate scheduler
+                * ``CosineAnnealingLR``: Cosine Annealing Learning Rate scheduler

         Returns
         -------
@@ -2796,7 +2817,8 @@ def _train(
             checkpoint_path = self.metrics_logger.checkpoint_path
             checkpoint = torch.load(checkpoint_path)

-            previous_epoch = self.model.current_epoch
+            checkpoint_epoch = checkpoint["epoch"] if "epoch" in checkpoint else 0
+            previous_epoch = max(self.model.current_epoch, checkpoint_epoch)

             # Set continue_training flag in model to update scheduler correctly
             self.model.continue_training = True
diff --git a/tests/test_utils.py b/tests/test_utils.py
index a1f8c5874..3b93721bf 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -115,3 +115,34 @@ def test_continue_training():
     metrics2 = m.fit(df, freq="D", continue_training=True, epochs=ADDITIONAL_EPOCHS)
     assert metrics["Loss"].min() >= metrics2["Loss"].min()
+
+
+def test_continue_training_with_scheduler_selection():
+    df = pd.read_csv(PEYTON_FILE, nrows=NROWS)
+    m = NeuralProphet(
+        epochs=EPOCHS,
+        batch_size=BATCH_SIZE,
+        learning_rate=LR,
+        n_lags=6,
+        n_forecasts=3,
+        n_changepoints=0,
+    )
+    metrics = m.fit(df, checkpointing=True, freq="D")
+    # Continue training with StepLR
+    metrics2 = m.fit(df, freq="D", continue_training=True, epochs=ADDITIONAL_EPOCHS, scheduler="StepLR")
+    assert metrics["Loss"].min() >= metrics2["Loss"].min()
+
+
+def test_save_load_continue_training():
+    df = pd.read_csv(PEYTON_FILE, nrows=NROWS)
+    m = NeuralProphet(
+        epochs=EPOCHS,
+        n_lags=6,
+        n_forecasts=3,
+        n_changepoints=0,
+    )
+    metrics = m.fit(df, checkpointing=True, freq="D")
+    save(m, "test_model.pt")
+    m2 = load("test_model.pt")
+    metrics2 = m2.fit(df, continue_training=True, epochs=ADDITIONAL_EPOCHS, scheduler="StepLR")
+    assert metrics["Loss"].min() >= metrics2["Loss"].min()
From e04320157766efd517ac3149d62a41d62373b67c Mon Sep 17 00:00:00 2001
From: Constantin Weberpals
Date: Tue, 9 Jul 2024 14:22:17 +0200
Subject: [PATCH 12/39] fix array mismatch

---
 neuralprophet/forecaster.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py
index 05cdc3f75..1c99def2d 100644
--- a/neuralprophet/forecaster.py
+++ b/neuralprophet/forecaster.py
@@ -2916,7 +2916,13 @@ def _train(
         # Return metrics collected in logger as dataframe
         if self.metrics_logger.history is not None:
-            metrics_df = pd.DataFrame(self.metrics_logger.history)
+            # avoid array mismatch when continuing training
+            history = self.metrics_logger.history
+            max_length = max(len(lst) for lst in history.values())
+            for key in history:
+                while len(history[key]) < max_length:
+                    history[key].append(None)
+            metrics_df = pd.DataFrame(history)
         else:
             metrics_df = pd.DataFrame()
         return metrics_df
From 63c935c8e4789cb4f3ce948de0b86c23c8d0ddbc Mon Sep 17 00:00:00 2001
From: ourownstory
Date: Fri, 23 Aug 2024 17:18:51 -0700
Subject: [PATCH 13/39] robustify scheduler config

---
 neuralprophet/configure.py  | 64 ++++++++++++++++++-------------------
 neuralprophet/forecaster.py | 44 +++++++++++++++++---------
 2 files changed, 59 insertions(+), 49 deletions(-)

diff --git a/neuralprophet/configure.py b/neuralprophet/configure.py
index d44d6af81..ee8a442d7 100644
--- a/neuralprophet/configure.py
+++ b/neuralprophet/configure.py
@@ -94,7 +94,7 @@ class Train:
     optimizer: Union[str, Type[torch.optim.Optimizer]]
     quantiles: List[float] = field(default_factory=list)
     optimizer_args: dict = field(default_factory=dict)
-    scheduler: Optional[Type[torch.optim.lr_scheduler._LRScheduler]] = None
+    scheduler: Optional[Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]] = None
     scheduler_args: dict = field(default_factory=dict)
     newer_samples_weight: float = 1.0
     newer_samples_start: float = 0.0
@@ -193,50 +193,48 @@ def set_scheduler(self):
         Set the scheduler and scheduler arg depending on the user selection.
         The scheduler is not initialized yet as this is done in configure_optimizers in TimeNet.
         """
-        self.scheduler_args.clear()
         if isinstance(self.scheduler, str):
             if self.scheduler.lower() == "onecyclelr":
                 self.scheduler = torch.optim.lr_scheduler.OneCycleLR
-                self.scheduler_args.update(
-                    {
-                        "pct_start": 0.3,
-                        "anneal_strategy": "cos",
-                        "div_factor": 10.0,
-                        "final_div_factor": 10.0,
-                        "three_phase": True,
-                    }
-                )
+                defaults = {
+                    "pct_start": 0.3,
+                    "anneal_strategy": "cos",
+                    "div_factor": 10.0,
+                    "final_div_factor": 10.0,
+                    "three_phase": True,
+                }
             elif self.scheduler.lower() == "steplr":
                 self.scheduler = torch.optim.lr_scheduler.StepLR
-                self.scheduler_args.update(
-                    {
-                        "step_size": 10,
-                        "gamma": 0.1,
-                    }
-                )
+                defaults = {
+                    "step_size": 10,
+                    "gamma": 0.1,
+                }
             elif self.scheduler.lower() == "exponentiallr":
                 self.scheduler = torch.optim.lr_scheduler.ExponentialLR
-                self.scheduler_args.update(
-                    {
-                        "gamma": 0.95,
-                    }
-                )
+                defaults = {
+                    "gamma": 0.95,
+                }
             elif self.scheduler.lower() == "cosineannealinglr":
                 self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR
-                self.scheduler_args.update(
-                    {
-                        "T_max": 50,
-                    }
-                )
+                defaults = {
+                    "T_max": 50,
+                }
             else:
-                raise NotImplementedError(f"Scheduler {self.scheduler} is not supported.")
+                raise NotImplementedError(
+                    f"Scheduler {self.scheduler} is not supported from string. Please pass the scheduler class."
+                )
+            if self.scheduler_args is not None:
+                defaults.update(self.scheduler_args)
+            self.scheduler_args = defaults
         elif self.scheduler is None:
             self.scheduler = torch.optim.lr_scheduler.ExponentialLR
-            self.scheduler_args.update(
-                {
-                    "gamma": 0.95,
-                }
-            )
+            self.scheduler_args = {
+                "gamma": 0.95,
+            }
+        else:  # if scheduler is a class
+            assert issubclass(
+                self.scheduler, torch.optim.lr_scheduler.LRScheduler
+            ), "Scheduler must be a subclass of torch.optim.lr_scheduler.LRScheduler"

     def set_lr_finder_args(self, dataset_size, num_batches):
         """
diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py
index 3cc386dc2..6caf1cc15 100644
--- a/neuralprophet/forecaster.py
+++ b/neuralprophet/forecaster.py
@@ -298,6 +298,7 @@ class NeuralProphet:
         >>> m = NeuralProphet(collect_metrics=["MSE", "MAE", "RMSE"])
         >>> # use custorm torchmetrics names
         >>> m = NeuralProphet(collect_metrics={"MAPE": "MeanAbsolutePercentageError", "MSLE": "MeanSquaredLogError",
+
         scheduler : str, torch.optim.lr_scheduler._LRScheduler
             Type of learning rate scheduler to use.
@@ -446,7 +447,8 @@ def __init__(
         batch_size: Optional[int] = None,
         loss_func: Union[str, torch.nn.modules.loss._Loss, Callable] = "SmoothL1Loss",
         optimizer: Union[str, Type[torch.optim.Optimizer]] = "AdamW",
-        scheduler: Optional[str] = "onecyclelr",
+        scheduler: Optional[Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]] = "onecyclelr",
+        scheduler_args: Optional[dict] = None,
         newer_samples_weight: float = 2,
         newer_samples_start: float = 0.0,
         quantiles: List[float] = [],
@@ -521,6 +523,7 @@ def __init__(
             quantiles=quantiles,
             learning_rate=learning_rate,
             scheduler=scheduler,
+            scheduler_args=scheduler_args,
             epochs=epochs,
             batch_size=batch_size,
             loss_func=loss_func,
@@ -932,7 +935,8 @@ def fit(
         continue_training: bool = False,
         num_workers: int = 0,
         deterministic: bool = False,
-        scheduler: Optional[str] = None,
+        scheduler: Optional[Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]] = None,
+        scheduler_args: Optional[dict] = None,
     ):
         """Train, and potentially evaluate model.
@@ -1002,20 +1006,30 @@ def fit(
                 "Model has been fitted already. If you want to continue training please set the flag continue_training."
             )

-        if continue_training and epochs is None:
-            raise ValueError("Continued training requires setting the number of epochs to train for.")
-
         if continue_training:
-            if scheduler is not None:
-                self.config_train.scheduler = scheduler
-            else:
+            if epochs is None:
+                raise ValueError("Continued training requires setting the number of epochs to train for.")
+
+        if continue_training and self.metrics_logger.checkpoint_path is None:
+            log.error("Continued training requires checkpointing in model to continue from last epoch.")
+
+            # if scheduler is not None:
+            #     log.warning(
+            #         "Scheduler can only be set in fit when continuing training. Please set the scheduler when initializing the model."
+            #     )
+
+            if scheduler is None:
+                log.warning(
+                    "No scheduler specified for continued training. Using a fallback scheduler for continued training."
+                )
                 self.config_train.scheduler = None
-            self.config_train.set_scheduler()
+                self.config_train.scheduler_args = None
+                self.config_train.set_scheduler()

-        if scheduler is not None and not continue_training:
-            log.warning(
-                "Scheduler can only be set in fit when continuing training. Please set the scheduler when initializing the model."
-            )
+            if scheduler is not None:
+                self.config_train.scheduler = scheduler
+                self.config_train.scheduler_args = scheduler_args
+                self.config_train.set_scheduler()

         # Configuration
         if epochs is not None:
@@ -1061,6 +1075,7 @@ def fit(
             log.info("When Global modeling with local normalization, metrics are displayed in normalized scale.")

         if minimal:
+            # overrides these settings:
             checkpointing = False
             self.metrics = False
             progress = None
@@ -1101,9 +1116,6 @@ def fit(
             or any(value != 1 for value in self.num_seasonalities_modelled_dict.values())
         )

-        if continue_training and self.metrics_logger.checkpoint_path is None:
-            log.error("Continued training requires checkpointing in model to continue from last epoch.")
-
         self.max_lags = df_utils.get_max_num_lags(
             n_lags=self.n_lags, config_lagged_regressors=self.config_lagged_regressors
         )
From 6a746804c84a565ee05d7d1b9986248c0d4cef70 Mon Sep 17 00:00:00 2001
From: ourownstory
Date: Fri, 23 Aug 2024 18:49:18 -0700
Subject: [PATCH 14/39] clean up train config setup

---
 neuralprophet/configure.py  |  88 ++++++++++++++++++---------
 neuralprophet/forecaster.py | 109 +++++++++++++++++++++---------------
 neuralprophet/time_net.py   |  29 +++++-----
 3 files changed, 138 insertions(+), 88 deletions(-)

diff --git a/neuralprophet/configure.py b/neuralprophet/configure.py
index ee8a442d7..5d5144fc7 100644
--- a/neuralprophet/configure.py
+++ b/neuralprophet/configure.py
@@ -105,16 +105,19 @@ class Train:
     n_data: int = field(init=False)
     loss_func_name: str = field(init=False)
     lr_finder_args: dict = field(default_factory=dict)
     optimizer_state: dict = field(default_factory=dict)
+    continue_training: bool = False

     def __post_init__(self):
         # assert the uncertainty estimation params and then finalize the quantiles
-        self.set_quantiles()
+        # self.set_quantiles()
         assert self.newer_samples_weight >= 1.0
         assert self.newer_samples_start >= 0.0
         assert self.newer_samples_start < 1.0
         self.set_loss_func()
-        self.set_optimizer()
-        self.set_scheduler()
+
+        # called in TimeNet configure_optimizers:
+        # self.set_optimizer()
+        # self.set_scheduler()

     def set_loss_func(self):
@@ -139,22 +142,22 @@ def set_loss_func(self):
         if len(self.quantiles) > 1:
             self.loss_func = PinballLoss(loss_func=self.loss_func, quantiles=self.quantiles)

-    def set_quantiles(self):
-        # convert quantiles to empty list [] if None
-        if self.quantiles is None:
-            self.quantiles = []
-        # assert quantiles is a list type
-        assert isinstance(self.quantiles, list), "Quantiles must be in a list format, not None or scalar."
-        # check if quantiles contain 0.5 or close to 0.5, remove if so as 0.5 will be inserted again as first index
-        self.quantiles = [quantile for quantile in self.quantiles if not math.isclose(0.5, quantile)]
-        # check if quantiles are float values in (0, 1)
-        assert all(
-            0 < quantile < 1 for quantile in self.quantiles
-        ), "The quantiles specified need to be floats in-between (0, 1)."
-        # sort the quantiles
-        self.quantiles.sort()
-        # 0 is the median quantile index
-        self.quantiles.insert(0, 0.5)
+    # def set_quantiles(self):
+    #     # convert quantiles to empty list [] if None
+    #     if self.quantiles is None:
+    #         self.quantiles = []
+    #     # assert quantiles is a list type
+    #     assert isinstance(self.quantiles, list), "Quantiles must be in a list format, not None or scalar."
+ # # check if quantiles contain 0.5 or close to 0.5, remove if so as 0.5 will be inserted again as first index + # self.quantiles = [quantile for quantile in self.quantiles if not math.isclose(0.5, quantile)] + # # check if quantiles are float values in (0, 1) + # assert all( + # 0 < quantile < 1 for quantile in self.quantiles + # ), "The quantiles specified need to be floats in-between (0, 1)." + # # sort the quantiles + # self.quantiles.sort() + # # 0 is the median quantile index + # self.quantiles.insert(0, 0.5) def set_auto_batch_epoch( self, @@ -183,16 +186,50 @@ def set_optimizer(self): """ Set the optimizer and optimizer args. If optimizer is a string, then it will be converted to the corresponding torch optimizer. The optimizer is not initialized yet as this is done in configure_optimizers in TimeNet. + + Parameters + ---------- + optimizer_name : int + Object provided to NeuralProphet as optimizer. + optimizer_args : dict + Arguments for the optimizer. + """ - self.optimizer, self.optimizer_args = utils_torch.create_optimizer_from_config( - self.optimizer, self.optimizer_args - ) + if isinstance(self.optimizer, str): + if self.optimizer.lower() == "adamw": + # Tends to overfit, but reliable + self.optimizer = torch.optim.AdamW + self.optimizer_args["weight_decay"] = 1e-3 + elif self.optimizer.lower() == "sgd": + # better validation performance, but diverges sometimes + self.optimizer = torch.optim.SGD + self.optimizer_args["momentum"] = 0.9 + self.optimizer_args["weight_decay"] = 1e-4 + else: + raise ValueError( + f"The optimizer name {self.optimizer} is not supported. Please pass the optimizer class." + ) + elif not issubclass(self.optimizer, torch.optim.Optimizer): + raise ValueError("The provided optimizer is not supported.") def set_scheduler(self): """ Set the scheduler and scheduler arg depending on the user selection. The scheduler is not initialized yet as this is done in configure_optimizers in TimeNet. """ + if self.continue_training: + if (isinstance(self.scheduler, str) and self.scheduler.lower() == "onecyclelr") or isinstance( + self.scheduler, torch.optim.lr_scheduler.OneCycleLR + ): + log.warning( + "OneCycleLR scheduler is not supported for continued training. Please set another scheduler. Falling back to ExponentialLR scheduler" + ) + self.scheduler = "exponentiallr" + + if self.scheduler is None: + log.warning("No scheduler specified. 
diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py
index 6caf1cc15..ea95834f4 100644
--- a/neuralprophet/forecaster.py
+++ b/neuralprophet/forecaster.py
@@ -1,4 +1,5 @@
 import logging
+import math
 import os
 import time
 from collections import OrderedDict
@@ -518,20 +519,36 @@ def __init__(
             trend_local_reg=trend_local_reg,
         )

+        # Model
+        self.quantiles = quantiles
+        # convert quantiles to empty list [] if None
+        if self.quantiles is None:
+            self.quantiles = []
+        # assert quantiles is a list type
+        assert isinstance(self.quantiles, list), "Quantiles must be in a list format, not None or scalar."
+        # check if quantiles contain 0.5 or close to 0.5, remove if so as 0.5 will be inserted again as first index
+        self.quantiles = [quantile for quantile in self.quantiles if not math.isclose(0.5, quantile)]
+        # check if quantiles are float values in (0, 1)
+        assert all(
+            0 < quantile < 1 for quantile in self.quantiles
+        ), "The quantiles specified need to be floats in-between (0, 1)."
+        # sort the quantiles
+        self.quantiles.sort()
+        # 0 is the median quantile index
+        self.quantiles.insert(0, 0.5)
+
         # Training
-        self.config_train = configure.Train(
-            quantiles=quantiles,
-            learning_rate=learning_rate,
-            scheduler=scheduler,
-            scheduler_args=scheduler_args,
-            epochs=epochs,
-            batch_size=batch_size,
-            loss_func=loss_func,
-            optimizer=optimizer,
-            newer_samples_weight=newer_samples_weight,
-            newer_samples_start=newer_samples_start,
-            trend_reg_threshold=self.config_trend.trend_reg_threshold,
-        )
+        self.learning_rate = learning_rate
+        self.scheduler = scheduler
+        self.scheduler_args = scheduler_args
+        self.epochs = epochs
+        self.batch_size = batch_size
+        self.loss_func = loss_func
+        self.optimizer = optimizer
+        self.newer_samples_weight = newer_samples_weight
+        self.newer_samples_start = newer_samples_start
+        self.trend_reg_threshold = self.config_trend.trend_reg_threshold
+        self.continue_training = False

         # Seasonality
         self.config_seasonality = configure.ConfigSeasonality(
@@ -1013,25 +1030,29 @@ def fit(
         if continue_training and self.metrics_logger.checkpoint_path is None:
             log.error("Continued training requires checkpointing in model to continue from last epoch.")

-            # if scheduler is not None:
-            #     log.warning(
-            #         "Scheduler can only be set in fit when continuing training. Please set the scheduler when initializing the model."
-            #     )
+        # Configuration
+        self.continue_training = continue_training

-            if scheduler is None:
-                log.warning(
-                    "No scheduler specified for continued training. Using a fallback scheduler for continued training."
-                )
-                self.config_train.scheduler = None
-                self.config_train.scheduler_args = None
-                self.config_train.set_scheduler()
+        # Config
+        self.config_train = configure.Train(
+            quantiles=self.quantiles,
+            learning_rate=self.learning_rate,
+            scheduler=self.scheduler,
+            scheduler_args=self.scheduler_args,
+            epochs=self.epochs,
+            batch_size=self.batch_size,
+            loss_func=self.loss_func,
+            optimizer=self.optimizer,
+            newer_samples_weight=self.newer_samples_weight,
+            newer_samples_start=self.newer_samples_start,
+            trend_reg_threshold=self.config_trend.trend_reg_threshold,
+            continue_training=self.continue_training,
+        )

-            if scheduler is not None:
-                self.config_train.scheduler = scheduler
-                self.config_train.scheduler_args = scheduler_args
-                self.config_train.set_scheduler()
+        if scheduler is not None:
+            self.config_train.scheduler = scheduler
+            self.config_train.scheduler_args = scheduler_args

-        # Configuration
         if epochs is not None:
             self.config_train.epochs = epochs
- ) - self.config_train.scheduler = None - self.config_train.scheduler_args = None - self.config_train.set_scheduler() + # Config + self.config_train = configure.Train( + quantiles=self.quantiles, + learning_rate=self.learning_rate, + scheduler=self.scheduler, + scheduler_args=self.scheduler_args, + epochs=self.epochs, + batch_size=self.batch_size, + loss_func=self.loss_func, + optimizer=self.optimizer, + newer_samples_weight=self.newer_samples_weight, + newer_samples_start=self.newer_samples_start, + trend_reg_threshold=self.config_trend.trend_reg_threshold, + continue_training=self.continue_training, + ) if scheduler is not None: self.config_train.scheduler = scheduler self.config_train.scheduler_args = scheduler_args - self.config_train.set_scheduler() - # Configuration if epochs is not None: self.config_train.epochs = epochs @@ -1245,7 +1266,7 @@ def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False, a dates=dates, predicted=predicted, n_forecasts=self.n_forecasts, - quantiles=self.config_train.quantiles, + quantiles=self.quantiles, components=components, ) if auto_extend and periods_added[df_name] > 0: @@ -1260,7 +1281,7 @@ def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False, a n_forecasts=self.n_forecasts, max_lags=self.max_lags, freq=self.data_freq, - quantiles=self.config_train.quantiles, + quantiles=self.quantiles, config_lagged_regressors=self.config_lagged_regressors, ) if auto_extend and periods_added[df_name] > 0: @@ -1901,7 +1922,7 @@ def predict_trend(self, df: pd.DataFrame, quantile: float = 0.5): else: meta_name_tensor = None - quantile_index = self.config_train.quantiles.index(quantile) + quantile_index = self.quantiles.index(quantile) trend = self.model.trend(t, meta_name_tensor).detach().numpy()[:, :, quantile_index].squeeze() data_params = self.config_normalization.get_data_params(df_name) @@ -1966,7 +1987,7 @@ def predict_seasonal_components(self, df: pd.DataFrame, quantile: float = 0.5): for name in self.config_seasonality.periods: features = inputs["seasonalities"][name] - quantile_index = self.config_train.quantiles.index(quantile) + quantile_index = self.quantiles.index(quantile) y_season = torch.squeeze( self.model.seasonality.compute_fourier(features=features, name=name, meta=meta_name_tensor)[ :, :, quantile_index @@ -2098,7 +2119,7 @@ def plot( log.info(f"Plotting data from ID {df_name}") if forecast_in_focus is None: forecast_in_focus = self.highlight_forecast_step_n - if len(self.config_train.quantiles) > 1: + if len(self.quantiles) > 1: if (self.highlight_forecast_step_n) is None and ( self.n_forecasts > 1 or self.n_lags > 0 ): # rather query if n_forecasts >1 than n_lags>1 @@ -2138,7 +2159,7 @@ def plot( if plotting_backend.startswith("plotly"): return plot_plotly( fcst=fcst, - quantiles=self.config_train.quantiles, + quantiles=self.quantiles, xlabel=xlabel, ylabel=ylabel, figsize=tuple(x * 70 for x in figsize), @@ -2149,7 +2170,7 @@ def plot( else: return plot( fcst=fcst, - quantiles=self.config_train.quantiles, + quantiles=self.quantiles, ax=ax, xlabel=xlabel, ylabel=ylabel, @@ -2217,9 +2238,7 @@ def get_latest_forecast( fcst = fcst[-(include_previous_forecasts + self.n_forecasts) :] elif include_history_data is True: fcst = fcst - fcst = utils.fcst_df_to_latest_forecast( - fcst, self.config_train.quantiles, n_last=1 + include_previous_forecasts - ) + fcst = utils.fcst_df_to_latest_forecast(fcst, self.quantiles, n_last=1 + include_previous_forecasts) return fcst def plot_latest_forecast( @@ -2287,7 
+2306,7 @@ def plot_latest_forecast( else: fcst = fcst[fcst["ID"] == df_name].copy(deep=True) log.info(f"Plotting data from ID {df_name}") - if len(self.config_train.quantiles) > 1: + if len(self.quantiles) > 1: log.warning( "Plotting latest forecasts when uncertainty estimation enabled" " plots only the median quantile forecasts." @@ -2298,9 +2317,7 @@ def plot_latest_forecast( fcst = fcst[-(include_previous_forecasts + self.n_forecasts) :] elif plot_history_data is True: fcst = fcst - fcst = utils.fcst_df_to_latest_forecast( - fcst, self.config_train.quantiles, n_last=1 + include_previous_forecasts - ) + fcst = utils.fcst_df_to_latest_forecast(fcst, self.quantiles, n_last=1 + include_previous_forecasts) # Check whether a local or global plotting backend is set. plotting_backend = select_plotting_backend(model=self, plotting_backend=plotting_backend) @@ -2309,7 +2326,7 @@ def plot_latest_forecast( if plotting_backend.startswith("plotly"): return plot_plotly( fcst=fcst, - quantiles=self.config_train.quantiles, + quantiles=self.quantiles, ylabel=ylabel, xlabel=xlabel, figsize=tuple(x * 70 for x in figsize), @@ -2321,7 +2338,7 @@ def plot_latest_forecast( else: return plot( fcst=fcst, - quantiles=self.config_train.quantiles, + quantiles=self.quantiles, ax=ax, ylabel=ylabel, xlabel=xlabel, @@ -2487,7 +2504,7 @@ def plot_components( m=self, fcst=fcst, plot_configuration=valid_plot_configuration, - quantile=self.config_train.quantiles[0], # plot components only for median quantile + quantile=self.quantiles[0], # plot components only for median quantile figsize=figsize, df_name=df_name, one_period_per_season=one_period_per_season, @@ -2597,11 +2614,11 @@ def plot_parameters( if not (0 < quantile < 1): raise ValueError("The quantile selected needs to be a float in-between (0,1)") # ValueError if selected quantile is out of range - if quantile not in self.config_train.quantiles: + if quantile not in self.quantiles: raise ValueError("Selected quantile is not specified in the model configuration.") else: # plot parameters for median quantile if not specified - quantile = self.config_train.quantiles[0] + quantile = self.quantiles[0] # Validate components to be plotted valid_parameters_set = [ @@ -3148,7 +3165,7 @@ def conformal_predict( alpha=alpha, method=method, n_forecasts=self.n_forecasts, - quantiles=self.config_train.quantiles, + quantiles=self.quantiles, ) df_forecast = c.predict(df=df_test, df_cal=df_cal, show_all_PI=show_all_PI) diff --git a/neuralprophet/time_net.py b/neuralprophet/time_net.py index c30594e29..e1c3ef8b8 100644 --- a/neuralprophet/time_net.py +++ b/neuralprophet/time_net.py @@ -158,9 +158,16 @@ def __init__( self.config_normalization = config_normalization self.compute_components_flag = compute_components_flag + # Continued training + self.continue_training = continue_training + self.start_epoch = start_epoch + # Optimizer and LR Scheduler - self._optimizer = self.config_train.optimizer - self._scheduler = self.config_train.scheduler + # self.config_train.set_optimizer() + # self.config_train.set_scheduler() + # self._optimizer = self.config_train.optimizer + # self._scheduler = self.config_train.scheduler + # Manual optimization: we are responsible for calling .backward(), .step(), .zero_grad(). 
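        #   For orientation, a minimal sketch of the manual-optimization pattern this
        #   flag enables (standard Lightning API; training_step further down in this
        #   series performs exactly these calls):
        #       optimizer = self.optimizers()
        #       optimizer.zero_grad()
        #       self.manual_backward(loss)
        #       optimizer.step()
        #       scheduler = self.lr_schedulers()
        #       scheduler.step()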
self.automatic_optimization = False # Hyperparameters (can be tuned using trainer.tune()) @@ -314,10 +321,6 @@ def __init__( else: self.config_regressors.regressors = None - # Continued training - self.continue_training = continue_training - self.start_epoch = start_epoch - @property def ar_weights(self) -> torch.Tensor: """sets property auto-regression weights for regularization. Update if AR is modelled differently""" @@ -867,12 +870,14 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): return prediction, components def configure_optimizers(self): + self.config_train.set_optimizer() + self.config_train.set_scheduler() + self._optimizer = self.config_train.optimizer + self._scheduler = self.config_train.scheduler + # Optimizer optimizer = self._optimizer(self.parameters(), lr=self.learning_rate, **self.config_train.optimizer_args) - # Scheduler - self._scheduler = self.config_train.scheduler - if self.continue_training: optimizer.load_state_dict(self.config_train.optimizer_state) @@ -882,11 +887,7 @@ def configure_optimizers(self): for param_group in optimizer.param_groups: param_group["initial_lr"] = (last_lr,) - if self._scheduler == torch.optim.lr_scheduler.OneCycleLR: - log.warning("OneCycleLR scheduler is not supported for continued training. Switching to ExponentialLR") - self._scheduler = torch.optim.lr_scheduler.ExponentialLR - self.config_train.scheduler_args = {"gamma": 0.95} - + # Scheduler if self._scheduler == torch.optim.lr_scheduler.OneCycleLR: lr_scheduler = self._scheduler( optimizer, From 420f8a697ffc89338af575068398c6708b735e9a Mon Sep 17 00:00:00 2001 From: ourownstory Date: Sat, 24 Aug 2024 00:57:48 -0700 Subject: [PATCH 15/39] restructure train model config --- neuralprophet/configure.py | 48 ++++++------- neuralprophet/forecaster.py | 138 ++++++++++++++---------------------- neuralprophet/time_net.py | 4 +- tests/test_configure.py | 35 +++------ tests/test_train_config.py | 87 +++++++++++++++++++++++ tests/test_utils.py | 46 ------------ 6 files changed, 178 insertions(+), 180 deletions(-) create mode 100644 tests/test_train_config.py diff --git a/neuralprophet/configure.py b/neuralprophet/configure.py index 5d5144fc7..00eabd72b 100644 --- a/neuralprophet/configure.py +++ b/neuralprophet/configure.py @@ -23,6 +23,24 @@ @dataclass class Model: lagged_reg_layers: Optional[List[int]] + quantiles: Optional[List[float]] = None + + def setup_quantiles(self): + # convert quantiles to empty list [] if None + if self.quantiles is None: + self.quantiles = [] + # assert quantiles is a list type + assert isinstance(self.quantiles, list), "Quantiles must be provided as list." + # check if quantiles are float values in (0, 1) + assert all( + 0 < quantile < 1 for quantile in self.quantiles + ), "The quantiles specified need to be floats in-between (0, 1)." 
+ # sort the quantiles + self.quantiles.sort() + # check if quantiles contain 0.5 or close to 0.5, remove if so as 0.5 will be inserted again as first index + self.quantiles = [quantile for quantile in self.quantiles if not math.isclose(0.5, quantile)] + # 0 is the median quantile index + self.quantiles.insert(0, 0.5) @dataclass @@ -92,7 +110,7 @@ class Train: batch_size: Optional[int] loss_func: Union[str, torch.nn.modules.loss._Loss, Callable] optimizer: Union[str, Type[torch.optim.Optimizer]] - quantiles: List[float] = field(default_factory=list) + # quantiles: List[float] = field(default_factory=list) optimizer_args: dict = field(default_factory=dict) scheduler: Optional[Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]] = None scheduler_args: dict = field(default_factory=dict) @@ -106,20 +124,19 @@ class Train: lr_finder_args: dict = field(default_factory=dict) optimizer_state: dict = field(default_factory=dict) continue_training: bool = False + trainer_config: dict = field(default_factory=dict) def __post_init__(self): - # assert the uncertainty estimation params and then finalize the quantiles - # self.set_quantiles() assert self.newer_samples_weight >= 1.0 assert self.newer_samples_start >= 0.0 assert self.newer_samples_start < 1.0 - self.set_loss_func() + # self.set_loss_func(self.quantiles) # called in TimeNet configure_optimizers: # self.set_optimizer() # self.set_scheduler() - def set_loss_func(self): + def set_loss_func(self, quantiles: List[float]): if isinstance(self.loss_func, str): if self.loss_func.lower() in ["smoothl1", "smoothl1loss", "huber"]: # keeping 'huber' for backwards compatiblility, though not identical @@ -139,25 +156,8 @@ def set_loss_func(self): self.loss_func_name = type(self.loss_func).__name__ else: raise NotImplementedError(f"Loss function {self.loss_func} not found") - if len(self.quantiles) > 1: - self.loss_func = PinballLoss(loss_func=self.loss_func, quantiles=self.quantiles) - - # def set_quantiles(self): - # # convert quantiles to empty list [] if None - # if self.quantiles is None: - # self.quantiles = [] - # # assert quantiles is a list type - # assert isinstance(self.quantiles, list), "Quantiles must be in a list format, not None or scalar." - # # check if quantiles contain 0.5 or close to 0.5, remove if so as 0.5 will be inserted again as first index - # self.quantiles = [quantile for quantile in self.quantiles if not math.isclose(0.5, quantile)] - # # check if quantiles are float values in (0, 1) - # assert all( - # 0 < quantile < 1 for quantile in self.quantiles - # ), "The quantiles specified need to be floats in-between (0, 1)." 
- # # sort the quantiles - # self.quantiles.sort() - # # 0 is the median quantile index - # self.quantiles.insert(0, 0.5) + if len(quantiles) > 1: + self.loss_func = PinballLoss(loss_func=self.loss_func, quantiles=quantiles) def set_auto_batch_epoch( self, diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py index ea95834f4..5c1b6d9cd 100644 --- a/neuralprophet/forecaster.py +++ b/neuralprophet/forecaster.py @@ -3,6 +3,7 @@ import os import time from collections import OrderedDict +from dataclasses import dataclass, field from typing import Callable, List, Optional, Tuple, Type, Union import matplotlib @@ -452,7 +453,7 @@ def __init__( scheduler_args: Optional[dict] = None, newer_samples_weight: float = 2, newer_samples_start: float = 0.0, - quantiles: List[float] = [], + quantiles: Optional[List[float]] = None, impute_missing: bool = True, impute_linear: int = 10, impute_rolling: int = 10, @@ -463,7 +464,7 @@ def __init__( global_time_normalization: bool = True, unknown_data_normalization: bool = False, accelerator: Optional[str] = None, - trainer_config: dict = {}, + trainer_config: Optional[dict] = None, prediction_frequency: Optional[dict] = None, ): self.config = locals() @@ -505,7 +506,11 @@ def __init__( self.max_lags = self.n_lags # Model - self.config_model = configure.Model(lagged_reg_layers=lagged_reg_layers) + self.config_model = configure.Model( + lagged_reg_layers=lagged_reg_layers, + quantiles=quantiles, + ) + self.config_model.setup_quantiles() # Trend self.config_trend = configure.Trend( @@ -519,24 +524,6 @@ def __init__( trend_local_reg=trend_local_reg, ) - # Model - self.quantiles = quantiles - # convert quantiles to empty list [] if None - if self.quantiles is None: - self.quantiles = [] - # assert quantiles is a list type - assert isinstance(self.quantiles, list), "Quantiles must be in a list format, not None or scalar." - # check if quantiles contain 0.5 or close to 0.5, remove if so as 0.5 will be inserted again as first index - self.quantiles = [quantile for quantile in self.quantiles if not math.isclose(0.5, quantile)] - # check if quantiles are float values in (0, 1) - assert all( - 0 < quantile < 1 for quantile in self.quantiles - ), "The quantiles specified need to be floats in-between (0, 1)." - # sort the quantiles - self.quantiles.sort() - # 0 is the median quantile index - self.quantiles.insert(0, 0.5) - # Training self.learning_rate = learning_rate self.scheduler = scheduler @@ -586,7 +573,7 @@ def __init__( # Pytorch Lightning Trainer self.metrics_logger = MetricsLogger(save_dir=os.getcwd()) self.accelerator = accelerator - self.trainer_config = trainer_config + self.trainer_config = trainer_config if trainer_config is not None else {} # set during prediction self.future_periods = None @@ -954,6 +941,7 @@ def fit( deterministic: bool = False, scheduler: Optional[Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]] = None, scheduler_args: Optional[dict] = None, + trainer_config: Optional[dict] = None, ): """Train, and potentially evaluate model. @@ -1018,6 +1006,12 @@ def fit( pd.DataFrame metrics with training and potentially evaluation metrics """ + if minimal: + # overrides these settings: + checkpointing = False + self.metrics = False + progress = None + if self.fitted and not continue_training: raise RuntimeError( "Model has been fitted already. If you want to continue training please set the flag continue_training." 
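The hunk below rebuilds the Train config inside fit(), so arguments passed to fit() take precedence over the values stored at init for that training run. A minimal usage sketch under that assumption (the CSV name is hypothetical; the frame needs "ds" and "y" columns):

    import pandas as pd
    from neuralprophet import NeuralProphet

    df = pd.read_csv("data.csv")  # hypothetical input with "ds" and "y" columns
    m = NeuralProphet(epochs=10, batch_size=32, learning_rate=0.1)
    # fit()-level settings override the init-level ones for this run:
    metrics = m.fit(df, freq="D", epochs=20, scheduler="ExponentialLR", scheduler_args={"gamma": 0.9})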
@@ -1031,43 +1025,25 @@ def fit( log.error("Continued training requires checkpointing in model to continue from last epoch.") # Configuration - self.continue_training = continue_training - - # Config self.config_train = configure.Train( - quantiles=self.quantiles, - learning_rate=self.learning_rate, - scheduler=self.scheduler, - scheduler_args=self.scheduler_args, - epochs=self.epochs, - batch_size=self.batch_size, + learning_rate=self.learning_rate if learning_rate is None else learning_rate, + scheduler=self.scheduler if scheduler is None else scheduler, + scheduler_args=self.scheduler_args if scheduler is None else scheduler_args, + epochs=self.epochs if epochs is None else epochs, + batch_size=self.batch_size if batch_size is None else batch_size, loss_func=self.loss_func, optimizer=self.optimizer, newer_samples_weight=self.newer_samples_weight, newer_samples_start=self.newer_samples_start, trend_reg_threshold=self.config_trend.trend_reg_threshold, - continue_training=self.continue_training, + continue_training=continue_training, + trainer_config=self.trainer_config if trainer_config is None else trainer_config, ) - - if scheduler is not None: - self.config_train.scheduler = scheduler - self.config_train.scheduler_args = scheduler_args - - if epochs is not None: - self.config_train.epochs = epochs - - if batch_size is not None: - self.config_train.batch_size = batch_size - - if learning_rate is not None: - self.config_train.learning_rate = learning_rate + self.config_train.set_loss_func(quantiles=self.config_model.quantiles) if early_stopping is not None: self.early_stopping = early_stopping - if metrics is not None: - self.metrics = utils_metrics.get_metrics(metrics) - # Warnings if early_stopping: reg_enabled = utils.check_for_regularization( @@ -1088,19 +1064,16 @@ def fit( number of epochs to train for." ) - if progress == "plot" and metrics is False: - log.info("Progress plot requires metrics to be enabled. Enabling the default metrics.") - metrics = utils_metrics.get_metrics(True) + if metrics: + self.metrics = utils_metrics.get_metrics(metrics) + + if progress == "plot" and not metrics: + log.info("Progress plot requires metrics to be enabled. 
Disabling progress plot.") + progress = None if not self.config_normalization.global_normalization: log.info("When Global modeling with local normalization, metrics are displayed in normalized scale.") - if minimal: - # overrides these settings: - checkpointing = False - self.metrics = False - progress = None - # Pre-processing # Copy df and save list of unique time series IDs (the latter for global-local modelling if enabled) df, _, _, self.id_list = df_utils.prep_or_copy_df(df) @@ -1266,7 +1239,7 @@ def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False, a dates=dates, predicted=predicted, n_forecasts=self.n_forecasts, - quantiles=self.quantiles, + quantiles=self.config_model.quantiles, components=components, ) if auto_extend and periods_added[df_name] > 0: @@ -1281,7 +1254,7 @@ def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False, a n_forecasts=self.n_forecasts, max_lags=self.max_lags, freq=self.data_freq, - quantiles=self.quantiles, + quantiles=self.config_model.quantiles, config_lagged_regressors=self.config_lagged_regressors, ) if auto_extend and periods_added[df_name] > 0: @@ -1922,7 +1895,7 @@ def predict_trend(self, df: pd.DataFrame, quantile: float = 0.5): else: meta_name_tensor = None - quantile_index = self.quantiles.index(quantile) + quantile_index = self.config_model.quantiles.index(quantile) trend = self.model.trend(t, meta_name_tensor).detach().numpy()[:, :, quantile_index].squeeze() data_params = self.config_normalization.get_data_params(df_name) @@ -1987,7 +1960,7 @@ def predict_seasonal_components(self, df: pd.DataFrame, quantile: float = 0.5): for name in self.config_seasonality.periods: features = inputs["seasonalities"][name] - quantile_index = self.quantiles.index(quantile) + quantile_index = self.config_model.quantiles.index(quantile) y_season = torch.squeeze( self.model.seasonality.compute_fourier(features=features, name=name, meta=meta_name_tensor)[ :, :, quantile_index @@ -2119,7 +2092,7 @@ def plot( log.info(f"Plotting data from ID {df_name}") if forecast_in_focus is None: forecast_in_focus = self.highlight_forecast_step_n - if len(self.quantiles) > 1: + if len(self.config_model.quantiles) > 1: if (self.highlight_forecast_step_n) is None and ( self.n_forecasts > 1 or self.n_lags > 0 ): # rather query if n_forecasts >1 than n_lags>1 @@ -2159,7 +2132,7 @@ def plot( if plotting_backend.startswith("plotly"): return plot_plotly( fcst=fcst, - quantiles=self.quantiles, + quantiles=self.config_model.quantiles, xlabel=xlabel, ylabel=ylabel, figsize=tuple(x * 70 for x in figsize), @@ -2170,7 +2143,7 @@ def plot( else: return plot( fcst=fcst, - quantiles=self.quantiles, + quantiles=self.config_model.quantiles, ax=ax, xlabel=xlabel, ylabel=ylabel, @@ -2238,7 +2211,9 @@ def get_latest_forecast( fcst = fcst[-(include_previous_forecasts + self.n_forecasts) :] elif include_history_data is True: fcst = fcst - fcst = utils.fcst_df_to_latest_forecast(fcst, self.quantiles, n_last=1 + include_previous_forecasts) + fcst = utils.fcst_df_to_latest_forecast( + fcst, self.config_model.quantiles, n_last=1 + include_previous_forecasts + ) return fcst def plot_latest_forecast( @@ -2306,7 +2281,7 @@ def plot_latest_forecast( else: fcst = fcst[fcst["ID"] == df_name].copy(deep=True) log.info(f"Plotting data from ID {df_name}") - if len(self.quantiles) > 1: + if len(self.config_model.quantiles) > 1: log.warning( "Plotting latest forecasts when uncertainty estimation enabled" " plots only the median quantile forecasts." 
@@ -2317,7 +2292,9 @@ def plot_latest_forecast(
             fcst = fcst[-(include_previous_forecasts + self.n_forecasts) :]
         elif plot_history_data is True:
             fcst = fcst
-        fcst = utils.fcst_df_to_latest_forecast(fcst, self.quantiles, n_last=1 + include_previous_forecasts)
+        fcst = utils.fcst_df_to_latest_forecast(
+            fcst, self.config_model.quantiles, n_last=1 + include_previous_forecasts
+        )

         # Check whether a local or global plotting backend is set.
         plotting_backend = select_plotting_backend(model=self, plotting_backend=plotting_backend)
@@ -2326,7 +2303,7 @@ def plot_latest_forecast(
         if plotting_backend.startswith("plotly"):
             return plot_plotly(
                 fcst=fcst,
-                quantiles=self.quantiles,
+                quantiles=self.config_model.quantiles,
                 ylabel=ylabel,
                 xlabel=xlabel,
                 figsize=tuple(x * 70 for x in figsize),
@@ -2338,7 +2315,7 @@ def plot_latest_forecast(
         else:
             return plot(
                 fcst=fcst,
-                quantiles=self.quantiles,
+                quantiles=self.config_model.quantiles,
                 ax=ax,
                 ylabel=ylabel,
                 xlabel=xlabel,
@@ -2504,7 +2481,7 @@ def plot_components(
             m=self,
             fcst=fcst,
             plot_configuration=valid_plot_configuration,
-            quantile=self.quantiles[0],  # plot components only for median quantile
+            quantile=self.config_model.quantiles[0],  # plot components only for median quantile
             figsize=figsize,
             df_name=df_name,
             one_period_per_season=one_period_per_season,
@@ -2614,11 +2591,11 @@ def plot_parameters(
             if not (0 < quantile < 1):
                 raise ValueError("The quantile selected needs to be a float in-between (0,1)")
             # ValueError if selected quantile is out of range
-            if quantile not in self.quantiles:
+            if quantile not in self.config_model.quantiles:
                 raise ValueError("Selected quantile is not specified in the model configuration.")
         else:
             # plot parameters for median quantile if not specified
-            quantile = self.quantiles[0]
+            quantile = self.config_model.quantiles[0]

         # Validate components to be plotted
         valid_parameters_set = [
@@ -2686,13 +2663,9 @@ def plot_parameters(
         )

     def _init_model(self):
-        """Build Pytorch model with configured hyperparamters.
-
-        Returns
-        -------
-            TimeNet model
-        """
+        """Build PyTorch model with configured hyperparameters."""
         self.model = time_net.TimeNet(
+            config_model=self.config_model,
             config_train=self.config_train,
             config_trend=self.config_trend,
             config_ar=self.config_ar,
@@ -2715,7 +2688,6 @@ def _init_model(self):
             meta_used_in_model=self.meta_used_in_model,
         )
         log.debug(self.model)
-        return self.model

     def _init_train_loader(self, df, num_workers=0):
         """Executes data preparation steps and initiates training procedure.
@@ -2855,14 +2827,14 @@ def _train( self.config_train.set_optimizer_state(checkpoint["optimizer_states"][0]) else: - self.model = self._init_model() + self._init_model() self.model.train_loader = train_loader # Init the Trainer self.trainer, checkpoint_callback = utils.configure_trainer( config_train=self.config_train, - config=self.trainer_config, + config=self.config_train.trainer_config, metrics_logger=self.metrics_logger, early_stopping=self.early_stopping, early_stopping_target="Loss_val" if validation_enabled else "Loss", @@ -2960,7 +2932,7 @@ def restore_trainer(self, accelerator: Optional[str] = None): """ self.trainer, _ = utils.configure_trainer( config_train=self.config_train, - config=self.trainer_config, + config=self.config_train.trainer_config, metrics_logger=self.metrics_logger, early_stopping=self.early_stopping, accelerator=accelerator, @@ -3165,7 +3137,7 @@ def conformal_predict( alpha=alpha, method=method, n_forecasts=self.n_forecasts, - quantiles=self.quantiles, + quantiles=self.config_model.quantiles, ) df_forecast = c.predict(df=df_test, df_cal=df_cal, show_all_PI=show_all_PI) diff --git a/neuralprophet/time_net.py b/neuralprophet/time_net.py index e1c3ef8b8..8f847d56d 100644 --- a/neuralprophet/time_net.py +++ b/neuralprophet/time_net.py @@ -42,6 +42,7 @@ class TimeNet(pl.LightningModule): def __init__( self, + config_model: configure.Model, config_seasonality: configure.ConfigSeasonality, config_train: Optional[configure.Train] = None, config_trend: Optional[configure.Trend] = None, @@ -151,6 +152,7 @@ def __init__( pass # General + self.config_model = config_model self.n_forecasts = n_forecasts # Lightning Config @@ -209,7 +211,7 @@ def __init__( ) # Quantiles - self.quantiles = self.config_train.quantiles + self.quantiles = self.config_model.quantiles # Trend self.config_trend = config_trend diff --git a/tests/test_configure.py b/tests/test_configure.py index e5c5e9800..a93539e29 100644 --- a/tests/test_configure.py +++ b/tests/test_configure.py @@ -1,20 +1,6 @@ import pytest -from neuralprophet.configure import Train - - -def generate_config_train_params(overrides={}): - config_train_params = { - "quantiles": None, - "learning_rate": None, - "epochs": None, - "batch_size": None, - "loss_func": "SmoothL1Loss", - "optimizer": "AdamW", - } - for key, value in overrides.items(): - config_train_params[key] = value - return config_train_params +from neuralprophet import NeuralProphet def test_config_training_quantiles(): @@ -26,24 +12,21 @@ def test_config_training_quantiles(): ({"quantiles": [0.2, 0.8]}, [0.5, 0.2, 0.8]), ({"quantiles": [0.5, 0.8]}, [0.5, 0.8]), ] - for overrides, expected in checks: - config_train_params = generate_config_train_params(overrides) - config = Train(**config_train_params) - assert config.quantiles == expected + model = NeuralProphet(**overrides) + assert model.config_model.quantiles == expected def test_config_training_quantiles_error_invalid_type(): - config_train_params = generate_config_train_params() - config_train_params["quantiles"] = "hello world" with pytest.raises(AssertionError) as err: - Train(**config_train_params) - assert str(err.value) == "Quantiles must be in a list format, not None or scalar." + _ = NeuralProphet(quantiles="hello world") + assert str(err.value) == "Quantiles must be provided as list." 
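For reference, the contract these tests pin down, as a short hedged sketch (assuming the configure.Model dataclass introduced above):

    from neuralprophet.configure import Model

    cfg = Model(lagged_reg_layers=None, quantiles=[0.9, 0.1])
    cfg.setup_quantiles()
    # quantiles are validated to lie in (0, 1) and sorted, then the median 0.5
    # is inserted at index 0:
    assert cfg.quantiles == [0.5, 0.1, 0.9]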
def test_config_training_quantiles_error_invalid_scale(): - config_train_params = generate_config_train_params() - config_train_params["quantiles"] = [-1] with pytest.raises(Exception) as err: - Train(**config_train_params) + _ = NeuralProphet(quantiles=[-1]) + assert str(err.value) == "The quantiles specified need to be floats in-between (0, 1)." + with pytest.raises(Exception) as err: + _ = NeuralProphet(quantiles=[1.3]) assert str(err.value) == "The quantiles specified need to be floats in-between (0, 1)." diff --git a/tests/test_train_config.py b/tests/test_train_config.py new file mode 100644 index 000000000..95716365e --- /dev/null +++ b/tests/test_train_config.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 + +import io +import logging +import os +import pathlib + +import pandas as pd +import pytest + +from neuralprophet import NeuralProphet, df_utils, load, save + +log = logging.getLogger("NP.test") +log.setLevel("ERROR") +log.parent.setLevel("ERROR") + +DIR = pathlib.Path(__file__).parent.parent.absolute() +DATA_DIR = os.path.join(DIR, "tests", "test-data") +PEYTON_FILE = os.path.join(DATA_DIR, "wp_log_peyton_manning.csv") +AIR_FILE = os.path.join(DATA_DIR, "air_passengers.csv") +YOS_FILE = os.path.join(DATA_DIR, "yosemite_temps.csv") +NROWS = 512 +EPOCHS = 10 +ADDITIONAL_EPOCHS = 5 +LR = 1.0 +BATCH_SIZE = 64 + +PLOT = False + + +def generate_config_train_params(overrides={}): + config_train_params = { + "learning_rate": None, + "epochs": None, + "batch_size": None, + "loss_func": "SmoothL1Loss", + "optimizer": "AdamW", + } + for key, value in overrides.items(): + config_train_params[key] = value + return config_train_params + + +def test_continue_training(): + df = pd.read_csv(PEYTON_FILE, nrows=NROWS) + m = NeuralProphet( + epochs=EPOCHS, + batch_size=BATCH_SIZE, + learning_rate=LR, + n_lags=6, + n_forecasts=3, + n_changepoints=0, + ) + metrics = m.fit(df, checkpointing=True, freq="D") + metrics2 = m.fit(df, freq="D", continue_training=True, epochs=ADDITIONAL_EPOCHS) + assert metrics["Loss"].min() >= metrics2["Loss"].min() + + +def test_continue_training_with_scheduler_selection(): + df = pd.read_csv(PEYTON_FILE, nrows=NROWS) + m = NeuralProphet( + epochs=EPOCHS, + batch_size=BATCH_SIZE, + learning_rate=LR, + n_lags=6, + n_forecasts=3, + n_changepoints=0, + ) + metrics = m.fit(df, checkpointing=True, freq="D") + # Continue training with StepLR + metrics2 = m.fit(df, freq="D", continue_training=True, epochs=ADDITIONAL_EPOCHS, scheduler="StepLR") + assert metrics["Loss"].min() >= metrics2["Loss"].min() + + +def test_save_load_continue_training(): + df = pd.read_csv(PEYTON_FILE, nrows=NROWS) + m = NeuralProphet( + epochs=EPOCHS, + n_lags=6, + n_forecasts=3, + n_changepoints=0, + ) + metrics = m.fit(df, checkpointing=True, freq="D") + save(m, "test_model.pt") + m2 = load("test_model.pt") + metrics2 = m2.fit(df, continue_training=True, epochs=ADDITIONAL_EPOCHS, scheduler="StepLR") + assert metrics["Loss"].min() >= metrics2["Loss"].min() diff --git a/tests/test_utils.py b/tests/test_utils.py index 3b93721bf..8bed33192 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -100,49 +100,3 @@ def test_save_load_io(): # Check that the forecasts are the same pd.testing.assert_frame_equal(forecast, forecast2) pd.testing.assert_frame_equal(forecast, forecast3) - - -def test_continue_training(): - df = pd.read_csv(PEYTON_FILE, nrows=NROWS) - m = NeuralProphet( - epochs=EPOCHS, - batch_size=BATCH_SIZE, - learning_rate=LR, - n_lags=6, - n_forecasts=3, - n_changepoints=0, - ) - metrics = 
m.fit(df, checkpointing=True, freq="D") - metrics2 = m.fit(df, freq="D", continue_training=True, epochs=ADDITIONAL_EPOCHS) - assert metrics["Loss"].min() >= metrics2["Loss"].min() - - -def test_continue_training_with_scheduler_selection(): - df = pd.read_csv(PEYTON_FILE, nrows=NROWS) - m = NeuralProphet( - epochs=EPOCHS, - batch_size=BATCH_SIZE, - learning_rate=LR, - n_lags=6, - n_forecasts=3, - n_changepoints=0, - ) - metrics = m.fit(df, checkpointing=True, freq="D") - # Continue training with StepLR - metrics2 = m.fit(df, freq="D", continue_training=True, epochs=ADDITIONAL_EPOCHS, scheduler="StepLR") - assert metrics["Loss"].min() >= metrics2["Loss"].min() - - -def test_save_load_continue_training(): - df = pd.read_csv(PEYTON_FILE, nrows=NROWS) - m = NeuralProphet( - epochs=EPOCHS, - n_lags=6, - n_forecasts=3, - n_changepoints=0, - ) - metrics = m.fit(df, checkpointing=True, freq="D") - save(m, "test_model.pt") - m2 = load("test_model.pt") - metrics2 = m2.fit(df, continue_training=True, epochs=ADDITIONAL_EPOCHS, scheduler="StepLR") - assert metrics["Loss"].min() >= metrics2["Loss"].min() From 1982089c266611faadfa135e6a6b3d240a8c5112 Mon Sep 17 00:00:00 2001 From: ourownstory Date: Tue, 27 Aug 2024 16:28:31 -0700 Subject: [PATCH 16/39] remove continue train --- neuralprophet/configure.py | 19 +---- neuralprophet/forecaster.py | 136 ++++++++++++------------------------ neuralprophet/time_net.py | 61 +++++++--------- tests/test_train_config.py | 99 ++++++++++++++++++-------- 4 files changed, 139 insertions(+), 176 deletions(-) diff --git a/neuralprophet/configure.py b/neuralprophet/configure.py index 00eabd72b..299f02d94 100644 --- a/neuralprophet/configure.py +++ b/neuralprophet/configure.py @@ -114,6 +114,7 @@ class Train: optimizer_args: dict = field(default_factory=dict) scheduler: Optional[Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]] = None scheduler_args: dict = field(default_factory=dict) + early_stopping: Optional[bool] = False newer_samples_weight: float = 1.0 newer_samples_start: float = 0.0 reg_delay_pct: float = 0.5 @@ -122,9 +123,7 @@ class Train: n_data: int = field(init=False) loss_func_name: str = field(init=False) lr_finder_args: dict = field(default_factory=dict) - optimizer_state: dict = field(default_factory=dict) - continue_training: bool = False - trainer_config: dict = field(default_factory=dict) + pl_trainer_config: dict = field(default_factory=dict) def __post_init__(self): assert self.newer_samples_weight >= 1.0 @@ -217,14 +216,6 @@ def set_scheduler(self): Set the scheduler and scheduler arg depending on the user selection. The scheduler is not initialized yet as this is done in configure_optimizers in TimeNet. """ - if self.continue_training: - if (isinstance(self.scheduler, str) and self.scheduler.lower() == "onecyclelr") or isinstance( - self.scheduler, torch.optim.lr_scheduler.OneCycleLR - ): - log.warning( - "OneCycleLR scheduler is not supported for continued training. Please set another scheduler. Falling back to ExponentialLR scheduler" - ) - self.scheduler = "exponentiallr" if self.scheduler is None: log.warning("No scheduler specified. 
Falling back to ExponentialLR scheduler.")
@@ -289,9 +280,8 @@ def set_lr_finder_args(self, dataset_size, num_batches):
             }
         )

-    def get_reg_delay_weight(self, e, iter_progress, reg_start_pct: float = 0.66, reg_full_pct: float = 1.0):
+    def get_reg_delay_weight(self, progress, reg_start_pct: float = 0.66, reg_full_pct: float = 1.0):
         # Ignore type warning of epochs possibly being None (does not work with dataclasses)
-        progress = (e + iter_progress) / float(self.epochs)  # type: ignore
         if reg_start_pct == reg_full_pct:
             reg_progress = float(progress > reg_start_pct)
         else:
@@ -304,9 +294,6 @@ def get_reg_delay_weight(self, e, iter_progress, reg_start_pct: float = 0.66, re
             delay_weight = 1
         return delay_weight

-    def set_optimizer_state(self, optimizer_state: dict):
-        self.optimizer_state = optimizer_state
-

 @dataclass
 class Trend:
diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py
index 5c1b6d9cd..333cc07cb 100644
--- a/neuralprophet/forecaster.py
+++ b/neuralprophet/forecaster.py
@@ -236,7 +236,7 @@ class NeuralProphet:
         Train Config
     COMMENT
         learning_rate : float
-            Maximum learning rate setting for 1cycle policy scheduler.
+            Maximum learning rate setting for the learning rate scheduler.

             Note
             ----
@@ -313,8 +313,7 @@ class NeuralProphet:
         Examples
         --------
         >>> from neuralprophet import NeuralProphet
-        >>> # Step Learning Rate scheduler
-        >>> m = NeuralProphet(scheduler="StepLR")
+        >>> m = NeuralProphet(scheduler="ExponentialLR", scheduler_args={"gamma": 0.99})
     COMMENT
         Uncertainty Estimation
@@ -379,7 +378,7 @@ class NeuralProphet:
             select an available accelerator.
             Provide `None` to deactivate the use of accelerators.
         trainer_config: dict
-            Dictionary of additional trainer configuration parameters.
+            Dictionary of additional PyTorch Lightning Trainer configuration parameters.
         prediction_frequency: dict
             Set a periodic interval in which forecasts should be made.
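Since the docstring above only names the parameter, here is a hedged sketch of passing extra options through trainer_config (max_time is a standard pytorch_lightning.Trainer argument; treat the combination as an illustration, not a tested configuration):

    from neuralprophet import NeuralProphet

    m = NeuralProphet(
        learning_rate=0.1,
        trainer_config={"max_time": "00:01:00:00"},  # cap wall-clock training time at one hour
    )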
@@ -525,17 +524,19 @@ def __init__( ) # Training - self.learning_rate = learning_rate - self.scheduler = scheduler - self.scheduler_args = scheduler_args - self.epochs = epochs - self.batch_size = batch_size - self.loss_func = loss_func - self.optimizer = optimizer - self.newer_samples_weight = newer_samples_weight - self.newer_samples_start = newer_samples_start - self.trend_reg_threshold = self.config_trend.trend_reg_threshold - self.continue_training = False + self.config_train = configure.Train( + learning_rate=learning_rate, + scheduler=scheduler, + scheduler_args=scheduler_args, + epochs=epochs, + batch_size=batch_size, + loss_func=loss_func, + optimizer=optimizer, + newer_samples_weight=newer_samples_weight, + newer_samples_start=newer_samples_start, + early_stopping=False, + pl_trainer_config=trainer_config, + ) # Seasonality self.config_seasonality = configure.ConfigSeasonality( @@ -573,7 +574,6 @@ def __init__( # Pytorch Lightning Trainer self.metrics_logger = MetricsLogger(save_dir=os.getcwd()) self.accelerator = accelerator - self.trainer_config = trainer_config if trainer_config is not None else {} # set during prediction self.future_periods = None @@ -936,7 +936,6 @@ def fit( metrics: Optional[np_types.CollectMetricsMode] = None, progress: Optional[str] = "bar", checkpointing: bool = False, - continue_training: bool = False, num_workers: int = 0, deterministic: bool = False, scheduler: Optional[Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]] = None, @@ -986,8 +985,6 @@ def fit( * `None` checkpointing : bool Flag whether to save checkpoints during training - continue_training : bool - Flag whether to continue training from the last checkpoint num_workers : int Number of workers for data loading. If 0, data will be loaded in the main process. Note: using multiple workers and therefore distributed training might significantly increase @@ -1012,40 +1009,28 @@ def fit( self.metrics = False progress = None - if self.fitted and not continue_training: - raise RuntimeError( - "Model has been fitted already. If you want to continue training please set the flag continue_training." 
- ) - - if continue_training: - if epochs is None: - raise ValueError("Continued training requires setting the number of epochs to train for.") - - if continue_training and self.metrics_logger.checkpoint_path is None: - log.error("Continued training requires checkpointing in model to continue from last epoch.") - - # Configuration - self.config_train = configure.Train( - learning_rate=self.learning_rate if learning_rate is None else learning_rate, - scheduler=self.scheduler if scheduler is None else scheduler, - scheduler_args=self.scheduler_args if scheduler is None else scheduler_args, - epochs=self.epochs if epochs is None else epochs, - batch_size=self.batch_size if batch_size is None else batch_size, - loss_func=self.loss_func, - optimizer=self.optimizer, - newer_samples_weight=self.newer_samples_weight, - newer_samples_start=self.newer_samples_start, - trend_reg_threshold=self.config_trend.trend_reg_threshold, - continue_training=continue_training, - trainer_config=self.trainer_config if trainer_config is None else trainer_config, - ) - self.config_train.set_loss_func(quantiles=self.config_model.quantiles) - + if self.fitted: + raise RuntimeError("Model has been fitted already.") + + # Train Configuration: overwrite self.config_train with user provided values + if learning_rate is not None: + self.config_train.learning_rate = learning_rate + if scheduler is not None: + self.config_train.scheduler = scheduler + if scheduler_args is not None: + self.config_train.scheduler_args = scheduler_args + if epochs is not None: + self.config_train.epochs = epochs + if batch_size is not None: + self.config_train.batch_size = batch_size + if trainer_config is not None: + self.config_train.pl_trainer_config = trainer_config if early_stopping is not None: - self.early_stopping = early_stopping + self.config_train.early_stopping = early_stopping + self.config_train.set_loss_func(quantiles=self.config_model.quantiles) # Warnings - if early_stopping: + if self.config_train.early_stopping: reg_enabled = utils.check_for_regularization( [ self.config_seasonality, @@ -1128,7 +1113,6 @@ def fit( progress_bar_enabled=bool(progress), metrics_enabled=bool(self.metrics), checkpointing_enabled=checkpointing, - continue_training=continue_training, num_workers=num_workers, deterministic=deterministic, ) @@ -1153,7 +1137,6 @@ def fit( progress_bar_enabled=bool(progress), metrics_enabled=bool(self.metrics), checkpointing_enabled=checkpointing, - continue_training=continue_training, num_workers=num_workers, deterministic=deterministic, ) @@ -2771,7 +2754,6 @@ def _train( progress_bar_enabled: bool = True, metrics_enabled: bool = False, checkpointing_enabled: bool = False, - continue_training=False, num_workers=0, deterministic: bool = False, ): @@ -2790,8 +2772,6 @@ def _train( whether to collect metrics during training checkpointing_enabled : bool whether to save checkpoints during training - continue_training : bool - whether to continue training from the last checkpoint num_workers : int number of workers for data loading @@ -2808,35 +2788,16 @@ def _train( # Internal flag to check if validation is enabled validation_enabled = df_val is not None - # Load model and optimizer state from checkpoint if continue_training is True - if continue_training: - checkpoint_path = self.metrics_logger.checkpoint_path - checkpoint = torch.load(checkpoint_path) - - checkpoint_epoch = checkpoint["epoch"] if "epoch" in checkpoint else 0 - previous_epoch = max(self.model.current_epoch, checkpoint_epoch) - - # Set continue_training 
flag in model to update scheduler correctly
-            self.model.continue_training = True
-            self.model.start_epoch = previous_epoch
-
-            # Adjust epochs
-            new_total_epochs = previous_epoch + self.config_train.epochs
-            self.config_train.epochs = new_total_epochs
-
-            self.config_train.set_optimizer_state(checkpoint["optimizer_states"][0])
-
-        else:
-            self._init_model()
+        self._init_model()

         self.model.train_loader = train_loader

         # Init the Trainer
         self.trainer, checkpoint_callback = utils.configure_trainer(
             config_train=self.config_train,
-            config=self.config_train.trainer_config,
+            config=self.config_train.pl_trainer_config,
             metrics_logger=self.metrics_logger,
-            early_stopping=self.early_stopping,
+            early_stopping=self.config_train.early_stopping,
             early_stopping_target="Loss_val" if validation_enabled else "Loss",
             accelerator=self.accelerator,
             progress_bar_enabled=progress_bar_enabled,
@@ -2852,7 +2813,7 @@ def _train(
             df_val, _, _, _ = df_utils.prep_or_copy_df(df_val)
             val_loader = self._init_val_loader(df_val)

-        if not continue_training and not self.config_train.learning_rate:
+        if not self.config_train.learning_rate:
             # Set parameters for the learning rate finder
             self.config_train.set_lr_finder_args(dataset_size=dataset_size, num_batches=len(train_loader))
             # Find suitable learning rate
@@ -2871,10 +2832,9 @@ def _train(
                 self.model,
                 train_loader,
                 val_loader,
-                ckpt_path=self.metrics_logger.checkpoint_path if continue_training else None,
             )
         else:
-            if not continue_training and not self.config_train.learning_rate:
+            if not self.config_train.learning_rate:
                 # Set parameters for the learning rate finder
                 self.config_train.set_lr_finder_args(dataset_size=dataset_size, num_batches=len(train_loader))
                 # Find suitable learning rate
@@ -2891,7 +2851,6 @@ def _train(
             self.trainer.fit(
                 self.model,
                 train_loader,
-                ckpt_path=self.metrics_logger.checkpoint_path if continue_training else None,
             )

         log.debug("Train Time: {:8.3f}".format(time.time() - start))
@@ -2909,16 +2868,7 @@ def _train(
             return None

         # Return metrics collected in logger as dataframe
-        if self.metrics_logger.history is not None:
-            # avoid array mismatch when continuing training
-            history = self.metrics_logger.history
-            max_length = max(len(lst) for lst in history.values())
-            for key in history:
-                while len(history[key]) < max_length:
-                    history[key].append(None)
-            metrics_df = pd.DataFrame(history)
-        else:
-            metrics_df = pd.DataFrame()
+        metrics_df = pd.DataFrame(self.metrics_logger.history)
         return metrics_df

     def restore_trainer(self, accelerator: Optional[str] = None):
@@ -2932,7 +2882,7 @@ def restore_trainer(self, accelerator: Optional[str] = None):
         """
         self.trainer, _ = utils.configure_trainer(
             config_train=self.config_train,
-            config=self.config_train.trainer_config,
+            config=self.config_train.pl_trainer_config,
             metrics_logger=self.metrics_logger,
-            early_stopping=self.early_stopping,
+            early_stopping=self.config_train.early_stopping,
             accelerator=accelerator,
             metrics_enabled=bool(self.metrics),
         )
diff --git a/neuralprophet/time_net.py b/neuralprophet/time_net.py
index 8f847d56d..3f3ae5b9d 100644
--- a/neuralprophet/time_net.py
+++ b/neuralprophet/time_net.py
@@ -64,8 +64,6 @@ def __init__(
         num_seasonalities_modelled: int = 1,
         num_seasonalities_modelled_dict: dict = None,
         meta_used_in_model: bool = False,
-        continue_training: bool = False,
-        start_epoch: int = 0,
     ):
         """
         Parameters
@@ -160,10 +158,6 @@ def __init__(
         self.config_normalization = config_normalization
         self.compute_components_flag = compute_components_flag

-        # Continued training
-
self.continue_training = continue_training - self.start_epoch = start_epoch - # Optimizer and LR Scheduler # self.config_train.set_optimizer() # self.config_train.set_scheduler() @@ -772,20 +766,22 @@ def loss_func(self, inputs, predicted, targets): loss = None # Compute loss. no reduction. loss = self.config_train.loss_func(predicted, targets) - # Weigh newer samples more. - loss = loss * self._get_time_based_sample_weight(t=inputs["time"][:, self.n_lags :]) + if self.config_train.newer_samples_weight > 1.0: + # Weigh newer samples more. + loss = loss * self._get_time_based_sample_weight(t=inputs["time"][:, self.n_lags :]) loss = loss.sum(dim=2).mean() # Regularize. if self.reg_enabled: - steps_per_epoch = math.ceil(self.trainer.estimated_stepping_batches / self.trainer.max_epochs) - progress_in_epoch = 1 - ((steps_per_epoch * (self.current_epoch + 1) - self.global_step) / steps_per_epoch) - loss, reg_loss = self._add_batch_regularizations(loss, self.current_epoch, progress_in_epoch) + loss, reg_loss = self._add_batch_regularizations(loss, self.train_progress) else: reg_loss = torch.tensor(0.0, device=self.device) return loss, reg_loss def training_step(self, batch, batch_idx): inputs, targets, meta = batch + self.train_progress = ( + self.trainer.current_epoch + float(batch_idx / self.train_steps_per_epoch) + ) / self.config_train.epochs # Global-local if self.meta_used_in_model: meta_name_tensor = torch.tensor([self.id_dict[i] for i in meta["df_name"]], device=self.device) @@ -805,7 +801,7 @@ def training_step(self, batch, batch_idx): optimizer.step() scheduler = self.lr_schedulers() - scheduler.step() + scheduler.step(epoch=self.train_progress) # Manually track the loss for the lr finder self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) @@ -872,6 +868,9 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): return prediction, components def configure_optimizers(self): + self.train_steps_per_epoch = len(self.trainer.train_dataloader) + assert self.train_steps_per_epoch * self.config_train.epochs == self.trainer.num_training_batches + self.config_train.set_optimizer() self.config_train.set_scheduler() self._optimizer = self.config_train.optimizer @@ -880,21 +879,12 @@ def configure_optimizers(self): # Optimizer optimizer = self._optimizer(self.parameters(), lr=self.learning_rate, **self.config_train.optimizer_args) - if self.continue_training: - optimizer.load_state_dict(self.config_train.optimizer_state) - - # Update initial learning rate to the last learning rate for continued training - last_lr = float(optimizer.param_groups[0]["lr"]) # Ensure it's a float - - for param_group in optimizer.param_groups: - param_group["initial_lr"] = (last_lr,) - # Scheduler if self._scheduler == torch.optim.lr_scheduler.OneCycleLR: lr_scheduler = self._scheduler( optimizer, max_lr=self.learning_rate, - total_steps=self.trainer.estimated_stepping_batches, + total_steps=self.config_train.epochs, **self.config_train.scheduler_args, ) else: @@ -907,33 +897,30 @@ def configure_optimizers(self): def _get_time_based_sample_weight(self, t): weight = torch.ones_like(t) - if self.config_train.newer_samples_weight > 1.0: - end_w = self.config_train.newer_samples_weight - start_t = self.config_train.newer_samples_start - time = (t.detach() - start_t) / (1.0 - start_t) - time = torch.clamp(time, 0.0, 1.0) # time = 0 to 1 - time = np.pi * (time - 1.0) # time = -pi to 0 - time = 0.5 * torch.cos(time) + 0.5 # time = 0 to 1 - # scales end to be end weight times bigger than 
start weight - # with end weight being 1.0 - weight = (1.0 + time * (end_w - 1.0)) / end_w + end_w = self.config_train.newer_samples_weight + start_t = self.config_train.newer_samples_start + time = (t.detach() - start_t) / (1.0 - start_t) + time = torch.clamp(time, 0.0, 1.0) # time = 0 to 1 + time = np.pi * (time - 1.0) # time = -pi to 0 + time = 0.5 * torch.cos(time) + 0.5 # time = 0 to 1 + # scales end to be end weight times bigger than start weight + # with end weight being 1.0 + weight = (1.0 + time * (end_w - 1.0)) / end_w return weight.unsqueeze(dim=2) # add an extra dimension for the quantiles - def _add_batch_regularizations(self, loss, epoch, progress): + def _add_batch_regularizations(self, loss, progress): """Add regularization terms to loss, if applicable Parameters ---------- loss : torch.Tensor, scalar current batch loss - epoch : int - current epoch number progress : float - progress within the epoch, between 0 and 1 + progress within training, across all epochs and batches, between 0 and 1 Returns ------- loss, reg_loss """ - delay_weight = self.config_train.get_reg_delay_weight(epoch, progress) + delay_weight = self.config_train.get_reg_delay_weight(progress) reg_loss = torch.zeros(1, dtype=torch.float, requires_grad=False, device=self.device) if delay_weight > 0: diff --git a/tests/test_train_config.py b/tests/test_train_config.py index 95716365e..6263315fb 100644 --- a/tests/test_train_config.py +++ b/tests/test_train_config.py @@ -41,47 +41,86 @@ def generate_config_train_params(overrides={}): return config_train_params -def test_continue_training(): +def test_custom_lr_scheduler(): df = pd.read_csv(PEYTON_FILE, nrows=NROWS) + + # Set in NeuralProphet() m = NeuralProphet( epochs=EPOCHS, batch_size=BATCH_SIZE, learning_rate=LR, - n_lags=6, - n_forecasts=3, - n_changepoints=0, + scheduler="CosineAnnealingWarmRestarts", + scheduler_args={"T_0": 5, "T_mult": 2}, ) - metrics = m.fit(df, checkpointing=True, freq="D") - metrics2 = m.fit(df, freq="D", continue_training=True, epochs=ADDITIONAL_EPOCHS) - assert metrics["Loss"].min() >= metrics2["Loss"].min() - - -def test_continue_training_with_scheduler_selection(): - df = pd.read_csv(PEYTON_FILE, nrows=NROWS) + metrics = m.fit(df, freq="D") + # Set in NeuralProphet(), no args m = NeuralProphet( epochs=EPOCHS, batch_size=BATCH_SIZE, learning_rate=LR, - n_lags=6, - n_forecasts=3, - n_changepoints=0, + scheduler="StepLR", ) - metrics = m.fit(df, checkpointing=True, freq="D") - # Continue training with StepLR - metrics2 = m.fit(df, freq="D", continue_training=True, epochs=ADDITIONAL_EPOCHS, scheduler="StepLR") - assert metrics["Loss"].min() >= metrics2["Loss"].min() + metrics = m.fit(df, freq="D") + # Set in fit() + m = NeuralProphet(epochs=EPOCHS, batch_size=BATCH_SIZE, learning_rate=LR) + metrics = m.fit( + df, + freq="D", + scheduler="ExponentialLR", + scheduler_args={"gamma": 0.95}, + ) -def test_save_load_continue_training(): - df = pd.read_csv(PEYTON_FILE, nrows=NROWS) - m = NeuralProphet( - epochs=EPOCHS, - n_lags=6, - n_forecasts=3, - n_changepoints=0, + # Set in fit(), no args + m = NeuralProphet(epochs=EPOCHS, batch_size=BATCH_SIZE, learning_rate=LR) + metrics = m.fit( + df, + freq="D", + scheduler="OneCycleLR", ) - metrics = m.fit(df, checkpointing=True, freq="D") - save(m, "test_model.pt") - m2 = load("test_model.pt") - metrics2 = m2.fit(df, continue_training=True, epochs=ADDITIONAL_EPOCHS, scheduler="StepLR") - assert metrics["Loss"].min() >= metrics2["Loss"].min() + + +# def test_continue_training_checkpoint(): +# df 
= pd.read_csv(PEYTON_FILE, nrows=NROWS) +# m = NeuralProphet( +# epochs=EPOCHS, +# batch_size=BATCH_SIZE, +# learning_rate=LR, +# n_lags=6, +# n_forecasts=3, +# n_changepoints=0, +# ) +# metrics = m.fit(df, checkpointing=True, freq="D") +# metrics2 = m.fit(df, freq="D", continue_training=True, epochs=ADDITIONAL_EPOCHS) +# assert metrics["Loss"].min() >= metrics2["Loss"].min() + + +# def test_continue_training_with_scheduler_selection(): +# df = pd.read_csv(PEYTON_FILE, nrows=NROWS) +# m = NeuralProphet( +# epochs=EPOCHS, +# batch_size=BATCH_SIZE, +# learning_rate=LR, +# n_lags=6, +# n_forecasts=3, +# n_changepoints=0, +# ) +# metrics = m.fit(df, checkpointing=True, freq="D") +# # Continue training with StepLR +# metrics2 = m.fit(df, freq="D", continue_training=True, epochs=ADDITIONAL_EPOCHS, scheduler="StepLR") +# assert metrics["Loss"].min() >= metrics2["Loss"].min() + + +# def test_save_load_continue_training(): +# df = pd.read_csv(PEYTON_FILE, nrows=NROWS) +# m = NeuralProphet( +# epochs=EPOCHS, +# n_lags=6, +# n_forecasts=3, +# n_changepoints=0, +# ) +# metrics = m.fit(df, checkpointing=True, freq="D") +# save(m, "test_model.pt") +# m2 = load("test_model.pt") +# metrics2 = m2.fit(df, continue_training=True, epochs=ADDITIONAL_EPOCHS, scheduler="StepLR") +# assert metrics["Loss"].min() >= metrics2["Loss"].min() From 99e03555346974e5746d969119faba3165712d2f Mon Sep 17 00:00:00 2001 From: ourownstory Date: Tue, 27 Aug 2024 17:02:52 -0700 Subject: [PATCH 17/39] fix regularization --- neuralprophet/configure.py | 6 +++ neuralprophet/forecaster.py | 4 +- neuralprophet/time_net.py | 4 +- neuralprophet/utils.py | 72 ++++++++++++++++++------------------ tests/test_regularization.py | 2 +- tests/test_unit.py | 4 +- 6 files changed, 50 insertions(+), 42 deletions(-) diff --git a/neuralprophet/configure.py b/neuralprophet/configure.py index 299f02d94..e447eae4e 100644 --- a/neuralprophet/configure.py +++ b/neuralprophet/configure.py @@ -247,6 +247,12 @@ def set_scheduler(self): defaults = { "T_max": 50, } + elif self.scheduler.lower() == "cosineannealingwarmrestarts": + self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts + defaults = { + "T_0": 5, + "T_mult": 2, + } else: raise NotImplementedError( f"Scheduler {self.scheduler} is not supported from string. Please pass the scheduler class." 
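A hedged usage sketch for the newly mapped scheduler string, matching the defaults in the hunk above (the same combination is exercised by test_custom_lr_scheduler from the previous patch):

    from neuralprophet import NeuralProphet

    m = NeuralProphet(
        learning_rate=0.1,
        scheduler="CosineAnnealingWarmRestarts",  # resolved to torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
        scheduler_args={"T_0": 5, "T_mult": 2},  # first restart after 5 steps, period doubling each cycle
    )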
diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py
index 333cc07cb..608e1e491 100644
--- a/neuralprophet/forecaster.py
+++ b/neuralprophet/forecaster.py
@@ -2795,9 +2795,7 @@ def _train(
         # Init the Trainer
         self.trainer, checkpoint_callback = utils.configure_trainer(
             config_train=self.config_train,
-            config=self.config_train.pl_trainer_config,
             metrics_logger=self.metrics_logger,
-            early_stopping=self.config_train.early_stopping,
             early_stopping_target="Loss_val" if validation_enabled else "Loss",
             accelerator=self.accelerator,
             progress_bar_enabled=progress_bar_enabled,
@@ -2882,7 +2880,7 @@ def restore_trainer(self, accelerator: Optional[str] = None):
         """
         self.trainer, _ = utils.configure_trainer(
             config_train=self.config_train,
-            config=self.config_train.pl_trainer_config,
+            pl_trainer_config=self.config_train.pl_trainer_config,
             metrics_logger=self.metrics_logger,
             early_stopping=self.config_train.early_stopping,
             accelerator=accelerator,
             metrics_enabled=bool(self.metrics),
         )
diff --git a/neuralprophet/time_net.py b/neuralprophet/time_net.py
index 3f3ae5b9d..8bb17a417 100644
--- a/neuralprophet/time_net.py
+++ b/neuralprophet/time_net.py
@@ -868,8 +868,8 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
         return prediction, components

     def configure_optimizers(self):
-        self.train_steps_per_epoch = len(self.trainer.train_dataloader)
-        assert self.train_steps_per_epoch * self.config_train.epochs == self.trainer.num_training_batches
+        self.train_steps_per_epoch = len(self.train_loader)
+        # self.trainer.num_training_batches = self.train_steps_per_epoch * self.config_train.epochs

         self.config_train.set_optimizer()
         self.config_train.set_scheduler()
diff --git a/neuralprophet/utils.py b/neuralprophet/utils.py
index 62b9e7481..cc5a3ed16 100644
--- a/neuralprophet/utils.py
+++ b/neuralprophet/utils.py
@@ -823,9 +823,7 @@ def _smooth_loss(loss, beta=0.9):

 def configure_trainer(
     config_train: Train,
-    config: dict,
     metrics_logger,
-    early_stopping: bool = False,
     early_stopping_target: str = "Loss",
     accelerator: Optional[str] = None,
     progress_bar_enabled: bool = True,
@@ -841,12 +839,8 @@
     ----------
     config_train : Dict
         dictionary containing the overall training configuration.
-    config : dict
-        dictionary containing the custom PyTorch Lightning trainer configuration.
     metrics_logger : MetricsLogger
         MetricsLogger object to log metrics to.
-    early_stopping: bool
-        If True, early stopping is enabled.
     early_stopping_target : str
         Target metric to use for early stopping.
accelerator : str @@ -868,52 +862,58 @@ def configure_trainer( checkpoint_callback PyTorch Lightning checkpoint callback to load the best model """ - config = config.copy() + if config_train.pl_trainer_config is None: + config_train.pl_trainer_config = {} + + pl_trainer_config = config_train.pl_trainer_config + # pl_trainer_config = pl_trainer_config.copy() # Set max number of epochs if hasattr(config_train, "epochs"): if config_train.epochs is not None: - config["max_epochs"] = config_train.epochs + pl_trainer_config["max_epochs"] = config_train.epochs # Configure the Ligthing-logs directory - if "default_root_dir" not in config.keys(): - config["default_root_dir"] = os.getcwd() + if "default_root_dir" not in pl_trainer_config.keys(): + pl_trainer_config["default_root_dir"] = os.getcwd() # Accelerator if isinstance(accelerator, str): if (accelerator == "auto" and torch.cuda.is_available()) or accelerator == "gpu": - config["accelerator"] = "gpu" - config["devices"] = -1 + pl_trainer_config["accelerator"] = "gpu" + pl_trainer_config["devices"] = -1 elif (accelerator == "auto" and hasattr(torch.backends, "mps")) or accelerator == "mps": if torch.backends.mps.is_available(): - config["accelerator"] = "mps" - config["devices"] = 1 + pl_trainer_config["accelerator"] = "mps" + pl_trainer_config["devices"] = 1 elif accelerator != "auto": - config["accelerator"] = accelerator - config["devices"] = 1 + pl_trainer_config["accelerator"] = accelerator + pl_trainer_config["devices"] = 1 - if "accelerator" in config: - log.info(f"Using accelerator {config['accelerator']} with {config['devices']} device(s).") + if "accelerator" in pl_trainer_config: + log.info( + f"Using accelerator {pl_trainer_config['accelerator']} with {pl_trainer_config['devices']} device(s)." + ) else: log.info("No accelerator available. 
Using CPU for training.") # Configure metrics if metrics_enabled: - config["logger"] = metrics_logger + pl_trainer_config["logger"] = metrics_logger else: - config["logger"] = False + pl_trainer_config["logger"] = False - config["deterministic"] = deterministic + pl_trainer_config["deterministic"] = deterministic # Configure callbacks callbacks = [] - has_custom_callbacks = True if "callbacks" in config else False + has_custom_callbacks = True if "callbacks" in pl_trainer_config else False # Configure checkpointing has_modelcheckpoint_callback = ( True if has_custom_callbacks - and any(isinstance(callback, pl.callbacks.ModelCheckpoint) for callback in config["callbacks"]) + and any(isinstance(callback, pl.callbacks.ModelCheckpoint) for callback in pl_trainer_config["callbacks"]) else False ) if has_modelcheckpoint_callback and not checkpointing_enabled: @@ -930,17 +930,19 @@ def configure_trainer( callbacks.append(checkpoint_callback) else: checkpoint_callback = next( - callback for callback in config["callbacks"] if isinstance(callback, pl.callbacks.ModelCheckpoint) + callback + for callback in pl_trainer_config["callbacks"] + if isinstance(callback, pl.callbacks.ModelCheckpoint) ) else: - config["enable_checkpointing"] = False + pl_trainer_config["enable_checkpointing"] = False checkpoint_callback = None # Configure the progress bar, refresh every epoch has_progressbar_callback = ( True if has_custom_callbacks - and any(isinstance(callback, pl.callbacks.ProgressBar) for callback in config["callbacks"]) + and any(isinstance(callback, pl.callbacks.ProgressBar) for callback in pl_trainer_config["callbacks"]) else False ) if has_progressbar_callback and not progress_bar_enabled: @@ -953,21 +955,21 @@ def configure_trainer( prog_bar_callback = ProgressBar(refresh_rate=num_batches_per_epoch, epochs=config_train.epochs) callbacks.append(prog_bar_callback) else: - config["enable_progress_bar"] = False + pl_trainer_config["enable_progress_bar"] = False # Early stopping monitor has_earlystopping_callback = ( True if has_custom_callbacks - and any(isinstance(callback, pl.callbacks.EarlyStopping) for callback in config["callbacks"]) + and any(isinstance(callback, pl.callbacks.EarlyStopping) for callback in pl_trainer_config["callbacks"]) else False ) - if has_earlystopping_callback and not early_stopping: + if has_earlystopping_callback and not config_train.early_stopping: raise ValueError( "Early stopping is disabled but an EarlyStopping callback is provided. Please enable early stopping or " "remove the callback." 
) - if early_stopping: + if config_train.early_stopping: if not metrics_enabled: raise ValueError("Early stopping requires metrics to be enabled.") if not has_earlystopping_callback: @@ -977,13 +979,13 @@ def configure_trainer( callbacks.append(early_stop_callback) if has_custom_callbacks: - config["callbacks"].extend(callbacks) + pl_trainer_config["callbacks"].extend(callbacks) else: - config["callbacks"] = callbacks - config["num_sanity_val_steps"] = 0 - config["enable_model_summary"] = False + pl_trainer_config["callbacks"] = callbacks + pl_trainer_config["num_sanity_val_steps"] = 0 + pl_trainer_config["enable_model_summary"] = False # TODO: Disabling sampler_ddp brings a good speedup in performance, however, check whether this is a good idea # https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#replace-sampler-ddp # config["replace_sampler_ddp"] = False - return pl.Trainer(**config), checkpoint_callback + return pl.Trainer(**pl_trainer_config), checkpoint_callback diff --git a/tests/test_regularization.py b/tests/test_regularization.py index 6631a4d43..d1bebb03e 100644 --- a/tests/test_regularization.py +++ b/tests/test_regularization.py @@ -68,7 +68,7 @@ def test_regularization_holidays(): daily_seasonality=False, growth="off", ) - m = m.add_country_holidays("US", regularization=0.001) + m = m.add_country_holidays("US", regularization=0.0001) m.fit(df, freq="D") to_reduce = [] diff --git a/tests/test_unit.py b/tests/test_unit.py index 2032ffecb..42f19d218 100644 --- a/tests/test_unit.py +++ b/tests/test_unit.py @@ -474,6 +474,7 @@ def test_reg_delay(): ) m.fit(df, freq="D") c = m.config_train + # weight, epoch, epoch_iteration_progress for w, e, i in [ (0, 0, 1), (0, 3, 0), @@ -484,7 +485,8 @@ def test_reg_delay(): (1, 7, 1), (1, 8, 0), ]: - weight = c.get_reg_delay_weight(e, i, reg_start_pct=0.5, reg_full_pct=0.8) + progress = float(e + i) / 10.0 + weight = c.get_reg_delay_weight(progress=progress, reg_start_pct=0.5, reg_full_pct=0.8) assert weight == w From 9575de1fa19e7931b3b259aa47953a2e08bdf6ff Mon Sep 17 00:00:00 2001 From: ourownstory Date: Tue, 27 Aug 2024 17:14:06 -0700 Subject: [PATCH 18/39] fix regularization of holidays test --- neuralprophet/forecaster.py | 4 +- tests/test_regularization.py | 6 +- tests/test_save.py | 135 +++++++++++++++++++++++++++++++++++ tests/test_train_config.py | 46 ------------ tests/test_utils.py | 64 +---------------- 5 files changed, 141 insertions(+), 114 deletions(-) create mode 100644 tests/test_save.py diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py index 608e1e491..2efb42886 100644 --- a/neuralprophet/forecaster.py +++ b/neuralprophet/forecaster.py @@ -747,7 +747,7 @@ def add_events( upper_window : int the upper window for the events in the list of events regularization : float - optional scale for regularization strength + optional scale for regularization strength (try values ~0.00001-0.001) mode : str ``additive`` (default) or ``multiplicative``. @@ -806,7 +806,7 @@ def add_country_holidays( upper_window : int the upper window for all the country holidays regularization : float - optional scale for regularization strength + optional scale for regularization strength (try values ~0.00001-0.001) mode : str ``additive`` (default) or ``multiplicative``. 
""" diff --git a/tests/test_regularization.py b/tests/test_regularization.py index d1bebb03e..793ed53f0 100644 --- a/tests/test_regularization.py +++ b/tests/test_regularization.py @@ -61,7 +61,7 @@ def test_regularization_holidays(): m = NeuralProphet( epochs=20, - batch_size=64, + batch_size=32, learning_rate=0.1, yearly_seasonality=False, weekly_seasonality=False, @@ -80,8 +80,8 @@ def test_regularization_holidays(): to_reduce.append(weight_list[0][0][0]) else: to_preserve.append(weight_list[0][0][0]) - # print(to_reduce) - # print(to_preserve) + print(f"To reduce (< 0.2) {to_reduce}") + print(f"To preserve (> 0.5) {to_preserve}") assert np.mean(to_reduce) < 0.2 assert np.mean(to_preserve) > 0.5 diff --git a/tests/test_save.py b/tests/test_save.py new file mode 100644 index 000000000..1aeab44fe --- /dev/null +++ b/tests/test_save.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python3 + +import io +import logging +import os +import pathlib + +import pandas as pd +import pytest + +from neuralprophet import NeuralProphet, load, save + +log = logging.getLogger("NP.test") +log.setLevel("ERROR") +log.parent.setLevel("ERROR") + +DIR = pathlib.Path(__file__).parent.parent.absolute() +DATA_DIR = os.path.join(DIR, "tests", "test-data") +PEYTON_FILE = os.path.join(DATA_DIR, "wp_log_peyton_manning.csv") +AIR_FILE = os.path.join(DATA_DIR, "air_passengers.csv") +YOS_FILE = os.path.join(DATA_DIR, "yosemite_temps.csv") +NROWS = 512 +EPOCHS = 10 +ADDITIONAL_EPOCHS = 5 +LR = 1.0 +BATCH_SIZE = 64 + +PLOT = False + + +def test_save_load(): + df = pd.read_csv(PEYTON_FILE, nrows=NROWS) + m = NeuralProphet( + epochs=EPOCHS, + batch_size=BATCH_SIZE, + learning_rate=LR, + n_lags=6, + n_forecasts=3, + n_changepoints=0, + ) + _ = m.fit(df, freq="D") + future = m.make_future_dataframe(df, periods=3) + forecast = m.predict(df=future) + log.info("testing: save") + save(m, "test_model.pt") + + log.info("testing: load") + m2 = load("test_model.pt") + forecast2 = m2.predict(df=future) + + m3 = load("test_model.pt", map_location="cpu") + forecast3 = m3.predict(df=future) + + # Check that the forecasts are the same + pd.testing.assert_frame_equal(forecast, forecast2) + pd.testing.assert_frame_equal(forecast, forecast3) + + +def test_save_load_io(): + df = pd.read_csv(PEYTON_FILE, nrows=NROWS) + m = NeuralProphet( + epochs=EPOCHS, + batch_size=BATCH_SIZE, + learning_rate=LR, + n_lags=6, + n_forecasts=3, + n_changepoints=0, + ) + _ = m.fit(df, freq="D") + future = m.make_future_dataframe(df, periods=3) + forecast = m.predict(df=future) + + # Save the model to an in-memory buffer + log.info("testing: save to buffer") + buffer = io.BytesIO() + save(m, buffer) + buffer.seek(0) # Reset buffer position to the beginning + + log.info("testing: load from buffer") + m2 = load(buffer) + forecast2 = m2.predict(df=future) + + buffer.seek(0) # Reset buffer position to the beginning for another load + m3 = load(buffer, map_location="cpu") + forecast3 = m3.predict(df=future) + + # Check that the forecasts are the same + pd.testing.assert_frame_equal(forecast, forecast2) + pd.testing.assert_frame_equal(forecast, forecast3) + + +# def test_continue_training_checkpoint(): +# df = pd.read_csv(PEYTON_FILE, nrows=NROWS) +# m = NeuralProphet( +# epochs=EPOCHS, +# batch_size=BATCH_SIZE, +# learning_rate=LR, +# n_lags=6, +# n_forecasts=3, +# n_changepoints=0, +# ) +# metrics = m.fit(df, checkpointing=True, freq="D") +# metrics2 = m.fit(df, freq="D", continue_training=True, epochs=ADDITIONAL_EPOCHS) +# assert metrics["Loss"].min() >= metrics2["Loss"].min() 
+ + +# def test_continue_training_with_scheduler_selection(): +# df = pd.read_csv(PEYTON_FILE, nrows=NROWS) +# m = NeuralProphet( +# epochs=EPOCHS, +# batch_size=BATCH_SIZE, +# learning_rate=LR, +# n_lags=6, +# n_forecasts=3, +# n_changepoints=0, +# ) +# metrics = m.fit(df, checkpointing=True, freq="D") +# # Continue training with StepLR +# metrics2 = m.fit(df, freq="D", continue_training=True, epochs=ADDITIONAL_EPOCHS, scheduler="StepLR") +# assert metrics["Loss"].min() >= metrics2["Loss"].min() + + +# def test_save_load_continue_training(): +# df = pd.read_csv(PEYTON_FILE, nrows=NROWS) +# m = NeuralProphet( +# epochs=EPOCHS, +# n_lags=6, +# n_forecasts=3, +# n_changepoints=0, +# ) +# metrics = m.fit(df, checkpointing=True, freq="D") +# save(m, "test_model.pt") +# m2 = load("test_model.pt") +# metrics2 = m2.fit(df, continue_training=True, epochs=ADDITIONAL_EPOCHS, scheduler="StepLR") +# assert metrics["Loss"].min() >= metrics2["Loss"].min() diff --git a/tests/test_train_config.py b/tests/test_train_config.py index 6263315fb..e1ecbde8b 100644 --- a/tests/test_train_config.py +++ b/tests/test_train_config.py @@ -78,49 +78,3 @@ def test_custom_lr_scheduler(): freq="D", scheduler="OneCycleLR", ) - - -# def test_continue_training_checkpoint(): -# df = pd.read_csv(PEYTON_FILE, nrows=NROWS) -# m = NeuralProphet( -# epochs=EPOCHS, -# batch_size=BATCH_SIZE, -# learning_rate=LR, -# n_lags=6, -# n_forecasts=3, -# n_changepoints=0, -# ) -# metrics = m.fit(df, checkpointing=True, freq="D") -# metrics2 = m.fit(df, freq="D", continue_training=True, epochs=ADDITIONAL_EPOCHS) -# assert metrics["Loss"].min() >= metrics2["Loss"].min() - - -# def test_continue_training_with_scheduler_selection(): -# df = pd.read_csv(PEYTON_FILE, nrows=NROWS) -# m = NeuralProphet( -# epochs=EPOCHS, -# batch_size=BATCH_SIZE, -# learning_rate=LR, -# n_lags=6, -# n_forecasts=3, -# n_changepoints=0, -# ) -# metrics = m.fit(df, checkpointing=True, freq="D") -# # Continue training with StepLR -# metrics2 = m.fit(df, freq="D", continue_training=True, epochs=ADDITIONAL_EPOCHS, scheduler="StepLR") -# assert metrics["Loss"].min() >= metrics2["Loss"].min() - - -# def test_save_load_continue_training(): -# df = pd.read_csv(PEYTON_FILE, nrows=NROWS) -# m = NeuralProphet( -# epochs=EPOCHS, -# n_lags=6, -# n_forecasts=3, -# n_changepoints=0, -# ) -# metrics = m.fit(df, checkpointing=True, freq="D") -# save(m, "test_model.pt") -# m2 = load("test_model.pt") -# metrics2 = m2.fit(df, continue_training=True, epochs=ADDITIONAL_EPOCHS, scheduler="StepLR") -# assert metrics["Loss"].min() >= metrics2["Loss"].min() diff --git a/tests/test_utils.py b/tests/test_utils.py index 8bed33192..2cb4d9cdb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -import io import logging import os import pathlib @@ -8,7 +7,7 @@ import pandas as pd import pytest -from neuralprophet import NeuralProphet, df_utils, load, save +from neuralprophet import NeuralProphet, df_utils log = logging.getLogger("NP.test") log.setLevel("ERROR") @@ -39,64 +38,3 @@ def test_create_dummy_datestamps(): m = NeuralProphet(epochs=EPOCHS, batch_size=BATCH_SIZE, learning_rate=LR) _ = m.fit(df_dummy) _ = m.make_future_dataframe(df_dummy, periods=365, n_historic_predictions=True) - - -def test_save_load(): - df = pd.read_csv(PEYTON_FILE, nrows=NROWS) - m = NeuralProphet( - epochs=EPOCHS, - batch_size=BATCH_SIZE, - learning_rate=LR, - n_lags=6, - n_forecasts=3, - n_changepoints=0, - ) - _ = m.fit(df, freq="D") - future = m.make_future_dataframe(df, 
periods=3) - forecast = m.predict(df=future) - log.info("testing: save") - save(m, "test_model.pt") - - log.info("testing: load") - m2 = load("test_model.pt") - forecast2 = m2.predict(df=future) - - m3 = load("test_model.pt", map_location="cpu") - forecast3 = m3.predict(df=future) - - # Check that the forecasts are the same - pd.testing.assert_frame_equal(forecast, forecast2) - pd.testing.assert_frame_equal(forecast, forecast3) - - -def test_save_load_io(): - df = pd.read_csv(PEYTON_FILE, nrows=NROWS) - m = NeuralProphet( - epochs=EPOCHS, - batch_size=BATCH_SIZE, - learning_rate=LR, - n_lags=6, - n_forecasts=3, - n_changepoints=0, - ) - _ = m.fit(df, freq="D") - future = m.make_future_dataframe(df, periods=3) - forecast = m.predict(df=future) - - # Save the model to an in-memory buffer - log.info("testing: save to buffer") - buffer = io.BytesIO() - save(m, buffer) - buffer.seek(0) # Reset buffer position to the beginning - - log.info("testing: load from buffer") - m2 = load(buffer) - forecast2 = m2.predict(df=future) - - buffer.seek(0) # Reset buffer position to the beginning for another load - m3 = load(buffer, map_location="cpu") - forecast3 = m3.predict(df=future) - - # Check that the forecasts are the same - pd.testing.assert_frame_equal(forecast, forecast2) - pd.testing.assert_frame_equal(forecast, forecast3) From 6d76cb01fa38b2a6bb8f7e720421b8115884d7f4 Mon Sep 17 00:00:00 2001 From: ourownstory Date: Tue, 27 Aug 2024 17:16:35 -0700 Subject: [PATCH 19/39] address events reg test --- tests/test_regularization.py | 10 ++++++++-- tests/test_save.py | 2 ++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/test_regularization.py b/tests/test_regularization.py index 793ed53f0..62240fd22 100644 --- a/tests/test_regularization.py +++ b/tests/test_regularization.py @@ -68,7 +68,10 @@ def test_regularization_holidays(): daily_seasonality=False, growth="off", ) - m = m.add_country_holidays("US", regularization=0.0001) + m = m.add_country_holidays( + "US", + regularization=0.0001, + ) m.fit(df, freq="D") to_reduce = [] @@ -100,7 +103,10 @@ def test_regularization_events(): daily_seasonality=False, growth="off", ) - m = m.add_events(["event_%i" % index for index, _ in enumerate(events)], regularization=REGULARIZATION) + m = m.add_events( + ["event_%i" % index for index, _ in enumerate(events)], + regularization=0.1, + ) events_df = pd.concat( [ pd.DataFrame( diff --git a/tests/test_save.py b/tests/test_save.py index 1aeab44fe..0e4c25452 100644 --- a/tests/test_save.py +++ b/tests/test_save.py @@ -133,3 +133,5 @@ def test_save_load_io(): # m2 = load("test_model.pt") # metrics2 = m2.fit(df, continue_training=True, epochs=ADDITIONAL_EPOCHS, scheduler="StepLR") # assert metrics["Loss"].min() >= metrics2["Loss"].min() + +test_save_load() From 19d649702c153bff42c6fdcd6879e4a36cafb0af Mon Sep 17 00:00:00 2001 From: ourownstory Date: Tue, 27 Aug 2024 17:19:27 -0700 Subject: [PATCH 20/39] fixed reg tests --- tests/test_regularization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_regularization.py b/tests/test_regularization.py index 62240fd22..0c45a9ffa 100644 --- a/tests/test_regularization.py +++ b/tests/test_regularization.py @@ -130,9 +130,9 @@ def test_regularization_events(): to_reduce.append(param.detach().numpy()[0][0]) else: to_preserve.append(param.detach().numpy()[0][0]) - # print(to_reduce) - # print(to_preserve) - assert np.mean(to_reduce) < 0.1 + print(f"To reduce (< 0.2) {to_reduce}") + print(f"To preserve (> 0.5) {to_preserve}") + 
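+    # Expectation behind the asserts below: regularization should shrink the
+    # weights collected in to_reduce toward zero (mean < 0.2), while the
+    # weights in to_preserve keep a substantial magnitude (mean > 0.5).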
assert np.mean(to_reduce) < 0.2 assert np.mean(to_preserve) > 0.5 From 9187f7f76fd0b934ded0bfd4156d4728b9781603 Mon Sep 17 00:00:00 2001 From: ourownstory Date: Tue, 27 Aug 2024 17:23:06 -0700 Subject: [PATCH 21/39] fix save --- neuralprophet/forecaster.py | 2 -- tests/test_regularization.py | 8 ++++---- tests/test_save.py | 2 -- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py index 2efb42886..57bb8eeb2 100644 --- a/neuralprophet/forecaster.py +++ b/neuralprophet/forecaster.py @@ -2880,9 +2880,7 @@ def restore_trainer(self, accelerator: Optional[str] = None): """ self.trainer, _ = utils.configure_trainer( config_train=self.config_train, - pl_trainer_config=self.config_train.pl_trainer_config, metrics_logger=self.metrics_logger, - early_stopping=self.self.config_train.early_stopping, accelerator=accelerator, metrics_enabled=bool(self.metrics), ) diff --git a/tests/test_regularization.py b/tests/test_regularization.py index 0c45a9ffa..e5c6d96eb 100644 --- a/tests/test_regularization.py +++ b/tests/test_regularization.py @@ -83,8 +83,8 @@ def test_regularization_holidays(): to_reduce.append(weight_list[0][0][0]) else: to_preserve.append(weight_list[0][0][0]) - print(f"To reduce (< 0.2) {to_reduce}") - print(f"To preserve (> 0.5) {to_preserve}") + # print(f"To reduce (< 0.2) {to_reduce}") + # print(f"To preserve (> 0.5) {to_preserve}") assert np.mean(to_reduce) < 0.2 assert np.mean(to_preserve) > 0.5 @@ -130,8 +130,8 @@ def test_regularization_events(): to_reduce.append(param.detach().numpy()[0][0]) else: to_preserve.append(param.detach().numpy()[0][0]) - print(f"To reduce (< 0.2) {to_reduce}") - print(f"To preserve (> 0.5) {to_preserve}") + # print(f"To reduce (< 0.2) {to_reduce}") + # print(f"To preserve (> 0.5) {to_preserve}") assert np.mean(to_reduce) < 0.2 assert np.mean(to_preserve) > 0.5 diff --git a/tests/test_save.py b/tests/test_save.py index 0e4c25452..1aeab44fe 100644 --- a/tests/test_save.py +++ b/tests/test_save.py @@ -133,5 +133,3 @@ def test_save_load_io(): # m2 = load("test_model.pt") # metrics2 = m2.fit(df, continue_training=True, epochs=ADDITIONAL_EPOCHS, scheduler="StepLR") # assert metrics["Loss"].min() >= metrics2["Loss"].min() - -test_save_load() From 7a86edf96fcfd088efc9bb90ae8a2560af9a4f42 Mon Sep 17 00:00:00 2001 From: ourownstory Date: Tue, 27 Aug 2024 18:13:19 -0700 Subject: [PATCH 22/39] move to debug folder --- tests/{metrics => debug}/debug-energy-price-daily.ipynb | 0 tests/{metrics => debug}/debug-energy-price-hourly.ipynb | 0 tests/{metrics => debug}/debug-yosemite.ipynb | 0 tests/{metrics => debug}/debug_glocal.py | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename tests/{metrics => debug}/debug-energy-price-daily.ipynb (100%) rename tests/{metrics => debug}/debug-energy-price-hourly.ipynb (100%) rename tests/{metrics => debug}/debug-yosemite.ipynb (100%) rename tests/{metrics => debug}/debug_glocal.py (100%) diff --git a/tests/metrics/debug-energy-price-daily.ipynb b/tests/debug/debug-energy-price-daily.ipynb similarity index 100% rename from tests/metrics/debug-energy-price-daily.ipynb rename to tests/debug/debug-energy-price-daily.ipynb diff --git a/tests/metrics/debug-energy-price-hourly.ipynb b/tests/debug/debug-energy-price-hourly.ipynb similarity index 100% rename from tests/metrics/debug-energy-price-hourly.ipynb rename to tests/debug/debug-energy-price-hourly.ipynb diff --git a/tests/metrics/debug-yosemite.ipynb b/tests/debug/debug-yosemite.ipynb similarity index 100% 
rename from tests/metrics/debug-yosemite.ipynb rename to tests/debug/debug-yosemite.ipynb diff --git a/tests/metrics/debug_glocal.py b/tests/debug/debug_glocal.py similarity index 100% rename from tests/metrics/debug_glocal.py rename to tests/debug/debug_glocal.py From ee9e0e4746d0645baab614f67176b3f893aa78b2 Mon Sep 17 00:00:00 2001 From: ourownstory Date: Tue, 27 Aug 2024 18:17:18 -0700 Subject: [PATCH 23/39] debugging --- tests/debug/debug-energy-price-hourly.ipynb | 44 +++++++++++++-------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/tests/debug/debug-energy-price-hourly.ipynb b/tests/debug/debug-energy-price-hourly.ipynb index 14a09c93e..a4cb07914 100644 --- a/tests/debug/debug-energy-price-hourly.ipynb +++ b/tests/debug/debug-energy-price-hourly.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -21,7 +21,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -98,7 +98,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -112,9 +112,19 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING - (py.warnings._showwarnmsg) - /tmp/ipykernel_23581/1728794800.py:8: FutureWarning: 'H' is deprecated and will be removed in a future version, please use 'h' instead.\n", + " df[\"ds\"] = pd.date_range(start=\"2015-01-01 00:00:00\", periods=len(df), freq=\"H\")\n", + "\n" + ] + } + ], "source": [ "df = pd.read_csv(ENERGY_PRICE_DAILY_FILE)\n", "df[\"temp\"] = df[\"temperature\"]\n", @@ -151,25 +161,27 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Using CPU\n" + "Using GPU\n" ] }, { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" + "ename": "FileNotFoundError", + "evalue": "[Errno 2] No such file or directory", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[8], line 39\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUsing \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mGPU\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mif\u001b[39;00m\u001b[38;5;250m \u001b[39muse_gpu\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01melse\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mCPU\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 38\u001b[0m \u001b[38;5;66;03m# Model\u001b[39;00m\n\u001b[0;32m---> 39\u001b[0m m \u001b[38;5;241m=\u001b[39m \u001b[43mNeuralProphet\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtuned_params\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtrainer_configs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mquantiles\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mquantile_list\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 41\u001b[0m \u001b[38;5;66;03m# Lagged Regressor\u001b[39;00m\n\u001b[1;32m 42\u001b[0m m\u001b[38;5;241m.\u001b[39madd_lagged_regressor(names\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtemp\u001b[39m\u001b[38;5;124m\"\u001b[39m, n_lags\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m33\u001b[39m, normalize\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstandardize\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/github/neural_prophet/neuralprophet/forecaster.py:575\u001b[0m, in \u001b[0;36mNeuralProphet.__init__\u001b[0;34m(self, growth, changepoints, n_changepoints, changepoints_range, trend_reg, trend_reg_threshold, trend_global_local, trend_local_reg, yearly_seasonality, yearly_seasonality_glocal_mode, weekly_seasonality, weekly_seasonality_glocal_mode, daily_seasonality, daily_seasonality_glocal_mode, seasonality_mode, seasonality_reg, season_global_local, seasonality_local_reg, future_regressors_model, future_regressors_layers, n_forecasts, n_lags, ar_layers, ar_reg, lagged_reg_layers, learning_rate, epochs, batch_size, loss_func, optimizer, scheduler, scheduler_args, newer_samples_weight, newer_samples_start, quantiles, impute_missing, impute_linear, impute_rolling, drop_missing, collect_metrics, normalize, global_normalization, global_time_normalization, unknown_data_normalization, accelerator, trainer_config, prediction_frequency)\u001b[0m\n\u001b[1;32m 572\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdata_params \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 574\u001b[0m \u001b[38;5;66;03m# Pytorch Lightning Trainer\u001b[39;00m\n\u001b[0;32m--> 575\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmetrics_logger \u001b[38;5;241m=\u001b[39m MetricsLogger(save_dir\u001b[38;5;241m=\u001b[39m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgetcwd\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 576\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator \u001b[38;5;241m=\u001b[39m accelerator\n\u001b[1;32m 578\u001b[0m \u001b[38;5;66;03m# set during prediction\u001b[39;00m\n", + "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory" + ] } ], "source": [ @@ -2521,7 +2533,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.0rc1" + "version": "3.10.12" } }, "nbformat": 4, From c3f3c3cbc6a9de74fdb3bd71426e02e88a3a0db7 Mon Sep 17 00:00:00 2001 From: ourownstory Date: Tue, 27 Aug 2024 21:21:52 -0700 Subject: [PATCH 24/39] fix custom lr --- neuralprophet/time_net.py | 30 +- tests/debug/debug-energy-price-hourly.ipynb | 1102 +++++++++++++------ 2 files changed, 806 insertions(+), 326 deletions(-) diff --git a/neuralprophet/time_net.py b/neuralprophet/time_net.py index 8bb17a417..4c2367e13 100644 --- a/neuralprophet/time_net.py +++ b/neuralprophet/time_net.py @@ -158,11 +158,6 @@ def __init__( self.config_normalization = config_normalization self.compute_components_flag = compute_components_flag - # Optimizer and LR Scheduler - # self.config_train.set_optimizer() - # self.config_train.set_scheduler() - # self._optimizer = self.config_train.optimizer - # self._scheduler = self.config_train.scheduler # Manual optimization: we are responsible for calling .backward(), .step(), .zero_grad(). 
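        # In Lightning's manual mode, training_step drives the update itself,
        # roughly (sketch): opt = self.optimizers(); opt.zero_grad();
        # self.manual_backward(loss); opt.step(); self.lr_schedulers().step()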
self.automatic_optimization = False @@ -801,7 +796,8 @@ def training_step(self, batch, batch_idx): optimizer.step() scheduler = self.lr_schedulers() - scheduler.step(epoch=self.train_progress) + scheduler.step() + # scheduler.step(epoch=self.train_progress) # Manually track the loss for the lr finder self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) @@ -873,22 +869,25 @@ def configure_optimizers(self): self.config_train.set_optimizer() self.config_train.set_scheduler() - self._optimizer = self.config_train.optimizer - self._scheduler = self.config_train.scheduler # Optimizer - optimizer = self._optimizer(self.parameters(), lr=self.learning_rate, **self.config_train.optimizer_args) + optimizer = self.config_train.optimizer( + self.parameters(), + lr=self.learning_rate, + **self.config_train.optimizer_args, + ) # Scheduler - if self._scheduler == torch.optim.lr_scheduler.OneCycleLR: - lr_scheduler = self._scheduler( + if self.config_train.scheduler == torch.optim.lr_scheduler.OneCycleLR: + lr_scheduler = self.config_train.scheduler( optimizer, max_lr=self.learning_rate, - total_steps=self.config_train.epochs, + total_steps=self.trainer.estimated_stepping_batches, + # total_steps=self.config_train.epochs, # if using self.lr_schedulers().step(epoch=self.train_progress) **self.config_train.scheduler_args, ) else: - lr_scheduler = self._scheduler( + lr_scheduler = self.config_train.scheduler( optimizer, **self.config_train.scheduler_args, ) @@ -896,7 +895,6 @@ def configure_optimizers(self): return {"optimizer": optimizer, "lr_scheduler": lr_scheduler} def _get_time_based_sample_weight(self, t): - weight = torch.ones_like(t) end_w = self.config_train.newer_samples_weight start_t = self.config_train.newer_samples_start time = (t.detach() - start_t) / (1.0 - start_t) @@ -906,7 +904,9 @@ def _get_time_based_sample_weight(self, t): # scales end to be end weight times bigger than start weight # with end weight being 1.0 weight = (1.0 + time * (end_w - 1.0)) / end_w - return weight.unsqueeze(dim=2) # add an extra dimension for the quantiles + # add an extra dimension for the quantiles + weight = weight.unsqueeze(dim=2) + return weight def _add_batch_regularizations(self, loss, progress): """Add regularization terms to loss, if applicable diff --git a/tests/debug/debug-energy-price-hourly.ipynb b/tests/debug/debug-energy-price-hourly.ipynb index a4cb07914..d81254f4e 100644 --- a/tests/debug/debug-energy-price-hourly.ipynb +++ b/tests/debug/debug-energy-price-hourly.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -21,7 +21,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -98,7 +98,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -112,19 +112,9 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 4, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING - (py.warnings._showwarnmsg) - /tmp/ipykernel_23581/1728794800.py:8: FutureWarning: 'H' is deprecated and will be removed in a future version, please use 'h' instead.\n", - " df[\"ds\"] = pd.date_range(start=\"2015-01-01 00:00:00\", periods=len(df), freq=\"H\")\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "df = pd.read_csv(ENERGY_PRICE_DAILY_FILE)\n", "df[\"temp\"] = df[\"temperature\"]\n", 
@@ -133,7 +123,7 @@ "df[\"y\"] = pd.to_numeric(df[\"y\"], errors=\"coerce\")\n", "\n", "df = df.drop(\"ds\", axis=1)\n", - "df[\"ds\"] = pd.date_range(start=\"2015-01-01 00:00:00\", periods=len(df), freq=\"H\")\n", + "df[\"ds\"] = pd.date_range(start=\"2015-01-01 00:00:00\", periods=len(df), freq=\"h\")\n", "df[\"ID\"] = \"test\"\n", "\n", "df_id = df[[\"ds\", \"y\", \"temp\"]].copy()\n", @@ -156,32 +146,35 @@ "df[\"temp\"] = (df[\"temp\"] - 65.0) / 50.0\n", "\n", "# df\n", - "df = df[[\"ID\", \"ds\", \"y\", \"temp\", \"winter\", \"summer\"]]" + "df = df[[\"ID\", \"ds\", \"y\", \"temp\", \"winter\", \"summer\"]]\n", + "\n", + "# Split\n", + "df_train = df[df[\"ds\"] < \"2015-03-01\"]\n", + "df_test = df[df[\"ds\"] >= \"2015-03-01\"]" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Using GPU\n" + "quantiles: [0.01, 0.99]\n", + "Using CPU\n" ] }, { - "ename": "FileNotFoundError", - "evalue": "[Errno 2] No such file or directory", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[8], line 39\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUsing \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mGPU\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mif\u001b[39;00m\u001b[38;5;250m \u001b[39muse_gpu\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01melse\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mCPU\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 38\u001b[0m \u001b[38;5;66;03m# Model\u001b[39;00m\n\u001b[0;32m---> 39\u001b[0m m \u001b[38;5;241m=\u001b[39m \u001b[43mNeuralProphet\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtuned_params\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtrainer_configs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mquantiles\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mquantile_list\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 41\u001b[0m \u001b[38;5;66;03m# Lagged Regressor\u001b[39;00m\n\u001b[1;32m 42\u001b[0m m\u001b[38;5;241m.\u001b[39madd_lagged_regressor(names\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtemp\u001b[39m\u001b[38;5;124m\"\u001b[39m, n_lags\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m33\u001b[39m, normalize\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstandardize\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[0;32m~/github/neural_prophet/neuralprophet/forecaster.py:575\u001b[0m, in \u001b[0;36mNeuralProphet.__init__\u001b[0;34m(self, growth, changepoints, n_changepoints, changepoints_range, trend_reg, trend_reg_threshold, trend_global_local, trend_local_reg, yearly_seasonality, yearly_seasonality_glocal_mode, weekly_seasonality, weekly_seasonality_glocal_mode, daily_seasonality, daily_seasonality_glocal_mode, seasonality_mode, seasonality_reg, season_global_local, seasonality_local_reg, future_regressors_model, future_regressors_layers, n_forecasts, n_lags, ar_layers, ar_reg, 
lagged_reg_layers, learning_rate, epochs, batch_size, loss_func, optimizer, scheduler, scheduler_args, newer_samples_weight, newer_samples_start, quantiles, impute_missing, impute_linear, impute_rolling, drop_missing, collect_metrics, normalize, global_normalization, global_time_normalization, unknown_data_normalization, accelerator, trainer_config, prediction_frequency)\u001b[0m\n\u001b[1;32m 572\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdata_params \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 574\u001b[0m \u001b[38;5;66;03m# Pytorch Lightning Trainer\u001b[39;00m\n\u001b[0;32m--> 575\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmetrics_logger \u001b[38;5;241m=\u001b[39m MetricsLogger(save_dir\u001b[38;5;241m=\u001b[39m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgetcwd\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 576\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator \u001b[38;5;241m=\u001b[39m accelerator\n\u001b[1;32m 578\u001b[0m \u001b[38;5;66;03m# set during prediction\u001b[39;00m\n", - "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory" - ] + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -212,9 +205,12 @@ "# Uncertainty Quantification\n", "confidence_lv = 0.98\n", "quantile_list = [round(((1 - confidence_lv) / 2), 2), round((confidence_lv + (1 - confidence_lv) / 2), 2)]\n", + "# quantile_list = None\n", + "print(f\"quantiles: {quantile_list}\")\n", "\n", "# Check if GPU is available\n", - "use_gpu = torch.cuda.is_available()\n", + "# use_gpu = torch.cuda.is_available()\n", + "use_gpu = False\n", "\n", "# Set trainer configuration\n", "trainer_configs = {\n", @@ -246,50 +242,96 @@ "output_type": "stream", "text": [ "INFO - (NP.forecaster.fit) - When Global modeling with local normalization, metrics are displayed in normalized scale.\n", - "INFO - (NP.df_utils._infer_frequency) - Major frequency H corresponds to 99.929% of the data.\n", - "INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - H\n", - "INFO - (NP.df_utils._infer_frequency) - Major frequency H corresponds to 99.929% of the data.\n", - "INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - H\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning: Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + " converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)\n", + "\n", + "INFO - (NP.df_utils._infer_frequency) - Major frequency h corresponds to 99.929% of the data.\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning: Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + " converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning: Series.view is deprecated and will be removed in a future version. 
Use ``astype`` as an alternative to change the dtype.\n", + " converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)\n", + "\n", + "INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - h\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning: Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + " converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)\n", + "\n", + "INFO - (NP.df_utils._infer_frequency) - Major frequency h corresponds to 99.929% of the data.\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning: Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + " converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning: Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + " converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)\n", + "\n", + "INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - h\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/time_dataset.py:692: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + " contains_nan = torch.cat([torch.tensor(contains_nan), torch.ones(n_forecasts, dtype=torch.bool)])\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/time_dataset.py:692: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + " contains_nan = torch.cat([torch.tensor(contains_nan), torch.ones(n_forecasts, dtype=torch.bool)])\n", + "\n", "INFO - (NP.utils.configure_trainer) - Using accelerator cpu with 1 device(s).\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "aa26aaf9191f401b9c69ebafca381bab", + "model_id": "b5912ecccdbe4255b43d1751767fc1e8", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Training: 0it [00:00, ?it/s]" + "Training: | | 0/? 
[00:00= \"2015-03-01\"]\n", - "\n", "# Training & Predict\n", - "metrics = m.fit(df=df_train, validation_df=df_test, freq=\"H\", num_workers=4, early_stopping=False)" + "metrics = m.fit(df=df_train, validation_df=df_test, freq=\"h\", early_stopping=False)" ] }, { @@ -430,29 +482,77 @@ }, "data": [ { - "legendgroup": "MAE", + "legendgroup": "train_loss", "line": { "color": "#2d92ff", "width": 2 }, "mode": "lines", - "name": "MAE", + "name": "train_loss", "type": "scatter", "xaxis": "x", "y": [ - 1.6991313695907593, - 1.5541504621505737, - 1.2866111993789673, - 1.0485198497772217, - 0.9603586792945862, - 0.933108389377594, - 0.9244528412818909, - 0.9177840948104858, - 0.9132021069526672, - 0.9105463027954102 + 3.6281654834747314, + 3.216404438018799, + 2.461292028427124, + 1.84268319606781, + 1.5482017993927002, + 1.4468858242034912, + 1.4149123430252075, + 1.3923537731170654, + 1.378030776977539, + 1.3715811967849731 ], "yaxis": "y" }, + { + "legendgroup": "reg_loss", + "line": { + "color": "#2d92ff", + "width": 2 + }, + "mode": "lines", + "name": "reg_loss", + "type": "scatter", + "xaxis": "x2", + "y": [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + "yaxis": "y2" + }, + { + "legendgroup": "MAE", + "line": { + "color": "#2d92ff", + "width": 2 + }, + "mode": "lines", + "name": "MAE", + "type": "scatter", + "xaxis": "x3", + "y": [ + 1.945920467376709, + 1.764636516571045, + 1.454800009727478, + 1.2135502099990845, + 1.1105183362960815, + 1.076404094696045, + 1.0656453371047974, + 1.0584872961044312, + 1.05453622341156, + 1.0535459518432617 + ], + "yaxis": "y3" + }, { "legendgroup": "MAE", "line": { @@ -462,20 +562,20 @@ "mode": "lines", "name": "MAE_val", "type": "scatter", - "xaxis": "x", + "xaxis": "x3", "y": [ - 1.9174306392669678, - 2.133635997772217, - 2.1361277103424072, - 1.954904317855835, - 1.8205108642578125, - 1.7834810018539429, - 1.7635681629180908, - 1.7493915557861328, - 1.7418491840362549, - 1.7389646768569946 + 2.0952579975128174, + 1.5834287405014038, + 1.1268068552017212, + 0.9564144015312195, + 0.879956841468811, + 0.8555251359939575, + 0.8443053364753723, + 0.8364248275756836, + 0.8316343426704407, + 0.8298080563545227 ], - "yaxis": "y" + "yaxis": "y3" }, { "legendgroup": "RMSE", @@ -486,20 +586,20 @@ "mode": "lines", "name": "RMSE", "type": "scatter", - "xaxis": "x2", + "xaxis": "x4", "y": [ - 2.249849557876587, - 2.062807083129883, - 1.6801131963729858, - 1.344346523284912, - 1.2270969152450562, - 1.1934525966644287, - 1.1826142072677612, - 1.1741188764572144, - 1.169130563735962, - 1.1649360656738281 + 2.4845407009124756, + 2.2576544284820557, + 1.8687987327575684, + 1.557382583618164, + 1.4208526611328125, + 1.3762744665145874, + 1.3620320558547974, + 1.3528945446014404, + 1.3467621803283691, + 1.3467226028442383 ], - "yaxis": "y2" + "yaxis": "y4" }, { "legendgroup": "RMSE", @@ -510,20 +610,20 @@ "mode": "lines", "name": "RMSE_val", "type": "scatter", - "xaxis": "x2", + "xaxis": "x4", "y": [ - 2.1282451152801514, - 2.287360668182373, - 2.3184731006622314, - 2.140346050262451, - 2.0008866786956787, - 1.962218999862671, - 1.9410110712051392, - 1.9257516860961914, - 1.9175572395324707, - 1.914405107498169 + 2.6508853435516357, + 2.058405876159668, + 1.3784314393997192, + 1.0836613178253174, + 0.9949031472206116, + 0.9689829349517822, + 0.9556937217712402, + 0.9464864134788513, + 0.941169023513794, + 0.9391241669654846 ], - "yaxis": "y2" + "yaxis": "y4" }, { "legendgroup": "Loss", @@ -534,20 +634,20 @@ "mode": "lines", "name": "Loss", "type": "scatter", - 
"xaxis": "x3", + "xaxis": "x5", "y": [ - 3.4565775394439697, - 3.047083854675293, - 2.3058581352233887, - 1.710412621498108, - 1.4448997974395752, - 1.353717565536499, - 1.3267676830291748, - 1.3102833032608032, - 1.2921112775802612, - 1.2888280153274536 + 3.6250030994415283, + 3.1991159915924072, + 2.4528629779815674, + 1.8357555866241455, + 1.5451138019561768, + 1.4480032920837402, + 1.4154229164123535, + 1.3918782472610474, + 1.3788602352142334, + 1.3751496076583862 ], - "yaxis": "y3" + "yaxis": "y5" }, { "legendgroup": "Loss", @@ -558,20 +658,20 @@ "mode": "lines", "name": "Loss_val", "type": "scatter", - "xaxis": "x3", + "xaxis": "x5", "y": [ - 4.821254730224609, - 4.705277919769287, - 4.240411758422852, - 3.7221953868865967, - 3.4264442920684814, - 3.345188617706299, - 3.2992584705352783, - 3.2648608684539795, - 3.246990919113159, - 3.2401645183563232 + 5.441538333892822, + 4.21077823638916, + 2.9864273071289062, + 2.4115242958068848, + 2.1774377822875977, + 2.101806163787842, + 2.061872959136963, + 2.03498911857605, + 2.0202457904815674, + 2.0146193504333496 ], - "yaxis": "y3" + "yaxis": "y5" }, { "legendgroup": "Loss", @@ -582,7 +682,7 @@ "mode": "lines", "name": "RegLoss", "type": "scatter", - "xaxis": "x3", + "xaxis": "x5", "y": [ 0, 0, @@ -595,18 +695,44 @@ 0, 0 ], - "yaxis": "y3" + "yaxis": "y5" } ], "layout": { "annotations": [ + { + "font": { + "size": 16 + }, + "showarrow": false, + "text": "train_loss", + "x": 0.08399999999999999, + "xanchor": "center", + "xref": "paper", + "y": 1, + "yanchor": "bottom", + "yref": "paper" + }, + { + "font": { + "size": 16 + }, + "showarrow": false, + "text": "reg_loss", + "x": 0.292, + "xanchor": "center", + "xref": "paper", + "y": 1, + "yanchor": "bottom", + "yref": "paper" + }, { "font": { "size": 16 }, "showarrow": false, "text": "MAE", - "x": 0.14444444444444446, + "x": 0.5, "xanchor": "center", "xref": "paper", "y": 1, @@ -619,7 +745,7 @@ }, "showarrow": false, "text": "RMSE", - "x": 0.5, + "x": 0.708, "xanchor": "center", "xref": "paper", "y": 1, @@ -632,7 +758,7 @@ }, "showarrow": false, "text": "Loss", - "x": 0.8555555555555556, + "x": 0.9159999999999999, "xanchor": "center", "xref": "paper", "y": 1, @@ -1478,7 +1604,7 @@ "anchor": "y", "domain": [ 0, - 0.2888888888888889 + 0.16799999999999998 ], "linewidth": 1.5, "mirror": true, @@ -1488,8 +1614,8 @@ "xaxis2": { "anchor": "y2", "domain": [ - 0.35555555555555557, - 0.6444444444444445 + 0.208, + 0.376 ], "linewidth": 1.5, "mirror": true, @@ -1499,7 +1625,29 @@ "xaxis3": { "anchor": "y3", "domain": [ - 0.7111111111111111, + 0.416, + 0.584 + ], + "linewidth": 1.5, + "mirror": true, + "showgrid": false, + "showline": true + }, + "xaxis4": { + "anchor": "y4", + "domain": [ + 0.624, + 0.792 + ], + "linewidth": 1.5, + "mirror": true, + "showgrid": false, + "showline": true + }, + "xaxis5": { + "anchor": "y5", + "domain": [ + 0.832, 1 ], "linewidth": 1.5, @@ -1545,6 +1693,32 @@ "showgrid": false, "showline": true, "type": "log" + }, + "yaxis4": { + "anchor": "x4", + "domain": [ + 0, + 1 + ], + "linewidth": 1.5, + "mirror": true, + "rangemode": "tozero", + "showgrid": false, + "showline": true, + "type": "log" + }, + "yaxis5": { + "anchor": "x5", + "domain": [ + 0, + 1 + ], + "linewidth": 1.5, + "mirror": true, + "rangemode": "tozero", + "showgrid": false, + "showline": true, + "type": "log" } } } @@ -1565,14 +1739,16 @@ { "data": { "text/plain": [ - "{'MAE_val': 1.7389646768569946,\n", - " 'RMSE_val': 1.914405107498169,\n", - " 'Loss_val': 3.2401645183563232,\n", + "{'MAE_val': 
0.8298080563545227,\n", + " 'RMSE_val': 0.9391241669654846,\n", + " 'Loss_val': 2.0146193504333496,\n", " 'RegLoss_val': 0.0,\n", " 'epoch': 9,\n", - " 'MAE': 0.9105463027954102,\n", - " 'RMSE': 1.1649360656738281,\n", - " 'Loss': 1.2888280153274536,\n", + " 'train_loss': 1.3715811967849731,\n", + " 'reg_loss': 0.0,\n", + " 'MAE': 1.0535459518432617,\n", + " 'RMSE': 1.3467226028442383,\n", + " 'Loss': 1.3751496076583862,\n", " 'RegLoss': 0.0}" ] }, @@ -1616,6 +1792,8 @@ " Loss_val\n", " RegLoss_val\n", " epoch\n", + " train_loss\n", + " reg_loss\n", " MAE\n", " RMSE\n", " Loss\n", @@ -1625,14 +1803,16 @@ " \n", " \n", " 9\n", - " 1.738965\n", - " 1.914405\n", - " 3.240165\n", + " 0.829808\n", + " 0.939124\n", + " 2.014619\n", " 0.0\n", " 9\n", - " 0.910546\n", - " 1.164936\n", - " 1.288828\n", + " 1.371581\n", + " 0.0\n", + " 1.053546\n", + " 1.346723\n", + " 1.37515\n", " 0.0\n", " \n", " \n", @@ -1640,11 +1820,11 @@ "" ], "text/plain": [ - " MAE_val RMSE_val Loss_val RegLoss_val epoch MAE RMSE \\\n", - "9 1.738965 1.914405 3.240165 0.0 9 0.910546 1.164936 \n", + " MAE_val RMSE_val Loss_val RegLoss_val epoch train_loss reg_loss \\\n", + "9 0.829808 0.939124 2.014619 0.0 9 1.371581 0.0 \n", "\n", - " Loss RegLoss \n", - "9 1.288828 0.0 " + " MAE RMSE Loss RegLoss \n", + "9 1.053546 1.346723 1.37515 0.0 " ] }, "execution_count": 9, @@ -1665,47 +1845,115 @@ "name": "stderr", "output_type": "stream", "text": [ - "INFO - (NP.df_utils._infer_frequency) - Major frequency H corresponds to 99.932% of the data.\n", - "INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - H\n", - "INFO - (NP.df_utils._infer_frequency) - Major frequency H corresponds to 99.932% of the data.\n", - "INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - H\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO - (NP.df_utils._infer_frequency) - Major frequency H corresponds to 99.932% of the data.\n", - "INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - H\n", - "INFO - (NP.data.processing._handle_missing_data) - Dropped 5 rows at the end with NaNs in 'y' column.\n", - "INFO - (NP.df_utils._infer_frequency) - Major frequency H corresponds to 99.932% of the data.\n", - "INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - H\n", - "INFO - (NP.data.processing._handle_missing_data) - Dropped 5 rows at the end with NaNs in 'y' column.\n" + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning:\n", + "\n", + "Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + "\n", + "\n", + "INFO - (NP.df_utils._infer_frequency) - Major frequency h corresponds to 99.932% of the data.\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning:\n", + "\n", + "Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning:\n", + "\n", + "Series.view is deprecated and will be removed in a future version. 
Use ``astype`` as an alternative to change the dtype.\n", + "\n", + "\n", + "INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - h\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning:\n", + "\n", + "Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + "\n", + "\n", + "INFO - (NP.df_utils._infer_frequency) - Major frequency h corresponds to 99.932% of the data.\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning:\n", + "\n", + "Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning:\n", + "\n", + "Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + "\n", + "\n", + "INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - h\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning:\n", + "\n", + "Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + "\n", + "\n", + "INFO - (NP.df_utils._infer_frequency) - Major frequency h corresponds to 99.932% of the data.\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning:\n", + "\n", + "Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning:\n", + "\n", + "Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + "\n", + "\n", + "INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - h\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning:\n", + "\n", + "Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + "\n", + "\n", + "INFO - (NP.df_utils._infer_frequency) - Major frequency h corresponds to 99.932% of the data.\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning:\n", + "\n", + "Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning:\n", + "\n", + "Series.view is deprecated and will be removed in a future version. 
Use ``astype`` as an alternative to change the dtype.\n",
+    "\n",
+    "\n",
+    "INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - h\n",
+    "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/time_dataset.py:692: UserWarning:\n",
+    "\n",
+    "To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+    "\n",
+    "\n"
    ]
   },

[notebook output diff, abridged: the prediction and plotting cells were re-executed. Widget model_ids, plotly trace uids, a few boundary timestamps, and all forecast values change, and execution_counts are renumbered (13 -> 11, 15 -> 12, 16 -> 13). Three figure dumps are affected:
 - the forecast plot, with traces "[R] yhat5 1.0% ~1h", "[R] yhat5 99.0% ~1h", "[R] Predicted ~1h", and "[R] Actual ~1h";
 - the components plot (trend, seasonality, event, AR, and lagged-regressor panels), whose cell now also emits repeated FutureWarnings from pandas and plotly_resampler about `DatetimeProperties.to_pydatetime` and the deprecated 'H' frequency alias;
 - the model-parameters plot, with panels for Trend, yearly, weekly, the conditional "winter" and "summer" seasonalities, AR weights, Lagged Regressor "temp" weights, and Additive event coefficients.]

From db091008f9e9e161b53beb4b6cbdd80006fb46a1 Mon Sep 17 00:00:00 2001
From: ourownstory
Date: Tue, 27 Aug 2024 21:33:48 -0700
Subject: [PATCH 25/39] set finding lr arg

---
 neuralprophet/forecaster.py | 10 +++++++---
 neuralprophet/time_net.py   | 12 +++++++-----
 2 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py
index 57bb8eeb2..5f6c59de2 100644
--- a/neuralprophet/forecaster.py
+++ b/neuralprophet/forecaster.py
@@ -494,7 +494,7 @@ def __init__(
             log.info(
                 DeprecationWarning(
                     "Providing metrics to collect via `collect_metrics` in NeuralProphet is deprecated and will be "
-                    + "removed in a future version. The metrics are now configure in the `fit()` method via `metrics`."
+                    + "removed in a future version. The metrics are now configured in the `fit()` method via `metrics`."
                 )
             )
             self.metrics = utils_metrics.get_metrics(collect_metrics)
@@ -2812,9 +2812,10 @@ def _train(
             val_loader = self._init_val_loader(df_val)
 
             if not self.config_train.learning_rate:
+                # Find suitable learning rate
                 # Set parameters for the learning rate finder
                 self.config_train.set_lr_finder_args(dataset_size=dataset_size, num_batches=len(train_loader))
-                # Find suitable learning rate
+                self.model.finding_lr = True
                 tuner = Tuner(self.trainer)
                 lr_finder = tuner.lr_find(
                     model=self.model,
@@ -2825,6 +2826,7 @@ def _train(
                 # Estimate the optimal learning rate from the loss curve
                 assert lr_finder is not None
                 _, _, self.model.learning_rate = utils.smooth_loss_and_suggest(lr_finder)
+                self.model.finding_lr = False
             start = time.time()
             self.trainer.fit(
                 self.model,
@@ -2833,9 +2835,10 @@ def _train(
             )
         else:
             if not self.config_train.learning_rate:
+                # Find suitable learning rate
                 # Set parameters for the learning rate finder
                 self.config_train.set_lr_finder_args(dataset_size=dataset_size, num_batches=len(train_loader))
-                # Find suitable learning rate
+                self.model.finding_lr = True
                 tuner = Tuner(self.trainer)
                 lr_finder = tuner.lr_find(
                     model=self.model,
@@ -2845,6 +2848,7 @@ def _train(
                 assert lr_finder is not None
                 # Estimate the optimal learning rate from the loss curve
                 _, _, self.model.learning_rate = utils.smooth_loss_and_suggest(lr_finder)
+                self.model.finding_lr = False
             start = time.time()
             self.trainer.fit(
                 self.model,
diff --git a/neuralprophet/time_net.py b/neuralprophet/time_net.py
index 4c2367e13..d7e6c3581 100644
--- a/neuralprophet/time_net.py
+++ b/neuralprophet/time_net.py
@@ -162,8 +162,9 @@ def __init__(
         self.automatic_optimization = False
 
         # Hyperparameters (can be tuned using trainer.tune())
-        self.learning_rate = self.config_train.learning_rate if self.config_train.learning_rate is not None else 1e-3
+        self.learning_rate = self.config_train.learning_rate
         self.batch_size = self.config_train.batch_size
+        self.finding_lr = False  # flag to indicate if we are in lr finder mode
 
         # Metrics Config
         self.metrics_enabled = bool(metrics)  # yields True if metrics is not an empty dictionary
@@ -799,12 +800,13 @@ def training_step(self, batch, batch_idx):
             scheduler.step()
             # scheduler.step(epoch=self.train_progress)
 
-        # Manually track the loss for the lr finder
-        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log("reg_loss", reg_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        if self.finding_lr:
+            # Manually track the loss for the lr finder
+            self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log("reg_loss", reg_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
 
         # Metrics
-        if self.metrics_enabled:
+        if self.metrics_enabled and not self.finding_lr:
             predicted_denorm = self.denormalize(predicted[:, :, 0])
             target_denorm = self.denormalize(targets.squeeze(dim=2))
             self.log_dict(self.metrics_train(predicted_denorm, target_denorm), **self.log_args)
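For context, the pattern this patch introduces can be reproduced outside NeuralProphet in a few lines of PyTorch Lightning. The sketch below is illustrative only (ToyModule and the random dataset are made up, not repo code) and assumes pytorch_lightning >= 2.0, where Tuner wraps an existing Trainer: a finding_lr flag on the module lets the LR finder see a plain "train_loss" curve while regular training skips that logging.

import torch
import pytorch_lightning as pl
from pytorch_lightning.tuner.tuning import Tuner
from torch.utils.data import DataLoader, TensorDataset


class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)
        self.learning_rate = 1e-3  # attribute the lr finder reads and updates
        self.finding_lr = False    # gates lr-finder-only logging, as in this patch

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        if self.finding_lr:
            # the finder only needs a scalar loss curve, not the full metric set
            self.log("train_loss", loss, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)


loader = DataLoader(TensorDataset(torch.randn(64, 1), torch.randn(64, 1)), batch_size=16)
model = ToyModule()
trainer = pl.Trainer(max_epochs=2, enable_checkpointing=False, logger=False)

model.finding_lr = True
lr_finder = Tuner(trainer).lr_find(model, train_dataloaders=loader)
model.finding_lr = False
model.learning_rate = lr_finder.suggestion() or model.learning_rate
trainer.fit(model, loader)

NeuralProphet's version differs in one respect worth noting: TimeNet uses manual optimization (self.automatic_optimization = False), which is why the loss has to be logged by hand for the finder to have a curve to read.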
From b8bf9b8ad15837f25db9129aaa5789b834b6c094 Mon Sep 17 00:00:00 2001
From: ourownstory
Date: Tue, 27 Aug 2024 21:50:18 -0700
Subject: [PATCH 26/39] add logging of progress and lr

---
 neuralprophet/time_net.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/neuralprophet/time_net.py b/neuralprophet/time_net.py
index d7e6c3581..e5ea65d7b 100644
--- a/neuralprophet/time_net.py
+++ b/neuralprophet/time_net.py
@@ -797,8 +797,11 @@ def training_step(self, batch, batch_idx):
             optimizer.step()
 
             scheduler = self.lr_schedulers()
-            scheduler.step()
-            # scheduler.step(epoch=self.train_progress)
+            if self.config_train.scheduler == torch.optim.lr_scheduler.OneCycleLR:
+                # is configured with total_steps (not epochs)
+                scheduler.step()
+            else:
+                scheduler.step(epoch=self.train_progress)
 
         if self.finding_lr:
             # Manually track the loss for the lr finder
@@ -812,6 +815,8 @@ def training_step(self, batch, batch_idx):
             self.log_dict(self.metrics_train(predicted_denorm, target_denorm), **self.log_args)
             self.log("Loss", loss, **self.log_args)
             self.log("RegLoss", reg_loss, **self.log_args)
+            self.log("TrainProgress", self.train_progress, **self.log_args)
+            self.log("LR", scheduler.get_last_lr()[0], **self.log_args)
         return loss
 
     def validation_step(self, batch, batch_idx):
@@ -873,6 +878,8 @@ def configure_optimizers(self):
         self.config_train.set_scheduler()
 
         # Optimizer
+        if self.finding_lr and self.learning_rate is None:
+            self.learning_rate = self.config_train.lr_finder_args["min_lr"]
         optimizer = self.config_train.optimizer(
             self.parameters(),
             lr=self.learning_rate,

From a87651cde5dc9a9829be24e39995485d7a840cdd Mon Sep 17 00:00:00 2001
From: ourownstory
Date: Tue, 27 Aug 2024 22:45:52 -0700
Subject: [PATCH 27/39] update lr schedulers to use epochs

---
 neuralprophet/configure.py                  |    4 +-
 neuralprophet/forecaster.py                 |    2 +-
 neuralprophet/time_net.py                   |   17 +-
 tests/debug/debug-energy-price-hourly.ipynb | 1493 ++++++++++---------
 4 files changed, 802 insertions(+), 714 deletions(-)

diff --git a/neuralprophet/configure.py b/neuralprophet/configure.py
index e447eae4e..ac6ea2330 100644
--- a/neuralprophet/configure.py
+++ b/neuralprophet/configure.py
@@ -222,7 +222,7 @@ def set_scheduler(self):
             self.scheduler = "exponentiallr"
 
         if isinstance(self.scheduler, str):
-            if self.scheduler.lower() == "onecyclelr":
+            if self.scheduler.lower() in ["onecycle", "onecyclelr"]:
                 self.scheduler = torch.optim.lr_scheduler.OneCycleLR
                 defaults = {
                     "pct_start": 0.3,
@@ -240,7 +240,7 @@ def set_scheduler(self):
             elif self.scheduler.lower() == "exponentiallr":
                 self.scheduler = torch.optim.lr_scheduler.ExponentialLR
                 defaults = {
-                    "gamma": 0.95,
+                    "gamma": 0.9,
                 }
             elif self.scheduler.lower() == "cosineannealinglr":
                 self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR
diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py
index 5f6c59de2..dae28e318 100644
--- a/neuralprophet/forecaster.py
+++ b/neuralprophet/forecaster.py
@@ -313,7 +313,7 @@ class NeuralProphet:
         Examples
         --------
         >>> from neuralprophet import NeuralProphet
-        >>> m = NeuralProphet(scheduler="ExponentialLR", scheduler_args={"gamma": 0.99})
+        >>> m = NeuralProphet(scheduler="ExponentialLR", scheduler_args={"gamma": 0.8})
 
         COMMENT
         Uncertainty Estimation
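The configure.py and forecaster.py hunks above only adjust defaults; the substantive change is in the time_net.py hunk below, which steps the scheduler once per batch with a fractional epoch index instead of per optimizer step. A standalone sketch (plain PyTorch, not repo code) of why this works for a scheduler with a closed-form schedule such as ExponentialLR:

import torch

# One parameter and optimizer are enough to drive a scheduler.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
# Closed form: lr = base_lr * gamma ** last_epoch, valid for float epochs.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

epochs, batches_per_epoch = 3, 4
for epoch in range(epochs):
    for batch_idx in range(batches_per_epoch):
        optimizer.step()  # a real loop would backpropagate a loss first
        epoch_float = epoch + batch_idx / batches_per_epoch
        # Passing `epoch` is deprecated in recent PyTorch (hence the
        # UserWarning in the notebook output further below), but it evaluates
        # the closed form at epoch_float, so the decay rate stays independent
        # of how many batches an epoch happens to contain.
        scheduler.step(epoch=epoch_float)
        print(f"epoch={epoch_float:.2f} lr={scheduler.get_last_lr()[0]:.5f}")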
diff --git a/neuralprophet/time_net.py b/neuralprophet/time_net.py
index e5ea65d7b..30fb7a56e 100644
--- a/neuralprophet/time_net.py
+++ b/neuralprophet/time_net.py
@@ -775,9 +775,8 @@ def loss_func(self, inputs, predicted, targets):
 
     def training_step(self, batch, batch_idx):
         inputs, targets, meta = batch
-        self.train_progress = (
-            self.trainer.current_epoch + float(batch_idx / self.train_steps_per_epoch)
-        ) / self.config_train.epochs
+        epoch_float = self.trainer.current_epoch + float(batch_idx / self.train_steps_per_epoch)
+        self.train_progress = epoch_float / self.config_train.epochs
         # Global-local
         if self.meta_used_in_model:
             meta_name_tensor = torch.tensor([self.id_dict[i] for i in meta["df_name"]], device=self.device)
@@ -797,11 +796,7 @@ def training_step(self, batch, batch_idx):
             optimizer.step()
 
             scheduler = self.lr_schedulers()
-            if self.config_train.scheduler == torch.optim.lr_scheduler.OneCycleLR:
-                # is configured with total_steps (not epochs)
-                scheduler.step()
-            else:
-                scheduler.step(epoch=self.train_progress)
+            scheduler.step(epoch=epoch_float)
 
         if self.finding_lr:
             # Manually track the loss for the lr finder
@@ -815,7 +810,7 @@ def training_step(self, batch, batch_idx):
             self.log_dict(self.metrics_train(predicted_denorm, target_denorm), **self.log_args)
             self.log("Loss", loss, **self.log_args)
             self.log("RegLoss", reg_loss, **self.log_args)
-            self.log("TrainProgress", self.train_progress, **self.log_args)
+            # self.log("TrainProgress", self.train_progress, **self.log_args)
             self.log("LR", scheduler.get_last_lr()[0], **self.log_args)
         return loss
 
@@ -891,8 +886,8 @@ def configure_optimizers(self):
             lr_scheduler = self.config_train.scheduler(
                 optimizer,
                 max_lr=self.learning_rate,
-                total_steps=self.trainer.estimated_stepping_batches,
-                # total_steps=self.config_train.epochs,  # if using self.lr_schedulers().step(epoch=self.train_progress)
+                # total_steps=self.trainer.estimated_stepping_batches,  # if using self.lr_schedulers().step()
+                total_steps=self.config_train.epochs,  # if using self.lr_schedulers().step(epoch=epoch_float)
                 **self.config_train.scheduler_args,
             )
         else:
diff --git a/tests/debug/debug-energy-price-hourly.ipynb b/tests/debug/debug-energy-price-hourly.ipynb
index d81254f4e..a8c769d20 100644
--- a/tests/debug/debug-energy-price-hourly.ipynb
+++ b/tests/debug/debug-energy-price-hourly.ipynb
@@ -169,7 +169,7 @@
    {
     "data": {
      "text/plain": [
-      ""
+      ""
     ]
    },
    "execution_count": 5,
    "metadata": {},
    "output_type": "execute_result"
   }
  ],
  "source": [
-   "### Temporary Test for on-the-fly sampling - very time consuming!\n",
-   "\n",
-   "\n",
    "# Hyperparameter\n",
    "tuned_params = {\n",
    "    \"n_lags\": 10,\n",
@@ -189,13 +186,13 @@
    "    \"yearly_seasonality\": 10,\n",
    "    \"weekly_seasonality\": True,\n",
    "    \"daily_seasonality\": False,  # due to conditional daily seasonality\n",
-   "    \"batch_size\": 128,\n",
+   "    \"batch_size\": 64,\n",
    "    \"ar_layers\": [8, 4],\n",
    "    \"lagged_reg_layers\": [8],\n",
    "    # not tuned\n",
    "    \"n_forecasts\": 5,\n",
-   "    \"learning_rate\": 0.001,\n",
-   "    \"epochs\": 10,\n",
+   "    \"learning_rate\": 0.1,\n",
+   "    \"epochs\": 20,\n",
    "    \"trend_global_local\": \"global\",\n",
    "    \"season_global_local\": \"global\",\n",
    "    \"drop_missing\": True,\n",

[fit-cell output diff, abridged: the training and validation progress-bar widgets get new model_ids on re-execution, and the two old "Trying to infer the `batch_size` from an ambiguous collection" warnings are replaced by a single new one caused by the epoch-indexed scheduler stepping:]

+    "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/.cache/pypoetry/virtualenvs/neuralprophet-CT7lk1Bv-py3.10/lib/python3.10/site-packages/torch/optim/lr_scheduler.py:232: UserWarning: The epoch parameter in `scheduler.step()` was not necessary and is being deprecated where possible. Please use `scheduler.step()` to step the scheduler. During the deprecation, if epoch is different from None, the closed form is used instead of the new chainable form, where available. Please open an issue if you are unable to replicate your use case: https://github.com/pytorch/pytorch/issues/new/choose.\n",
+    "  warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)\n",

[the metrics cell output is replaced wholesale: the old run showed only the final tail row (epoch 9, with train_loss/reg_loss columns), while the new run displays all 20 epochs, including the logged learning rate:]

    MAE_val   RMSE_val  Loss_val  RegLoss_val  epoch  MAE       RMSE      Loss      RegLoss  LR
0   0.499936  0.583346  0.938270  0.0          0      1.503294  2.114124  1.916612  0.0      0.004087
1   0.534045  0.631530  0.440998  0.0          1      0.718145  0.943761  0.505523  0.0      0.021600
2   0.542755  0.644081  0.454675  0.0          2      0.537536  0.724863  0.347341  0.0      0.050152
3   0.508438  0.616892  0.487512  0.0          3      0.503906  0.677358  0.312197  0.0      0.078837
4   0.649246  0.755430  0.545550  0.0          4      0.505661  0.671692  0.313073  0.0      0.096699
5   0.463848  0.568442  0.367994  0.0          5      0.520102  0.691615  0.322044  0.0      0.099596
6   0.355072  0.410634  0.251356  0.0          6      0.511964  0.684423  0.316359  0.0      0.097137
7   0.447367  0.500184  0.336913  0.0          7      0.503181  0.669457  0.307173  0.0      0.092315
8   0.821846  0.951728  0.720978  0.0          8      0.503031  0.671102  0.308114  0.0      0.085371
9   0.414638  0.474769  0.334302  0.0          9      0.511291  0.686945  0.311271  0.0      0.076654
10  0.606577  0.723609  0.504883  0.0          10     0.493725  0.657971  0.301624  0.0      0.066600
11  0.560590  0.657100  0.453766  0.0          11     0.487225  0.654937  0.295672  0.0      0.055713
12  0.419592  0.459256  0.307631  0.0          12     0.479861  0.642756  0.287683  0.0      0.044541
13  0.492459  0.561360  0.379794  0.0          13     0.479290  0.643680  0.284241  0.0      0.033641
14  0.547214  0.630017  0.432885  0.0          14     0.471661  0.633883  0.280081  0.0      0.023563
15  0.542842  0.630828  0.427475  0.0          15     0.467507  0.630942  0.275439  0.0      0.014810
16  0.497607  0.569062  0.380621  0.0          16     0.468031  0.631191  0.276560  0.0      0.007821
17  0.507053  0.580275  0.390214  0.0          17     0.458170  0.620013  0.268218  0.0      0.002948
18  0.506170  0.578457  0.389007  0.0          18     0.460292  0.622816  0.268188  0.0      0.000434
19  0.508543  0.581377  0.391374  0.0          19     0.459247  0.622094  0.267627  0.0      0.000405
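The LR column follows a one-cycle shape consistent with these patches: pct_start=0.3 is the default set in configure.py, and the hyperparameters above set learning_rate=0.1 and epochs=20. A quick standalone sanity check of that shape (a sketch, assuming the run used OneCycleLR with PyTorch's default div_factor=25, so the cycle starts at max_lr / 25 = 0.004):

import torch

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
# Mirrors the run above: total_steps equals the 20 epochs, stepped once each.
sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=0.1, total_steps=20, pct_start=0.3)
lrs = [sched.get_last_lr()[0]]
for _ in range(19):
    opt.step()
    sched.step()
    lrs.append(sched.get_last_lr()[0])
print([round(lr, 4) for lr in lrs])
# Rises from 0.004 to the max_lr=0.1 peak around step 5, then cosine-anneals
# toward zero: the same shape as the LR column, up to the fact that the
# notebook logs the LR mid-epoch at fractional positions.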
[remaining output diffs, abridged: the metrics cell keeps execution_count 9 but its source changes from `metrics.tail(1)` to `metrics`; the subsequent prediction and plotting cells are re-executed, refreshing their progress-bar widget model_ids. The forecast-figure diff follows.]
[forecast-figure diff, abridged: trace uids and forecast arrays are fully refreshed for "[R] yhat5 1.0% ~1h", "[R] yhat5 99.0% ~1h", "[R] Predicted ~1h", and "[R] Actual ~1h". The plot_components cell drops most of its previously duplicated pandas/plotly-resampler FutureWarnings (`DatetimeProperties.to_pydatetime`, deprecated 'H' alias), and the old components figure (widget model_id c3aa3b2182364855a6dd57d92d26b188, with trend, seasonality, AR, lagged-regressor, and event panels) is removed in favor of the re-executed output. The removed dump continues below:]
-    "          {'line': {'color': '#2d92ff', 'width': 2},\n",
-    "           'mode': 'lines',\n",
-    "           'name': ('[R' ... 
' style=\"color:#fc9944\">~1h'),\n", - " 'showlegend': False,\n", - " 'type': 'scatter',\n", - " 'uid': '089c5b7d-4f26-4ebe-91e1-0f4943efc2e3',\n", - " 'x': array([datetime.datetime(2015, 1, 2, 9, 0),\n", - " datetime.datetime(2015, 1, 2, 10, 0),\n", - " datetime.datetime(2015, 1, 2, 11, 0), ...,\n", - " datetime.datetime(2015, 3, 2, 17, 0),\n", - " datetime.datetime(2015, 3, 2, 18, 0),\n", - " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'xaxis': 'x8',\n", - " 'y': array([0., 0., 0., ..., 0., 0., 0.], dtype=float32),\n", - " 'yaxis': 'y8'},\n", - " {'fill': 'tozeroy',\n", - " 'fillcolor': 'rgba(45, 146, 255, 0.2)',\n", - " 'line': {'color': 'rgba(45, 146, 255, 0.2)', 'width': 1},\n", - " 'mode': 'lines',\n", - " 'name': '[R] yhat5 1.0% ~1h',\n", - " 'showlegend': True,\n", - " 'type': 'scatter',\n", - " 'uid': '57a063db-b1cc-4868-8208-9a193411c395',\n", - " 'x': array([datetime.datetime(2015, 1, 2, 13, 0),\n", - " datetime.datetime(2015, 1, 2, 14, 0),\n", - " datetime.datetime(2015, 1, 2, 15, 0), ...,\n", - " datetime.datetime(2015, 3, 2, 13, 0),\n", - " datetime.datetime(2015, 3, 2, 14, 0),\n", - " datetime.datetime(2015, 3, 2, 16, 0)], dtype=object),\n", - " 'xaxis': 'x9',\n", - " 'y': array([-45.265923, -77.71274 , -55.634865, ..., -31.16211 , -29.599195,\n", - " -61.93689 ], dtype=float32),\n", - " 'yaxis': 'y9'},\n", - " {'fill': 'tozeroy',\n", - " 'fillcolor': 'rgba(45, 146, 255, 0.2)',\n", - " 'line': {'color': 'rgba(45, 146, 255, 0.2)', 'width': 1},\n", - " 'mode': 'lines',\n", - " 'name': '[R] yhat5 99.0% ~1h',\n", - " 'showlegend': True,\n", - " 'type': 'scatter',\n", - " 'uid': 'bd0275b8-e4ef-4b65-a4dd-dcedabd36797',\n", - " 'x': array([datetime.datetime(2015, 1, 2, 13, 0),\n", - " datetime.datetime(2015, 1, 2, 14, 0),\n", - " datetime.datetime(2015, 1, 2, 15, 0), ...,\n", - " datetime.datetime(2015, 3, 2, 13, 0),\n", - " datetime.datetime(2015, 3, 2, 15, 0),\n", - " datetime.datetime(2015, 3, 2, 16, 0)], dtype=object),\n", - " 'xaxis': 'x9',\n", - " 'y': array([ 2.0333786, -13.141644 , 7.119713 , ..., 5.760723 , -2.4633446,\n", - " 2.6906433], dtype=float32),\n", - " 'yaxis': 'y9'}],\n", - " 'layout': {'autosize': True,\n", - " 'barmode': 'overlay',\n", - " 'font': {'size': 10},\n", - " 'height': 1890,\n", - " 'hovermode': 'x unified',\n", - " 'legend': {'traceorder': 'reversed', 'y': 0.1},\n", - " 'margin': {'b': 0, 'l': 0, 'pad': 0, 'r': 10, 't': 10},\n", - " 'template': '...',\n", - " 'title': {'font': {'size': 12}},\n", - " 'width': 700,\n", - " 'xaxis': {'anchor': 'y',\n", - " 'domain': [0.0, 1.0],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'range': [2014-12-30 10:00:00, 2015-03-05 19:00:00],\n", - " 'showline': True,\n", - " 'title': {'text': 'ds'},\n", - " 'type': 'date'},\n", - " 'xaxis2': {'anchor': 'y2',\n", - " 'domain': [0.0, 1.0],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'range': [2014-12-30 10:00:00, 2015-03-05 19:00:00],\n", - " 'showline': True,\n", - " 'title': {'text': 'ds'},\n", - " 'type': 'date'},\n", - " 'xaxis3': {'anchor': 'y3',\n", - " 'domain': [0.0, 1.0],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'range': [2014-12-30 10:00:00, 2015-03-05 19:00:00],\n", - " 'showline': True,\n", - " 'title': {'text': 'ds'},\n", - " 'type': 'date'},\n", - " 'xaxis4': {'anchor': 'y4',\n", - " 'domain': [0.0, 1.0],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'range': [2014-12-30 10:00:00, 2015-03-05 19:00:00],\n", - " 'showline': True,\n", - " 'title': {'text': 'ds'},\n", - " 'type': 
'date'},\n", - " 'xaxis5': {'anchor': 'y5',\n", - " 'domain': [0.0, 1.0],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'range': [2014-12-30 10:00:00, 2015-03-05 19:00:00],\n", - " 'showline': True,\n", - " 'title': {'text': 'ds'},\n", - " 'type': 'date'},\n", - " 'xaxis6': {'anchor': 'y6',\n", - " 'domain': [0.0, 1.0],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'range': [2014-12-30 14:00:00, 2015-03-05 19:00:00],\n", - " 'showline': True,\n", - " 'title': {'text': 'ds'},\n", - " 'type': 'date'},\n", - " 'xaxis7': {'anchor': 'y7',\n", - " 'domain': [0.0, 1.0],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'range': [2014-12-30 14:00:00, 2015-03-05 19:00:00],\n", - " 'showline': True,\n", - " 'title': {'text': 'ds'},\n", - " 'type': 'date'},\n", - " 'xaxis8': {'anchor': 'y8',\n", - " 'domain': [0.0, 1.0],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'range': [2014-12-30 10:00:00, 2015-03-05 19:00:00],\n", - " 'showline': True,\n", - " 'title': {'text': 'ds'},\n", - " 'type': 'date'},\n", - " 'xaxis9': {'anchor': 'y9',\n", - " 'domain': [0.0, 1.0],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'range': [2014-12-30 14:00:00, 2015-03-05 19:00:00],\n", - " 'showline': True,\n", - " 'title': {'text': 'ds'},\n", - " 'type': 'date'},\n", - " 'yaxis': {'anchor': 'x',\n", - " 'domain': [0.9185185185185185, 1.0],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'rangemode': 'normal',\n", - " 'showline': True,\n", - " 'title': {'text': 'Trend'}},\n", - " 'yaxis2': {'anchor': 'x2',\n", - " 'domain': [0.8037037037037038, 0.8851851851851853],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'rangemode': 'tozero',\n", - " 'showline': True,\n", - " 'title': {'text': 'yearly seasonality'}},\n", - " 'yaxis3': {'anchor': 'x3',\n", - " 'domain': [0.6888888888888889, 0.7703703703703704],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'rangemode': 'tozero',\n", - " 'showline': True,\n", - " 'title': {'text': 'weekly seasonality'}},\n", - " 'yaxis4': {'anchor': 'x4',\n", - " 'domain': [0.5740740740740741, 0.6555555555555556],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'rangemode': 'tozero',\n", - " 'showline': True,\n", - " 'title': {'text': 'winter seasonality'}},\n", - " 'yaxis5': {'anchor': 'x5',\n", - " 'domain': [0.45925925925925926, 0.5407407407407407],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'rangemode': 'tozero',\n", - " 'showline': True,\n", - " 'title': {'text': 'summer seasonality'}},\n", - " 'yaxis6': {'anchor': 'x6',\n", - " 'domain': [0.34444444444444444, 0.42592592592592593],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'rangemode': 'tozero',\n", - " 'showline': True,\n", - " 'title': {'text': 'AR (5)-ahead'}},\n", - " 'yaxis7': {'anchor': 'x7',\n", - " 'domain': [0.22962962962962963, 0.3111111111111111],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'rangemode': 'tozero',\n", - " 'showline': True,\n", - " 'title': {'text': 'Lagged Regressor \"temp\" (5)-ahead'}},\n", - " 'yaxis8': {'anchor': 'x8',\n", - " 'domain': [0.11481481481481481, 0.1962962962962963],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'rangemode': 'tozero',\n", - " 'showline': True,\n", - " 'title': {'text': 'Additive Events'}},\n", - " 'yaxis9': {'anchor': 'x9',\n", - " 'domain': [0.0, 0.08148148148148149],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'rangemode': 'tozero',\n", - " 'showline': True,\n", - " 'title': {'text': 'Uncertainty'}}}\n", - "})" - ] 
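A note on the two FutureWarnings summarized above: both have drop-in replacements. The sketch below uses plain pandas/numpy and is independent of these patches; the names are illustrative.

import numpy as np
import pandas as pd

# Lowercase "h" is the forward-compatible hourly alias ("H" is deprecated).
ds = pd.Series(pd.date_range("2015-01-01", periods=3, freq="h"))

# Wrapping to_pydatetime() in np.array() keeps an ndarray even after pandas
# changes the default return type to a Series, silencing the FutureWarning.
x = np.array(ds.dt.to_pydatetime())
print(x.dtype, x[0])  # object, datetime.datetime(2015, 1, 1, 0, 0)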
-     },
-     "execution_count": 12,
-     "metadata": {},
-     "output_type": "execute_result"
+     "ename": "IndexError",
+     "evalue": "index -1 is out of bounds for axis 0 with size 0",
+     "output_type": "error",
+     "traceback": [
+      "IndexError                                Traceback (most recent call last)",
+      "Cell In[12], line 1",
+      "----> 1 m.plot_components(forecast, df_name=\"test\")",
+      "File ~/github/neural_prophet/neuralprophet/forecaster.py:2452, in NeuralProphet.plot_components(self, fcst, df_name, figsize, forecast_in_focus, plotting_backend, components, one_period_per_season)",
+      "-> 2452     return plot_components_plotly(m=self, fcst=fcst, plot_configuration=valid_plot_configuration, figsize=tuple(x * 70 for x in figsize) if figsize else (700, 210), df_name=df_name, one_period_per_season=one_period_per_season, resampler_active=plotting_backend == \"plotly-resampler\", plotly_static=plotting_backend == \"plotly-static\")",
+      "File ~/github/neural_prophet/neuralprophet/plot_forecast_plotly.py:332, in plot_components(m, fcst, plot_configuration, df_name, one_period_per_season, figsize, resampler_active, plotly_static)",
+      "-> 332     trace_object = get_multiforecast_component_props(fcst=fcst, **comp)",
+      "File ~/github/neural_prophet/neuralprophet/plot_forecast_plotly.py:603, in get_multiforecast_component_props(fcst, comp_name, plot_name, multiplicative, bar, focus, num_overplot, **kwargs)",
+      "-> 603     y[-1] = 0",
+      "IndexError: index -1 is out of bounds for axis 0 with size 0"
+     ]
    }
   ],
   "source": [
@@ -2640,7 +2727,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": null,
    "metadata": {},
    "outputs": [
@@ -2672,7 +2759,7 @@
    [... regenerated components-figure output omitted: widget model_ids and trace uids change, and the re-run y-values shift (e.g. the Trend endpoints move from [44.99, 50.66] to [25.74, 14.89], the yearly/weekly/winter/summer seasonality curves are rescaled, the AR bar contributions collapse to all zeros, and the Lagged Regressor "temp" and Additive-event bars take new values); the figure layout is unchanged ...]
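Before moving on: the IndexError captured above originates in get_multiforecast_component_props, where the trailing element of a component column is zeroed without checking that the column is non-empty. A minimal sketch of the failing pattern and a defensive guard follows; this is a hypothetical fix for illustration, not part of these patches.

import numpy as np

# Failing pattern from plot_forecast_plotly.py:603 -- the component column is
# fetched and its last element zeroed without checking that it is non-empty:
#     y = fcst[f"{comp_name}"].values
#     y[-1] = 0   # IndexError when y has size 0

def zero_trailing(y: np.ndarray) -> np.ndarray:
    """Zero the trailing element, tolerating an empty component column."""
    if y.size:
        y[-1] = 0
    return y

print(zero_trailing(np.array([], dtype=np.float32)))  # [] instead of a crash
print(zero_trailing(np.array([1.0, 2.0])))            # [1. 0.]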

From b79b7e109508399c8f8cf05d2454cf6491095835 Mon Sep 17 00:00:00 2001
From: ourownstory
Date: Wed, 28 Aug 2024 18:00:38 -0700
Subject: [PATCH 28/39] fix lr-finder

---
 neuralprophet/configure.py                  |   5 +-
 neuralprophet/forecaster.py                 |  77 +-
 neuralprophet/time_net.py                   |  11 +-
 neuralprophet/utils.py                      |  39 +-
 tests/debug/debug-energy-price-hourly.ipynb | 995 +++++++-------------
 5 files changed, 401 insertions(+), 726 deletions(-)

diff --git a/neuralprophet/configure.py b/neuralprophet/configure.py
index ac6ea2330..467922213 100644
--- a/neuralprophet/configure.py
+++ b/neuralprophet/configure.py
@@ -279,10 +279,11 @@ def set_lr_finder_args(self, dataset_size, num_batches):
         # num_training = num_batches
         self.lr_finder_args.update(
             {
-                "min_lr": 1e-7,
-                "max_lr": 10,
+                "min_lr": 1e-8,
+                "max_lr": 1e1,
                 "num_training": num_training,
                 "early_stop_threshold": None,
+                "mode": "exponential",
             }
         )

diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py
index f0f7f1b36..746549462 100644
--- a/neuralprophet/forecaster.py
+++ b/neuralprophet/forecaster.py
@@ -2796,7 +2796,10 @@ def _train(
         # Set up the training dataloader
         df, _, _, _ = df_utils.prep_or_copy_df(df)  # TODO: Can this call be removed?
         train_loader = self._init_train_loader(df, num_workers)
-        dataset_size = len(df)  # train_loader.dataset
+        dataset_size = len(train_loader.dataset)  # df
+        batches_per_epoch = len(train_loader)
+        log.info(f"Dataset size: {dataset_size}")
+        log.info(f"Number of batches per training epoch: {batches_per_epoch}")

         # Internal flag to check if validation is enabled
         validation_enabled = df_val is not None
@@ -2818,55 +2821,41 @@
             deterministic=deterministic,
         )

+        # Find suitable learning rate
+        if not self.config_train.learning_rate:
+            log.info("No Learning Rate provided. Activating learning rate finder")
+            # Set parameters for the learning rate finder
+            self.config_train.set_lr_finder_args(dataset_size=dataset_size, num_batches=batches_per_epoch)
+            log.info(f"Learning rate finder ---- ARGs: {self.config_train.lr_finder_args}")
+            self.model.finding_lr = True
+            tuner = Tuner(self.trainer)
+            lr_finder = tuner.lr_find(
+                model=self.model,
+                train_dataloaders=train_loader,
+                # val_dataloaders=val_loader,  # not used, but may lead to a Lightning bug if not provided
+                **self.config_train.lr_finder_args,
+            )
+            # Estimate the optimal learning rate from the loss curve
+            assert lr_finder is not None
+            _, _, lr_suggested = utils.smooth_loss_and_suggest(lr_finder)
+            self.model.learning_rate = lr_suggested
+            self.config_train.learning_rate = lr_suggested
+            log.info(f"Learning rate finder suggested learning rate: {lr_suggested}")
+            self.model.finding_lr = False
+
         # Tune hyperparams and train
         if validation_enabled:
             # Set up the validation dataloader
             df_val, _, _, _ = df_utils.prep_or_copy_df(df_val)
             val_loader = self._init_val_loader(df_val)
-            if not self.config_train.learning_rate:
-                # Find suitable learning rate
-                # Set parameters for the learning rate finder
-                self.config_train.set_lr_finder_args(dataset_size=dataset_size, num_batches=len(train_loader))
-                self.model.finding_lr = True
-                tuner = Tuner(self.trainer)
-                lr_finder = tuner.lr_find(
-                    model=self.model,
-                    train_dataloaders=train_loader,
-                    # val_dataloaders=val_loader,  # not be used, but may lead to Lightning bug if not provided
-                    **self.config_train.lr_finder_args,
-                )
-                # Estimate the optimal learning rate from the loss curve
-                assert lr_finder is not None
-                _, _, self.model.learning_rate = utils.smooth_loss_and_suggest(lr_finder)
-                self.model.finding_lr = False
-            start = time.time()
-            self.trainer.fit(
-                self.model,
-                train_loader,
-                val_loader,
-            )
-        else:
-            if not self.config_train.learning_rate:
-                # Find suitable learning rate
-                # Set parameters for the learning rate finder
-                self.config_train.set_lr_finder_args(dataset_size=dataset_size, num_batches=len(train_loader))
-                self.model.finding_lr = True
-                tuner = Tuner(self.trainer)
-                lr_finder = tuner.lr_find(
-                    model=self.model,
-                    train_dataloaders=train_loader,
-                    **self.config_train.lr_finder_args,
-                )
-                assert lr_finder is not None
-                # Estimate the optimal learning rate from the loss curve
-                _, _, self.model.learning_rate = utils.smooth_loss_and_suggest(lr_finder)
-                self.model.finding_lr = False
-            start = time.time()
-            self.trainer.fit(
-                self.model,
-                train_loader,
-            )
+        self.model.finding_lr = False
+        start = time.time()
+        self.trainer.fit(
+            model=self.model,
+            train_dataloaders=train_loader,
+            val_dataloaders=val_loader if validation_enabled else None,
+        )

         log.debug("Train Time: {:8.3f}".format(time.time() - start))
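For orientation, here is a minimal, self-contained sketch of the Tuner.lr_find call pattern used above, with the same search-space arguments that set_lr_finder_args now configures. ToyModel and the synthetic loader are illustrative stand-ins, not NeuralProphet code.

import torch
import pytorch_lightning as pl
from pytorch_lightning.tuner.tuning import Tuner
from torch.utils.data import DataLoader, TensorDataset

class ToyModel(pl.LightningModule):
    # Tuner.lr_find reads and updates this `lr` attribute by default.
    def __init__(self, lr=0.1):
        super().__init__()
        self.lr = lr
        self.layer = torch.nn.Linear(1, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)

x = torch.randn(512, 1)
loader = DataLoader(TensorDataset(x, 2 * x), batch_size=32)

tuner = Tuner(pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False))
lr_finder = tuner.lr_find(
    ToyModel(),
    train_dataloaders=loader,
    min_lr=1e-8,                # same bounds as set_lr_finder_args above
    max_lr=1e1,
    num_training=168,           # number of LR candidates to evaluate
    early_stop_threshold=None,  # sweep the full range, never abort early
    mode="exponential",         # candidates spaced on a log grid
)
print(lr_finder.suggestion())   # Lightning's steepest-descent default suggestion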
diff --git a/neuralprophet/time_net.py b/neuralprophet/time_net.py
index 30fb7a56e..a1148aa4f 100644
--- a/neuralprophet/time_net.py
+++ b/neuralprophet/time_net.py
@@ -775,8 +775,8 @@ def loss_func(self, inputs, predicted, targets):

     def training_step(self, batch, batch_idx):
         inputs, targets, meta = batch
-        epoch_float = self.trainer.current_epoch + float(batch_idx / self.train_steps_per_epoch)
-        self.train_progress = epoch_float / self.config_train.epochs
+        epoch_float = self.trainer.current_epoch + batch_idx / float(self.train_steps_per_epoch)
+        self.train_progress = epoch_float / float(self.config_train.epochs)
         # Global-local
         if self.meta_used_in_model:
             meta_name_tensor = torch.tensor([self.id_dict[i] for i in meta["df_name"]], device=self.device)
@@ -796,7 +796,10 @@ def training_step(self, batch, batch_idx):
             optimizer.step()

             scheduler = self.lr_schedulers()
-            scheduler.step(epoch=epoch_float)
+            if self.finding_lr:
+                scheduler.step()
+            else:
+                scheduler.step(epoch=epoch_float)

             if self.finding_lr:
                 # Manually track the loss for the lr finder
@@ -874,7 +877,7 @@ def configure_optimizers(self):

         # Optimizer
         if self.finding_lr and self.learning_rate is None:
-            self.learning_rate = self.config_train.lr_finder_args["min_lr"]
+            self.learning_rate = 0.1
         optimizer = self.config_train.optimizer(
             self.parameters(),
             lr=self.learning_rate,
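The training_step change above relies on Lightning's manual-optimization mode, where the module steps its own optimizer and scheduler: a plain scheduler.step() while the LR finder sweeps, and a fractional-epoch step otherwise. A toy sketch of that pattern follows; BATCHES_PER_EPOCH stands in for self.train_steps_per_epoch, and ExponentialLR is only a simple scheduler that accepts an epoch argument, not the one NeuralProphet uses.

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

BATCHES_PER_EPOCH = 16  # stand-in for self.train_steps_per_epoch

class ManualOptSketch(pl.LightningModule):
    """Toy module mirroring the manual-optimization flow in TimeNet.training_step."""

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # we step optimizer/scheduler ourselves
        self.layer = torch.nn.Linear(1, 1)
        self.finding_lr = False  # toggled around the LR-finder sweep in the patch

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        opt = self.optimizers()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()
        sched = self.lr_schedulers()
        if self.finding_lr:
            sched.step()  # plain step: the LR sweep owns the schedule
        else:
            # fractional-epoch step, as in the patched TimeNet.training_step
            epoch_float = self.current_epoch + batch_idx / float(BATCHES_PER_EPOCH)
            sched.step(epoch=epoch_float)
        return loss

    def configure_optimizers(self):
        opt = torch.optim.SGD(self.parameters(), lr=0.1)
        sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.9)
        return [opt], [sched]

x = torch.randn(BATCHES_PER_EPOCH * 4, 1)
loader = DataLoader(TensorDataset(x, 3 * x), batch_size=4)
pl.Trainer(max_epochs=2, logger=False, enable_checkpointing=False).fit(
    ManualOptSketch(), loader
)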
diff --git a/neuralprophet/utils.py b/neuralprophet/utils.py
index cc5a3ed16..309d9098f 100644
--- a/neuralprophet/utils.py
+++ b/neuralprophet/utils.py
@@ -771,17 +771,17 @@ def smooth_loss_and_suggest(lr_finder, window=10):
     """
     lr_finder_results = lr_finder.results
     lr = lr_finder_results["lr"]
-    loss = lr_finder_results["loss"]
+    loss = np.array(lr_finder_results["loss"])
     # Derive window size from num lr searches, ensure window is divisible by 2
     # half_window = math.ceil(round(len(loss) * 0.1) / 2)
     half_window = math.ceil(window / 2)
     # Pad sequence and initialize hamming filter
-    loss = np.pad(np.array(loss), pad_width=half_window, mode="edge")
-    window = np.hamming(half_window * 2)
+    loss = np.pad(loss, pad_width=half_window, mode="edge")
+    hamming_window = np.hamming(2 * half_window)
     # Convolve over the loss distribution
     try:
-        loss = np.convolve(
-            window / window.sum(),
+        loss_smooth = np.convolve(
+            hamming_window / hamming_window.sum(),
             loss,
             mode="valid",
         )[1:]
@@ -790,26 +790,41 @@ def smooth_loss_and_suggest(lr_finder, window=10):
         f"The number of loss values ({len(loss)}) is too small to apply smoothing with the window size of "
         f"{window}."
         )
+    # Suggest the lr with the steepest negative gradient
     try:
         # Find the steepest gradient and the minimum loss after that
-        suggestion = lr[np.argmin(np.gradient(loss))]
+        suggestion_steepest = lr[np.argmin(np.gradient(loss_smooth))]
+        suggestion_minimum = lr[np.argmin(loss_smooth)]
     except ValueError:
         log.error(
             f"The number of loss values ({len(loss)}) is too small to estimate a learning rate. Increase the number of "
             "samples or manually set the learning rate."
         )
         raise
-    suggestion_default = lr_finder.suggestion(skip_begin=10, skip_end=3)
-    if suggestion is not None and suggestion_default is not None:
-        log_suggestion_smooth = np.log(suggestion)
+
+    # Get the tuner's default suggestion
+    suggestion_default = lr_finder.suggestion(skip_begin=20, skip_end=10)
+
+    log.info(f"Learning rate finder ---- default suggestion: {suggestion_default}")
+    log.info(f"Learning rate finder ---- steepest: {suggestion_steepest}")
+    log.info(f"Learning rate finder ---- minimum: {suggestion_minimum}")
+    if suggestion_steepest is not None and suggestion_minimum is not None and suggestion_default is not None:
+        log_suggestion_smooth = np.log(suggestion_steepest)
+        log_suggestion_minimum = np.log(suggestion_minimum)
         log_suggestion_default = np.log(suggestion_default)
-        lr_suggestion = np.exp((log_suggestion_smooth + log_suggestion_default) / 2)
-    elif suggestion is None and suggestion_default is None:
+        lr_suggestion = np.exp((log_suggestion_smooth + log_suggestion_minimum + log_suggestion_default) / 3)
+        log.info(f"Learning rate finder ---- log-avg: {lr_suggestion}")
+    elif suggestion_steepest is None and suggestion_default is None:
         log.error("Automatic learning rate test failed. Please set the learning rate manually.")
         raise
     else:
-        lr_suggestion = suggestion if suggestion is not None else suggestion_default
+        lr_suggestion = suggestion_steepest if suggestion_steepest is not None else suggestion_default
+
+    log.info(f"Learning rate finder ---- returning: {lr_suggestion}")
+    log.info(f"Learning rate finder ---- LR (start): {lr[:5]}")
+    log.info(f"Learning rate finder ---- LR (end): {lr[-5:]}")
+    log.info(f"Learning rate finder ---- LOSS (start): {loss[:5]}")
+    log.info(f"Learning rate finder ---- LOSS (end): {loss[-5:]}")
     return (loss, lr, lr_suggestion)

diff --git a/tests/debug/debug-energy-price-hourly.ipynb b/tests/debug/debug-energy-price-hourly.ipynb
index a8c769d20..ab4485f1f 100644
--- a/tests/debug/debug-energy-price-hourly.ipynb
+++ b/tests/debug/debug-energy-price-hourly.ipynb
@@ -16,7 +16,9 @@
     "from plotly.subplots import make_subplots\n",
     "from plotly_resampler import unregister_plotly_resampler\n",
     "\n",
-    "from neuralprophet import NeuralProphet, set_random_seed"
+    "from neuralprophet import NeuralProphet, set_random_seed, set_log_level\n",
+    "\n",
+    "set_log_level(\"INFO\")"
    ]
   },
   {
    [... regenerated cell execution counts and object reprs omitted ...]
@@ -186,13 +188,13 @@
     "    \"yearly_seasonality\": 10,\n",
     "    \"weekly_seasonality\": True,\n",
     "    \"daily_seasonality\": False,  # due to conditional daily seasonality\n",
-    "    \"batch_size\": 64,\n",
+    "    \"batch_size\": 32,\n",
     "    \"ar_layers\": [8, 4],\n",
     "    \"lagged_reg_layers\": [8],\n",
     "    # not tuned\n",
     "    \"n_forecasts\": 5,\n",
-    "    \"learning_rate\": 0.1,\n",
-    "    \"epochs\": 20,\n",
+    "    # \"learning_rate\": 0.1,\n",
+    "    \"epochs\": 10,\n",
     "    \"trend_global_local\": \"global\",\n",
     "    \"season_global_local\": \"global\",\n",
     "    \"drop_missing\": True,\n",
@@ -239,6 +241,7 @@
     "output_type": "stream",
     "text": [
      "INFO - (NP.forecaster.fit) - When Global modeling with local normalization, metrics are displayed in normalized scale.\n",
+     "WARNING - (NP.forecaster.fit) - Metrics are enabled. Please provide valid metrics logging directory. Setting to CWD\n",
      "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning: Series.view is deprecated and will be removed in a future version. 
Use ``astype`` as an alternative to change the dtype.\n", " converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)\n", "\n", @@ -267,13 +270,15 @@ "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/time_dataset.py:692: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " contains_nan = torch.cat([torch.tensor(contains_nan), torch.ones(n_forecasts, dtype=torch.bool)])\n", "\n", + "INFO - (NP.forecaster._train) - Dataset size: 2758\n", + "INFO - (NP.forecaster._train) - Number of batches per training epoch: 87\n", "INFO - (NP.utils.configure_trainer) - Using accelerator cpu with 1 device(s).\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "a3a2688119ad4f35babe4b5751d7a677", + "model_id": "3c6d261e96524335a24f00923ad36c02", "version_major": 2, "version_minor": 0 }, @@ -288,23 +293,20 @@ "name": "stderr", "output_type": "stream", "text": [ - "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/time_dataset.py:692: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", - " contains_nan = torch.cat([torch.tensor(contains_nan), torch.ones(n_forecasts, dtype=torch.bool)])\n", - "\n", - "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/time_dataset.py:692: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", - " contains_nan = torch.cat([torch.tensor(contains_nan), torch.ones(n_forecasts, dtype=torch.bool)])\n", - "\n" + "INFO - (NP.forecaster._train) - No Learning Rate provided. Activating learning rate finder\n", + "WARNING - (NP.config.set_lr_finder_args) - Learning rate finder: The number of batches (87) is too small than the required number for the learning rate finder (168). The results might not be optimal.\n", + "INFO - (NP.forecaster._train) - Learning rate finder ---- ARGs: {'min_lr': 1e-08, 'max_lr': 10.0, 'num_training': 168, 'early_stop_threshold': None, 'mode': 'exponential'}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "be32cb53d78b4ead975b12aa5ad15196", + "model_id": "35957cbfc5044d2eab5eb3fe1ccee7c8", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Training: | | 0/? 
[00:00\n", " \n", " 0\n", - " 0.499936\n", - " 0.583346\n", - " 0.938270\n", + " 0.710609\n", + " 0.819329\n", + " 0.622614\n", " 0.0\n", " 0\n", - " 1.503294\n", - " 2.114124\n", - " 1.916612\n", + " 1.144053\n", + " 1.562842\n", + " 1.368656\n", " 0.0\n", - " 0.004087\n", + " 0.012448\n", " \n", " \n", " 1\n", - " 0.534045\n", - " 0.631530\n", - " 0.440998\n", + " 0.836989\n", + " 0.946932\n", + " 0.733583\n", " 0.0\n", " 1\n", - " 0.718145\n", - " 0.943761\n", - " 0.505523\n", + " 0.532974\n", + " 0.702971\n", + " 0.343495\n", " 0.0\n", - " 0.021600\n", + " 0.039781\n", " \n", " \n", " 2\n", - " 0.542755\n", - " 0.644081\n", - " 0.454675\n", + " 0.588745\n", + " 0.704277\n", + " 0.497290\n", " 0.0\n", " 2\n", - " 0.537536\n", - " 0.724863\n", - " 0.347341\n", + " 0.495191\n", + " 0.658636\n", + " 0.304316\n", " 0.0\n", - " 0.050152\n", + " 0.040028\n", " \n", " \n", " 3\n", - " 0.508438\n", - " 0.616892\n", - " 0.487512\n", + " 0.699847\n", + " 0.818369\n", + " 0.594933\n", " 0.0\n", " 3\n", - " 0.503906\n", - " 0.677358\n", - " 0.312197\n", + " 0.475755\n", + " 0.632438\n", + " 0.283402\n", " 0.0\n", - " 0.078837\n", + " 0.012695\n", " \n", " \n", " 4\n", - " 0.649246\n", - " 0.755430\n", - " 0.545550\n", + " 0.704670\n", + " 0.828111\n", + " 0.594259\n", " 0.0\n", " 4\n", - " 0.505661\n", - " 0.671692\n", - " 0.313073\n", + " 0.460465\n", + " 0.615198\n", + " 0.271323\n", " 0.0\n", - " 0.096699\n", + " 0.004634\n", " \n", " \n", " 5\n", - " 0.463848\n", - " 0.568442\n", - " 0.367994\n", + " 0.648891\n", + " 0.755905\n", + " 0.530240\n", " 0.0\n", " 5\n", - " 0.520102\n", - " 0.691615\n", - " 0.322044\n", + " 0.458983\n", + " 0.614499\n", + " 0.270039\n", " 0.0\n", - " 0.099596\n", + " 0.003871\n", " \n", " \n", " 6\n", - " 0.355072\n", - " 0.410634\n", - " 0.251356\n", + " 0.715093\n", + " 0.839661\n", + " 0.608006\n", " 0.0\n", " 6\n", - " 0.511964\n", - " 0.684423\n", - " 0.316359\n", + " 0.459262\n", + " 0.614916\n", + " 0.269727\n", " 0.0\n", - " 0.097137\n", + " 0.002631\n", " \n", " \n", " 7\n", - " 0.447367\n", - " 0.500184\n", - " 0.336913\n", + " 0.689927\n", + " 0.807763\n", + " 0.577745\n", " 0.0\n", " 7\n", - " 0.503181\n", - " 0.669457\n", - " 0.307173\n", + " 0.455545\n", + " 0.609835\n", + " 0.266736\n", " 0.0\n", - " 0.092315\n", + " 0.001389\n", " \n", " \n", " 8\n", - " 0.821846\n", - " 0.951728\n", - " 0.720978\n", + " 0.646029\n", + " 0.753095\n", + " 0.530006\n", " 0.0\n", " 8\n", - " 0.503031\n", - " 0.671102\n", - " 0.308114\n", + " 0.457110\n", + " 0.611970\n", + " 0.267934\n", " 0.0\n", - " 0.085371\n", + " 0.000618\n", " \n", " \n", " 9\n", - " 0.414638\n", - " 0.474769\n", - " 0.334302\n", + " 0.688786\n", + " 0.806359\n", + " 0.577159\n", " 0.0\n", " 9\n", - " 0.511291\n", - " 0.686945\n", - " 0.311271\n", + " 0.456113\n", + " 0.611694\n", + " 0.267121\n", " 0.0\n", - " 0.076654\n", - " \n", - " \n", - " 10\n", - " 0.606577\n", - " 0.723609\n", - " 0.504883\n", - " 0.0\n", - " 10\n", - " 0.493725\n", - " 0.657971\n", - " 0.301624\n", - " 0.0\n", - " 0.066600\n", - " \n", - " \n", - " 11\n", - " 0.560590\n", - " 0.657100\n", - " 0.453766\n", - " 0.0\n", - " 11\n", - " 0.487225\n", - " 0.654937\n", - " 0.295672\n", - " 0.0\n", - " 0.055713\n", - " \n", - " \n", - " 12\n", - " 0.419592\n", - " 0.459256\n", - " 0.307631\n", - " 0.0\n", - " 12\n", - " 0.479861\n", - " 0.642756\n", - " 0.287683\n", - " 0.0\n", - " 0.044541\n", - " \n", - " \n", - " 13\n", - " 0.492459\n", - " 0.561360\n", - " 0.379794\n", - " 0.0\n", - " 13\n", - " 0.479290\n", - " 0.643680\n", - 
" 0.284241\n", - " 0.0\n", - " 0.033641\n", - " \n", - " \n", - " 14\n", - " 0.547214\n", - " 0.630017\n", - " 0.432885\n", - " 0.0\n", - " 14\n", - " 0.471661\n", - " 0.633883\n", - " 0.280081\n", - " 0.0\n", - " 0.023563\n", - " \n", - " \n", - " 15\n", - " 0.542842\n", - " 0.630828\n", - " 0.427475\n", - " 0.0\n", - " 15\n", - " 0.467507\n", - " 0.630942\n", - " 0.275439\n", - " 0.0\n", - " 0.014810\n", - " \n", - " \n", - " 16\n", - " 0.497607\n", - " 0.569062\n", - " 0.380621\n", - " 0.0\n", - " 16\n", - " 0.468031\n", - " 0.631191\n", - " 0.276560\n", - " 0.0\n", - " 0.007821\n", - " \n", - " \n", - " 17\n", - " 0.507053\n", - " 0.580275\n", - " 0.390214\n", - " 0.0\n", - " 17\n", - " 0.458170\n", - " 0.620013\n", - " 0.268218\n", - " 0.0\n", - " 0.002948\n", - " \n", - " \n", - " 18\n", - " 0.506170\n", - " 0.578457\n", - " 0.389007\n", - " 0.0\n", - " 18\n", - " 0.460292\n", - " 0.622816\n", - " 0.268188\n", - " 0.0\n", - " 0.000434\n", - " \n", - " \n", - " 19\n", - " 0.508543\n", - " 0.581377\n", - " 0.391374\n", - " 0.0\n", - " 19\n", - " 0.459247\n", - " 0.622094\n", - " 0.267627\n", - " 0.0\n", - " 0.000405\n", + " 0.000613\n", " \n", " \n", "\n", "" ], "text/plain": [ - " MAE_val RMSE_val Loss_val RegLoss_val epoch MAE RMSE \\\n", - "0 0.499936 0.583346 0.938270 0.0 0 1.503294 2.114124 \n", - "1 0.534045 0.631530 0.440998 0.0 1 0.718145 0.943761 \n", - "2 0.542755 0.644081 0.454675 0.0 2 0.537536 0.724863 \n", - "3 0.508438 0.616892 0.487512 0.0 3 0.503906 0.677358 \n", - "4 0.649246 0.755430 0.545550 0.0 4 0.505661 0.671692 \n", - "5 0.463848 0.568442 0.367994 0.0 5 0.520102 0.691615 \n", - "6 0.355072 0.410634 0.251356 0.0 6 0.511964 0.684423 \n", - "7 0.447367 0.500184 0.336913 0.0 7 0.503181 0.669457 \n", - "8 0.821846 0.951728 0.720978 0.0 8 0.503031 0.671102 \n", - "9 0.414638 0.474769 0.334302 0.0 9 0.511291 0.686945 \n", - "10 0.606577 0.723609 0.504883 0.0 10 0.493725 0.657971 \n", - "11 0.560590 0.657100 0.453766 0.0 11 0.487225 0.654937 \n", - "12 0.419592 0.459256 0.307631 0.0 12 0.479861 0.642756 \n", - "13 0.492459 0.561360 0.379794 0.0 13 0.479290 0.643680 \n", - "14 0.547214 0.630017 0.432885 0.0 14 0.471661 0.633883 \n", - "15 0.542842 0.630828 0.427475 0.0 15 0.467507 0.630942 \n", - "16 0.497607 0.569062 0.380621 0.0 16 0.468031 0.631191 \n", - "17 0.507053 0.580275 0.390214 0.0 17 0.458170 0.620013 \n", - "18 0.506170 0.578457 0.389007 0.0 18 0.460292 0.622816 \n", - "19 0.508543 0.581377 0.391374 0.0 19 0.459247 0.622094 \n", + " MAE_val RMSE_val Loss_val RegLoss_val epoch MAE RMSE \\\n", + "0 0.710609 0.819329 0.622614 0.0 0 1.144053 1.562842 \n", + "1 0.836989 0.946932 0.733583 0.0 1 0.532974 0.702971 \n", + "2 0.588745 0.704277 0.497290 0.0 2 0.495191 0.658636 \n", + "3 0.699847 0.818369 0.594933 0.0 3 0.475755 0.632438 \n", + "4 0.704670 0.828111 0.594259 0.0 4 0.460465 0.615198 \n", + "5 0.648891 0.755905 0.530240 0.0 5 0.458983 0.614499 \n", + "6 0.715093 0.839661 0.608006 0.0 6 0.459262 0.614916 \n", + "7 0.689927 0.807763 0.577745 0.0 7 0.455545 0.609835 \n", + "8 0.646029 0.753095 0.530006 0.0 8 0.457110 0.611970 \n", + "9 0.688786 0.806359 0.577159 0.0 9 0.456113 0.611694 \n", "\n", - " Loss RegLoss LR \n", - "0 1.916612 0.0 0.004087 \n", - "1 0.505523 0.0 0.021600 \n", - "2 0.347341 0.0 0.050152 \n", - "3 0.312197 0.0 0.078837 \n", - "4 0.313073 0.0 0.096699 \n", - "5 0.322044 0.0 0.099596 \n", - "6 0.316359 0.0 0.097137 \n", - "7 0.307173 0.0 0.092315 \n", - "8 0.308114 0.0 0.085371 \n", - "9 0.311271 0.0 0.076654 \n", - "10 0.301624 0.0 
0.066600 \n", - "11 0.295672 0.0 0.055713 \n", - "12 0.287683 0.0 0.044541 \n", - "13 0.284241 0.0 0.033641 \n", - "14 0.280081 0.0 0.023563 \n", - "15 0.275439 0.0 0.014810 \n", - "16 0.276560 0.0 0.007821 \n", - "17 0.268218 0.0 0.002948 \n", - "18 0.268188 0.0 0.000434 \n", - "19 0.267627 0.0 0.000405 " + " Loss RegLoss LR \n", + "0 1.368656 0.0 0.012448 \n", + "1 0.343495 0.0 0.039781 \n", + "2 0.304316 0.0 0.040028 \n", + "3 0.283402 0.0 0.012695 \n", + "4 0.271323 0.0 0.004634 \n", + "5 0.270039 0.0 0.003871 \n", + "6 0.269727 0.0 0.002631 \n", + "7 0.266736 0.0 0.001389 \n", + "8 0.267934 0.0 0.000618 \n", + "9 0.267121 0.0 0.000613 " ] }, "execution_count": 9, @@ -2375,7 +2047,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "23b4bbc68cea4966bf719a33132a3726", + "model_id": "9c2666a108ac4123919f8203b5f548b1", "version_major": 2, "version_minor": 0 }, @@ -2400,7 +2072,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "dd54d62acf464964b7d15974106128e3", + "model_id": "08c352843e98443a8b24b2713b9636b1", "version_major": 2, "version_minor": 0 }, @@ -2486,7 +2158,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "c873800851d442818fb758ee0b8565b0", + "model_id": "fdb0f6945c704c788e4243c7789a7e29", "version_major": 2, "version_minor": 0 }, @@ -2497,61 +2169,61 @@ " 'mode': 'lines',\n", " 'name': '[R] yhat5 1.0% ~1h',\n", " 'type': 'scatter',\n", - " 'uid': '1e41dda2-f7ca-4501-ae0d-394dbc69313f',\n", + " 'uid': 'cceaf554-f88b-47ac-b077-bd98eebd51bd',\n", " 'x': array([datetime.datetime(2015, 1, 2, 13, 0),\n", " datetime.datetime(2015, 1, 2, 14, 0),\n", " datetime.datetime(2015, 1, 2, 15, 0), ...,\n", " datetime.datetime(2015, 3, 2, 17, 0),\n", - " datetime.datetime(2015, 3, 2, 18, 0),\n", + " datetime.datetime(2015, 3, 2, 19, 0),\n", " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'y': array([ 7.0392876, 9.7315445, 17.582043 , ..., 48.029076 , 46.43782 ,\n", - " 48.867878 ], dtype=float32)},\n", + " 'y': array([21.843597, 25.104948, 33.001038, ..., 46.6899 , 41.747295, 48.700737],\n", + " dtype=float32)},\n", " {'fill': 'tonexty',\n", " 'fillcolor': 'rgba(45, 146, 255, 0.2)',\n", " 'line': {'color': 'rgba(45, 146, 255, 0.2)', 'width': 1},\n", " 'mode': 'lines',\n", " 'name': '[R] yhat5 99.0% ~1h',\n", " 'type': 'scatter',\n", - " 'uid': '8b1256ed-4f9c-4a39-98f6-d94c1af49272',\n", + " 'uid': 'bda8f5ed-a117-47cf-8298-3b7802a38dcc',\n", " 'x': array([datetime.datetime(2015, 1, 2, 13, 0),\n", " datetime.datetime(2015, 1, 2, 14, 0),\n", " datetime.datetime(2015, 1, 2, 15, 0), ...,\n", " datetime.datetime(2015, 3, 2, 17, 0),\n", - " datetime.datetime(2015, 3, 2, 19, 0),\n", + " datetime.datetime(2015, 3, 2, 18, 0),\n", " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'y': array([64.257675, 64.03381 , 72.06434 , ..., 74.77048 , 70.81393 , 73.04162 ],\n", + " 'y': array([83.986046, 89.35564 , 74.86833 , ..., 74.58182 , 77.62551 , 77.05947 ],\n", " dtype=float32)},\n", " {'line': {'color': '#2d92ff', 'width': 2},\n", " 'mode': 'lines',\n", " 'name': '[R] Predicted ~1h',\n", " 'type': 'scatter',\n", - " 'uid': 'c9c42f86-4573-4aea-8a34-4980b55458a9',\n", + " 'uid': 'a0f939c9-0e13-45ce-8081-3dac7cf67c72',\n", " 'x': array([datetime.datetime(2015, 1, 2, 13, 0),\n", " datetime.datetime(2015, 1, 2, 14, 0),\n", " datetime.datetime(2015, 1, 2, 15, 0), ...,\n", " datetime.datetime(2015, 3, 2, 17, 0),\n", - " datetime.datetime(2015, 3, 2, 19, 0),\n", + " datetime.datetime(2015, 3, 2, 18, 0),\n", " 
datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'y': array([41.12671 , 40.386654, 44.95556 , ..., 63.092262, 65.077774, 63.004234],\n", + " 'y': array([47.22839 , 49.346603, 51.1183 , ..., 58.201473, 60.27031 , 59.059807],\n", " dtype=float32)},\n", " {'marker': {'color': 'blue', 'size': 4, 'symbol': 'x'},\n", " 'mode': 'markers',\n", " 'name': '[R] Predicted ~1h',\n", " 'type': 'scatter',\n", - " 'uid': '8704218f-c879-46d2-98f8-70840910069f',\n", + " 'uid': '555cf752-d8ea-41a9-8f47-596ee22a34be',\n", " 'x': array([datetime.datetime(2015, 1, 2, 13, 0),\n", " datetime.datetime(2015, 1, 2, 14, 0),\n", " datetime.datetime(2015, 1, 2, 15, 0), ...,\n", " datetime.datetime(2015, 3, 2, 17, 0),\n", - " datetime.datetime(2015, 3, 2, 19, 0),\n", + " datetime.datetime(2015, 3, 2, 18, 0),\n", " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'y': array([41.12671 , 40.386654, 44.95556 , ..., 63.092262, 65.077774, 63.004234],\n", + " 'y': array([47.22839 , 49.346603, 51.1183 , ..., 58.201473, 60.27031 , 59.059807],\n", " dtype=float32)},\n", " {'marker': {'color': 'black', 'size': 4},\n", " 'mode': 'markers',\n", " 'name': '[R] Actual ~1h',\n", " 'type': 'scatter',\n", - " 'uid': '91a4e9d7-9480-462b-b473-7f38c9371ea4',\n", + " 'uid': '09a782ca-96bb-4a77-80a7-fd42484a363d',\n", " 'x': array([datetime.datetime(2015, 1, 1, 0, 0),\n", " datetime.datetime(2015, 1, 1, 1, 0),\n", " datetime.datetime(2015, 1, 1, 2, 0), ...,\n", @@ -2714,7 +2386,7 @@ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[12], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mplot_components\u001b[49m\u001b[43m(\u001b[49m\u001b[43mforecast\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdf_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtest\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/github/neural_prophet/neuralprophet/forecaster.py:2452\u001b[0m, in \u001b[0;36mNeuralProphet.plot_components\u001b[0;34m(self, fcst, df_name, figsize, forecast_in_focus, plotting_backend, components, one_period_per_season)\u001b[0m\n\u001b[1;32m 2450\u001b[0m log_warning_deprecation_plotly(plotting_backend)\n\u001b[1;32m 2451\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m plotting_backend\u001b[38;5;241m.\u001b[39mstartswith(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mplotly\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m-> 2452\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mplot_components_plotly\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2453\u001b[0m \u001b[43m \u001b[49m\u001b[43mm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2454\u001b[0m \u001b[43m \u001b[49m\u001b[43mfcst\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfcst\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2455\u001b[0m \u001b[43m \u001b[49m\u001b[43mplot_configuration\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalid_plot_configuration\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2456\u001b[0m \u001b[43m \u001b[49m\u001b[43mfigsize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mtuple\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m 
\u001b[49m\u001b[38;5;241;43m70\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mfigsize\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mfigsize\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m700\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m210\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2457\u001b[0m \u001b[43m \u001b[49m\u001b[43mdf_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdf_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2458\u001b[0m \u001b[43m \u001b[49m\u001b[43mone_period_per_season\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mone_period_per_season\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2459\u001b[0m \u001b[43m \u001b[49m\u001b[43mresampler_active\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mplotting_backend\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mplotly-resampler\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2460\u001b[0m \u001b[43m \u001b[49m\u001b[43mplotly_static\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mplotting_backend\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mplotly-static\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2461\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2462\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 2463\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m plot_components(\n\u001b[1;32m 2464\u001b[0m m\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 2465\u001b[0m fcst\u001b[38;5;241m=\u001b[39mfcst,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 2470\u001b[0m one_period_per_season\u001b[38;5;241m=\u001b[39mone_period_per_season,\n\u001b[1;32m 2471\u001b[0m )\n", + "File \u001b[0;32m~/github/neural_prophet/neuralprophet/forecaster.py:2465\u001b[0m, in \u001b[0;36mNeuralProphet.plot_components\u001b[0;34m(self, fcst, df_name, figsize, forecast_in_focus, plotting_backend, components, one_period_per_season)\u001b[0m\n\u001b[1;32m 2463\u001b[0m log_warning_deprecation_plotly(plotting_backend)\n\u001b[1;32m 2464\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m plotting_backend\u001b[38;5;241m.\u001b[39mstartswith(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mplotly\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m-> 2465\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mplot_components_plotly\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2466\u001b[0m \u001b[43m \u001b[49m\u001b[43mm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2467\u001b[0m \u001b[43m \u001b[49m\u001b[43mfcst\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfcst\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2468\u001b[0m \u001b[43m \u001b[49m\u001b[43mplot_configuration\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalid_plot_configuration\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2469\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mfigsize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mtuple\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m70\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mfigsize\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mfigsize\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m700\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m210\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2470\u001b[0m \u001b[43m \u001b[49m\u001b[43mdf_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdf_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2471\u001b[0m \u001b[43m \u001b[49m\u001b[43mone_period_per_season\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mone_period_per_season\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2472\u001b[0m \u001b[43m \u001b[49m\u001b[43mresampler_active\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mplotting_backend\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mplotly-resampler\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2473\u001b[0m \u001b[43m \u001b[49m\u001b[43mplotly_static\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mplotting_backend\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mplotly-static\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2474\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2475\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 2476\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m plot_components(\n\u001b[1;32m 2477\u001b[0m m\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 2478\u001b[0m fcst\u001b[38;5;241m=\u001b[39mfcst,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 2483\u001b[0m one_period_per_season\u001b[38;5;241m=\u001b[39mone_period_per_season,\n\u001b[1;32m 2484\u001b[0m )\n", "File \u001b[0;32m~/github/neural_prophet/neuralprophet/plot_forecast_plotly.py:332\u001b[0m, in \u001b[0;36mplot_components\u001b[0;34m(m, fcst, plot_configuration, df_name, one_period_per_season, figsize, resampler_active, plotly_static)\u001b[0m\n\u001b[1;32m 327\u001b[0m trace_object \u001b[38;5;241m=\u001b[39m get_forecast_component_props(\n\u001b[1;32m 328\u001b[0m fcst\u001b[38;5;241m=\u001b[39mfcst, df_name\u001b[38;5;241m=\u001b[39mdf_name, comp_name\u001b[38;5;241m=\u001b[39mcomp_name, plot_name\u001b[38;5;241m=\u001b[39mcomp[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mplot_name\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 329\u001b[0m )\n\u001b[1;32m 331\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto-regression\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m name \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlagged regressor\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m name:\n\u001b[0;32m--> 332\u001b[0m trace_object 
\u001b[38;5;241m=\u001b[39m \u001b[43mget_multiforecast_component_props\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfcst\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfcst\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mcomp\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 333\u001b[0m fig\u001b[38;5;241m.\u001b[39mupdate_layout(barmode\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moverlay\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 335\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m j \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n", "File \u001b[0;32m~/github/neural_prophet/neuralprophet/plot_forecast_plotly.py:603\u001b[0m, in \u001b[0;36mget_multiforecast_component_props\u001b[0;34m(fcst, comp_name, plot_name, multiplicative, bar, focus, num_overplot, **kwargs)\u001b[0m\n\u001b[1;32m 601\u001b[0m y \u001b[38;5;241m=\u001b[39m fcst[\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcomp_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 602\u001b[0m y \u001b[38;5;241m=\u001b[39m y\u001b[38;5;241m.\u001b[39mvalues\n\u001b[0;32m--> 603\u001b[0m \u001b[43my\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 604\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m bar:\n\u001b[1;32m 605\u001b[0m traces\u001b[38;5;241m.\u001b[39mappend(\n\u001b[1;32m 606\u001b[0m go\u001b[38;5;241m.\u001b[39mBar(\n\u001b[1;32m 607\u001b[0m name\u001b[38;5;241m=\u001b[39mplot_name,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 613\u001b[0m )\n\u001b[1;32m 614\u001b[0m )\n", "\u001b[0;31mIndexError\u001b[0m: index -1 is out of bounds for axis 0 with size 0" @@ -2759,7 +2431,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "06ddcd4df2f94141aefe0f5056b5c617", + "model_id": "a09e98a70e3e4d4da2d2fe3865b8abc7", "version_major": 2, "version_minor": 0 }, @@ -2770,18 +2442,18 @@ " 'mode': 'lines',\n", " 'name': 'Trend',\n", " 'type': 'scatter',\n", - " 'uid': '2b2cb1c3-4a0e-4729-92d7-1167a376d732',\n", + " 'uid': '6d93394c-496d-4c52-8c79-e1a55e9bff0d',\n", " 'x': array([datetime.datetime(2015, 1, 1, 0, 0),\n", " datetime.datetime(2015, 2, 28, 23, 0)], dtype=object),\n", " 'xaxis': 'x',\n", - " 'y': array([25.74136 , 14.885769], dtype=float32),\n", + " 'y': array([35.735615, 26.46712 ], dtype=float32),\n", " 'yaxis': 'y'},\n", " {'fill': 'none',\n", " 'line': {'color': '#2d92ff', 'width': 2},\n", " 'mode': 'lines',\n", " 'name': 'yearly',\n", " 'type': 'scatter',\n", - " 'uid': '965409ff-c889-48d6-bae0-936574b46f88',\n", + " 'uid': '91e577de-4754-46f9-a832-ff20851064d6',\n", " 'x': array([datetime.datetime(2017, 1, 1, 0, 0),\n", " datetime.datetime(2017, 1, 2, 0, 0),\n", " datetime.datetime(2017, 1, 3, 0, 0), ...,\n", @@ -2789,15 +2461,15 @@ " datetime.datetime(2017, 12, 30, 0, 0),\n", " datetime.datetime(2017, 12, 31, 0, 0)], dtype=object),\n", " 'xaxis': 'x2',\n", - " 'y': array([-48.182846, -50.157616, -51.64516 , ..., -38.81324 , -42.12695 ,\n", - " -45.080776], dtype=float32),\n", + " 'y': array([-1.7568997 , -2.1306572 , -2.4605272 , ..., -0.36249205, -0.8088371 ,\n", + " -1.2453859 ], dtype=float32),\n", " 'yaxis': 'y2'},\n", " {'fill': 'none',\n", " 'line': {'color': '#2d92ff', 'width': 2},\n", " 'mode': 'lines',\n", " 'name': 'weekly',\n", " 'type': 'scatter',\n", - " 'uid': 
'ba3b298a-be66-40c5-8264-ef5b716d06bf',\n", + " 'uid': 'fbb9e766-a826-4066-b0a7-9bbd7391dedd',\n", " 'x': array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,\n", " 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,\n", " 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,\n", @@ -2811,124 +2483,119 @@ " 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,\n", " 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167]),\n", " 'xaxis': 'x3',\n", - " 'y': array([-3.04588795e+01, -3.54400597e+01, -3.99797134e+01, -4.43526154e+01,\n", - " -4.81006012e+01, -5.13112793e+01, -5.40508690e+01, -5.63052025e+01,\n", - " -5.81308632e+01, -5.93055077e+01, -6.00412445e+01, -6.02870750e+01,\n", - " -5.99476738e+01, -5.91321754e+01, -5.79114304e+01, -5.63154335e+01,\n", - " -5.43354149e+01, -5.17919426e+01, -4.89967918e+01, -4.60626144e+01,\n", - " -4.26569977e+01, -3.94107780e+01, -3.56459961e+01, -3.19178200e+01,\n", - " -2.82482281e+01, -2.45639687e+01, -2.09295406e+01, -1.71104145e+01,\n", - " -1.35622263e+01, -1.01446161e+01, -7.01516199e+00, -4.13522530e+00,\n", - " -1.30050945e+00, 1.12308824e+00, 3.24085855e+00, 5.06825876e+00,\n", - " 6.57082605e+00, 7.81317091e+00, 8.63896179e+00, 9.19974041e+00,\n", - " 9.46362782e+00, 9.35620880e+00, 8.95828724e+00, 8.39426994e+00,\n", - " 7.48970175e+00, 6.47928476e+00, 5.27563381e+00, 3.69769454e+00,\n", - " 2.23112702e+00, 5.54561198e-01, -9.77979958e-01, -2.75594759e+00,\n", - " -4.51804447e+00, -6.08968639e+00, -7.68710136e+00, -9.01867199e+00,\n", - " -1.04195356e+01, -1.15890474e+01, -1.25064888e+01, -1.32924528e+01,\n", - " -1.37729158e+01, -1.40936012e+01, -1.41434650e+01, -1.39483709e+01,\n", - " -1.34957485e+01, -1.28450623e+01, -1.19179354e+01, -1.07594900e+01,\n", - " -9.40466499e+00, -7.81799698e+00, -6.03295660e+00, -4.35296869e+00,\n", - " -2.36515474e+00, -4.06748503e-01, 1.75163662e+00, 3.73803329e+00,\n", - " 5.86639977e+00, 7.98941374e+00, 9.86133385e+00, 1.17813129e+01,\n", - " 1.33550673e+01, 1.49084988e+01, 1.61548271e+01, 1.73153419e+01,\n", - " 1.81719952e+01, 1.87052574e+01, 1.89303913e+01, 1.88749886e+01,\n", - " 1.85987663e+01, 1.79191532e+01, 1.69609566e+01, 1.57270803e+01,\n", - " 1.42448225e+01, 1.23795385e+01, 1.02109413e+01, 8.01280880e+00,\n", - " 5.49813223e+00, 2.89395428e+00, 5.37771396e-02, -3.00232887e+00,\n", - " -6.02308273e+00, -8.95809174e+00, -1.18841352e+01, -1.48639030e+01,\n", - " -1.78287868e+01, -2.03142662e+01, -2.28633080e+01, -2.51014042e+01,\n", - " -2.72459736e+01, -2.90301781e+01, -3.02393055e+01, -3.12937450e+01,\n", - " -3.19238663e+01, -3.21995316e+01, -3.19541264e+01, -3.12862282e+01,\n", - " -3.01354771e+01, -2.86372204e+01, -2.65992775e+01, -2.41690311e+01,\n", - " -2.13332558e+01, -1.81720409e+01, -1.45338621e+01, -1.03474331e+01,\n", - " -6.08415222e+00, -1.40853214e+00, 3.29980421e+00, 8.35535431e+00,\n", - " 1.37836151e+01, 1.89796047e+01, 2.41642418e+01, 2.93892212e+01,\n", - " 3.48417130e+01, 4.02113190e+01, 4.50756607e+01, 4.96866379e+01,\n", - " 5.40651703e+01, 5.83653717e+01, 6.21502800e+01, 6.54130478e+01,\n", - " 6.82579117e+01, 7.07123795e+01, 7.26812973e+01, 7.40298157e+01,\n", - " 7.47914352e+01, 7.49953308e+01, 7.46389465e+01, 7.36509781e+01,\n", - " 7.20302200e+01, 6.99344711e+01, 6.73451157e+01, 6.42106323e+01,\n", - " 6.03940163e+01, 5.62932777e+01, 5.17331505e+01, 4.69316330e+01,\n", - " 4.13555679e+01, 3.55166130e+01, 2.97218113e+01, 2.36869488e+01,\n", - " 1.76951599e+01, 1.11073389e+01, 4.76392841e+00, -1.35005784e+00,\n", - " 
-7.75571394e+00, -1.36067734e+01, -1.97326603e+01, -2.53451424e+01],\n", - " dtype=float32),\n", + " 'y': array([ 7.755577 , 7.5237527 , 7.224617 , 6.815465 , 6.314899 ,\n", + " 5.7965226 , 5.222597 , 4.60115 , 3.888913 , 3.192386 ,\n", + " 2.480555 , 1.7581066 , 0.9770121 , 0.21064605, -0.4857402 ,\n", + " -1.1512595 , -1.7806304 , -2.4084048 , -2.9580166 , -3.4257982 ,\n", + " -3.855735 , -4.175071 , -4.4489365 , -4.6284823 , -4.724985 ,\n", + " -4.7434278 , -4.6815186 , -4.530946 , -4.302239 , -4.0017447 ,\n", + " -3.6467676 , -3.2275088 , -2.7205632 , -2.1764457 , -1.6195372 ,\n", + " -1.0332191 , -0.3954899 , 0.27314964, 0.9262204 , 1.5400877 ,\n", + " 2.1368022 , 2.7368646 , 3.2993367 , 3.788525 , 4.2504687 ,\n", + " 4.6347704 , 4.9488535 , 5.214139 , 5.3700595 , 5.456191 ,\n", + " 5.4460826 , 5.3499737 , 5.1563053 , 4.8846197 , 4.5234604 ,\n", + " 4.1024094 , 3.5617476 , 2.9600484 , 2.3084648 , 1.5802894 ,\n", + " 0.81475705, -0.04111661, -0.90458584, -1.7624184 , -2.661798 ,\n", + " -3.53263 , -4.4244337 , -5.295709 , -6.1333375 , -6.931251 ,\n", + " -7.696276 , -8.319346 , -8.925283 , -9.427417 , -9.868788 ,\n", + " -10.188685 , -10.430543 , -10.565057 , -10.580534 , -10.500571 ,\n", + " -10.296801 , -9.9909 , -9.583059 , -9.07224 , -8.425532 ,\n", + " -7.701064 , -6.915518 , -6.0554175 , -5.111462 , -4.0590043 ,\n", + " -3.0526552 , -1.9540824 , -0.86603564, 0.3103157 , 1.494285 ,\n", + " 2.560513 , 3.6657102 , 4.700126 , 5.7225103 , 6.720365 ,\n", + " 7.6138163 , 8.40019 , 9.098643 , 9.724774 , 10.263643 ,\n", + " 10.650285 , 10.948275 , 11.123994 , 11.187449 , 11.121981 ,\n", + " 10.951801 , 10.656306 , 10.26268 , 9.729534 , 9.129556 ,\n", + " 8.439994 , 7.6128716 , 6.767242 , 5.794735 , 4.8272123 ,\n", + " 3.8182733 , 2.7837412 , 1.6946272 , 0.55752414, -0.5014487 ,\n", + " -1.5780605 , -2.5846512 , -3.583228 , -4.562805 , -5.417882 ,\n", + " -6.1975856 , -6.9047456 , -7.558581 , -8.112523 , -8.532403 ,\n", + " -8.8463745 , -9.0614 , -9.167561 , -9.156995 , -9.042185 ,\n", + " -8.828875 , -8.505391 , -8.057592 , -7.5299797 , -6.938876 ,\n", + " -6.275693 , -5.5465274 , -4.701907 , -3.8042374 , -2.929176 ,\n", + " -2.0293174 , -1.1140192 , -0.13168602, 0.7801494 , 1.7001534 ,\n", + " 2.5384672 , 3.4228206 , 4.248964 , 4.967284 , 5.6338806 ,\n", + " 6.2010517 , 6.737668 , 7.1634216 , 7.488221 , 7.7340417 ,\n", + " 7.8822103 , 7.9357567 , 7.8853316 ], dtype=float32),\n", " 'yaxis': 'y3'},\n", " {'fill': 'none',\n", " 'line': {'color': '#2d92ff', 'width': 2},\n", " 'mode': 'lines',\n", " 'name': 'winter',\n", " 'type': 'scatter',\n", - " 'uid': '6deeeb47-3be1-4b89-a492-0e6bf00cdd53',\n", + " 'uid': '11f43753-196c-416c-8679-caab4f55210d',\n", " 'x': array([ 0, 1, 2, ..., 285, 286, 287]),\n", " 'xaxis': 'x4',\n", - " 'y': array([-0.43253064, 0.38781527, 1.1857711 , ..., -2.2786856 , -1.8861564 ,\n", - " -1.1711025 ], dtype=float32),\n", + " 'y': array([ 1.5749581 , 0.68877584, -0.0385443 , ..., 3.668294 , 3.1973646 ,\n", + " 2.3746142 ], dtype=float32),\n", " 'yaxis': 'y4'},\n", " {'fill': 'none',\n", " 'line': {'color': '#2d92ff', 'width': 2},\n", " 'mode': 'lines',\n", " 'name': 'summer',\n", " 'type': 'scatter',\n", - " 'uid': 'b5b32b83-f643-4973-8db7-1797073f910f',\n", + " 'uid': '82f6c088-001e-487b-947b-2b32fbf6b06c',\n", " 'x': array([ 0, 1, 2, ..., 285, 286, 287]),\n", " 'xaxis': 'x5',\n", - " 'y': array([-21.528275, -22.261412, -22.15886 , ..., -17.52789 , -18.709206,\n", - " -20.179296], dtype=float32),\n", + " 'y': array([ 1.621103 , 0.41932815, -0.4366651 , ..., 4.0205083 
, 3.520041 ,\n", + " 2.6508992 ], dtype=float32),\n", " 'yaxis': 'y5'},\n", " {'marker': {'color': '#2d92ff'},\n", " 'name': 'AR',\n", " 'type': 'bar',\n", - " 'uid': 'dea67a1e-e499-4e43-a97e-0fb70895854a',\n", + " 'uid': 'c3664db2-136a-40d7-9103-625733dda176',\n", " 'width': 0.8,\n", " 'x': array([10, 9, 8, 7, 6, 5, 4, 3, 2, 1]),\n", " 'xaxis': 'x6',\n", - " 'y': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n", + " 'y': array([-0.3051259 , -0.20342994, 0.05545649, 0.13164242, 0.2857538 ,\n", + " 0.06021553, -0.3720306 , -0.01876787, -0.00064597, 0.04794757],\n", + " dtype=float32),\n", " 'yaxis': 'y6'},\n", " {'marker': {'color': '#2d92ff'},\n", " 'name': 'Lagged Regressor \"temp\"',\n", " 'type': 'bar',\n", - " 'uid': '78f9e1a7-125b-4563-aaec-5ff3fe054d54',\n", + " 'uid': 'd28e2fd9-0522-488a-b89a-02b1296bae1a',\n", " 'width': 0.8,\n", " 'x': array([33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16,\n", " 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]),\n", " 'xaxis': 'x7',\n", - " 'y': array([ 0.4898405 , 1.880786 , 2.1640918 , -0.2928444 , 0.86554664,\n", - " -0.08299019, 0.70314807, 1.0031208 , 0.90763193, 0.4682139 ,\n", - " -1.1896216 , -1.6901426 , -1.2109685 , 0.9098389 , -0.64848685,\n", - " 0.9634216 , 0.91694885, 2.0049295 , 2.8199239 , 0.83436155,\n", - " 1.654415 , 2.4778936 , 0.64203995, 2.3134997 , 1.7692485 ,\n", - " 1.1947386 , 0.62410027, 0.79597855, 2.8871663 , 0.6992378 ,\n", - " 0.69474286, 1.997743 , 2.4678695 ], dtype=float32),\n", + " 'y': array([ 6.2832564e-01, -2.3228288e-01, 6.6017294e-01, 3.3255139e-01,\n", + " 5.0744390e-01, 1.1816436e-01, -1.5548144e-01, 1.6358766e-01,\n", + " -1.7810777e-01, 3.2371131e-01, 4.4875324e-01, -3.1604797e-01,\n", + " -1.3501082e-03, -9.3391158e-02, 8.1444037e-01, -7.3939008e-01,\n", + " 4.2238832e-01, 5.4563276e-02, 3.5837620e-01, -5.2361876e-02,\n", + " -5.4710191e-01, -7.3065239e-01, -3.4761795e-01, 4.7559822e-01,\n", + " 2.0330952e-02, 2.5780448e-01, 1.0076398e-01, 3.2984644e-01,\n", + " 2.2101782e-01, 2.5692052e-01, -8.7424242e-01, -8.4744475e-04,\n", + " 3.0343091e-01], dtype=float32),\n", " 'yaxis': 'y7'},\n", " {'marker': {'color': '#2d92ff'},\n", " 'name': 'Additive event',\n", " 'type': 'bar',\n", - " 'uid': 'df559219-3f3e-45df-864c-8eebbaeadb67',\n", + " 'uid': 'ec08d1ed-660d-43c4-8112-e6bd03345ae9',\n", " 'width': 0.8,\n", - " 'x': array(['Labor Day_+0', 'Labor Day_+1', 'Labor Day_-1', 'Veterans Day_+0',\n", - " 'Veterans Day_+1', 'Veterans Day_-1', 'Martin Luther King Jr. Day_+0',\n", + " 'x': array(['Memorial Day_+0', 'Memorial Day_+1', 'Memorial Day_-1',\n", + " \"Washington's Birthday_+0\", \"Washington's Birthday_+1\",\n", + " \"Washington's Birthday_-1\", 'Columbus Day_+0', 'Columbus Day_+1',\n", + " 'Columbus Day_-1', 'Martin Luther King Jr. Day_+0',\n", " 'Martin Luther King Jr. Day_+1', 'Martin Luther King Jr. 
Day_-1',\n", " \"New Year's Day_+0\", \"New Year's Day_+1\", \"New Year's Day_-1\",\n", - " \"Washington's Birthday_+0\", \"Washington's Birthday_+1\",\n", - " \"Washington's Birthday_-1\", 'Independence Day_+0',\n", - " 'Independence Day_+1', 'Independence Day_-1', 'Memorial Day_+0',\n", - " 'Memorial Day_+1', 'Memorial Day_-1', 'Columbus Day_+0',\n", - " 'Columbus Day_+1', 'Columbus Day_-1', 'Thanksgiving_+0',\n", - " 'Thanksgiving_+1', 'Thanksgiving_-1', 'Christmas Day_+0',\n", - " 'Christmas Day_+1', 'Christmas Day_-1'], dtype=object),\n", + " 'Thanksgiving_+0', 'Thanksgiving_+1', 'Thanksgiving_-1',\n", + " 'Christmas Day_+0', 'Christmas Day_+1', 'Christmas Day_-1',\n", + " 'Veterans Day_+0', 'Veterans Day_+1', 'Veterans Day_-1',\n", + " 'Independence Day_+0', 'Independence Day_+1', 'Independence Day_-1',\n", + " 'Labor Day_+0', 'Labor Day_+1', 'Labor Day_-1'], dtype=object),\n", " 'xaxis': 'x8',\n", - " 'y': [-6.505150318145752, 3.078960418701172, -3.0603911876678467,\n", - " -1.937178373336792, -0.9162442684173584, -3.922412395477295,\n", - " -43.94681930541992, 48.15086364746094, -47.38690948486328,\n", - " 0.037264175713062286, 4.726099967956543, 0.49383774399757385,\n", - " -0.8578076362609863, -8.193577766418457, 8.767333030700684,\n", - " -2.2701916694641113, -0.19705480337142944, -1.0239486694335938,\n", - " 2.4958767890930176, 5.431707859039307, -3.5964465141296387,\n", - " -3.9246764183044434, 3.1682686805725098, 1.5535764694213867,\n", - " -3.401339054107666, 0.7919614911079407, 1.1661392450332642,\n", - " 0.8668169975280762, -6.069958686828613, -1.4564253091812134],\n", + " 'y': [2.815302848815918, -1.2508420944213867, -2.8079898357391357,\n", + " 1.0754282474517822, 1.3034065961837769, 3.282367706298828,\n", + " -0.20270073413848877, 5.353240489959717, 3.5992088317871094,\n", + " 1.0151381492614746, 2.7972025871276855, 0.7658103108406067,\n", + " 0.6110072731971741, -3.8502631187438965, -1.1796778440475464,\n", + " -1.184556484222412, -1.5890964269638062, 6.425841331481934,\n", + " -0.3609931468963623, 5.461928367614746, -3.2548437118530273,\n", + " -7.143401622772217, 4.949648857116699, -2.3242950439453125,\n", + " -3.5402631759643555, 1.2076290845870972, 5.756880283355713,\n", + " -1.5092906951904297, -2.149829149246216, 2.4397971630096436],\n", " 'yaxis': 'y8'}],\n", " 'layout': {'autosize': True,\n", " 'font': {'size': 10},\n", From bc64a527eecc55adb1553351c981dfabf7bdfd99 Mon Sep 17 00:00:00 2001 From: ourownstory Date: Wed, 28 Aug 2024 18:24:19 -0700 Subject: [PATCH 29/39] improve num_training calculation for lr-finder and remove loss-min for lr calc --- neuralprophet/configure.py | 12 +- neuralprophet/forecaster.py | 6 +- neuralprophet/utils.py | 9 +- tests/debug/debug-energy-price-hourly.ipynb | 586 ++++++++++---------- 4 files changed, 310 insertions(+), 303 deletions(-) diff --git a/neuralprophet/configure.py b/neuralprophet/configure.py index 467922213..659888651 100644 --- a/neuralprophet/configure.py +++ b/neuralprophet/configure.py @@ -265,21 +265,23 @@ def set_scheduler(self): self.scheduler, torch.optim.lr_scheduler.LRScheduler ), "Scheduler must be a subclass of torch.optim.lr_scheduler.LRScheduler" - def set_lr_finder_args(self, dataset_size, num_batches): + def set_lr_finder_args(self, main_training_epochs: int, batches_per_epoch: int): """ Set the lr_finder_args. This is the range of learning rates to test. 
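+
+        Illustration (hypothetical sizes, for clarity): a main training run of
+        100 epochs at 500 batches per epoch gives
+        main_training_total_steps = 50_000, so
+        num_training = 100 + int(np.log10(1 + 50_000 / 1000) * 100) = 270
+        learning-rate test steps.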
""" - num_training = 100 + int(np.log10(dataset_size) * 20) - if num_batches < num_training: + main_training_total_steps = main_training_epochs * batches_per_epoch + # main_training_total_steps is around 1e3 to 1e6 -> num_training 100 to 400 + num_training = 100 + int(np.log10(1 + main_training_total_steps / 1000) * 100) + if batches_per_epoch < num_training: log.warning( - f"Learning rate finder: The number of batches ({num_batches}) is too small than the required number \ + f"Learning rate finder: The number of batches per epoch ({batches_per_epoch}) is too small than the required number \ for the learning rate finder ({num_training}). The results might not be optimal." ) # num_training = num_batches self.lr_finder_args.update( { - "min_lr": 1e-8, + "min_lr": 1e-7, "max_lr": 1e1, "num_training": num_training, "early_stop_threshold": None, diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py index 746549462..33beb416c 100644 --- a/neuralprophet/forecaster.py +++ b/neuralprophet/forecaster.py @@ -2817,7 +2817,7 @@ def _train( progress_bar_enabled=progress_bar_enabled, metrics_enabled=metrics_enabled, checkpointing_enabled=checkpointing_enabled, - num_batches_per_epoch=len(train_loader), + num_batches_per_epoch=batches_per_epoch, deterministic=deterministic, ) @@ -2825,7 +2825,9 @@ def _train( if not self.config_train.learning_rate: log.info("No Learning Rate provided. Activating learning rate finder") # Set parameters for the learning rate finder - self.config_train.set_lr_finder_args(dataset_size=dataset_size, num_batches=batches_per_epoch) + self.config_train.set_lr_finder_args( + main_training_epochs=self.config_train.epochs, batches_per_epoch=batches_per_epoch + ) log.info(f"Learning rate finder ---- ARGs: {self.config_train.lr_finder_args}") self.model.finding_lr = True tuner = Tuner(self.trainer) diff --git a/neuralprophet/utils.py b/neuralprophet/utils.py index 309d9098f..39a82f84f 100644 --- a/neuralprophet/utils.py +++ b/neuralprophet/utils.py @@ -795,7 +795,7 @@ def smooth_loss_and_suggest(lr_finder, window=10): try: # Find the steepest gradient and the minimum loss after that suggestion_steepest = lr[np.argmin(np.gradient(loss_smooth))] - suggestion_minimum = lr[np.argmin(loss_smooth)] + suggestion_minimum = lr[np.argmin(np.array(lr_finder_results["loss"]))] except ValueError: log.error( f"The number of loss values ({len(loss)}) is too small to estimate a learning rate. Increase the number of " @@ -807,12 +807,11 @@ def smooth_loss_and_suggest(lr_finder, window=10): log.info(f"Learning rate finder ---- default suggestion: {suggestion_default}") log.info(f"Learning rate finder ---- steepest: {suggestion_steepest}") - log.info(f"Learning rate finder ---- minimum: {suggestion_minimum}") - if suggestion_steepest is not None and suggestion_minimum is not None and suggestion_default is not None: + log.info(f"Learning rate finder ---- minimum (not used): {suggestion_minimum}") + if suggestion_steepest is not None and suggestion_default is not None: log_suggestion_smooth = np.log(suggestion_steepest) - log_suggestion_minimum = np.log(suggestion_minimum) log_suggestion_default = np.log(suggestion_default) - lr_suggestion = np.exp((log_suggestion_smooth + log_suggestion_minimum + log_suggestion_default) / 3) + lr_suggestion = np.exp((log_suggestion_smooth + log_suggestion_default) / 2) log.info(f"Learning rate finder ---- log-avg: {lr_suggestion}") elif suggestion_steepest is None and suggestion_default is None: log.error("Automatic learning rate test failed. 
Please set manually the learning rate.") diff --git a/tests/debug/debug-energy-price-hourly.ipynb b/tests/debug/debug-energy-price-hourly.ipynb index ab4485f1f..f78de7a04 100644 --- a/tests/debug/debug-energy-price-hourly.ipynb +++ b/tests/debug/debug-energy-price-hourly.ipynb @@ -171,7 +171,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 5, @@ -278,7 +278,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "3c6d261e96524335a24f00923ad36c02", + "model_id": "ded17dc7d6e940bfb29321cd972603b6", "version_major": 2, "version_minor": 0 }, @@ -294,19 +294,19 @@ "output_type": "stream", "text": [ "INFO - (NP.forecaster._train) - No Learning Rate provided. Activating learning rate finder\n", - "WARNING - (NP.config.set_lr_finder_args) - Learning rate finder: The number of batches (87) is too small than the required number for the learning rate finder (168). The results might not be optimal.\n", - "INFO - (NP.forecaster._train) - Learning rate finder ---- ARGs: {'min_lr': 1e-08, 'max_lr': 10.0, 'num_training': 168, 'early_stop_threshold': None, 'mode': 'exponential'}\n" + "WARNING - (NP.config.set_lr_finder_args) - Learning rate finder: The number of batches per epoch (87) is too small than the required number for the learning rate finder (127). The results might not be optimal.\n", + "INFO - (NP.forecaster._train) - Learning rate finder ---- ARGs: {'min_lr': 1e-07, 'max_lr': 10.0, 'num_training': 127, 'early_stop_threshold': None, 'mode': 'exponential'}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "35957cbfc5044d2eab5eb3fe1ccee7c8", + "model_id": "5c81445c09e94d1f94a5e2a46d3b581f", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Finding best initial lr: 0%| | 0/168 [00:00\n", " \n", " 0\n", - " 0.710609\n", - " 0.819329\n", - " 0.622614\n", + " 0.518220\n", + " 0.618144\n", + " 0.463509\n", " 0.0\n", " 0\n", - " 1.144053\n", - " 1.562842\n", - " 1.368656\n", + " 1.071776\n", + " 1.369143\n", + " 1.176725\n", " 0.0\n", - " 0.012448\n", + " 0.002857\n", " \n", " \n", " 1\n", - " 0.836989\n", - " 0.946932\n", - " 0.733583\n", + " 0.544104\n", + " 0.619763\n", + " 0.485463\n", " 0.0\n", " 1\n", - " 0.532974\n", - " 0.702971\n", - " 0.343495\n", + " 0.551133\n", + " 0.734627\n", + " 0.364626\n", " 0.0\n", - " 0.039781\n", + " 0.009130\n", " \n", " \n", " 2\n", - " 0.588745\n", - " 0.704277\n", - " 0.497290\n", + " 0.479003\n", + " 0.555191\n", + " 0.394251\n", " 0.0\n", " 2\n", - " 0.495191\n", - " 0.658636\n", - " 0.304316\n", + " 0.495316\n", + " 0.666528\n", + " 0.309311\n", " 0.0\n", - " 0.040028\n", + " 0.009187\n", " \n", " \n", " 3\n", - " 0.699847\n", - " 0.818369\n", - " 0.594933\n", + " 0.516385\n", + " 0.592248\n", + " 0.435629\n", " 0.0\n", " 3\n", - " 0.475755\n", - " 0.632438\n", - " 0.283402\n", + " 0.481878\n", + " 0.650643\n", + " 0.295138\n", " 0.0\n", - " 0.012695\n", + " 0.002914\n", " \n", " \n", " 4\n", - " 0.704670\n", - " 0.828111\n", - " 0.594259\n", + " 0.492940\n", + " 0.569359\n", + " 0.405212\n", " 0.0\n", " 4\n", - " 0.460465\n", - " 0.615198\n", - " 0.271323\n", + " 0.473522\n", + " 0.639349\n", + " 0.288111\n", " 0.0\n", - " 0.004634\n", + " 0.001064\n", " \n", " \n", " 5\n", - " 0.648891\n", - " 0.755905\n", - " 0.530240\n", + " 0.509749\n", + " 0.587457\n", + " 0.422531\n", " 0.0\n", " 5\n", - " 0.458983\n", - " 0.614499\n", - " 0.270039\n", + " 0.471753\n", + " 0.637636\n", + " 0.286060\n", " 0.0\n", - " 0.003871\n", + " 0.000888\n", " \n", " \n", " 6\n", - " 0.715093\n", - " 
0.839661\n", - " 0.608006\n", + " 0.512940\n", + " 0.592517\n", + " 0.425102\n", " 0.0\n", " 6\n", - " 0.459262\n", - " 0.614916\n", - " 0.269727\n", + " 0.470406\n", + " 0.636353\n", + " 0.284498\n", " 0.0\n", - " 0.002631\n", + " 0.000604\n", " \n", " \n", " 7\n", - " 0.689927\n", - " 0.807763\n", - " 0.577745\n", + " 0.507922\n", + " 0.586603\n", + " 0.418404\n", " 0.0\n", " 7\n", - " 0.455545\n", - " 0.609835\n", - " 0.266736\n", + " 0.470709\n", + " 0.635788\n", + " 0.283936\n", " 0.0\n", - " 0.001389\n", + " 0.000319\n", " \n", " \n", " 8\n", - " 0.646029\n", - " 0.753095\n", - " 0.530006\n", + " 0.509402\n", + " 0.588127\n", + " 0.420135\n", " 0.0\n", " 8\n", - " 0.457110\n", - " 0.611970\n", - " 0.267934\n", + " 0.469353\n", + " 0.635285\n", + " 0.282983\n", " 0.0\n", - " 0.000618\n", + " 0.000142\n", " \n", " \n", " 9\n", - " 0.688786\n", - " 0.806359\n", - " 0.577159\n", + " 0.510809\n", + " 0.589129\n", + " 0.420723\n", " 0.0\n", " 9\n", - " 0.456113\n", - " 0.611694\n", - " 0.267121\n", + " 0.468861\n", + " 0.634873\n", + " 0.283335\n", " 0.0\n", - " 0.000613\n", + " 0.000141\n", " \n", " \n", "\n", @@ -1927,28 +1927,28 @@ ], "text/plain": [ " MAE_val RMSE_val Loss_val RegLoss_val epoch MAE RMSE \\\n", - "0 0.710609 0.819329 0.622614 0.0 0 1.144053 1.562842 \n", - "1 0.836989 0.946932 0.733583 0.0 1 0.532974 0.702971 \n", - "2 0.588745 0.704277 0.497290 0.0 2 0.495191 0.658636 \n", - "3 0.699847 0.818369 0.594933 0.0 3 0.475755 0.632438 \n", - "4 0.704670 0.828111 0.594259 0.0 4 0.460465 0.615198 \n", - "5 0.648891 0.755905 0.530240 0.0 5 0.458983 0.614499 \n", - "6 0.715093 0.839661 0.608006 0.0 6 0.459262 0.614916 \n", - "7 0.689927 0.807763 0.577745 0.0 7 0.455545 0.609835 \n", - "8 0.646029 0.753095 0.530006 0.0 8 0.457110 0.611970 \n", - "9 0.688786 0.806359 0.577159 0.0 9 0.456113 0.611694 \n", + "0 0.518220 0.618144 0.463509 0.0 0 1.071776 1.369143 \n", + "1 0.544104 0.619763 0.485463 0.0 1 0.551133 0.734627 \n", + "2 0.479003 0.555191 0.394251 0.0 2 0.495316 0.666528 \n", + "3 0.516385 0.592248 0.435629 0.0 3 0.481878 0.650643 \n", + "4 0.492940 0.569359 0.405212 0.0 4 0.473522 0.639349 \n", + "5 0.509749 0.587457 0.422531 0.0 5 0.471753 0.637636 \n", + "6 0.512940 0.592517 0.425102 0.0 6 0.470406 0.636353 \n", + "7 0.507922 0.586603 0.418404 0.0 7 0.470709 0.635788 \n", + "8 0.509402 0.588127 0.420135 0.0 8 0.469353 0.635285 \n", + "9 0.510809 0.589129 0.420723 0.0 9 0.468861 0.634873 \n", "\n", " Loss RegLoss LR \n", - "0 1.368656 0.0 0.012448 \n", - "1 0.343495 0.0 0.039781 \n", - "2 0.304316 0.0 0.040028 \n", - "3 0.283402 0.0 0.012695 \n", - "4 0.271323 0.0 0.004634 \n", - "5 0.270039 0.0 0.003871 \n", - "6 0.269727 0.0 0.002631 \n", - "7 0.266736 0.0 0.001389 \n", - "8 0.267934 0.0 0.000618 \n", - "9 0.267121 0.0 0.000613 " + "0 1.176725 0.0 0.002857 \n", + "1 0.364626 0.0 0.009130 \n", + "2 0.309311 0.0 0.009187 \n", + "3 0.295138 0.0 0.002914 \n", + "4 0.288111 0.0 0.001064 \n", + "5 0.286060 0.0 0.000888 \n", + "6 0.284498 0.0 0.000604 \n", + "7 0.283936 0.0 0.000319 \n", + "8 0.282983 0.0 0.000142 \n", + "9 0.283335 0.0 0.000141 " ] }, "execution_count": 9, @@ -1973,7 +1973,13 @@ "\n", "Series.view is deprecated and will be removed in a future version. 
Use ``astype`` as an alternative to change the dtype.\n", "\n", - "\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ "INFO - (NP.df_utils._infer_frequency) - Major frequency h corresponds to 99.932% of the data.\n", "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning:\n", "\n", @@ -2047,7 +2053,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "9c2666a108ac4123919f8203b5f548b1", + "model_id": "f1e4231ad84f4ce2a3b3152a04780df8", "version_major": 2, "version_minor": 0 }, @@ -2072,7 +2078,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "08c352843e98443a8b24b2713b9636b1", + "model_id": "6eb4c8947cf94fceb753083bf4633d92", "version_major": 2, "version_minor": 0 }, @@ -2158,7 +2164,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "fdb0f6945c704c788e4243c7789a7e29", + "model_id": "a7458da38e7d4fe88012eb289b1bcf6e", "version_major": 2, "version_minor": 0 }, @@ -2169,14 +2175,14 @@ " 'mode': 'lines',\n", " 'name': '[R] yhat5 1.0% ~1h',\n", " 'type': 'scatter',\n", - " 'uid': 'cceaf554-f88b-47ac-b077-bd98eebd51bd',\n", + " 'uid': '1e485c1d-dae9-439f-97e8-f8960bf19265',\n", " 'x': array([datetime.datetime(2015, 1, 2, 13, 0),\n", " datetime.datetime(2015, 1, 2, 14, 0),\n", " datetime.datetime(2015, 1, 2, 15, 0), ...,\n", " datetime.datetime(2015, 3, 2, 17, 0),\n", - " datetime.datetime(2015, 3, 2, 19, 0),\n", + " datetime.datetime(2015, 3, 2, 18, 0),\n", " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'y': array([21.843597, 25.104948, 33.001038, ..., 46.6899 , 41.747295, 48.700737],\n", + " 'y': array([-8.401451, -8.331238, -7.641697, ..., 35.0834 , 31.378742, 26.125694],\n", " dtype=float32)},\n", " {'fill': 'tonexty',\n", " 'fillcolor': 'rgba(45, 146, 255, 0.2)',\n", @@ -2184,46 +2190,46 @@ " 'mode': 'lines',\n", " 'name': '[R] yhat5 99.0% ~1h',\n", " 'type': 'scatter',\n", - " 'uid': 'bda8f5ed-a117-47cf-8298-3b7802a38dcc',\n", + " 'uid': 'ddeebab6-2948-4ed2-839d-80f25bcbbb46',\n", " 'x': array([datetime.datetime(2015, 1, 2, 13, 0),\n", " datetime.datetime(2015, 1, 2, 14, 0),\n", " datetime.datetime(2015, 1, 2, 15, 0), ...,\n", " datetime.datetime(2015, 3, 2, 17, 0),\n", " datetime.datetime(2015, 3, 2, 18, 0),\n", " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'y': array([83.986046, 89.35564 , 74.86833 , ..., 74.58182 , 77.62551 , 77.05947 ],\n", + " 'y': array([70.66735 , 76.33557 , 73.91716 , ..., 72.79848 , 76.60592 , 75.965355],\n", " dtype=float32)},\n", " {'line': {'color': '#2d92ff', 'width': 2},\n", " 'mode': 'lines',\n", " 'name': '[R] Predicted ~1h',\n", " 'type': 'scatter',\n", - " 'uid': 'a0f939c9-0e13-45ce-8081-3dac7cf67c72',\n", + " 'uid': '3979d61f-ad4c-49eb-92f2-0519eec19b62',\n", " 'x': array([datetime.datetime(2015, 1, 2, 13, 0),\n", " datetime.datetime(2015, 1, 2, 14, 0),\n", " datetime.datetime(2015, 1, 2, 15, 0), ...,\n", " datetime.datetime(2015, 3, 2, 17, 0),\n", - " datetime.datetime(2015, 3, 2, 18, 0),\n", + " datetime.datetime(2015, 3, 2, 19, 0),\n", " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'y': array([47.22839 , 49.346603, 51.1183 , ..., 58.201473, 60.27031 , 59.059807],\n", + " 'y': array([39.514854, 43.52948 , 40.765232, ..., 61.876595, 65.44407 , 61.56613 ],\n", " dtype=float32)},\n", " {'marker': {'color': 'blue', 'size': 4, 'symbol': 'x'},\n", " 'mode': 'markers',\n", " 'name': '[R] Predicted ~1h',\n", " 'type': 'scatter',\n", - 
" 'uid': '555cf752-d8ea-41a9-8f47-596ee22a34be',\n", + " 'uid': 'b8c0ed1f-761f-44db-a774-93abc8ab8338',\n", " 'x': array([datetime.datetime(2015, 1, 2, 13, 0),\n", " datetime.datetime(2015, 1, 2, 14, 0),\n", " datetime.datetime(2015, 1, 2, 15, 0), ...,\n", " datetime.datetime(2015, 3, 2, 17, 0),\n", - " datetime.datetime(2015, 3, 2, 18, 0),\n", + " datetime.datetime(2015, 3, 2, 19, 0),\n", " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'y': array([47.22839 , 49.346603, 51.1183 , ..., 58.201473, 60.27031 , 59.059807],\n", + " 'y': array([39.514854, 43.52948 , 40.765232, ..., 61.876595, 65.44407 , 61.56613 ],\n", " dtype=float32)},\n", " {'marker': {'color': 'black', 'size': 4},\n", " 'mode': 'markers',\n", " 'name': '[R] Actual ~1h',\n", " 'type': 'scatter',\n", - " 'uid': '09a782ca-96bb-4a77-80a7-fd42484a363d',\n", + " 'uid': '46d51f13-dea8-4bc0-b1fa-e458374a0cbc',\n", " 'x': array([datetime.datetime(2015, 1, 1, 0, 0),\n", " datetime.datetime(2015, 1, 1, 1, 0),\n", " datetime.datetime(2015, 1, 1, 2, 0), ...,\n", @@ -2431,7 +2437,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "a09e98a70e3e4d4da2d2fe3865b8abc7", + "model_id": "4549df8821ee4b649e748ccf20cbb9a9", "version_major": 2, "version_minor": 0 }, @@ -2442,18 +2448,18 @@ " 'mode': 'lines',\n", " 'name': 'Trend',\n", " 'type': 'scatter',\n", - " 'uid': '6d93394c-496d-4c52-8c79-e1a55e9bff0d',\n", + " 'uid': 'a87ad6c0-8302-4d80-a053-3de55909d7d9',\n", " 'x': array([datetime.datetime(2015, 1, 1, 0, 0),\n", " datetime.datetime(2015, 2, 28, 23, 0)], dtype=object),\n", " 'xaxis': 'x',\n", - " 'y': array([35.735615, 26.46712 ], dtype=float32),\n", + " 'y': array([44.171093, 46.755657], dtype=float32),\n", " 'yaxis': 'y'},\n", " {'fill': 'none',\n", " 'line': {'color': '#2d92ff', 'width': 2},\n", " 'mode': 'lines',\n", " 'name': 'yearly',\n", " 'type': 'scatter',\n", - " 'uid': '91e577de-4754-46f9-a832-ff20851064d6',\n", + " 'uid': '687528c8-278a-4dba-a432-7580315842b3',\n", " 'x': array([datetime.datetime(2017, 1, 1, 0, 0),\n", " datetime.datetime(2017, 1, 2, 0, 0),\n", " datetime.datetime(2017, 1, 3, 0, 0), ...,\n", @@ -2461,15 +2467,15 @@ " datetime.datetime(2017, 12, 30, 0, 0),\n", " datetime.datetime(2017, 12, 31, 0, 0)], dtype=object),\n", " 'xaxis': 'x2',\n", - " 'y': array([-1.7568997 , -2.1306572 , -2.4605272 , ..., -0.36249205, -0.8088371 ,\n", - " -1.2453859 ], dtype=float32),\n", + " 'y': array([3.5837142, 3.9962187, 4.3687215, ..., 2.132554 , 2.5801535, 3.0329647],\n", + " dtype=float32),\n", " 'yaxis': 'y2'},\n", " {'fill': 'none',\n", " 'line': {'color': '#2d92ff', 'width': 2},\n", " 'mode': 'lines',\n", " 'name': 'weekly',\n", " 'type': 'scatter',\n", - " 'uid': 'fbb9e766-a826-4066-b0a7-9bbd7391dedd',\n", + " 'uid': '9bc616f5-17ad-4fc4-b88f-bb47bdcc4e5f',\n", " 'x': array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,\n", " 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,\n", " 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,\n", @@ -2483,119 +2489,117 @@ " 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,\n", " 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167]),\n", " 'xaxis': 'x3',\n", - " 'y': array([ 7.755577 , 7.5237527 , 7.224617 , 6.815465 , 6.314899 ,\n", - " 5.7965226 , 5.222597 , 4.60115 , 3.888913 , 3.192386 ,\n", - " 2.480555 , 1.7581066 , 0.9770121 , 0.21064605, -0.4857402 ,\n", - " -1.1512595 , -1.7806304 , -2.4084048 , -2.9580166 , -3.4257982 ,\n", - " -3.855735 , -4.175071 , -4.4489365 , -4.6284823 , -4.724985 
,\n", - " -4.7434278 , -4.6815186 , -4.530946 , -4.302239 , -4.0017447 ,\n", - " -3.6467676 , -3.2275088 , -2.7205632 , -2.1764457 , -1.6195372 ,\n", - " -1.0332191 , -0.3954899 , 0.27314964, 0.9262204 , 1.5400877 ,\n", - " 2.1368022 , 2.7368646 , 3.2993367 , 3.788525 , 4.2504687 ,\n", - " 4.6347704 , 4.9488535 , 5.214139 , 5.3700595 , 5.456191 ,\n", - " 5.4460826 , 5.3499737 , 5.1563053 , 4.8846197 , 4.5234604 ,\n", - " 4.1024094 , 3.5617476 , 2.9600484 , 2.3084648 , 1.5802894 ,\n", - " 0.81475705, -0.04111661, -0.90458584, -1.7624184 , -2.661798 ,\n", - " -3.53263 , -4.4244337 , -5.295709 , -6.1333375 , -6.931251 ,\n", - " -7.696276 , -8.319346 , -8.925283 , -9.427417 , -9.868788 ,\n", - " -10.188685 , -10.430543 , -10.565057 , -10.580534 , -10.500571 ,\n", - " -10.296801 , -9.9909 , -9.583059 , -9.07224 , -8.425532 ,\n", - " -7.701064 , -6.915518 , -6.0554175 , -5.111462 , -4.0590043 ,\n", - " -3.0526552 , -1.9540824 , -0.86603564, 0.3103157 , 1.494285 ,\n", - " 2.560513 , 3.6657102 , 4.700126 , 5.7225103 , 6.720365 ,\n", - " 7.6138163 , 8.40019 , 9.098643 , 9.724774 , 10.263643 ,\n", - " 10.650285 , 10.948275 , 11.123994 , 11.187449 , 11.121981 ,\n", - " 10.951801 , 10.656306 , 10.26268 , 9.729534 , 9.129556 ,\n", - " 8.439994 , 7.6128716 , 6.767242 , 5.794735 , 4.8272123 ,\n", - " 3.8182733 , 2.7837412 , 1.6946272 , 0.55752414, -0.5014487 ,\n", - " -1.5780605 , -2.5846512 , -3.583228 , -4.562805 , -5.417882 ,\n", - " -6.1975856 , -6.9047456 , -7.558581 , -8.112523 , -8.532403 ,\n", - " -8.8463745 , -9.0614 , -9.167561 , -9.156995 , -9.042185 ,\n", - " -8.828875 , -8.505391 , -8.057592 , -7.5299797 , -6.938876 ,\n", - " -6.275693 , -5.5465274 , -4.701907 , -3.8042374 , -2.929176 ,\n", - " -2.0293174 , -1.1140192 , -0.13168602, 0.7801494 , 1.7001534 ,\n", - " 2.5384672 , 3.4228206 , 4.248964 , 4.967284 , 5.6338806 ,\n", - " 6.2010517 , 6.737668 , 7.1634216 , 7.488221 , 7.7340417 ,\n", - " 7.8822103 , 7.9357567 , 7.8853316 ], dtype=float32),\n", + " 'y': array([ 3.6775422 , 3.4448764 , 3.1795988 , 2.856169 , 2.5034661 ,\n", + " 2.132053 , 1.7310926 , 1.3034438 , 0.8253559 , 0.36245304,\n", + " -0.12688631, -0.6307145 , -1.1616575 , -1.7126511 , -2.2267792 ,\n", + " -2.7566912 , -3.2863293 , -3.8237705 , -4.32482 , -4.836178 ,\n", + " -5.361228 , -5.8546944 , -6.334524 , -6.7697473 , -7.2097106 ,\n", + " -7.6307163 , -8.015902 , -8.410988 , -8.759339 , -9.097741 ,\n", + " -9.40439 , -9.673478 , -9.938627 , -10.160235 , -10.360814 ,\n", + " -10.533252 , -10.672475 , -10.790765 , -10.871719 , -10.92612 ,\n", + " -10.949665 , -10.941065 , -10.902167 , -10.831596 , -10.7297535 ,\n", + " -10.584343 , -10.400615 , -10.213366 , -9.973349 , -9.703287 ,\n", + " -9.400256 , -9.039257 , -8.638971 , -8.218772 , -7.76986 ,\n", + " -7.29001 , -6.7592626 , -6.205373 , -5.610344 , -5.011173 ,\n", + " -4.306703 , -3.6050465 , -2.8899956 , -2.1421108 , -1.3980246 ,\n", + " -0.55479425, 0.25694594, 1.0867394 , 1.9551147 , 2.7880616 ,\n", + " 3.6983426 , 4.5652947 , 5.432634 , 6.2935643 , 7.1280026 ,\n", + " 7.9856586 , 8.809366 , 9.650012 , 10.417858 , 11.145185 ,\n", + " 11.858826 , 12.523886 , 13.137134 , 13.699438 , 14.241183 ,\n", + " 14.707007 , 15.098125 , 15.426196 , 15.693906 , 15.901334 ,\n", + " 16.017977 , 16.07477 , 16.05878 , 15.968486 , 15.792491 ,\n", + " 15.547381 , 15.236185 , 14.865028 , 14.416378 , 13.873076 ,\n", + " 13.277334 , 12.638401 , 11.962018 , 11.21623 , 10.383526 ,\n", + " 9.558632 , 8.688934 , 7.814783 , 6.8657885 , 5.8782206 ,\n", + " 4.93936 , 3.9818935 , 3.0527442 , 
2.0779295 , 1.1511891 ,\n", + " 0.24442616, -0.68556976, -1.5159744 , -2.3555033 , -3.1236107 ,\n", + " -3.845236 , -4.5020337 , -5.122622 , -5.7106447 , -6.2054214 ,\n", + " -6.647813 , -7.003909 , -7.302908 , -7.544511 , -7.7055807 ,\n", + " -7.7959027 , -7.8254924 , -7.783019 , -7.673332 , -7.5097384 ,\n", + " -7.290905 , -7.017191 , -6.67524 , -6.2961516 , -5.888133 ,\n", + " -5.437577 , -4.9315596 , -4.396648 , -3.855047 , -3.313541 ,\n", + " -2.7510498 , -2.1820972 , -1.5792464 , -0.9815789 , -0.42582962,\n", + " 0.10404018, 0.6269935 , 1.1417981 , 1.6131238 , 2.0423858 ,\n", + " 2.4490814 , 2.8360126 , 3.1679304 , 3.4505346 , 3.676464 ,\n", + " 3.8737729 , 4.024116 , 4.1148095 , 4.157773 , 4.150377 ,\n", + " 4.109914 , 4.006662 , 3.8584704 ], dtype=float32),\n", " 'yaxis': 'y3'},\n", " {'fill': 'none',\n", " 'line': {'color': '#2d92ff', 'width': 2},\n", " 'mode': 'lines',\n", " 'name': 'winter',\n", " 'type': 'scatter',\n", - " 'uid': '11f43753-196c-416c-8679-caab4f55210d',\n", + " 'uid': '4495dca3-32d1-4ba1-8c78-6e1a1b0af24c',\n", " 'x': array([ 0, 1, 2, ..., 285, 286, 287]),\n", " 'xaxis': 'x4',\n", - " 'y': array([ 1.5749581 , 0.68877584, -0.0385443 , ..., 3.668294 , 3.1973646 ,\n", - " 2.3746142 ], dtype=float32),\n", + " 'y': array([ 0.96416605, -0.55186653, -1.687766 , ..., 3.8534515 , 3.2716925 ,\n", + " 2.1252832 ], dtype=float32),\n", " 'yaxis': 'y4'},\n", " {'fill': 'none',\n", " 'line': {'color': '#2d92ff', 'width': 2},\n", " 'mode': 'lines',\n", " 'name': 'summer',\n", " 'type': 'scatter',\n", - " 'uid': '82f6c088-001e-487b-947b-2b32fbf6b06c',\n", + " 'uid': 'cceb9a1a-bc38-482b-a7d5-82b72a93a298',\n", " 'x': array([ 0, 1, 2, ..., 285, 286, 287]),\n", " 'xaxis': 'x5',\n", - " 'y': array([ 1.621103 , 0.41932815, -0.4366651 , ..., 4.0205083 , 3.520041 ,\n", - " 2.6508992 ], dtype=float32),\n", + " 'y': array([-5.745831 , -5.2108707, -4.672284 , ..., -6.2985435, -6.2671404,\n", + " -6.0654645], dtype=float32),\n", " 'yaxis': 'y5'},\n", " {'marker': {'color': '#2d92ff'},\n", " 'name': 'AR',\n", " 'type': 'bar',\n", - " 'uid': 'c3664db2-136a-40d7-9103-625733dda176',\n", + " 'uid': 'fc2e7ac5-066d-48a0-8e16-b80235e12bfe',\n", " 'width': 0.8,\n", " 'x': array([10, 9, 8, 7, 6, 5, 4, 3, 2, 1]),\n", " 'xaxis': 'x6',\n", - " 'y': array([-0.3051259 , -0.20342994, 0.05545649, 0.13164242, 0.2857538 ,\n", - " 0.06021553, -0.3720306 , -0.01876787, -0.00064597, 0.04794757],\n", - " dtype=float32),\n", + " 'y': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n", " 'yaxis': 'y6'},\n", " {'marker': {'color': '#2d92ff'},\n", " 'name': 'Lagged Regressor \"temp\"',\n", " 'type': 'bar',\n", - " 'uid': 'd28e2fd9-0522-488a-b89a-02b1296bae1a',\n", + " 'uid': '6a963e5b-a63f-45bf-ae7f-1fa296e6f623',\n", " 'width': 0.8,\n", " 'x': array([33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16,\n", " 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]),\n", " 'xaxis': 'x7',\n", - " 'y': array([ 6.2832564e-01, -2.3228288e-01, 6.6017294e-01, 3.3255139e-01,\n", - " 5.0744390e-01, 1.1816436e-01, -1.5548144e-01, 1.6358766e-01,\n", - " -1.7810777e-01, 3.2371131e-01, 4.4875324e-01, -3.1604797e-01,\n", - " -1.3501082e-03, -9.3391158e-02, 8.1444037e-01, -7.3939008e-01,\n", - " 4.2238832e-01, 5.4563276e-02, 3.5837620e-01, -5.2361876e-02,\n", - " -5.4710191e-01, -7.3065239e-01, -3.4761795e-01, 4.7559822e-01,\n", - " 2.0330952e-02, 2.5780448e-01, 1.0076398e-01, 3.2984644e-01,\n", - " 2.2101782e-01, 2.5692052e-01, -8.7424242e-01, -8.4744475e-04,\n", - " 3.0343091e-01], 
dtype=float32),\n", + " 'y': array([-2.3441856e-01, -7.4248093e-01, 1.3433978e-01, 4.7361168e-01,\n", + " 4.8439783e-01, 2.8078523e-01, -1.9517194e-01, 2.2985543e-01,\n", + " 1.5531473e-01, -4.2631316e-01, 5.0868553e-01, 1.1522221e-01,\n", + " -4.8527386e-02, 2.0242128e-01, 4.4463417e-03, -2.3070528e-01,\n", + " 1.7045366e-02, -8.4169136e-05, 1.5831508e-01, -2.2444238e-01,\n", + " 1.4253077e-01, -2.9090768e-02, -1.4969027e-01, 3.8341036e-01,\n", + " -1.2710637e-01, -1.4844303e-01, 1.1406808e-01, -2.2177878e-01,\n", + " 2.8057915e-01, -3.3217099e-01, -1.6262497e-01, -3.2851827e-01,\n", + " -1.3853197e-01], dtype=float32),\n", " 'yaxis': 'y7'},\n", " {'marker': {'color': '#2d92ff'},\n", " 'name': 'Additive event',\n", " 'type': 'bar',\n", - " 'uid': 'ec08d1ed-660d-43c4-8112-e6bd03345ae9',\n", + " 'uid': 'f045a98c-7b21-4761-b615-bc9edab8575b',\n", " 'width': 0.8,\n", - " 'x': array(['Memorial Day_+0', 'Memorial Day_+1', 'Memorial Day_-1',\n", - " \"Washington's Birthday_+0\", \"Washington's Birthday_+1\",\n", - " \"Washington's Birthday_-1\", 'Columbus Day_+0', 'Columbus Day_+1',\n", - " 'Columbus Day_-1', 'Martin Luther King Jr. Day_+0',\n", - " 'Martin Luther King Jr. Day_+1', 'Martin Luther King Jr. Day_-1',\n", - " \"New Year's Day_+0\", \"New Year's Day_+1\", \"New Year's Day_-1\",\n", - " 'Thanksgiving_+0', 'Thanksgiving_+1', 'Thanksgiving_-1',\n", - " 'Christmas Day_+0', 'Christmas Day_+1', 'Christmas Day_-1',\n", + " 'x': array([\"Washington's Birthday_+0\", \"Washington's Birthday_+1\",\n", + " \"Washington's Birthday_-1\", 'Thanksgiving_+0', 'Thanksgiving_+1',\n", + " 'Thanksgiving_-1', 'Labor Day_+0', 'Labor Day_+1', 'Labor Day_-1',\n", " 'Veterans Day_+0', 'Veterans Day_+1', 'Veterans Day_-1',\n", + " \"New Year's Day_+0\", \"New Year's Day_+1\", \"New Year's Day_-1\",\n", " 'Independence Day_+0', 'Independence Day_+1', 'Independence Day_-1',\n", - " 'Labor Day_+0', 'Labor Day_+1', 'Labor Day_-1'], dtype=object),\n", + " 'Martin Luther King Jr. Day_+0', 'Martin Luther King Jr. Day_+1',\n", + " 'Martin Luther King Jr. 
Day_-1', 'Memorial Day_+0', 'Memorial Day_+1',\n", + " 'Memorial Day_-1', 'Columbus Day_+0', 'Columbus Day_+1',\n", + " 'Columbus Day_-1', 'Christmas Day_+0', 'Christmas Day_+1',\n", + " 'Christmas Day_-1'], dtype=object),\n", " 'xaxis': 'x8',\n", - " 'y': [2.815302848815918, -1.2508420944213867, -2.8079898357391357,\n", - " 1.0754282474517822, 1.3034065961837769, 3.282367706298828,\n", - " -0.20270073413848877, 5.353240489959717, 3.5992088317871094,\n", - " 1.0151381492614746, 2.7972025871276855, 0.7658103108406067,\n", - " 0.6110072731971741, -3.8502631187438965, -1.1796778440475464,\n", - " -1.184556484222412, -1.5890964269638062, 6.425841331481934,\n", - " -0.3609931468963623, 5.461928367614746, -3.2548437118530273,\n", - " -7.143401622772217, 4.949648857116699, -2.3242950439453125,\n", - " -3.5402631759643555, 1.2076290845870972, 5.756880283355713,\n", - " -1.5092906951904297, -2.149829149246216, 2.4397971630096436],\n", + " 'y': [-5.999994277954102, 0.15511037409305573, -0.6804019212722778,\n", + " -0.8969926834106445, -4.350093841552734, 2.10798978805542,\n", + " -4.097671031951904, 5.030608177185059, 3.1227762699127197,\n", + " 3.44264817237854, 4.6125640869140625, 0.7293226718902588,\n", + " 2.7135281562805176, -0.2420026659965515, 4.3908257484436035,\n", + " -7.856958866119385, -5.952345848083496, 5.613704204559326,\n", + " -7.10869026184082, -2.1775100231170654, 2.4739584922790527,\n", + " 0.04653293639421463, 1.881555438041687, -0.2442491501569748,\n", + " 1.7328133583068848, 3.332047462463379, -4.845430850982666,\n", + " 0.990510880947113, 3.7318599224090576, -1.215183973312378],\n", " 'yaxis': 'y8'}],\n", " 'layout': {'autosize': True,\n", " 'font': {'size': 10},\n", From c7c631302280c1de3ede10b27f72abe970841fb3 Mon Sep 17 00:00:00 2001 From: ourownstory Date: Thu, 29 Aug 2024 17:47:46 -0700 Subject: [PATCH 30/39] large changeset - isolate lr-finder --- neuralprophet/configure.py | 25 --- neuralprophet/data/process.py | 2 +- neuralprophet/df_utils.py | 2 +- neuralprophet/forecaster.py | 200 ++++++++++---------- neuralprophet/time_net.py | 4 +- neuralprophet/utils.py | 254 ------------------------- neuralprophet/utils_lightning.py | 309 +++++++++++++++++++++++++++++++ 7 files changed, 410 insertions(+), 386 deletions(-) create mode 100644 neuralprophet/utils_lightning.py diff --git a/neuralprophet/configure.py b/neuralprophet/configure.py index 659888651..00c32429a 100644 --- a/neuralprophet/configure.py +++ b/neuralprophet/configure.py @@ -122,7 +122,6 @@ class Train: trend_reg_threshold: Optional[Union[bool, float]] = None n_data: int = field(init=False) loss_func_name: str = field(init=False) - lr_finder_args: dict = field(default_factory=dict) pl_trainer_config: dict = field(default_factory=dict) def __post_init__(self): @@ -265,30 +264,6 @@ def set_scheduler(self): self.scheduler, torch.optim.lr_scheduler.LRScheduler ), "Scheduler must be a subclass of torch.optim.lr_scheduler.LRScheduler" - def set_lr_finder_args(self, main_training_epochs: int, batches_per_epoch: int): - """ - Set the lr_finder_args. - This is the range of learning rates to test. 
- """ - main_training_total_steps = main_training_epochs * batches_per_epoch - # main_training_total_steps is around 1e3 to 1e6 -> num_training 100 to 400 - num_training = 100 + int(np.log10(1 + main_training_total_steps / 1000) * 100) - if batches_per_epoch < num_training: - log.warning( - f"Learning rate finder: The number of batches per epoch ({batches_per_epoch}) is too small than the required number \ - for the learning rate finder ({num_training}). The results might not be optimal." - ) - # num_training = num_batches - self.lr_finder_args.update( - { - "min_lr": 1e-7, - "max_lr": 1e1, - "num_training": num_training, - "early_stop_threshold": None, - "mode": "exponential", - } - ) - def get_reg_delay_weight(self, progress, reg_start_pct: float = 0.66, reg_full_pct: float = 1.0): # Ignore type warning of epochs possibly being None (does not work with dataclasses) if reg_start_pct == reg_full_pct: diff --git a/neuralprophet/data/process.py b/neuralprophet/data/process.py index 2958dde49..a5c71a2d7 100644 --- a/neuralprophet/data/process.py +++ b/neuralprophet/data/process.py @@ -612,7 +612,7 @@ def _create_dataset(model, df, predict_mode, prediction_frequency=None): ------- TimeDataset """ - df, _, _, _ = df_utils.prep_or_copy_df(df) + # df, _, _, _ = df_utils.prep_or_copy_df(df) return time_dataset.GlobalTimeDataset( df, predict_mode=predict_mode, diff --git a/neuralprophet/df_utils.py b/neuralprophet/df_utils.py index c5f6367a5..3c4e4bfa4 100644 --- a/neuralprophet/df_utils.py +++ b/neuralprophet/df_utils.py @@ -308,7 +308,7 @@ def init_data_params( ShiftScale entries containing ``shift`` and ``scale`` parameters for each column """ # Compute Global data params - df, _, _, _ = prep_or_copy_df(df) + # df, _, _, _ = prep_or_copy_df(df) df_merged = df.copy(deep=True).drop("ID", axis=1) global_data_params = data_params_definition( df_merged, normalize, config_lagged_regressors, config_regressors, config_events, config_seasonality diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py index 33beb416c..1988d0c69 100644 --- a/neuralprophet/forecaster.py +++ b/neuralprophet/forecaster.py @@ -16,7 +16,7 @@ from pytorch_lightning.tuner.tuning import Tuner from torch.utils.data import DataLoader -from neuralprophet import configure, df_utils, np_types, time_dataset, time_net, utils, utils_metrics +from neuralprophet import configure, df_utils, np_types, time_dataset, time_net, utils, utils_lightning, utils_metrics from neuralprophet.data.process import ( _check_dataframe, _convert_raw_predictions_to_raw_df, @@ -1292,9 +1292,12 @@ def test(self, df: pd.DataFrame, verbose: bool = True): config_seasonality=self.config_seasonality, predicting=False, ) - loader = self._init_val_loader(df) + # df, _, _, _ = df_utils.prep_or_copy_df(df) + df = _normalize(df=df, config_normalization=self.config_normalization) + dataset = _create_dataset(self, df, predict_mode=False) + test_loader = DataLoader(dataset, batch_size=min(1024, len(dataset)), shuffle=False, drop_last=False) # Use Lightning to calculate metrics - val_metrics = self.trainer.test(self.model, dataloaders=loader, verbose=verbose) + val_metrics = self.trainer.test(self.model, dataloaders=test_loader, verbose=verbose) val_metrics_df = pd.DataFrame(val_metrics) # TODO Check whether supported by Lightning if not self.config_normalization.global_normalization: @@ -2660,7 +2663,7 @@ def plot_parameters( def _init_model(self): """Build Pytorch model with configured hyperparamters.""" - self.model = time_net.TimeNet( + model = 
time_net.TimeNet( config_model=self.config_model, config_train=self.config_train, config_trend=self.config_trend, @@ -2683,9 +2686,10 @@ def _init_model(self): num_seasonalities_modelled_dict=self.num_seasonalities_modelled_dict, meta_used_in_model=self.meta_used_in_model, ) - log.debug(self.model) + log.debug(model) + return model - def _init_train_loader(self, df, num_workers=0): + def _data_setup(self, df): """Executes data preparation steps and initiates training procedure. Parameters @@ -2699,8 +2703,10 @@ def _init_train_loader(self, df, num_workers=0): ------- torch DataLoader """ - df, _, _, _ = df_utils.prep_or_copy_df(df) # TODO: Can this call be avoided? + # df, _, _, _ = df_utils.prep_or_copy_df(df) + if not self.fitted: + # Initialize data normalization parameters self.config_normalization.init_data_params( df=df, config_lagged_regressors=self.config_lagged_regressors, @@ -2709,56 +2715,26 @@ def _init_train_loader(self, df, num_workers=0): config_seasonality=self.config_seasonality, ) - df = _normalize(df=df, config_normalization=self.config_normalization) if not self.fitted: + # scale user-specified changepoint times if self.config_trend.changepoints is not None: - # scale user-specified changepoint times df_aux = pd.DataFrame({"ds": pd.Series(self.config_trend.changepoints)}) + df_aux = _normalize(df=df_aux, config_normalization=self.config_normalization) + self.config_trend.changepoints = df_aux["t"].values - df_normalized = _normalize(df=df_aux, config_normalization=self.config_normalization) - self.config_trend.changepoints = df_normalized["t"].values # type: ignore - - # df_merged, _ = df_utils.join_dataframes(df) - # df_merged = df_merged.sort_values("ds") - # df_merged.drop_duplicates(inplace=True, keep="first", subset=["ds"]) - df_merged = df_utils.merge_dataframes(df) - self.config_seasonality = utils.set_auto_seasonalities(df_merged, config_seasonality=self.config_seasonality) - if self.config_country_holidays is not None: - self.config_country_holidays.init_holidays(df_merged) - - dataset = _create_dataset( - self, df, predict_mode=False, prediction_frequency=self.prediction_frequency - ) # needs to be called after set_auto_seasonalities - - # Determine the max_number of epochs - self.config_train.set_auto_batch_epoch(n_data=len(dataset)) - - loader = DataLoader( - dataset, - batch_size=self.config_train.batch_size, - shuffle=True, - num_workers=num_workers, - ) - - return loader - - def _init_val_loader(self, df): - """Executes data preparation steps and initiates evaluation procedure. 
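The changepoint handling above feeds user-specified changepoint dates through the same `ds`-to-`t` normalization as the training data. `_normalize`'s internals are not part of this diff; a minimal sketch under the assumption that `t` linearly maps the observed time span onto the unit interval:

```python
import pandas as pd

# Assumption (not shown in this diff): 't' is a linear map of the training span onto [0, 1]
ds_min, ds_max = pd.Timestamp("2020-01-01"), pd.Timestamp("2021-01-01")
changepoints = pd.Series(pd.to_datetime(["2020-04-01", "2020-09-01"]))
t = (changepoints - ds_min) / (ds_max - ds_min)  # fraction of the training span
print(t.values)  # ~[0.2486 0.6667] -- scaled values replace the raw dates in config_trend.changepoints
```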
+        # Apply normalization to data
+        df = _normalize(df=df, config_normalization=self.config_normalization)

-        Parameters
-        ----------
-        df : pd.DataFrame
-            dataframe containing column ``ds``, ``y``, and optionally``ID`` with all data
+        if not self.fitted:
+            # Temporarily merge df to set auto seasonalities and country holidays
+            df_merged = df_utils.merge_dataframes(df)
+            self.config_seasonality = utils.set_auto_seasonalities(
+                df_merged, config_seasonality=self.config_seasonality
+            )
+            if self.config_country_holidays is not None:
+                self.config_country_holidays.init_holidays(df_merged)

-        Returns
-        -------
-        torch DataLoader
-        """
-        df, _, _, _ = df_utils.prep_or_copy_df(df)
-        df = _normalize(df=df, config_normalization=self.config_normalization)
-        dataset = _create_dataset(self, df, predict_mode=False)
-        loader = DataLoader(dataset, batch_size=min(1024, len(dataset)), shuffle=False, drop_last=False)
-        return loader
+        return df

     def _train(
         self,
@@ -2793,23 +2769,33 @@ def _train(
         pd.DataFrame
             metrics
         """
-        # Set up data the training dataloader
-        df, _, _, _ = df_utils.prep_or_copy_df(df)  # TODO: Can this call be removed?
-        train_loader = self._init_train_loader(df, num_workers)
-        dataset_size = len(train_loader.dataset)  # df
-        batches_per_epoch = len(train_loader)
-        log.info(f"Dataset size: {dataset_size}")
-        log.info(f"Number of batches per training epoch: {batches_per_epoch}")
-
-        # Internal flag to check if validation is enabled
-        validation_enabled = df_val is not None
+        # Set up train dataset and data dependent configurations
+        df = self._data_setup(df)
+        # Note: _create_dataset() needs to be called after set_auto_seasonalities()
+        dataset = _create_dataset(self, df, predict_mode=False, prediction_frequency=self.prediction_frequency)
+        # Determine the max_number of epochs
+        self.config_train.set_auto_batch_epoch(n_data=len(dataset))

-        self._init_model()
+        # Set up DataLoaders: Train
+        loader = DataLoader(
+            dataset,
+            batch_size=self.config_train.batch_size,
+            shuffle=True,
+            num_workers=num_workers,
+        )
+        log.info(f"Train Dataset size: {len(dataset)}")
+        log.info(f"Number of batches per training epoch: {len(loader)}")

-        self.model.train_loader = train_loader
+        # Set up DataLoaders: Validation
+        validation_enabled = df_val is not None and isinstance(df_val, pd.DataFrame)
+        if validation_enabled:
+            # df_val, _, _, _ = df_utils.prep_or_copy_df(df_val)
+            df_val = _normalize(df=df_val, config_normalization=self.config_normalization)
+            dataset_val = _create_dataset(self, df_val, predict_mode=False)
+            loader_val = DataLoader(dataset_val, batch_size=min(1024, len(dataset_val)), shuffle=False, drop_last=False)

         # Init the Trainer
-        self.trainer, checkpoint_callback = utils.configure_trainer(
+        self.trainer, checkpoint_callback = utils_lightning.configure_trainer(
             config_train=self.config_train,
             metrics_logger=self.metrics_logger,
             early_stopping_target="Loss_val" if validation_enabled else "Loss",
@@ -2817,51 +2803,58 @@ def _train(
             progress_bar_enabled=progress_bar_enabled,
             metrics_enabled=metrics_enabled,
             checkpointing_enabled=checkpointing_enabled,
-            num_batches_per_epoch=batches_per_epoch,
+            num_batches_per_epoch=len(loader),
             deterministic=deterministic,
         )

-        # Find suitable learning rate
-        if not self.config_train.learning_rate:
-            log.info("No Learning Rate provided.
Activating learning rate finder") - # Set parameters for the learning rate finder - self.config_train.set_lr_finder_args( - main_training_epochs=self.config_train.epochs, batches_per_epoch=batches_per_epoch + # Find suitable learning rate if not set + if self.config_train.learning_rate is None: + assert not self.fitted, "Learning rate must be provided for re-training a fitted model." + # Init a separate Model for LR finder (optional, done for safety) + model_lr_finder = self._init_model() + # Init a separate DataLoader for LR finder (optional, done for safety) + loader_lr_finder = DataLoader( + dataset, + batch_size=self.config_train.batch_size, + shuffle=True, + num_workers=num_workers, + ) + # Init a separate Trainer for LR finder (optional, done for safety) + trainer_lr_finder, _ = utils_lightning.configure_trainer( + config_train=self.config_train, + metrics_logger=self.metrics_logger, + early_stopping_target="Loss", + accelerator=self.accelerator, + progress_bar_enabled=progress_bar_enabled, + metrics_enabled=False, + checkpointing_enabled=False, + num_batches_per_epoch=len(loader), + deterministic=deterministic, ) - log.info(f"Learning rate finder ---- ARGs: {self.config_train.lr_finder_args}") - self.model.finding_lr = True - tuner = Tuner(self.trainer) - lr_finder = tuner.lr_find( - model=self.model, - train_dataloaders=train_loader, - # val_dataloaders=val_loader, # not used, but may lead to Lightning bug if not provided - **self.config_train.lr_finder_args, + # Setup and execute LR finder + self.config_train.learning_rate = utils_lightning.find_learning_rate( + model=model_lr_finder, + loader=loader_lr_finder, + trainer=trainer_lr_finder, + train_epochs=self.config_train.epochs, ) - # Estimate the optimal learning rate from the loss curve - assert lr_finder is not None - _, _, lr_suggested = utils.smooth_loss_and_suggest(lr_finder) - self.model.learning_rate = lr_suggested - self.config_train.learning_rate = lr_suggested - log.info(f"Learning rate finder suggested learning rate: {lr_suggested}") - self.model.finding_lr = False - - # Tune hyperparams and train - if validation_enabled: - # Set up data the validation dataloader - df_val, _, _, _ = df_utils.prep_or_copy_df(df_val) - val_loader = self._init_val_loader(df_val) - self.model.finding_lr = False + # Set up the model for training + if not self.fitted: + self.model = self._init_model() + # self.model.train_loader = loader + # self.model.finding_lr = False + + # Execute Training Loop start = time.time() self.trainer.fit( model=self.model, - train_dataloaders=train_loader, - val_dataloaders=val_loader if validation_enabled else None, + train_dataloaders=loader, + val_dataloaders=loader_val if validation_enabled else None, ) + log.info("Train Time: {:8.3f}".format(time.time() - start)) - log.debug("Train Time: {:8.3f}".format(time.time() - start)) - - # Load best model from training + # Load best model from checkpoint if end state not best if checkpoint_callback is not None: if checkpoint_callback.best_model_score < checkpoint_callback.current_score: log.info( @@ -2870,11 +2863,12 @@ def _train( ) self.model = time_net.TimeNet.load_from_checkpoint(checkpoint_callback.best_model_path) - if not metrics_enabled: - return None + if metrics_enabled: + # Return metrics collected in logger as dataframe + metrics_df = pd.DataFrame(self.metrics_logger.history) + else: + metrics_df = None - # Return metrics collected in logger as dataframe - metrics_df = pd.DataFrame(self.metrics_logger.history) return metrics_df def 
restore_trainer(self, accelerator: Optional[str] = None): @@ -2886,7 +2880,7 @@ def restore_trainer(self, accelerator: Optional[str] = None): """ Restore the trainer based on the forecaster configuration. """ - self.trainer, _ = utils.configure_trainer( + self.trainer, _ = utils_lightning.configure_trainer( config_train=self.config_train, metrics_logger=self.metrics_logger, accelerator=accelerator, diff --git a/neuralprophet/time_net.py b/neuralprophet/time_net.py index a1148aa4f..32fa72fd9 100644 --- a/neuralprophet/time_net.py +++ b/neuralprophet/time_net.py @@ -767,7 +767,7 @@ def loss_func(self, inputs, predicted, targets): loss = loss * self._get_time_based_sample_weight(t=inputs["time"][:, self.n_lags :]) loss = loss.sum(dim=2).mean() # Regularize. - if self.reg_enabled: + if self.reg_enabled and not self.finding_lr: loss, reg_loss = self._add_batch_regularizations(loss, self.train_progress) else: reg_loss = torch.tensor(0.0, device=self.device) @@ -804,7 +804,7 @@ def training_step(self, batch, batch_idx): if self.finding_lr: # Manually track the loss for the lr finder self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) - self.log("reg_loss", reg_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) + # self.log("reg_loss", reg_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) # Metrics if self.metrics_enabled and not self.finding_lr: diff --git a/neuralprophet/utils.py b/neuralprophet/utils.py index 39a82f84f..10fa63f43 100644 --- a/neuralprophet/utils.py +++ b/neuralprophet/utils.py @@ -749,257 +749,3 @@ def set_log_level(log_level: str = "INFO", include_handlers: bool = False): >>> set_log_level("ERROR") """ set_logger_level(logging.getLogger("NP"), log_level, include_handlers) - - -def smooth_loss_and_suggest(lr_finder, window=10): - """ - Smooth loss using a Hamming filter. - - Parameters - ---------- - loss : np.array - Loss values - - Returns - ------- - loss_smoothed : np.array - Smoothed loss values - lr: np.array - Learning rate values - suggested_lr: float - Suggested learning rate based on gradient - """ - lr_finder_results = lr_finder.results - lr = lr_finder_results["lr"] - loss = np.array(lr_finder_results["loss"]) - # Derive window size from num lr searches, ensure window is divisible by 2 - # half_window = math.ceil(round(len(loss) * 0.1) / 2) - half_window = math.ceil(window / 2) - # Pad sequence and initialialize hamming filter - loss = np.pad(loss, pad_width=half_window, mode="edge") - hamming_window = np.hamming(2 * half_window) - # Convolve the over the loss distribution - try: - loss_smooth = np.convolve( - hamming_window / hamming_window.sum(), - loss, - mode="valid", - )[1:] - except ValueError: - log.warning( - f"The number of loss values ({len(loss)}) is too small to apply smoothing with a the window size of " - f"{window}." - ) - - # Suggest the lr with steepest negative gradient - try: - # Find the steepest gradient and the minimum loss after that - suggestion_steepest = lr[np.argmin(np.gradient(loss_smooth))] - suggestion_minimum = lr[np.argmin(np.array(lr_finder_results["loss"]))] - except ValueError: - log.error( - f"The number of loss values ({len(loss)}) is too small to estimate a learning rate. Increase the number of " - "samples or manually set the learning rate." 
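The suggestion logic being removed here (moved verbatim into `utils_lightning.py` below) smooths the range-test losses with a Hamming window, picks the learning rate at the steepest negative gradient, and finally log-averages it with the tuner's default suggestion. A self-contained sketch of the smoothing and steepest-gradient step on a synthetic loss curve (the curve shape and window size are illustrative only):

```python
import math
import numpy as np

# Synthetic LR range test: loss falls with lr, bottoms out, then diverges
lr = np.logspace(-7, 1, 200)
x = np.log10(lr)
loss = 2.0 - 1.5 / (1.0 + np.exp(-2.0 * (x + 3.0))) + np.where(x > -1, 0.3 * (x + 1.0) ** 2, 0.0)

# Hamming smoothing exactly as in the helper (window=10)
half_window = math.ceil(10 / 2)
padded = np.pad(loss, pad_width=half_window, mode="edge")
kernel = np.hamming(2 * half_window)
loss_smooth = np.convolve(kernel / kernel.sum(), padded, mode="valid")[1:]

# LR at the steepest negative gradient of the smoothed curve
suggestion_steepest = lr[np.argmin(np.gradient(loss_smooth))]
print(f"steepest-gradient suggestion: {suggestion_steepest:.2e}")
# The returned value is then exp(mean(log(steepest), log(default))), i.e. the
# geometric mean with Lightning's own default suggestion.
```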
- ) - raise - # get the tuner's default suggestion - suggestion_default = lr_finder.suggestion(skip_begin=20, skip_end=10) - - log.info(f"Learning rate finder ---- default suggestion: {suggestion_default}") - log.info(f"Learning rate finder ---- steepest: {suggestion_steepest}") - log.info(f"Learning rate finder ---- minimum (not used): {suggestion_minimum}") - if suggestion_steepest is not None and suggestion_default is not None: - log_suggestion_smooth = np.log(suggestion_steepest) - log_suggestion_default = np.log(suggestion_default) - lr_suggestion = np.exp((log_suggestion_smooth + log_suggestion_default) / 2) - log.info(f"Learning rate finder ---- log-avg: {lr_suggestion}") - elif suggestion_steepest is None and suggestion_default is None: - log.error("Automatic learning rate test failed. Please set manually the learning rate.") - raise - else: - lr_suggestion = suggestion_steepest if suggestion_steepest is not None else suggestion_default - - log.info(f"Learning rate finder ---- returning: {lr_suggestion}") - log.info(f"Learning rate finder ---- LR (start): {lr[:5]}") - log.info(f"Learning rate finder ---- LR (end): {lr[-5:]}") - log.info(f"Learning rate finder ---- LOSS (start): {loss[:5]}") - log.info(f"Learning rate finder ---- LOSS (end): {loss[-5:]}") - return (loss, lr, lr_suggestion) - - -def _smooth_loss(loss, beta=0.9): - smoothed_loss = np.zeros_like(loss) - smoothed_loss[0] = loss[0] - for i in range(1, len(loss)): - smoothed_loss[i] = smoothed_loss[i - 1] * beta + (1 - beta) * loss[i] - return smoothed_loss - - -def configure_trainer( - config_train: Train, - metrics_logger, - early_stopping_target: str = "Loss", - accelerator: Optional[str] = None, - progress_bar_enabled: bool = True, - metrics_enabled: bool = False, - checkpointing_enabled: bool = False, - num_batches_per_epoch: int = 100, - deterministic: bool = False, -): - """ - Configures the PyTorch Lightning trainer. - - Parameters - ---------- - config_train : Dict - dictionary containing the overall training configuration. - metrics_logger : MetricsLogger - MetricsLogger object to log metrics to. - early_stopping_target : str - Target metric to use for early stopping. - accelerator : str - Accelerator to use for training. - progress_bar_enabled : bool - If False, no progress bar is shown. - metrics_enabled : bool - If False, no metrics are logged. Calculating metrics is computationally expensive and reduces the training - speed. - checkpointing_enabled : bool - If False, no checkpointing is performed. Checkpointing reduces the training speed. - num_batches_per_epoch : int - Number of batches per epoch. 
- - Returns - ------- - pl.Trainer - PyTorch Lightning trainer - checkpoint_callback - PyTorch Lightning checkpoint callback to load the best model - """ - if config_train.pl_trainer_config is None: - config_train.pl_trainer_config = {} - - pl_trainer_config = config_train.pl_trainer_config - # pl_trainer_config = pl_trainer_config.copy() - - # Set max number of epochs - if hasattr(config_train, "epochs"): - if config_train.epochs is not None: - pl_trainer_config["max_epochs"] = config_train.epochs - - # Configure the Ligthing-logs directory - if "default_root_dir" not in pl_trainer_config.keys(): - pl_trainer_config["default_root_dir"] = os.getcwd() - - # Accelerator - if isinstance(accelerator, str): - if (accelerator == "auto" and torch.cuda.is_available()) or accelerator == "gpu": - pl_trainer_config["accelerator"] = "gpu" - pl_trainer_config["devices"] = -1 - elif (accelerator == "auto" and hasattr(torch.backends, "mps")) or accelerator == "mps": - if torch.backends.mps.is_available(): - pl_trainer_config["accelerator"] = "mps" - pl_trainer_config["devices"] = 1 - elif accelerator != "auto": - pl_trainer_config["accelerator"] = accelerator - pl_trainer_config["devices"] = 1 - - if "accelerator" in pl_trainer_config: - log.info( - f"Using accelerator {pl_trainer_config['accelerator']} with {pl_trainer_config['devices']} device(s)." - ) - else: - log.info("No accelerator available. Using CPU for training.") - - # Configure metrics - if metrics_enabled: - pl_trainer_config["logger"] = metrics_logger - else: - pl_trainer_config["logger"] = False - - pl_trainer_config["deterministic"] = deterministic - - # Configure callbacks - callbacks = [] - has_custom_callbacks = True if "callbacks" in pl_trainer_config else False - - # Configure checkpointing - has_modelcheckpoint_callback = ( - True - if has_custom_callbacks - and any(isinstance(callback, pl.callbacks.ModelCheckpoint) for callback in pl_trainer_config["callbacks"]) - else False - ) - if has_modelcheckpoint_callback and not checkpointing_enabled: - raise ValueError( - "Checkpointing is disabled but a ModelCheckpoint callback is provided. Please enable checkpointing or " - "remove the callback." - ) - if checkpointing_enabled: - if not has_modelcheckpoint_callback: - # Callback to access both the last and best model - checkpoint_callback = pl.callbacks.ModelCheckpoint( - monitor=early_stopping_target, mode="min", save_top_k=1, save_last=True - ) - callbacks.append(checkpoint_callback) - else: - checkpoint_callback = next( - callback - for callback in pl_trainer_config["callbacks"] - if isinstance(callback, pl.callbacks.ModelCheckpoint) - ) - else: - pl_trainer_config["enable_checkpointing"] = False - checkpoint_callback = None - - # Configure the progress bar, refresh every epoch - has_progressbar_callback = ( - True - if has_custom_callbacks - and any(isinstance(callback, pl.callbacks.ProgressBar) for callback in pl_trainer_config["callbacks"]) - else False - ) - if has_progressbar_callback and not progress_bar_enabled: - raise ValueError( - "Progress bar is disabled but a ProgressBar callback is provided. Please enable the progress bar or remove" - " the callback." 
-    )
-    if progress_bar_enabled:
-        if not has_progressbar_callback:
-            prog_bar_callback = ProgressBar(refresh_rate=num_batches_per_epoch, epochs=config_train.epochs)
-            callbacks.append(prog_bar_callback)
-    else:
-        pl_trainer_config["enable_progress_bar"] = False
-
-    # Early stopping monitor
-    has_earlystopping_callback = (
-        True
-        if has_custom_callbacks
-        and any(isinstance(callback, pl.callbacks.EarlyStopping) for callback in pl_trainer_config["callbacks"])
-        else False
-    )
-    if has_earlystopping_callback and not config_train.early_stopping:
-        raise ValueError(
-            "Early stopping is disabled but an EarlyStopping callback is provided. Please enable early stopping or "
-            "remove the callback."
-        )
-    if config_train.early_stopping:
-        if not metrics_enabled:
-            raise ValueError("Early stopping requires metrics to be enabled.")
-        if not has_earlystopping_callback:
-            early_stop_callback = pl.callbacks.EarlyStopping(
-                monitor=early_stopping_target, mode="min", patience=20, divergence_threshold=5.0
-            )
-            callbacks.append(early_stop_callback)
-
-    if has_custom_callbacks:
-        pl_trainer_config["callbacks"].extend(callbacks)
-    else:
-        pl_trainer_config["callbacks"] = callbacks
-    pl_trainer_config["num_sanity_val_steps"] = 0
-    pl_trainer_config["enable_model_summary"] = False
-    # TODO: Disabling sampler_ddp brings a good speedup in performance, however, check whether this is a good idea
-    # https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#replace-sampler-ddp
-    # config["replace_sampler_ddp"] = False
-
-    return pl.Trainer(**pl_trainer_config), checkpoint_callback
diff --git a/neuralprophet/utils_lightning.py b/neuralprophet/utils_lightning.py
new file mode 100644
index 000000000..9091dc727
--- /dev/null
+++ b/neuralprophet/utils_lightning.py
@@ -0,0 +1,309 @@
+import logging
+import math
+import os
+from typing import Optional
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+from pytorch_lightning.callbacks import ProgressBar
+from torch.utils.data import DataLoader
+
+from neuralprophet.configure import Train
+
+log = logging.getLogger("NP.utils_lightning")
+
+
+def smooth_loss_and_suggest(lr_finder, window=10):
+    """
+    Smooth loss using a Hamming filter.
+
+    Parameters
+    ----------
+    lr_finder : lr_finder
+        LR range test result from Lightning's tuner (``Tuner.lr_find``)
+
+    Returns
+    -------
+    loss_smoothed : np.array
+        Smoothed loss values
+    lr: np.array
+        Learning rate values
+    suggested_lr: float
+        Suggested learning rate based on gradient
+    """
+    lr_finder_results = lr_finder.results
+    lr = lr_finder_results["lr"]
+    loss = np.array(lr_finder_results["loss"])
+    # Derive window size from num lr searches, ensure window is divisible by 2
+    # half_window = math.ceil(round(len(loss) * 0.1) / 2)
+    half_window = math.ceil(window / 2)
+    # Pad sequence and initialize Hamming filter
+    loss = np.pad(loss, pad_width=half_window, mode="edge")
+    hamming_window = np.hamming(2 * half_window)
+    # Convolve over the loss distribution
+    try:
+        loss_smooth = np.convolve(
+            hamming_window / hamming_window.sum(),
+            loss,
+            mode="valid",
+        )[1:]
+    except ValueError:
+        log.warning(
+            f"The number of loss values ({len(loss)}) is too small to apply smoothing with the window size of "
+            f"{window}."
+    )
+
+    # Suggest the lr with steepest negative gradient
+    try:
+        # Find the steepest gradient and the minimum loss after that
+        suggestion_steepest = lr[np.argmin(np.gradient(loss_smooth))]
+        suggestion_minimum = lr[np.argmin(np.array(lr_finder_results["loss"]))]
+    except ValueError:
+        log.error(
+            f"The number of loss values ({len(loss)}) is too small to estimate a learning rate. Increase the number of "
+            "samples or manually set the learning rate."
+        )
+        raise
+    # get the tuner's default suggestion
+    suggestion_default = lr_finder.suggestion(skip_begin=20, skip_end=10)
+
+    log.info(f"Learning rate finder ---- default suggestion: {suggestion_default}")
+    log.info(f"Learning rate finder ---- steepest: {suggestion_steepest}")
+    log.info(f"Learning rate finder ---- minimum (not used): {suggestion_minimum}")
+    if suggestion_steepest is not None and suggestion_default is not None:
+        log_suggestion_smooth = np.log(suggestion_steepest)
+        log_suggestion_default = np.log(suggestion_default)
+        lr_suggestion = np.exp((log_suggestion_smooth + log_suggestion_default) / 2)
+        log.info(f"Learning rate finder ---- log-avg: {lr_suggestion}")
+    elif suggestion_steepest is None and suggestion_default is None:
+        log.error("Automatic learning rate test failed. Please set the learning rate manually.")
+        raise
+    else:
+        lr_suggestion = suggestion_steepest if suggestion_steepest is not None else suggestion_default
+
+    log.info(f"Learning rate finder ---- returning: {lr_suggestion}")
+    log.info(f"Learning rate finder ---- LR (start): {lr[:5]}")
+    log.info(f"Learning rate finder ---- LR (end): {lr[-5:]}")
+    log.info(f"Learning rate finder ---- LOSS (start): {loss[:5]}")
+    log.info(f"Learning rate finder ---- LOSS (end): {loss[-5:]}")
+    return loss, lr, lr_suggestion
+
+
+def _smooth_loss(loss, beta=0.9):
+    smoothed_loss = np.zeros_like(loss)
+    smoothed_loss[0] = loss[0]
+    for i in range(1, len(loss)):
+        smoothed_loss[i] = smoothed_loss[i - 1] * beta + (1 - beta) * loss[i]
+    return smoothed_loss
+
+
+def configure_trainer(
+    config_train: Train,
+    metrics_logger,
+    early_stopping_target: str = "Loss",
+    accelerator: Optional[str] = None,
+    progress_bar_enabled: bool = True,
+    metrics_enabled: bool = False,
+    checkpointing_enabled: bool = False,
+    num_batches_per_epoch: int = 100,
+    deterministic: bool = False,
+):
+    """
+    Configures the PyTorch Lightning trainer.
+
+    Parameters
+    ----------
+    config_train : Train
+        Train dataclass containing the overall training configuration.
+    metrics_logger : MetricsLogger
+        MetricsLogger object to log metrics to.
+    early_stopping_target : str
+        Target metric to use for early stopping.
+    accelerator : str
+        Accelerator to use for training.
+    progress_bar_enabled : bool
+        If False, no progress bar is shown.
+    metrics_enabled : bool
+        If False, no metrics are logged. Calculating metrics is computationally expensive and reduces the training
+        speed.
+    checkpointing_enabled : bool
+        If False, no checkpointing is performed. Checkpointing reduces the training speed.
+    num_batches_per_epoch : int
+        Number of batches per epoch.
+
+    Returns
+    -------
+    pl.Trainer
+        PyTorch Lightning trainer
+    checkpoint_callback
+        PyTorch Lightning checkpoint callback to load the best model
+    """
+    if config_train.pl_trainer_config is None:
+        config_train.pl_trainer_config = {}
+
+    pl_trainer_config = config_train.pl_trainer_config
+    # pl_trainer_config = pl_trainer_config.copy()
+
+    # Set max number of epochs
+    assert hasattr(config_train, "epochs") and config_train.epochs is not None
+    pl_trainer_config["max_epochs"] = config_train.epochs
+
+    # Configure the Lightning-logs directory
+    if "default_root_dir" not in pl_trainer_config.keys():
+        pl_trainer_config["default_root_dir"] = os.getcwd()
+
+    # Accelerator
+    if isinstance(accelerator, str):
+        if (accelerator == "auto" and torch.cuda.is_available()) or accelerator == "gpu":
+            pl_trainer_config["accelerator"] = "gpu"
+            pl_trainer_config["devices"] = -1
+        elif (accelerator == "auto" and hasattr(torch.backends, "mps")) or accelerator == "mps":
+            if torch.backends.mps.is_available():
+                pl_trainer_config["accelerator"] = "mps"
+                pl_trainer_config["devices"] = 1
+        elif accelerator != "auto":
+            pl_trainer_config["accelerator"] = accelerator
+            pl_trainer_config["devices"] = 1
+
+    if "accelerator" in pl_trainer_config:
+        log.info(
+            f"Using accelerator {pl_trainer_config['accelerator']} with {pl_trainer_config['devices']} device(s)."
+        )
+    else:
+        log.info("No accelerator available. Using CPU for training.")
+
+    # Configure metrics
+    if metrics_enabled:
+        pl_trainer_config["logger"] = metrics_logger
+    else:
+        pl_trainer_config["logger"] = False
+
+    pl_trainer_config["deterministic"] = deterministic
+
+    # Configure callbacks
+    callbacks = []
+    has_custom_callbacks = True if "callbacks" in pl_trainer_config else False
+
+    # Configure checkpointing
+    has_modelcheckpoint_callback = (
+        True
+        if has_custom_callbacks
+        and any(isinstance(callback, pl.callbacks.ModelCheckpoint) for callback in pl_trainer_config["callbacks"])
+        else False
+    )
+    if has_modelcheckpoint_callback and not checkpointing_enabled:
+        raise ValueError(
+            "Checkpointing is disabled but a ModelCheckpoint callback is provided. Please enable checkpointing or "
+            "remove the callback."
+        )
+    if checkpointing_enabled:
+        if not has_modelcheckpoint_callback:
+            # Callback to access both the last and best model
+            checkpoint_callback = pl.callbacks.ModelCheckpoint(
+                monitor=early_stopping_target, mode="min", save_top_k=1, save_last=True
+            )
+            callbacks.append(checkpoint_callback)
+        else:
+            checkpoint_callback = next(
+                callback
+                for callback in pl_trainer_config["callbacks"]
+                if isinstance(callback, pl.callbacks.ModelCheckpoint)
+            )
+    else:
+        pl_trainer_config["enable_checkpointing"] = False
+        checkpoint_callback = None
+
+    # Configure the progress bar, refresh every epoch
+    has_progressbar_callback = (
+        True
+        if has_custom_callbacks
+        and any(isinstance(callback, ProgressBar) for callback in pl_trainer_config["callbacks"])
+        else False
+    )
+    if has_progressbar_callback and not progress_bar_enabled:
+        raise ValueError(
+            "Progress bar is disabled but a ProgressBar callback is provided. Please enable the progress bar or remove"
+            " the callback."
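These guards keep user-supplied callbacks in `pl_trainer_config` consistent with the corresponding feature flags. A hedged usage sketch of that contract as introduced by this patch; the `SimpleNamespace` below stands in for the real `configure.Train` dataclass, which carries many more fields:

```python
from types import SimpleNamespace

import pytorch_lightning as pl
from neuralprophet import utils_lightning  # module created by this patch

# A user-supplied checkpoint callback inside pl_trainer_config ...
ckpt = pl.callbacks.ModelCheckpoint(monitor="Loss", mode="min", save_top_k=1, save_last=True)
config_train = SimpleNamespace(
    epochs=10,
    early_stopping=False,
    pl_trainer_config={"callbacks": [ckpt]},
)

# ... is detected and reused when checkpointing is enabled; with
# checkpointing_enabled=False the same input raises a ValueError instead.
trainer, checkpoint_callback = utils_lightning.configure_trainer(
    config_train=config_train,  # stand-in object, not the real Train dataclass
    metrics_logger=None,  # unused when metrics_enabled=False
    checkpointing_enabled=True,
    metrics_enabled=False,
    progress_bar_enabled=False,
)
assert checkpoint_callback is ckpt
```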
+    )
+    if progress_bar_enabled:
+        if not has_progressbar_callback:
+            prog_bar_callback = ProgressBar(refresh_rate=num_batches_per_epoch, epochs=config_train.epochs)
+            callbacks.append(prog_bar_callback)
+    else:
+        pl_trainer_config["enable_progress_bar"] = False
+
+    # Early stopping monitor
+    has_earlystopping_callback = (
+        True
+        if has_custom_callbacks
+        and any(isinstance(callback, pl.callbacks.EarlyStopping) for callback in pl_trainer_config["callbacks"])
+        else False
+    )
+    if has_earlystopping_callback and not config_train.early_stopping:
+        raise ValueError(
+            "Early stopping is disabled but an EarlyStopping callback is provided. Please enable early stopping or "
+            "remove the callback."
+        )
+    if config_train.early_stopping:
+        if not metrics_enabled:
+            raise ValueError("Early stopping requires metrics to be enabled.")
+        if not has_earlystopping_callback:
+            early_stop_callback = pl.callbacks.EarlyStopping(
+                monitor=early_stopping_target, mode="min", patience=20, divergence_threshold=5.0
+            )
+            callbacks.append(early_stop_callback)
+
+    if has_custom_callbacks:
+        pl_trainer_config["callbacks"].extend(callbacks)
+    else:
+        pl_trainer_config["callbacks"] = callbacks
+    pl_trainer_config["num_sanity_val_steps"] = 0
+    pl_trainer_config["enable_model_summary"] = False
+    # TODO: Disabling sampler_ddp brings a good speedup in performance, however, check whether this is a good idea
+    # https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#replace-sampler-ddp
+    # config["replace_sampler_ddp"] = False
+
+    return pl.Trainer(**pl_trainer_config), checkpoint_callback
+
+
+def find_learning_rate(model, loader, trainer, train_epochs):
+    log.info("No Learning Rate provided. Activating learning rate finder")
+
+    # Configure the learning rate finder args
+    batches_per_epoch = len(loader)
+    main_training_total_steps = train_epochs * batches_per_epoch
+    # main_training_total_steps is around 1e3 to 1e6 -> num_training 100 to 400
+    num_training = 100 + int(np.log10(1 + main_training_total_steps / 1000) * 100)
+    if batches_per_epoch < num_training:
+        log.warning(
+            f"Learning rate finder: The number of batches per epoch ({batches_per_epoch}) is smaller than the required number \
+            for the learning rate finder ({num_training}). The results might not be optimal."
+        )
+        # num_training = num_batches
+    lr_finder_args = {
+        "min_lr": 1e-7,
+        "max_lr": 1e1,
+        "num_training": num_training,
+        "early_stop_threshold": None,
+        "mode": "exponential",
+    }
+    log.info(f"Learning rate finder ---- ARGs: {lr_finder_args}")
+
+    # Execute the learning rate range finder
+    tuner = pl.Tuner(trainer)
+    model.finding_lr = True
+    # model.train_loader = loader
+    lr_finder = tuner.lr_find(
+        model=model,
+        train_dataloaders=loader,
+        # val_dataloaders=val_loader, # not used, but led to a Lightning bug if not provided in prior versions.
+ **lr_finder_args, + ) + model.finding_lr = False + + # Estimate the optimal learning rate from the loss curve + assert lr_finder is not None + loss_list, lr_list, lr_suggested = smooth_loss_and_suggest(lr_finder) + log.info(f"Learning rate finder suggested learning rate: {lr_suggested}") + return lr_suggested From ee28441a65e29a7c47f94979d69c624414fcb28e Mon Sep 17 00:00:00 2001 From: ourownstory Date: Thu, 29 Aug 2024 17:54:36 -0700 Subject: [PATCH 31/39] fix progressbar --- neuralprophet/utils_lightning.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/neuralprophet/utils_lightning.py b/neuralprophet/utils_lightning.py index 9091dc727..826840e1f 100644 --- a/neuralprophet/utils_lightning.py +++ b/neuralprophet/utils_lightning.py @@ -6,10 +6,9 @@ import numpy as np import pytorch_lightning as pl import torch -from pytorch_lightning.callbacks import ProgressBar -from torch.utils.data import DataLoader from neuralprophet.configure import Train +from neuralprophet.logger import ProgressBar log = logging.getLogger("NP.utils_lightning") @@ -218,7 +217,7 @@ def configure_trainer( has_progressbar_callback = ( True if has_custom_callbacks - and any(isinstance(callback, ProgressBar) for callback in pl_trainer_config["callbacks"]) + and any(isinstance(callback, pl.callback.ProgressBar) for callback in pl_trainer_config["callbacks"]) else False ) if has_progressbar_callback and not progress_bar_enabled: From be3b6cfd4e4675cb3d3027e3791716fb2bf03c4c Mon Sep 17 00:00:00 2001 From: ourownstory Date: Thu, 29 Aug 2024 18:04:38 -0700 Subject: [PATCH 32/39] remove dataloader from model --- neuralprophet/configure.py | 3 +++ neuralprophet/forecaster.py | 11 +++++++---- neuralprophet/time_net.py | 6 +++--- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/neuralprophet/configure.py b/neuralprophet/configure.py index 00c32429a..e7c5a5328 100644 --- a/neuralprophet/configure.py +++ b/neuralprophet/configure.py @@ -278,6 +278,9 @@ def get_reg_delay_weight(self, progress, reg_start_pct: float = 0.66, reg_full_p delay_weight = 1 return delay_weight + def set_batches_per_epoch(self, batches_per_epoch: int): + self.batches_per_epoch = batches_per_epoch + @dataclass class Trend: diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py index 1988d0c69..4cf8fe974 100644 --- a/neuralprophet/forecaster.py +++ b/neuralprophet/forecaster.py @@ -2783,6 +2783,7 @@ def _train( shuffle=True, num_workers=num_workers, ) + self.config_train.set_batches_per_epoch(len(loader)) log.info(f"Train Dataset size: {len(dataset)}") log.info(f"Number of batches per training epoch: {len(loader)}") @@ -2810,16 +2811,14 @@ def _train( # Find suitable learning rate if not set if self.config_train.learning_rate is None: assert not self.fitted, "Learning rate must be provided for re-training a fitted model." 
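The `set_batches_per_epoch` hand-off above decouples the LightningModule from its DataLoader: the trainer-side code records `len(loader)` on the shared config, and `configure_optimizers` reads it back to size the scheduler's step budget. A sketch of why that number matters, assuming a OneCycleLR-style scheduler that needs the total step budget up front (the actual scheduler setup is outside this hunk):

```python
import torch

# Stand-in for the Train config: the trainer-side code calls
# set_batches_per_epoch(len(loader)), so the module never needs the loader itself.
class TrainConfigStub:
    epochs = 20
    learning_rate = 1e-2
    batches_per_epoch = None

    def set_batches_per_epoch(self, batches_per_epoch: int):
        self.batches_per_epoch = batches_per_epoch

config = TrainConfigStub()
config.set_batches_per_epoch(128)  # e.g. len(train_loader)

# Assumption: a OneCycleLR-style scheduler, which must know total_steps at construction
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=config.learning_rate,
    total_steps=config.epochs * config.batches_per_epoch,
)
```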
- # Init a separate Model for LR finder (optional, done for safety) + # Init a separate Model, Loader and Trainer copy for LR finder (optional, done for safety) model_lr_finder = self._init_model() - # Init a separate DataLoader for LR finder (optional, done for safety) loader_lr_finder = DataLoader( dataset, batch_size=self.config_train.batch_size, shuffle=True, num_workers=num_workers, ) - # Init a separate Trainer for LR finder (optional, done for safety) trainer_lr_finder, _ = utils_lightning.configure_trainer( config_train=self.config_train, metrics_logger=self.metrics_logger, @@ -2832,12 +2831,16 @@ def _train( deterministic=deterministic, ) # Setup and execute LR finder - self.config_train.learning_rate = utils_lightning.find_learning_rate( + suggested_lr = utils_lightning.find_learning_rate( model=model_lr_finder, loader=loader_lr_finder, trainer=trainer_lr_finder, train_epochs=self.config_train.epochs, ) + # Save the suggested learning rate + self.config_train.learning_rate = suggested_lr + # Clean up the LR finder copies of Model, Loader and Trainer + del model_lr_finder, loader_lr_finder, trainer_lr_finder # Set up the model for training if not self.fitted: diff --git a/neuralprophet/time_net.py b/neuralprophet/time_net.py index 32fa72fd9..1097c8e57 100644 --- a/neuralprophet/time_net.py +++ b/neuralprophet/time_net.py @@ -869,7 +869,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): return prediction, components def configure_optimizers(self): - self.train_steps_per_epoch = len(self.train_loader) + self.train_steps_per_epoch = self.config_train.batches_per_epoch # self.trainer.num_training_batches = self.train_steps_per_epoch * self.config_train.epochs self.config_train.set_optimizer() @@ -1010,8 +1010,8 @@ def denormalize(self, ts): ts = scale_y * ts + shift_y return ts - def train_dataloader(self): - return self.train_loader + # def train_dataloader(self): + # return self.train_loader class FlatNet(nn.Module): From 462f80fcd699759ada8bf2f17b023d61d4200cc3 Mon Sep 17 00:00:00 2001 From: ourownstory Date: Thu, 29 Aug 2024 18:07:46 -0700 Subject: [PATCH 33/39] fix callbacks ProgressBar --- neuralprophet/utils_lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neuralprophet/utils_lightning.py b/neuralprophet/utils_lightning.py index 826840e1f..4ad102b2d 100644 --- a/neuralprophet/utils_lightning.py +++ b/neuralprophet/utils_lightning.py @@ -217,7 +217,7 @@ def configure_trainer( has_progressbar_callback = ( True if has_custom_callbacks - and any(isinstance(callback, pl.callback.ProgressBar) for callback in pl_trainer_config["callbacks"]) + and any(isinstance(callback, pl.callbacks.ProgressBar) for callback in pl_trainer_config["callbacks"]) else False ) if has_progressbar_callback and not progress_bar_enabled: From 628b4ad6aa1dd267ea1b177719ebf2b0f419db84 Mon Sep 17 00:00:00 2001 From: ourownstory Date: Thu, 29 Aug 2024 18:17:58 -0700 Subject: [PATCH 34/39] fixing tuner --- tests/test_glocal.py | 8 +++++--- tests/test_integration.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/tests/test_glocal.py b/tests/test_glocal.py index 9bda1882c..fe4719140 100644 --- a/tests/test_glocal.py +++ b/tests/test_glocal.py @@ -20,7 +20,7 @@ YOS_FILE = os.path.join(DATA_DIR, "yosemite_temps.csv") NROWS = 256 EPOCHS = 1 -BATCH_SIZE = 128 +BATCH_SIZE = 32 LR = 1.0 PLOT = False @@ -60,7 +60,7 @@ def test_regularized_trend_global_local_modeling(): df2_0["ID"] = "df2" df3_0 = df.iloc[256:384, :].copy(deep=True) 
df3_0["ID"] = "df3" - m = NeuralProphet(n_lags=10, epochs=EPOCHS, trend_global_local="local", trend_reg=1) + m = NeuralProphet(n_lags=10, epochs=EPOCHS, learning_rate=LR, trend_global_local="local", trend_reg=1) train_df, test_df = m.split_df(pd.concat((df1_0, df2_0, df3_0)), valid_p=0.33, local_split=True) m.fit(train_df) future = m.make_future_dataframe(test_df) @@ -286,7 +286,9 @@ def test_adding_new_local_seasonality(): df2_0["ID"] = "df2" df3_0 = df.iloc[256:384, :].copy(deep=True) df3_0["ID"] = "df3" - m = NeuralProphet(epochs=EPOCHS, batch_size=BATCH_SIZE, season_global_local="global", trend_global_local="local") + m = NeuralProphet( + epochs=EPOCHS, learning_rate=LR, batch_size=BATCH_SIZE, season_global_local="global", trend_global_local="local" + ) m.add_seasonality(period=30, fourier_order=8, name="monthly", global_local="local") train_df, test_df = m.split_df(pd.concat((df1_0, df2_0, df3_0)), valid_p=0.33, local_split=True) m.fit(train_df) diff --git a/tests/test_integration.py b/tests/test_integration.py index 8ef45b10a..ee4a9028d 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -611,6 +611,20 @@ def test_loss_func_torch(): m.predict(future) +def test_loss_func_torch_lr_finder(): + log.info("TEST setting torch.nn loss func") + df = pd.read_csv(PEYTON_FILE, nrows=512) + m = NeuralProphet( + epochs=EPOCHS, + batch_size=BATCH_SIZE, + loss_func=torch.nn.MSELoss, + learning_rate=None, + ) + m.fit(df, freq="D") + future = m.make_future_dataframe(df, periods=10, n_historic_predictions=10) + m.predict(future) + + def test_callable_loss(): log.info("TEST Callable Loss") @@ -630,6 +644,7 @@ def my_loss(output, target): batch_size=BATCH_SIZE, seasonality_mode="multiplicative", loss_func=my_loss, + learning_rate=LR, ) m.fit(df, freq="5min") future = m.make_future_dataframe(df, periods=12 * 24, n_historic_predictions=12 * 24) @@ -659,6 +674,7 @@ def forward(self, input, target): epochs=EPOCHS, batch_size=BATCH_SIZE, loss_func=MyLoss, + learning_rate=LR, ) m.fit(df, freq="5min") future = m.make_future_dataframe(df, periods=12, n_historic_predictions=12) From 4b9b30568e86e73fc456295bfab2a5c64103f09e Mon Sep 17 00:00:00 2001 From: ourownstory Date: Thu, 29 Aug 2024 18:18:59 -0700 Subject: [PATCH 35/39] fix tuner --- neuralprophet/utils_lightning.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/neuralprophet/utils_lightning.py b/neuralprophet/utils_lightning.py index 4ad102b2d..2e7bfc7e9 100644 --- a/neuralprophet/utils_lightning.py +++ b/neuralprophet/utils_lightning.py @@ -6,6 +6,7 @@ import numpy as np import pytorch_lightning as pl import torch +from pytorch_lightning.tuner.tuning import Tuner from neuralprophet.configure import Train from neuralprophet.logger import ProgressBar @@ -290,7 +291,7 @@ def find_learning_rate(model, loader, trainer, train_epochs): log.info(f"Learning rate finder ---- ARGs: {lr_finder_args}") # Execute the learning rate range finder - tuner = pl.Tuner(trainer) + tuner = Tuner(trainer) model.finding_lr = True # model.train_loader = loader lr_finder = tuner.lr_find( From bc27cfc1abf381d029dce16842e042ac8901c0b6 Mon Sep 17 00:00:00 2001 From: ourownstory Date: Thu, 29 Aug 2024 18:41:30 -0700 Subject: [PATCH 36/39] readd prep_or_copy --- neuralprophet/data/process.py | 2 +- neuralprophet/df_utils.py | 2 +- neuralprophet/forecaster.py | 8 ++++---- neuralprophet/utils_lightning.py | 10 ++++------ 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/neuralprophet/data/process.py b/neuralprophet/data/process.py 
index a5c71a2d7..2958dde49 100644 --- a/neuralprophet/data/process.py +++ b/neuralprophet/data/process.py @@ -612,7 +612,7 @@ def _create_dataset(model, df, predict_mode, prediction_frequency=None): ------- TimeDataset """ - # df, _, _, _ = df_utils.prep_or_copy_df(df) + df, _, _, _ = df_utils.prep_or_copy_df(df) return time_dataset.GlobalTimeDataset( df, predict_mode=predict_mode, diff --git a/neuralprophet/df_utils.py b/neuralprophet/df_utils.py index 3c4e4bfa4..c5f6367a5 100644 --- a/neuralprophet/df_utils.py +++ b/neuralprophet/df_utils.py @@ -308,7 +308,7 @@ def init_data_params( ShiftScale entries containing ``shift`` and ``scale`` parameters for each column """ # Compute Global data params - # df, _, _, _ = prep_or_copy_df(df) + df, _, _, _ = prep_or_copy_df(df) df_merged = df.copy(deep=True).drop("ID", axis=1) global_data_params = data_params_definition( df_merged, normalize, config_lagged_regressors, config_regressors, config_events, config_seasonality diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py index 4cf8fe974..27efc53da 100644 --- a/neuralprophet/forecaster.py +++ b/neuralprophet/forecaster.py @@ -1292,7 +1292,7 @@ def test(self, df: pd.DataFrame, verbose: bool = True): config_seasonality=self.config_seasonality, predicting=False, ) - # df, _, _, _ = df_utils.prep_or_copy_df(df) + df, _, _, _ = df_utils.prep_or_copy_df(df) df = _normalize(df=df, config_normalization=self.config_normalization) dataset = _create_dataset(self, df, predict_mode=False) test_loader = DataLoader(dataset, batch_size=min(1024, len(dataset)), shuffle=False, drop_last=False) @@ -2703,7 +2703,7 @@ def _data_setup(self, df): ------- torch DataLoader """ - # df, _, _, _ = df_utils.prep_or_copy_df(df) + df, _, _, _ = df_utils.prep_or_copy_df(df) if not self.fitted: # Initialize data normalization parameters @@ -2790,7 +2790,7 @@ def _train( # Set up DataLoaders: Validation validation_enabled = df_val is not None and isinstance(df_val, pd.DataFrame) if validation_enabled: - # df_val, _, _, _ = df_utils.prep_or_copy_df(df_val) + df_val, _, _, _ = df_utils.prep_or_copy_df(df_val) df_val = _normalize(df=df_val, config_normalization=self.config_normalization) dataset_val = _create_dataset(self, df_val, predict_mode=False) loader_val = DataLoader(dataset_val, batch_size=min(1024, len(dataset_val)), shuffle=False, drop_last=False) @@ -2840,7 +2840,7 @@ def _train( # Save the suggested learning rate self.config_train.learning_rate = suggested_lr # Clean up the LR finder copies of Model, Loader and Trainer - del model_lr_finder, loader_lr_finder, trainer_lr_finder + # del model_lr_finder, loader_lr_finder, trainer_lr_finder # Set up the model for training if not self.fitted: diff --git a/neuralprophet/utils_lightning.py b/neuralprophet/utils_lightning.py index 2e7bfc7e9..2ecdb3eb3 100644 --- a/neuralprophet/utils_lightning.py +++ b/neuralprophet/utils_lightning.py @@ -166,12 +166,10 @@ def configure_trainer( pl_trainer_config["accelerator"] = accelerator pl_trainer_config["devices"] = 1 - if "accelerator" in pl_trainer_config: - log.info( - f"Using accelerator {pl_trainer_config['accelerator']} with {pl_trainer_config['devices']} device(s)." - ) - else: - log.info("No accelerator available. Using CPU for training.") + if "accelerator" in pl_trainer_config: + log.info(f"Using accelerator {pl_trainer_config['accelerator']} with {pl_trainer_config['devices']} device(s).") + elif accelerator == "auto": + log.info("No accelerator available. 
Using CPU for training.") # Configure metrics if metrics_enabled: From 56215b7acc7c442776ae85f045f625ea6c769b23 Mon Sep 17 00:00:00 2001 From: ourownstory Date: Thu, 29 Aug 2024 18:48:44 -0700 Subject: [PATCH 37/39] undo copy of model, loader, trainer --- neuralprophet/forecaster.py | 54 ++++++++++++++++++------------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py index 27efc53da..303a945e8 100644 --- a/neuralprophet/forecaster.py +++ b/neuralprophet/forecaster.py @@ -1,9 +1,7 @@ import logging -import math import os import time from collections import OrderedDict -from dataclasses import dataclass, field from typing import Callable, List, Optional, Tuple, Type, Union import matplotlib @@ -13,7 +11,6 @@ import torch from matplotlib import pyplot from matplotlib.axes import Axes -from pytorch_lightning.tuner.tuning import Tuner from torch.utils.data import DataLoader from neuralprophet import configure, df_utils, np_types, time_dataset, time_net, utils, utils_lightning, utils_metrics @@ -2808,28 +2805,37 @@ def _train( deterministic=deterministic, ) + # Set up the model for training + if not self.fitted: + self.model = self._init_model() + # self.model.train_loader = loader + # self.model.finding_lr = False + # Find suitable learning rate if not set if self.config_train.learning_rate is None: assert not self.fitted, "Learning rate must be provided for re-training a fitted model." + model_lr_finder = self.model + loader_lr_finder = loader + trainer_lr_finder = self.trainer # Init a separate Model, Loader and Trainer copy for LR finder (optional, done for safety) - model_lr_finder = self._init_model() - loader_lr_finder = DataLoader( - dataset, - batch_size=self.config_train.batch_size, - shuffle=True, - num_workers=num_workers, - ) - trainer_lr_finder, _ = utils_lightning.configure_trainer( - config_train=self.config_train, - metrics_logger=self.metrics_logger, - early_stopping_target="Loss", - accelerator=self.accelerator, - progress_bar_enabled=progress_bar_enabled, - metrics_enabled=False, - checkpointing_enabled=False, - num_batches_per_epoch=len(loader), - deterministic=deterministic, - ) + # model_lr_finder = self._init_model() + # loader_lr_finder = DataLoader( + # dataset, + # batch_size=self.config_train.batch_size, + # shuffle=True, + # num_workers=num_workers, + # ) + # trainer_lr_finder, _ = utils_lightning.configure_trainer( + # config_train=self.config_train, + # metrics_logger=self.metrics_logger, + # early_stopping_target="Loss", + # accelerator=self.accelerator, + # progress_bar_enabled=progress_bar_enabled, + # metrics_enabled=False, + # checkpointing_enabled=False, + # num_batches_per_epoch=len(loader), + # deterministic=deterministic, + # ) # Setup and execute LR finder suggested_lr = utils_lightning.find_learning_rate( model=model_lr_finder, @@ -2842,12 +2848,6 @@ def _train( # Clean up the LR finder copies of Model, Loader and Trainer # del model_lr_finder, loader_lr_finder, trainer_lr_finder - # Set up the model for training - if not self.fitted: - self.model = self._init_model() - # self.model.train_loader = loader - # self.model.finding_lr = False - # Execute Training Loop start = time.time() self.trainer.fit( From fda417aeef8ce22e27b44b0c276667d9bfa4af35 Mon Sep 17 00:00:00 2001 From: ourownstory Date: Thu, 29 Aug 2024 18:50:01 -0700 Subject: [PATCH 38/39] add comment about separate lr finder copies --- neuralprophet/forecaster.py | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py
index 303a945e8..0a054c5c0 100644
--- a/neuralprophet/forecaster.py
+++ b/neuralprophet/forecaster.py
@@ -2817,7 +2817,9 @@ def _train(
             model_lr_finder = self.model
             loader_lr_finder = loader
             trainer_lr_finder = self.trainer
+
             # Init a separate Model, Loader and Trainer copy for LR finder (optional, done for safety)
+            # Note: Leads to a CUDA issue. Needs to be fixed before enabling this feature.
             # model_lr_finder = self._init_model()
             # loader_lr_finder = DataLoader(
             #     dataset,

From 70202444d0b51e7eafec526ba0c91c5707f49114 Mon Sep 17 00:00:00 2001
From: ourownstory
Date: Thu, 29 Aug 2024 18:53:17 -0700
Subject: [PATCH 39/39] improve lr finder comment

---
 neuralprophet/forecaster.py | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py
index 0a054c5c0..51f818377 100644
--- a/neuralprophet/forecaster.py
+++ b/neuralprophet/forecaster.py
@@ -2808,15 +2808,10 @@ def _train(
         # Set up the model for training
         if not self.fitted:
             self.model = self._init_model()
-        # self.model.train_loader = loader
-        # self.model.finding_lr = False

         # Find suitable learning rate if not set
         if self.config_train.learning_rate is None:
             assert not self.fitted, "Learning rate must be provided for re-training a fitted model."
-            model_lr_finder = self.model
-            loader_lr_finder = loader
-            trainer_lr_finder = self.trainer

             # Init a separate Model, Loader and Trainer copy for LR finder (optional, done for safety)
             # Note: Leads to a CUDA issue. Needs to be fixed before enabling this feature.
@@ -2838,18 +2833,21 @@ def _train(
             #     num_batches_per_epoch=len(loader),
             #     deterministic=deterministic,
             # )
+
             # Setup and execute LR finder
             suggested_lr = utils_lightning.find_learning_rate(
-                model=model_lr_finder,
-                loader=loader_lr_finder,
-                trainer=trainer_lr_finder,
+                model=self.model,  # model_lr_finder,
+                loader=loader,  # loader_lr_finder,
+                trainer=self.trainer,  # trainer_lr_finder,
                 train_epochs=self.config_train.epochs,
             )
-            # Save the suggested learning rate
-            self.config_train.learning_rate = suggested_lr
             # Clean up the LR finder copies of Model, Loader and Trainer
             # del model_lr_finder, loader_lr_finder, trainer_lr_finder

+            # Save the suggested learning rate
+            self.config_train.learning_rate = suggested_lr
+            self.model.finding_lr = False
+
             # Execute Training Loop
             start = time.time()
             self.trainer.fit(