Fix formatting.
veichta committed Jan 24, 2024
1 parent e3d4515 commit 1c8d315
Showing 1 changed file with 64 additions and 20 deletions.
gluefactory/train.py: 84 changes (64 additions & 20 deletions)
@@ -28,7 +28,14 @@
 from .utils.experiments import get_best_checkpoint, get_last_checkpoint, save_experiment
 from .utils.stdout_capturing import capture_outputs
 from .utils.tensor import batch_to_device
-from .utils.tools import AverageMetric, MedianMetric, PRMetric, RecallMetric, fork_rng, set_seed
+from .utils.tools import (
+    AverageMetric,
+    MedianMetric,
+    PRMetric,
+    RecallMetric,
+    fork_rng,
+    set_seed,
+)
 
 # @TODO: Fix pbar pollution in logs
 # @TODO: add plotting during evaluation
@@ -79,7 +86,9 @@ def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
     if conf.plot is not None:
         n, plot_fn = conf.plot
         plot_ids = np.random.choice(len(loader), min(len(loader), n), replace=False)
-    for i, data in enumerate(tqdm(loader, desc="Evaluation", ascii=True, disable=not pbar)):
+    for i, data in enumerate(
+        tqdm(loader, desc="Evaluation", ascii=True, disable=not pbar)
+    ):
         data = batch_to_device(data, device, non_blocking=True)
         with torch.no_grad():
             pred = model(data)
@@ -149,7 +158,6 @@ def get_lr_scheduler(optimizer, conf):
         return getattr(torch.optim.lr_scheduler, conf.type)(
             optimizer, schedulers, **conf.options
         )
-
     return getattr(torch.optim.lr_scheduler, conf.type)(optimizer, **conf.options)
 
     # backward compatibility
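
As an aside, the fall-through return above resolves the scheduler class by name from torch.optim.lr_scheduler, so any stock torch scheduler can be selected from the config. A minimal sketch of that lookup, with made-up stand-ins for conf.type and conf.options:

import torch

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

conf_type = "ExponentialLR"  # illustrative value, not from gluefactory's configs
conf_options = {"gamma": 0.95}
scheduler = getattr(torch.optim.lr_scheduler, conf_type)(optimizer, **conf_options)
scheduler.step()  # lr is now 0.1 * 0.95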
@@ -183,7 +191,8 @@ def pack_lr_parameters(params, base_lr, lr_scaling):
         {s: [n for n, _ in ps] for s, ps in scale2params.items() if s != 1},
     )
     lr_params = [
-        {"lr": scale * base_lr, "params": [p for _, p in ps]} for scale, ps in scale2params.items()
+        {"lr": scale * base_lr, "params": [p for _, p in ps]}
+        for scale, ps in scale2params.items()
     ]
     return lr_params
 
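For context, the wrapped list comprehension above builds standard PyTorch parameter groups with per-group learning rates. A minimal runnable sketch of the same pattern, using illustrative names (model, scale2params) rather than gluefactory's actual values:

import torch

# Hypothetical stand-ins mirroring the scale -> [(name, param)] mapping above.
model = torch.nn.Linear(4, 4)
base_lr = 1e-4
scale2params = {1.0: [("weight", model.weight)], 0.1: [("bias", model.bias)]}

# Same construction as in pack_lr_parameters: one param group per lr scale.
lr_params = [
    {"lr": scale * base_lr, "params": [p for _, p in ps]}
    for scale, ps in scale2params.items()
]
optimizer = torch.optim.Adam(lr_params, lr=base_lr)  # top-level lr is the default
print([g["lr"] for g in optimizer.param_groups])  # [0.0001, 1e-05]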
@@ -221,7 +230,9 @@ def training(rank, conf, output_dir, args):
         # init_cp = get_last_checkpoint(conf.train.load_experiment)
         init_cp = torch.load(str(init_cp), map_location="cpu")
         # load the model config of the old setup, and overwrite with current config
-        conf.model = OmegaConf.merge(OmegaConf.create(init_cp["conf"]).model, conf.model)
+        conf.model = OmegaConf.merge(
+            OmegaConf.create(init_cp["conf"]).model, conf.model
+        )
         print(conf.model)
     else:
         init_cp = None
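
The merge above follows OmegaConf's usual semantics: later arguments override earlier ones, so the current config wins over the checkpoint's. A minimal sketch with made-up keys:

from omegaconf import OmegaConf

old = OmegaConf.create({"name": "superpoint", "max_keypoints": 1024})
new = OmegaConf.create({"max_keypoints": 2048})
merged = OmegaConf.merge(old, new)
print(merged.name, merged.max_keypoints)  # superpoint 2048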
@@ -250,7 +261,9 @@ def training(rank, conf, output_dir, args):
         if "train_batch_size" in data_conf:
             data_conf.train_batch_size = int(data_conf.train_batch_size / args.n_gpus)
         if "num_workers" in data_conf:
-            data_conf.num_workers = int((data_conf.num_workers + args.n_gpus - 1) / args.n_gpus)
+            data_conf.num_workers = int(
+                (data_conf.num_workers + args.n_gpus - 1) / args.n_gpus
+            )
     else:
         device = "cuda" if torch.cuda.is_available() else "cpu"
     logger.info(f"Using device {device}")
@@ -317,7 +330,9 @@ def sigint_handler(signal, frame):
     all_params = [p for n, p in params]
 
     lr_params = pack_lr_parameters(params, conf.train.lr, conf.train.lr_scaling)
-    optimizer = optimizer_fn(lr_params, lr=conf.train.lr, **conf.train.optimizer_options)
+    optimizer = optimizer_fn(
+        lr_params, lr=conf.train.lr, **conf.train.optimizer_options
+    )
     scaler = GradScaler(enabled=args.mixed_precision is not None)
     logger.info(f"Training with mixed_precision={args.mixed_precision}")
 
@@ -336,7 +351,9 @@ def sigint_handler(signal, frame):
         lr_scheduler.load_state_dict(init_cp["lr_scheduler"])
 
     if rank == 0:
-        logger.info("Starting training with configuration:\n%s", OmegaConf.to_yaml(conf))
+        logger.info(
+            "Starting training with configuration:\n%s", OmegaConf.to_yaml(conf)
+        )
     losses_ = None
 
     def trace_handler(p):
@@ -360,7 +377,11 @@ def trace_handler(p):
         logger.info(f"Starting epoch {epoch}")
 
         # we first run the eval
-        if rank == 0 and epoch % conf.train.test_every_epoch == 0 and args.run_benchmarks:
+        if (
+            rank == 0
+            and epoch % conf.train.test_every_epoch == 0
+            and args.run_benchmarks
+        ):
             for bname, eval_conf in conf.get("benchmarks", {}).items():
                 logger.info(f"Running eval on {bname}")
                 s, f, r = run_benchmark(
@@ -382,7 +403,9 @@ def trace_handler(p):
         if conf.train.lr_schedule.on_epoch and epoch > 0:
             old_lr = optimizer.param_groups[0]["lr"]
             lr_scheduler.step()
-            logger.info(f'lr changed from {old_lr} to {optimizer.param_groups[0]["lr"]}')
+            logger.info(
+                f'lr changed from {old_lr} to {optimizer.param_groups[0]["lr"]}'
+            )
         if args.distributed:
             train_loader.sampler.set_epoch(epoch)
         if epoch > 0 and conf.train.dataset_callback_fn and not args.overfit:
@@ -395,9 +418,13 @@ def trace_handler(p):
                         conf.train.seed + epoch
                     )
                 else:
-                    getattr(loader.dataset, conf.train.dataset_callback_fn)(conf.train.seed + epoch)
+                    getattr(loader.dataset, conf.train.dataset_callback_fn)(
+                        conf.train.seed + epoch
+                    )
         for it, data in enumerate(train_loader):
-            tot_it = (len(train_loader) * epoch + it) * (args.n_gpus if args.distributed else 1)
+            tot_it = (len(train_loader) * epoch + it) * (
+                args.n_gpus if args.distributed else 1
+            )
             tot_n_samples = tot_it
             if not args.log_it:
                 # We normalize the x-axis of tensorflow to num samples!
@@ -419,7 +446,9 @@ def trace_handler(p):
             do_backward = loss.requires_grad
             if args.distributed:
                 do_backward = torch.tensor(do_backward).float().to(device)
-                torch.distributed.all_reduce(do_backward, torch.distributed.ReduceOp.PRODUCT)
+                torch.distributed.all_reduce(
+                    do_backward, torch.distributed.ReduceOp.PRODUCT
+                )
                 do_backward = do_backward > 0
             if do_backward:
                 scaler.scale(loss).backward()
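
The PRODUCT all-reduce above acts as an "all ranks agree" gate: the flag survives as 1 only if every rank wants to run backward, so the ranks skip or run the backward pass together. A minimal single-process sketch of the same idiom (the gloo backend and world_size=1 are illustrative choices for a quick local run):

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

do_backward = torch.tensor(True).float()
# With ReduceOp.PRODUCT, a single 0.0 from any rank zeroes the result everywhere.
dist.all_reduce(do_backward, dist.ReduceOp.PRODUCT)
print(bool(do_backward > 0))  # True

dist.destroy_process_group()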
@@ -468,11 +497,15 @@ def trace_handler(p):
                 if rank == 0:
                     str_losses = [f"{k} {v:.3E}" for k, v in losses.items()]
                     logger.info(
-                        "[E {} | it {}] loss {{{}}}".format(epoch, it, ", ".join(str_losses))
+                        "[E {} | it {}] loss {{{}}}".format(
+                            epoch, it, ", ".join(str_losses)
+                        )
                     )
                     for k, v in losses.items():
                         writer.add_scalar("training/" + k, v, tot_n_samples)
-                    writer.add_scalar("training/lr", optimizer.param_groups[0]["lr"], tot_n_samples)
+                    writer.add_scalar(
+                        "training/lr", optimizer.param_groups[0]["lr"], tot_n_samples
+                    )
                     writer.add_scalar("training/epoch", epoch, tot_n_samples)
 
             if conf.train.log_grad_every_iter is not None:
@@ -482,15 +515,20 @@
                     if param.grad is not None and param.requires_grad:
                         if name.endswith("bias"):
                             continue
-                        writer.add_histogram(f"grad/{name}", param.grad.detach(), tot_n_samples)
+                        writer.add_histogram(
+                            f"grad/{name}", param.grad.detach(), tot_n_samples
+                        )
                         norm = torch.norm(param.grad.detach(), 2)
                         grad_txt += f"{name} {norm.item():.3f} \n"
                 writer.add_text("grad/summary", grad_txt, tot_n_samples)
             del pred, data, loss, losses
 
             # Run validation
             if (
-                (it % conf.train.eval_every_iter == 0 and (it > 0 or epoch == -int(args.no_eval_0)))
+                (
+                    it % conf.train.eval_every_iter == 0
+                    and (it > 0 or epoch == -int(args.no_eval_0))
+                )
                 or stop
                 or it == (len(train_loader) - 1)
             ):
@@ -506,7 +544,9 @@ def trace_handler(p):
 
                 if rank == 0:
                     str_results = [
-                        f"{k} {v:.3E}" for k, v in results.items() if isinstance(v, float)
+                        f"{k} {v:.3E}"
+                        for k, v in results.items()
+                        if isinstance(v, float)
                     ]
                     logger.info(f'[Validation] {{{", ".join(str_results)}}}')
                     for k, v in results.items():
@@ -538,7 +578,9 @@ def trace_handler(p):
                     if len(figures) > 0:
                         for i, figs in enumerate(figures):
                             for name, fig in figs.items():
-                                writer.add_figure(f"figures/{i}_{name}", fig, tot_n_samples)
+                                writer.add_figure(
+                                    f"figures/{i}_{name}", fig, tot_n_samples
+                                )
                 torch.cuda.empty_cache()  # should be cleared at the first iter
 
             if (tot_it % conf.train.save_every_iter == 0 and tot_it > 0) and rank == 0:
@@ -655,6 +697,8 @@ def main_worker(rank, conf, output_dir, args):
         args.lock_file = output_dir / "distributed_lock"
         if args.lock_file.exists():
             args.lock_file.unlink()
-        torch.multiprocessing.spawn(main_worker, nprocs=args.n_gpus, args=(conf, output_dir, args))
+        torch.multiprocessing.spawn(
+            main_worker, nprocs=args.n_gpus, args=(conf, output_dir, args)
+        )
     else:
         main_worker(0, conf, output_dir, args)

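For reference, torch.multiprocessing.spawn launches nprocs copies of the target function and prepends the process rank to the args tuple, matching the main_worker(rank, conf, output_dir, args) signature above. A minimal sketch with a made-up worker:

import torch.multiprocessing as mp

def worker(rank, message):
    # spawn calls worker(i, *args) for i in range(nprocs)
    print(f"[rank {rank}] {message}")

if __name__ == "__main__":
    mp.spawn(worker, nprocs=2, args=("hello",))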