Skip to content

Commit

Permalink
Added test for lr_scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
EnricoTrizio committed Jan 15, 2024
1 parent 911f3ef commit 6adb5aa
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions mlcolvar/tests/test_cvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,28 @@ def test_resume_from_checkpoint(cv_model, dataset):
cv_model.eval()
cv_model2.eval()
assert torch.allclose(cv_model(x), cv_model2(x))

def test_lr_scheduler():

# create dataset
X = torch.randn((100, 2))
y = X.square().sum(1)
dataset = DictDataset({"data": X, "target": y})
datamodule = DictModule(dataset, lengths=[0.75, 0.2, 0.05], batch_size=25)

# initialize and pass scheduler to the model
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR
initial_lr = 1e-3
options = {'optimizer' : {'lr' : initial_lr},
'lr_scheduler' : { 'scheduler' : lr_scheduler, 'gamma' : 0.9999}}
model = mlcolvar.cvs.RegressionCV(layers=[2,5,1], options=options)

# check training and lr scheduling
trainer = lightning.Trainer(max_epochs=10,
enable_checkpointing=False,
logger=False,
enable_progress_bar=False,
enable_model_summary=False)
trainer.fit(model, datamodule)

assert(trainer.optimizers[0].param_groups[0]['lr'] < initial_lr)

0 comments on commit 6adb5aa

Please sign in to comment.