diff --git a/mace/tools/train.py b/mace/tools/train.py
index ea00dc0f..14095470 100644
--- a/mace/tools/train.py
+++ b/mace/tools/train.py
@@ -455,6 +455,7 @@ def validate_and_checkpoint(
     )
     if "ScheduleFree" in type(optimizer).__name__:
         optimizer.eval()
+
     with param_context:
         valid_loss = 0.0
         wandb_log_dict = {}
@@ -487,50 +488,54 @@ def validate_and_checkpoint(
             valid_loss = (
                 valid_loss_head  # consider only the last head for the checkpoint
            )
+
     if log_wandb:
         import wandb
 
         wandb.log(wandb_log_dict)
-    if rank == 0:
-        if valid_loss >= lowest_loss:
-            patience_counter += 1
-            if patience_counter >= patience:
-                if swa is not None and epoch < swa.start:
-                    logging.info(
-                        f"Stopping optimization after {patience_counter} epochs without improvement and starting Stage Two"
-                    )
-                    epoch = swa.start
-                else:
-                    logging.info(
-                        f"Stopping optimization after {patience_counter} epochs without improvement"
-                    )
-                    return True, epoch, lowest_loss, patience_counter
-            if save_all_checkpoints:
-                param_context = (
-                    ema.average_parameters()
-                    if ema is not None
-                    else nullcontext()
+
+    if rank != 0:
+        return False, epoch, lowest_loss, patience_counter
+
+    if valid_loss >= lowest_loss:
+        patience_counter += 1
+        if patience_counter >= patience:
+            if swa is not None and epoch < swa.start:
+                logging.info(
+                    f"Stopping optimization after {patience_counter} epochs without improvement and starting Stage Two"
                 )
-                with param_context:
-                    checkpoint_handler.save(
-                        state=CheckpointState(model, optimizer, lr_scheduler),
-                        epochs=epoch,
-                        keep_last=True,
-                    )
-        else:
-            lowest_loss = valid_loss
-            patience_counter = 0
+                epoch = swa.start
+            else:
+                logging.info(
+                    f"Stopping optimization after {patience_counter} epochs without improvement"
+                )
+                return True, epoch, lowest_loss, patience_counter
+        if save_all_checkpoints:
             param_context = (
-                ema.average_parameters() if ema is not None else nullcontext()
+                ema.average_parameters()
+                if ema is not None
+                else nullcontext()
             )
             with param_context:
                 checkpoint_handler.save(
                     state=CheckpointState(model, optimizer, lr_scheduler),
                     epochs=epoch,
-                    keep_last=keep_last,
+                    keep_last=True,
                 )
-            keep_last = False or save_all_checkpoints
+    else:
+        lowest_loss = valid_loss
+        patience_counter = 0
+        param_context = (
+            ema.average_parameters() if ema is not None else nullcontext()
+        )
+        with param_context:
+            checkpoint_handler.save(
+                state=CheckpointState(model, optimizer, lr_scheduler),
+                epochs=epoch,
+                keep_last=keep_last,
+            )
+        keep_last = False or save_all_checkpoints
-    return False, epoch, lowest_loss, patience_counter
+    return False, epoch, lowest_loss, patience_counter
 
 
 def evaluate(