diff --git a/gluefactory/train.py b/gluefactory/train.py index ece43743..91f66002 100644 --- a/gluefactory/train.py +++ b/gluefactory/train.py @@ -153,11 +153,11 @@ def get_lr_scheduler(optimizer, conf): scheduler = get_lr_scheduler(optimizer, scheduler_conf) schedulers.append(scheduler) - # remove conf.options.schedulers - del conf.options.schedulers + options = {k: v for k, v in conf.options.items() if k != "schedulers"} return getattr(torch.optim.lr_scheduler, conf.type)( - optimizer, schedulers, **conf.options + optimizer, schedulers, **options ) + return getattr(torch.optim.lr_scheduler, conf.type)(optimizer, **conf.options) # backward compatibility