From e9ba22d51f6f5d605551559704106edc104c2490 Mon Sep 17 00:00:00 2001 From: vue1999 Date: Fri, 22 Nov 2024 01:02:17 +0000 Subject: [PATCH] Adding valid error logging and disabling EMA during LBFGS step --- mace/tools/train.py | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/mace/tools/train.py b/mace/tools/train.py index 36f29013..9d0f9e51 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -266,8 +266,8 @@ def train( epoch=10000 # TODO: fix the code instead of using this workaround if distributed: train_sampler.set_epoch(epoch) - - lbfgs_optimizer = Minimizer(model.parameters(), method='l-bfgs', tol=1e-05) + print("HISTORY:20") + lbfgs_optimizer = Minimizer(model.parameters(), method='l-bfgs', tol=1e-05, options={'history_size': 20}) train_one_epoch( model=model, loss_fn=loss_fn, @@ -276,7 +276,7 @@ def train( epoch=epoch, output_args=output_args, max_grad_norm=max_grad_norm, - ema=ema, + ema=None, logger=logger, device=device, distributed_model=distributed_model, @@ -286,16 +286,40 @@ def train( if distributed: torch.distributed.barrier() - + + model_to_evaluate = ( + model if distributed_model is None else distributed_model + ) param_context = ( ema.average_parameters() if ema is not None else nullcontext() ) + with param_context: + valid_loss = 0.0 checkpoint_handler.save( state=CheckpointState(model, optimizer, lr_scheduler), epochs=epoch, keep_last=keep_last, - ) + ) + + for valid_loader_name, valid_loader in valid_loaders.items(): + valid_loss_head, eval_metrics = evaluate( + model=model_to_evaluate, + loss_fn=loss_fn, + data_loader=valid_loader, + output_args=output_args, + device=device, + ) + if rank == 0: + valid_err_log( + valid_loss_head, + eval_metrics, + logger, + log_errors, + epoch, + valid_loader_name, + ) + logging.info("Training complete")