diff --git a/mlcolvar/tests/test_cvs.py b/mlcolvar/tests/test_cvs.py index a76c1a99..c7bc0c30 100644 --- a/mlcolvar/tests/test_cvs.py +++ b/mlcolvar/tests/test_cvs.py @@ -98,3 +98,28 @@ def test_resume_from_checkpoint(cv_model, dataset): cv_model.eval() cv_model2.eval() assert torch.allclose(cv_model(x), cv_model2(x)) + +def test_lr_scheduler(): + + # create dataset + X = torch.randn((100, 2)) + y = X.square().sum(1) + dataset = DictDataset({"data": X, "target": y}) + datamodule = DictModule(dataset, lengths=[0.75, 0.2, 0.05], batch_size=25) + + # initialize and pass scheduler to the model + lr_scheduler = torch.optim.lr_scheduler.ExponentialLR + initial_lr = 1e-3 + options = {'optimizer' : {'lr' : initial_lr}, + 'lr_scheduler' : { 'scheduler' : lr_scheduler, 'gamma' : 0.9999}} + model = mlcolvar.cvs.RegressionCV(layers=[2,5,1], options=options) + + # check training and lr scheduling + trainer = lightning.Trainer(max_epochs=10, + enable_checkpointing=False, + logger=False, + enable_progress_bar=False, + enable_model_summary=False) + trainer.fit(model, datamodule) + + assert(trainer.optimizers[0].param_groups[0]['lr'] < initial_lr) \ No newline at end of file