From e3d4515788b4341b324823b3a0e067208c3443e7 Mon Sep 17 00:00:00 2001
From: Alexander Veicht
Date: Wed, 24 Jan 2024 11:34:34 +0100
Subject: [PATCH 1/4] Add option to chain lr schedules.

---
 gluefactory/train.py | 99 +++++++++++++++----------------------------
 1 file changed, 34 insertions(+), 65 deletions(-)

diff --git a/gluefactory/train.py b/gluefactory/train.py
index debf2125..bf9b44dd 100644
--- a/gluefactory/train.py
+++ b/gluefactory/train.py
@@ -28,14 +28,7 @@
 from .utils.experiments import get_best_checkpoint, get_last_checkpoint, save_experiment
 from .utils.stdout_capturing import capture_outputs
 from .utils.tensor import batch_to_device
-from .utils.tools import (
-    AverageMetric,
-    MedianMetric,
-    PRMetric,
-    RecallMetric,
-    fork_rng,
-    set_seed,
-)
+from .utils.tools import AverageMetric, MedianMetric, PRMetric, RecallMetric, fork_rng, set_seed
 
 # @TODO: Fix pbar pollution in logs
 # @TODO: add plotting during evaluation
@@ -86,9 +79,7 @@ def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
     if conf.plot is not None:
         n, plot_fn = conf.plot
         plot_ids = np.random.choice(len(loader), min(len(loader), n), replace=False)
-    for i, data in enumerate(
-        tqdm(loader, desc="Evaluation", ascii=True, disable=not pbar)
-    ):
+    for i, data in enumerate(tqdm(loader, desc="Evaluation", ascii=True, disable=not pbar)):
         data = batch_to_device(data, device, non_blocking=True)
         with torch.no_grad():
             pred = model(data)
@@ -143,8 +134,22 @@ def filter_fn(x):
 
 
 def get_lr_scheduler(optimizer, conf):
-    """Get lr scheduler specified by conf.train.lr_schedule."""
+    """Get lr scheduler specified by conf."""
     if conf.type not in ["factor", "exp", None]:
+        if hasattr(conf.options, "schedulers"):
+            # Add option to chain multiple schedulers together
+            # This is useful for e.g. warmup, then cosine decay
+            schedulers = []
+            for scheduler_conf in conf.options.schedulers:
+                scheduler = get_lr_scheduler(optimizer, scheduler_conf)
+                schedulers.append(scheduler)
+
+            # remove conf.options.schedulers
+            del conf.options.schedulers
+            return getattr(torch.optim.lr_scheduler, conf.type)(
+                optimizer, schedulers, **conf.options
+            )
+
         return getattr(torch.optim.lr_scheduler, conf.type)(optimizer, **conf.options)
 
     # backward compatibility
@@ -178,8 +183,7 @@ def pack_lr_parameters(params, base_lr, lr_scaling):
         {s: [n for n, _ in ps] for s, ps in scale2params.items() if s != 1},
     )
     lr_params = [
-        {"lr": scale * base_lr, "params": [p for _, p in ps]}
-        for scale, ps in scale2params.items()
+        {"lr": scale * base_lr, "params": [p for _, p in ps]} for scale, ps in scale2params.items()
     ]
     return lr_params
 
@@ -217,9 +221,7 @@ def training(rank, conf, output_dir, args):
         # init_cp = get_last_checkpoint(conf.train.load_experiment)
         init_cp = torch.load(str(init_cp), map_location="cpu")
         # load the model config of the old setup, and overwrite with current config
-        conf.model = OmegaConf.merge(
-            OmegaConf.create(init_cp["conf"]).model, conf.model
-        )
+        conf.model = OmegaConf.merge(OmegaConf.create(init_cp["conf"]).model, conf.model)
         print(conf.model)
     else:
         init_cp = None
@@ -248,9 +250,7 @@ def training(rank, conf, output_dir, args):
         if "train_batch_size" in data_conf:
             data_conf.train_batch_size = int(data_conf.train_batch_size / args.n_gpus)
         if "num_workers" in data_conf:
-            data_conf.num_workers = int(
-                (data_conf.num_workers + args.n_gpus - 1) / args.n_gpus
-            )
+            data_conf.num_workers = int((data_conf.num_workers + args.n_gpus - 1) / args.n_gpus)
     else:
         device = "cuda" if torch.cuda.is_available() else "cpu"
         logger.info(f"Using device {device}")
@@ -317,9 +317,7 @@ def sigint_handler(signal, frame):
     all_params = [p for n, p in params]
 
     lr_params = pack_lr_parameters(params, conf.train.lr, conf.train.lr_scaling)
-    optimizer = optimizer_fn(
-        lr_params, lr=conf.train.lr, **conf.train.optimizer_options
-    )
+    optimizer = optimizer_fn(lr_params, lr=conf.train.lr, **conf.train.optimizer_options)
     scaler = GradScaler(enabled=args.mixed_precision is not None)
     logger.info(f"Training with mixed_precision={args.mixed_precision}")
 
@@ -338,9 +336,7 @@ def sigint_handler(signal, frame):
             lr_scheduler.load_state_dict(init_cp["lr_scheduler"])
 
     if rank == 0:
-        logger.info(
-            "Starting training with configuration:\n%s", OmegaConf.to_yaml(conf)
-        )
+        logger.info("Starting training with configuration:\n%s", OmegaConf.to_yaml(conf))
     losses_ = None
 
     def trace_handler(p):
@@ -364,11 +360,7 @@ def trace_handler(p):
         logger.info(f"Starting epoch {epoch}")
 
         # we first run the eval
-        if (
-            rank == 0
-            and epoch % conf.train.test_every_epoch == 0
-            and args.run_benchmarks
-        ):
+        if rank == 0 and epoch % conf.train.test_every_epoch == 0 and args.run_benchmarks:
             for bname, eval_conf in conf.get("benchmarks", {}).items():
                 logger.info(f"Running eval on {bname}")
                 s, f, r = run_benchmark(
@@ -390,9 +382,7 @@ def trace_handler(p):
         if conf.train.lr_schedule.on_epoch and epoch > 0:
             old_lr = optimizer.param_groups[0]["lr"]
             lr_scheduler.step()
-            logger.info(
-                f'lr changed from {old_lr} to {optimizer.param_groups[0]["lr"]}'
-            )
+            logger.info(f'lr changed from {old_lr} to {optimizer.param_groups[0]["lr"]}')
         if args.distributed:
             train_loader.sampler.set_epoch(epoch)
         if epoch > 0 and conf.train.dataset_callback_fn and not args.overfit:
@@ -405,13 +395,9 @@ def trace_handler(p):
                         conf.train.seed + epoch
                     )
                 else:
-                    getattr(loader.dataset, conf.train.dataset_callback_fn)(
-                        conf.train.seed + epoch
-                    )
+                    getattr(loader.dataset, conf.train.dataset_callback_fn)(conf.train.seed + epoch)
         for it, data in enumerate(train_loader):
-            tot_it = (len(train_loader) * epoch + it) * (
-                args.n_gpus if args.distributed else 1
-            )
+            tot_it = (len(train_loader) * epoch + it) * (args.n_gpus if args.distributed else 1)
             tot_n_samples = tot_it
             if not args.log_it:
                 # We normalize the x-axis of tensorflow to num samples!
@@ -433,9 +419,7 @@ def trace_handler(p):
             do_backward = loss.requires_grad
             if args.distributed:
                 do_backward = torch.tensor(do_backward).float().to(device)
-                torch.distributed.all_reduce(
-                    do_backward, torch.distributed.ReduceOp.PRODUCT
-                )
+                torch.distributed.all_reduce(do_backward, torch.distributed.ReduceOp.PRODUCT)
                 do_backward = do_backward > 0
             if do_backward:
                 scaler.scale(loss).backward()
@@ -484,15 +468,11 @@ def trace_handler(p):
                 if rank == 0:
                     str_losses = [f"{k} {v:.3E}" for k, v in losses.items()]
                     logger.info(
-                        "[E {} | it {}] loss {{{}}}".format(
-                            epoch, it, ", ".join(str_losses)
-                        )
+                        "[E {} | it {}] loss {{{}}}".format(epoch, it, ", ".join(str_losses))
                     )
                     for k, v in losses.items():
                         writer.add_scalar("training/" + k, v, tot_n_samples)
-                    writer.add_scalar(
-                        "training/lr", optimizer.param_groups[0]["lr"], tot_n_samples
-                    )
+                    writer.add_scalar("training/lr", optimizer.param_groups[0]["lr"], tot_n_samples)
                     writer.add_scalar("training/epoch", epoch, tot_n_samples)
 
             if conf.train.log_grad_every_iter is not None:
@@ -502,9 +482,7 @@ def trace_handler(p):
                         if param.grad is not None and param.requires_grad:
                             if name.endswith("bias"):
                                 continue
-                            writer.add_histogram(
-                                f"grad/{name}", param.grad.detach(), tot_n_samples
-                            )
+                            writer.add_histogram(f"grad/{name}", param.grad.detach(), tot_n_samples)
                             norm = torch.norm(param.grad.detach(), 2)
                             grad_txt += f"{name} {norm.item():.3f} \n"
                     writer.add_text("grad/summary", grad_txt, tot_n_samples)
@@ -512,10 +490,7 @@ def trace_handler(p):
                 del grad_txt
             # Run validation
            if (
-                (
-                    it % conf.train.eval_every_iter == 0
-                    and (it > 0 or epoch == -int(args.no_eval_0))
-                )
+                (it % conf.train.eval_every_iter == 0 and (it > 0 or epoch == -int(args.no_eval_0)))
                 or stop
                 or it == (len(train_loader) - 1)
             ):
@@ -531,9 +506,7 @@ def trace_handler(p):
 
                 if rank == 0:
                     str_results = [
-                        f"{k} {v:.3E}"
-                        for k, v in results.items()
-                        if isinstance(v, float)
+                        f"{k} {v:.3E}" for k, v in results.items() if isinstance(v, float)
                     ]
                     logger.info(f'[Validation] {{{", ".join(str_results)}}}')
                     for k, v in results.items():
@@ -565,9 +538,7 @@ def trace_handler(p):
                     if len(figures) > 0:
                         for i, figs in enumerate(figures):
                             for name, fig in figs.items():
-                                writer.add_figure(
-                                    f"figures/{i}_{name}", fig, tot_n_samples
-                                )
+                                writer.add_figure(f"figures/{i}_{name}", fig, tot_n_samples)
                 torch.cuda.empty_cache()  # should be cleared at the first iter
 
             if (tot_it % conf.train.save_every_iter == 0 and tot_it > 0) and rank == 0:
@@ -684,8 +655,6 @@ def main_worker(rank, conf, output_dir, args):
         args.lock_file = output_dir / "distributed_lock"
         if args.lock_file.exists():
             args.lock_file.unlink()
-        torch.multiprocessing.spawn(
-            main_worker, nprocs=args.n_gpus, args=(conf, output_dir, args)
-        )
+        torch.multiprocessing.spawn(main_worker, nprocs=args.n_gpus, args=(conf, output_dir, args))
     else:
         main_worker(0, conf, output_dir, args)

From 1c8d315ddcdfd45d0d6e834d7a930b65c8b04cd1 Mon Sep 17 00:00:00 2001
From: Alexander Veicht
Date: Wed, 24 Jan 2024 12:49:38 +0100
Subject: [PATCH 2/4] Fix formatting.

---
 gluefactory/train.py | 84 +++++++++++++++++++++++++++++++++----------
 1 file changed, 64 insertions(+), 20 deletions(-)

diff --git a/gluefactory/train.py b/gluefactory/train.py
index bf9b44dd..f885b597 100644
--- a/gluefactory/train.py
+++ b/gluefactory/train.py
@@ -28,7 +28,14 @@
 from .utils.experiments import get_best_checkpoint, get_last_checkpoint, save_experiment
 from .utils.stdout_capturing import capture_outputs
 from .utils.tensor import batch_to_device
-from .utils.tools import AverageMetric, MedianMetric, PRMetric, RecallMetric, fork_rng, set_seed
+from .utils.tools import (
+    AverageMetric,
+    MedianMetric,
+    PRMetric,
+    RecallMetric,
+    fork_rng,
+    set_seed,
+)
 
 # @TODO: Fix pbar pollution in logs
 # @TODO: add plotting during evaluation
@@ -79,7 +86,9 @@ def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
     if conf.plot is not None:
         n, plot_fn = conf.plot
         plot_ids = np.random.choice(len(loader), min(len(loader), n), replace=False)
-    for i, data in enumerate(tqdm(loader, desc="Evaluation", ascii=True, disable=not pbar)):
+    for i, data in enumerate(
+        tqdm(loader, desc="Evaluation", ascii=True, disable=not pbar)
+    ):
         data = batch_to_device(data, device, non_blocking=True)
         with torch.no_grad():
             pred = model(data)
@@ -149,7 +158,6 @@ def get_lr_scheduler(optimizer, conf):
             return getattr(torch.optim.lr_scheduler, conf.type)(
                 optimizer, schedulers, **conf.options
             )
-
         return getattr(torch.optim.lr_scheduler, conf.type)(optimizer, **conf.options)
 
     # backward compatibility
@@ -183,7 +191,8 @@ def pack_lr_parameters(params, base_lr, lr_scaling):
         {s: [n for n, _ in ps] for s, ps in scale2params.items() if s != 1},
     )
     lr_params = [
-        {"lr": scale * base_lr, "params": [p for _, p in ps]} for scale, ps in scale2params.items()
+        {"lr": scale * base_lr, "params": [p for _, p in ps]}
+        for scale, ps in scale2params.items()
     ]
     return lr_params
 
@@ -221,7 +230,9 @@ def training(rank, conf, output_dir, args):
         # init_cp = get_last_checkpoint(conf.train.load_experiment)
         init_cp = torch.load(str(init_cp), map_location="cpu")
         # load the model config of the old setup, and overwrite with current config
-        conf.model = OmegaConf.merge(OmegaConf.create(init_cp["conf"]).model, conf.model)
+        conf.model = OmegaConf.merge(
+            OmegaConf.create(init_cp["conf"]).model, conf.model
+        )
         print(conf.model)
     else:
         init_cp = None
@@ -250,7 +261,9 @@ def training(rank, conf, output_dir, args):
         if "train_batch_size" in data_conf:
             data_conf.train_batch_size = int(data_conf.train_batch_size / args.n_gpus)
         if "num_workers" in data_conf:
-            data_conf.num_workers = int((data_conf.num_workers + args.n_gpus - 1) / args.n_gpus)
+            data_conf.num_workers = int(
+                (data_conf.num_workers + args.n_gpus - 1) / args.n_gpus
+            )
     else:
         device = "cuda" if torch.cuda.is_available() else "cpu"
         logger.info(f"Using device {device}")
@@ -317,7 +330,9 @@ def sigint_handler(signal, frame):
     all_params = [p for n, p in params]
 
     lr_params = pack_lr_parameters(params, conf.train.lr, conf.train.lr_scaling)
-    optimizer = optimizer_fn(lr_params, lr=conf.train.lr, **conf.train.optimizer_options)
+    optimizer = optimizer_fn(
+        lr_params, lr=conf.train.lr, **conf.train.optimizer_options
+    )
     scaler = GradScaler(enabled=args.mixed_precision is not None)
     logger.info(f"Training with mixed_precision={args.mixed_precision}")
 
@@ -336,7 +351,9 @@ def sigint_handler(signal, frame):
             lr_scheduler.load_state_dict(init_cp["lr_scheduler"])
 
     if rank == 0:
-        logger.info("Starting training with configuration:\n%s", OmegaConf.to_yaml(conf))
+        logger.info(
+            "Starting training with configuration:\n%s", OmegaConf.to_yaml(conf)
+        )
     losses_ = None
 
     def trace_handler(p):
@@ -360,7 +377,11 @@ def trace_handler(p):
         logger.info(f"Starting epoch {epoch}")
 
         # we first run the eval
-        if rank == 0 and epoch % conf.train.test_every_epoch == 0 and args.run_benchmarks:
+        if (
+            rank == 0
+            and epoch % conf.train.test_every_epoch == 0
+            and args.run_benchmarks
+        ):
             for bname, eval_conf in conf.get("benchmarks", {}).items():
                 logger.info(f"Running eval on {bname}")
                 s, f, r = run_benchmark(
@@ -382,7 +403,9 @@ def trace_handler(p):
         if conf.train.lr_schedule.on_epoch and epoch > 0:
             old_lr = optimizer.param_groups[0]["lr"]
             lr_scheduler.step()
-            logger.info(f'lr changed from {old_lr} to {optimizer.param_groups[0]["lr"]}')
+            logger.info(
+                f'lr changed from {old_lr} to {optimizer.param_groups[0]["lr"]}'
+            )
         if args.distributed:
             train_loader.sampler.set_epoch(epoch)
         if epoch > 0 and conf.train.dataset_callback_fn and not args.overfit:
@@ -395,9 +418,13 @@ def trace_handler(p):
                         conf.train.seed + epoch
                     )
                 else:
-                    getattr(loader.dataset, conf.train.dataset_callback_fn)(conf.train.seed + epoch)
+                    getattr(loader.dataset, conf.train.dataset_callback_fn)(
+                        conf.train.seed + epoch
+                    )
         for it, data in enumerate(train_loader):
-            tot_it = (len(train_loader) * epoch + it) * (args.n_gpus if args.distributed else 1)
+            tot_it = (len(train_loader) * epoch + it) * (
+                args.n_gpus if args.distributed else 1
+            )
             tot_n_samples = tot_it
             if not args.log_it:
                 # We normalize the x-axis of tensorflow to num samples!
@@ -419,7 +446,9 @@ def trace_handler(p):
             do_backward = loss.requires_grad
             if args.distributed:
                 do_backward = torch.tensor(do_backward).float().to(device)
-                torch.distributed.all_reduce(do_backward, torch.distributed.ReduceOp.PRODUCT)
+                torch.distributed.all_reduce(
+                    do_backward, torch.distributed.ReduceOp.PRODUCT
+                )
                 do_backward = do_backward > 0
             if do_backward:
                 scaler.scale(loss).backward()
@@ -468,11 +497,15 @@ def trace_handler(p):
                 if rank == 0:
                     str_losses = [f"{k} {v:.3E}" for k, v in losses.items()]
                     logger.info(
-                        "[E {} | it {}] loss {{{}}}".format(epoch, it, ", ".join(str_losses))
+                        "[E {} | it {}] loss {{{}}}".format(
+                            epoch, it, ", ".join(str_losses)
+                        )
                     )
                     for k, v in losses.items():
                         writer.add_scalar("training/" + k, v, tot_n_samples)
-                    writer.add_scalar("training/lr", optimizer.param_groups[0]["lr"], tot_n_samples)
+                    writer.add_scalar(
+                        "training/lr", optimizer.param_groups[0]["lr"], tot_n_samples
+                    )
                     writer.add_scalar("training/epoch", epoch, tot_n_samples)
 
             if conf.train.log_grad_every_iter is not None:
@@ -482,7 +515,9 @@ def trace_handler(p):
                         if param.grad is not None and param.requires_grad:
                             if name.endswith("bias"):
                                 continue
-                            writer.add_histogram(f"grad/{name}", param.grad.detach(), tot_n_samples)
+                            writer.add_histogram(
+                                f"grad/{name}", param.grad.detach(), tot_n_samples
+                            )
                             norm = torch.norm(param.grad.detach(), 2)
                             grad_txt += f"{name} {norm.item():.3f} \n"
                     writer.add_text("grad/summary", grad_txt, tot_n_samples)
@@ -490,7 +525,10 @@ def trace_handler(p):
                 del grad_txt
             # Run validation
             if (
-                (it % conf.train.eval_every_iter == 0 and (it > 0 or epoch == -int(args.no_eval_0)))
+                (
+                    it % conf.train.eval_every_iter == 0
+                    and (it > 0 or epoch == -int(args.no_eval_0))
+                )
                 or stop
                 or it == (len(train_loader) - 1)
             ):
@@ -506,7 +544,9 @@ def trace_handler(p):
 
                 if rank == 0:
                     str_results = [
-                        f"{k} {v:.3E}" for k, v in results.items() if isinstance(v, float)
+                        f"{k} {v:.3E}"
+                        for k, v in results.items()
+                        if isinstance(v, float)
                     ]
                     logger.info(f'[Validation] {{{", ".join(str_results)}}}')
                     for k, v in results.items():
@@ -538,7 +578,9 @@ def trace_handler(p):
                     if len(figures) > 0:
                         for i, figs in enumerate(figures):
                             for name, fig in figs.items():
-                                writer.add_figure(f"figures/{i}_{name}", fig, tot_n_samples)
+                                writer.add_figure(
+                                    f"figures/{i}_{name}", fig, tot_n_samples
+                                )
                 torch.cuda.empty_cache()  # should be cleared at the first iter
 
             if (tot_it % conf.train.save_every_iter == 0 and tot_it > 0) and rank == 0:
@@ -655,6 +697,8 @@ def main_worker(rank, conf, output_dir, args):
         args.lock_file = output_dir / "distributed_lock"
         if args.lock_file.exists():
             args.lock_file.unlink()
-        torch.multiprocessing.spawn(main_worker, nprocs=args.n_gpus, args=(conf, output_dir, args))
+        torch.multiprocessing.spawn(
+            main_worker, nprocs=args.n_gpus, args=(conf, output_dir, args)
+        )
     else:
         main_worker(0, conf, output_dir, args)

From 5863345b81cecdeef92b432b9acafe8a5e454f52 Mon Sep 17 00:00:00 2001
From: Alexander Veicht
Date: Wed, 24 Jan 2024 12:50:40 +0100
Subject: [PATCH 3/4] Fix docstring.

---
 gluefactory/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/gluefactory/train.py b/gluefactory/train.py
index f885b597..ece43743 100644
--- a/gluefactory/train.py
+++ b/gluefactory/train.py
@@ -143,7 +143,7 @@ def filter_fn(x):
 
 
 def get_lr_scheduler(optimizer, conf):
-    """Get lr scheduler specified by conf."""
+    """Get lr scheduler specified by conf.train.lr_schedule."""
     if conf.type not in ["factor", "exp", None]:
         if hasattr(conf.options, "schedulers"):
             # Add option to chain multiple schedulers together

From 84b9a125d753d6de66125aa57c7924808c871f39 Mon Sep 17 00:00:00 2001
From: Alexander Veicht
Date: Wed, 24 Jan 2024 13:14:59 +0100
Subject: [PATCH 4/4] Cleaner way to get options.

---
 gluefactory/train.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/gluefactory/train.py b/gluefactory/train.py
index ece43743..91f66002 100644
--- a/gluefactory/train.py
+++ b/gluefactory/train.py
@@ -153,11 +153,11 @@ def get_lr_scheduler(optimizer, conf):
                 scheduler = get_lr_scheduler(optimizer, scheduler_conf)
                 schedulers.append(scheduler)
 
-            # remove conf.options.schedulers
-            del conf.options.schedulers
+            options = {k: v for k, v in conf.options.items() if k != "schedulers"}
             return getattr(torch.optim.lr_scheduler, conf.type)(
-                optimizer, schedulers, **conf.options
+                optimizer, schedulers, **options
             )
+
         return getattr(torch.optim.lr_scheduler, conf.type)(optimizer, **conf.options)
 
     # backward compatibility
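
Note on usage (not part of the patch series itself): the sketch below shows one way the chained-scheduler path introduced above might be configured. It assumes the `type`/`options`/`on_epoch` keys of `conf.train.lr_schedule` visible in the diff and uses PyTorch's `torch.optim.lr_scheduler.SequentialLR`, whose `(optimizer, schedulers, milestones=...)` signature matches the `getattr(torch.optim.lr_scheduler, conf.type)(optimizer, schedulers, **options)` call; the scheduler names and numbers are illustrative only, not taken from the patches.

# Hypothetical config sketch: linear warmup chained with a cosine decay.
# Each entry of "schedulers" is built recursively by get_lr_scheduler; the
# remaining options (here "milestones") are forwarded to SequentialLR.
from omegaconf import OmegaConf

lr_schedule = OmegaConf.create(
    {
        "type": "SequentialLR",
        "on_epoch": True,  # stepped once per epoch by the training loop
        "options": {
            "schedulers": [
                {"type": "LinearLR", "options": {"start_factor": 0.1, "total_iters": 5}},
                {"type": "CosineAnnealingLR", "options": {"T_max": 45}},
            ],
            "milestones": [5],  # switch from warmup to cosine decay after 5 steps
        },
    }
)

With a config like this, `get_lr_scheduler(optimizer, lr_schedule)` would first build the two sub-schedulers and then instantiate `SequentialLR` with the remaining options.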