
Commit

add logging of progress and lr
ourownstory committed Aug 28, 2024
1 parent db09100 commit b8bf9b8
Showing 1 changed file with 9 additions and 2 deletions.
neuralprophet/time_net.py (11 changes: 9 additions & 2 deletions)
@@ -797,8 +797,11 @@ def training_step(self, batch, batch_idx):
         optimizer.step()
 
         scheduler = self.lr_schedulers()
-        scheduler.step()
-        # scheduler.step(epoch=self.train_progress)
+        if self.config_train.scheduler == torch.optim.lr_scheduler.OneCycleLR:
+            # is configured with total_steps (not epochs)
+            scheduler.step()
+        else:
+            scheduler.step(epoch=self.train_progress)
 
         if self.finding_lr:
             # Manually track the loss for the lr finder
@@ -812,6 +815,8 @@ def training_step(self, batch, batch_idx):
             self.log_dict(self.metrics_train(predicted_denorm, target_denorm), **self.log_args)
             self.log("Loss", loss, **self.log_args)
             self.log("RegLoss", reg_loss, **self.log_args)
+            self.log("TrainProgress", self.train_progress, **self.log_args)
+            self.log("LR", scheduler.get_last_lr()[0], **self.log_args)
         return loss
 
     def validation_step(self, batch, batch_idx):
@@ -873,6 +878,8 @@ def configure_optimizers(self):
         self.config_train.set_scheduler()
 
         # Optimizer
+        if self.finding_lr and self.learning_rate is None:
+            self.learning_rate = self.config_train.lr_finder_args["min_lr"]
         optimizer = self.config_train.optimizer(
             self.parameters(),
             lr=self.learning_rate,
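For context, here is a minimal standalone sketch of the two scheduler-stepping styles that the new branch distinguishes. Only the step() and step(epoch=...) calls mirror the diff; the toy parameters, optimizers, and the ExponentialLR choice are illustrative assumptions, not NeuralProphet code.

```python
import torch

# Toy parameters; each scheduler gets its own optimizer so their
# learning-rate updates do not interfere.
params = [torch.nn.Parameter(torch.zeros(1))]
opt_cycle = torch.optim.SGD(params, lr=0.1)
opt_other = torch.optim.SGD(params, lr=0.1)

# OneCycleLR is configured with total_steps (not epochs), so it is advanced
# once per batch with a plain step(), as in the OneCycleLR branch of the diff.
one_cycle = torch.optim.lr_scheduler.OneCycleLR(opt_cycle, max_lr=0.1, total_steps=100)
one_cycle.step()
print("OneCycleLR lr after one step:", one_cycle.get_last_lr()[0])

# Other schedulers can be moved to a (possibly fractional) epoch position,
# analogous to scheduler.step(epoch=self.train_progress) in the diff.
# Passing epoch to step() is deprecated in recent PyTorch but still accepted.
exponential = torch.optim.lr_scheduler.ExponentialLR(opt_other, gamma=0.9)
exponential.step(epoch=2.5)
print("ExponentialLR lr at epoch 2.5:", exponential.get_last_lr()[0])
```

One reason for the special case: in current PyTorch versions, OneCycleLR raises an error if it is stepped more times than its configured total_steps, so it is advanced exactly once per batch rather than being driven by a fractional epoch value. The get_last_lr()[0] call used for the new "LR" log entry returns the learning rate most recently computed by the scheduler.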
