[Major] Support Custom Learning Rate Scheduler #1637
Conversation
Model Benchmark
neuralprophet/time_net.py
Outdated
```python
if self.finding_lr:
    # Manually track the loss for the lr finder
    self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
    self.log("reg_loss", reg_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
```
skip reg_loss
```python
self.log("reg_loss", reg_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
if self.finding_lr:
    # Manually track the loss for the lr finder
    self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
```
pass log_args
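A minimal sketch of what "pass log_args" could look like: the shared logging kwargs are collected once and forwarded via `**`, instead of repeating `on_step`/`on_epoch`/`prog_bar` at every call site. The helper name `log_metric` and the stand-in log function are illustrative, not part of the PR.

```python
# Shared logging kwargs, collected once (mirrors the log_args dict below).
log_args = {
    "on_step": False,
    "on_epoch": True,
    "prog_bar": True,
}

def log_metric(log_fn, name, value, **extra):
    """Forward a metric to a Lightning-style log function with shared kwargs.

    Per-call kwargs in `extra` override the shared defaults.
    """
    merged = {**log_args, **extra}
    log_fn(name, value, **merged)

# Usage with a stand-in log function that just records its arguments:
calls = []
log_metric(lambda name, value, **kw: calls.append((name, value, kw)), "train_loss", 0.5)
```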
check for lr-finder requirements
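One way to read "check for lr-finder requirements" is an up-front validation that the settings the LR finder depends on are present, so it fails with a clear message instead of erroring later inside the tuner. The required keys below are purely illustrative assumptions:

```python
def check_lr_finder_requirements(config):
    """Hypothetical sketch: validate config fields the LR finder relies on.

    The required keys are assumptions for illustration; the actual
    requirements in NeuralProphet may differ.
    """
    missing = [k for k in ("batch_size", "epochs") if config.get(k) is None]
    if missing:
        raise ValueError(f"lr finder requires these settings: {missing}")
    return True

# Usage:
check_lr_finder_requirements({"batch_size": 32, "epochs": 10})
```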
Implement better progress/metrics logging/printing.
E.g. we also need to touch time_net.__init__:
```python
self.log_args = {
    "on_step": False,
    "on_epoch": True,
    "prog_bar": True,
    "batch_size": self.config_train.batch_size,
}
```
```python
if self.config_train.newer_samples_weight > 1.0:
    # Weigh newer samples more.
    loss = loss * self._get_time_based_sample_weight(t=inputs["time"][:, self.n_lags :])
```
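The newer-samples weighting above could be sketched roughly as follows. This is an assumed linear ramp, not the exact `_get_time_based_sample_weight` implementation in `time_net.py`; the parameter names mirror the config fields but the formula is illustrative:

```python
import torch

def time_based_sample_weight(t, newer_samples_weight=2.0, newer_samples_start=0.0):
    """Hypothetical sketch: ramp per-sample weights from 1.0 up to
    newer_samples_weight as normalized time t goes from newer_samples_start
    to 1.0, so newer samples contribute more to the loss."""
    # Fraction of the ramp each sample has progressed through, clipped to [0, 1].
    progress = ((t - newer_samples_start) / (1.0 - newer_samples_start)).clamp(0.0, 1.0)
    weight = 1.0 + (newer_samples_weight - 1.0) * progress
    # Normalize so the mean weight stays 1.0 and the overall loss scale is unchanged.
    return weight / weight.mean()

t = torch.tensor([0.0, 0.5, 1.0])
w = time_based_sample_weight(t)
```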
simplify to only pass first forecast target
will do another time
```python
lr_scheduler = self.config_train.scheduler(
    optimizer,
    **self.config_train.scheduler_args,
)

return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

def _get_time_based_sample_weight(self, t):
```
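The scheduler instantiation in the diff above is the core of the custom-scheduler support: the scheduler class and its kwargs come from the training config, so any `torch.optim.lr_scheduler` class can be plugged in. A self-contained sketch, where the `ConfigTrain` stub and model are stand-ins (only the `scheduler`/`scheduler_args` fields mirror the diff):

```python
import torch

class ConfigTrain:
    """Stand-in for the training config carrying a user-supplied scheduler."""
    def __init__(self):
        self.scheduler = torch.optim.lr_scheduler.StepLR
        self.scheduler_args = {"step_size": 10, "gamma": 0.5}

config_train = ConfigTrain()
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Instantiate the user-supplied scheduler class, as in the diff.
lr_scheduler = config_train.scheduler(optimizer, **config_train.scheduler_args)
result = {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
```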
simplify to compute based only on first forecast target
postponed
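The postponed simplification could look something like this: take only the first forecast target's timestamp per sample and let the resulting `(batch, 1)` weight broadcast over the `(batch, n_forecasts)` loss. The shapes and weighting function here are illustrative assumptions:

```python
import torch

# Hypothetical sketch: instead of weighting each of the n_forecasts steps
# by its own timestamp, use only the first forecast target's time.
t = torch.rand(8, 5)           # (batch, n_forecasts) normalized times
t_first = t[:, 0:1]            # keep only the first forecast target
weight = 1.0 + t_first         # stand-in weighting function
# weight has shape (batch, 1) and broadcasts against a (batch, n_forecasts) loss.
weighted_loss = torch.ones(8, 5) * weight
```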
Changes:
Future TODOs: