diff --git a/gluefactory/train.py b/gluefactory/train.py
index debf2125..91f66002 100644
--- a/gluefactory/train.py
+++ b/gluefactory/train.py
@@ -145,6 +145,19 @@ def filter_fn(x):
 def get_lr_scheduler(optimizer, conf):
     """Get lr scheduler specified by conf.train.lr_schedule."""
     if conf.type not in ["factor", "exp", None]:
+        if hasattr(conf.options, "schedulers"):
+            # Add option to chain multiple schedulers together
+            # This is useful for e.g. warmup, then cosine decay
+            schedulers = []
+            for scheduler_conf in conf.options.schedulers:
+                scheduler = get_lr_scheduler(optimizer, scheduler_conf)
+                schedulers.append(scheduler)
+
+            options = {k: v for k, v in conf.options.items() if k != "schedulers"}
+            return getattr(torch.optim.lr_scheduler, conf.type)(
+                optimizer, schedulers, **options
+            )
+
         return getattr(torch.optim.lr_scheduler, conf.type)(optimizer, **conf.options)
 
     # backward compatibility
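
For reference, a minimal sketch of a config that would exercise the new code path, assuming the `type`/`options` layout already used for `conf.train.lr_schedule` and a scheduler class such as `SequentialLR` that accepts a list of schedulers. The scheduler types, milestones, and iteration counts below are illustrative only, not part of this change:

```python
# Hypothetical config for a warmup phase followed by cosine decay,
# built on the new "schedulers" option handled by the patch above.
import torch
from omegaconf import OmegaConf

from gluefactory.train import get_lr_scheduler  # the function modified above

# Dummy optimizer just to show the call; in practice this is the training optimizer.
optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)

lr_schedule = OmegaConf.create(
    {
        "type": "SequentialLR",  # any torch.optim.lr_scheduler class taking a scheduler list
        "options": {
            "schedulers": [
                # each entry is parsed recursively by get_lr_scheduler
                {"type": "LinearLR", "options": {"start_factor": 0.1, "total_iters": 500}},
                {"type": "CosineAnnealingLR", "options": {"T_max": 10000}},
            ],
            # remaining options (here: milestones) are forwarded via **options
            "milestones": [500],
        },
    }
)

scheduler = get_lr_scheduler(optimizer, lr_schedule)
```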