diff --git a/mlcolvar/cvs/cv.py b/mlcolvar/cvs/cv.py index 64ec5484..09d48d0d 100644 --- a/mlcolvar/cvs/cv.py +++ b/mlcolvar/cvs/cv.py @@ -43,7 +43,6 @@ def __init__( # OPTIM self._optimizer_name = "Adam" self.optimizer_kwargs = {} - self._lr_scheduler_name = None self.lr_scheduler_kwargs = {} # PRE/POST @@ -202,9 +201,9 @@ def configure_optimizers(self): ) if self.lr_scheduler_kwargs: - if self._lr_scheduler_name is None: - self._lr_scheduler_name = self.lr_scheduler_kwargs.pop('scheduler') - lr_scheduler = self._lr_scheduler_name(optimizer, **self.lr_scheduler_kwargs) + scheduler_cls = self.lr_scheduler_kwargs['scheduler'] + scheduler_kwargs = {k: v for k, v in self.lr_scheduler_kwargs.items() if k != 'scheduler'} + lr_scheduler = scheduler_cls(**scheduler_kwargs) return [optimizer] , [lr_scheduler] else: return optimizer