Skip to content

Commit

Permalink
Merge pull request #116 from luigibonati/scheduler_fix
Browse files Browse the repository at this point in the history
Fix bug in lr scheduler configure, closes #115
  • Loading branch information
EnricoTrizio authored Jan 16, 2024
2 parents 04596c2 + 6adb5aa commit 973ef81
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
2 changes: 1 addition & 1 deletion mlcolvar/cvs/cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def configure_optimizers(self):
if self.lr_scheduler_kwargs:
scheduler_cls = self.lr_scheduler_kwargs['scheduler']
scheduler_kwargs = {k: v for k, v in self.lr_scheduler_kwargs.items() if k != 'scheduler'}
lr_scheduler = scheduler_cls(**scheduler_kwargs)
lr_scheduler = scheduler_cls(optimizer, **scheduler_kwargs)
return [optimizer] , [lr_scheduler]
else:
return optimizer
Expand Down
25 changes: 25 additions & 0 deletions mlcolvar/tests/test_cvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,28 @@ def test_resume_from_checkpoint(cv_model, dataset):
cv_model.eval()
cv_model2.eval()
assert torch.allclose(cv_model(x), cv_model2(x))

def test_lr_scheduler():
    """Verify that a learning-rate scheduler passed via `options` is wired into
    the optimizer and actually decays the lr during training (regression test
    for the scheduler-configuration bug)."""

    # small synthetic regression problem: y = |x|^2
    inputs = torch.randn((100, 2))
    targets = inputs.square().sum(1)
    dataset = DictDataset({"data": inputs, "target": targets})
    datamodule = DictModule(dataset, lengths=[0.75, 0.2, 0.05], batch_size=25)

    # request an exponentially decaying lr through the model options
    start_lr = 1e-3
    options = {
        'optimizer': {'lr': start_lr},
        'lr_scheduler': {
            'scheduler': torch.optim.lr_scheduler.ExponentialLR,
            'gamma': 0.9999,
        },
    }
    model = mlcolvar.cvs.RegressionCV(layers=[2, 5, 1], options=options)

    # short, silent training run
    trainer = lightning.Trainer(
        max_epochs=10,
        enable_checkpointing=False,
        logger=False,
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    trainer.fit(model, datamodule)

    # after training, the scheduler must have reduced the lr below its start value
    assert trainer.optimizers[0].param_groups[0]['lr'] < start_lr

0 comments on commit 973ef81

Please sign in to comment.