diff --git a/mace/tools/train.py b/mace/tools/train.py
index b5425314..ea00dc0f 100644
--- a/mace/tools/train.py
+++ b/mace/tools/train.py
@@ -167,8 +167,6 @@ def train(
     patience_counter = 0
     swa_start = True
     keep_last = False
-    if log_wandb:
-        import wandb

     if max_grad_norm is not None:
         logging.info(f"Using gradient clipping with tolerance={max_grad_norm:.3f}")
@@ -236,87 +234,29 @@ def train(
         # Validate
         if epoch % eval_interval == 0:
-            model_to_evaluate = (
-                model if distributed_model is None else distributed_model
-            )
-            param_context = (
-                ema.average_parameters() if ema is not None else nullcontext()
+            should_stop, epoch, lowest_loss, patience_counter = validate_and_checkpoint(
+                model=model,
+                loss_fn=loss_fn,
+                valid_loaders=valid_loaders,
+                optimizer=optimizer,
+                lr_scheduler=lr_scheduler,
+                ema=ema,
+                checkpoint_handler=checkpoint_handler,
+                logger=logger,
+                log_errors=log_errors,
+                swa=swa,
+                epoch=epoch,
+                lowest_loss=lowest_loss,
+                patience_counter=patience_counter,
+                patience=patience,
+                save_all_checkpoints=save_all_checkpoints,
+                log_wandb=log_wandb,
+                distributed_model=distributed_model,
+                output_args=output_args,
+                device=device,
+                rank=rank,
             )
-            if "ScheduleFree" in type(optimizer).__name__:
-                optimizer.eval()
-            with param_context:
-                valid_loss = 0.0
-                wandb_log_dict = {}
-                for valid_loader_name, valid_loader in valid_loaders.items():
-                    valid_loss_head, eval_metrics = evaluate(
-                        model=model_to_evaluate,
-                        loss_fn=loss_fn,
-                        data_loader=valid_loader,
-                        output_args=output_args,
-                        device=device,
-                    )
-                    if rank == 0:
-                        valid_err_log(
-                            valid_loss_head,
-                            eval_metrics,
-                            logger,
-                            log_errors,
-                            epoch,
-                            valid_loader_name,
-                        )
-                    if log_wandb:
-                        wandb_log_dict[valid_loader_name] = {
-                            "epoch": epoch,
-                            "valid_loss": valid_loss_head,
-                            "valid_rmse_e_per_atom": eval_metrics[
-                                "rmse_e_per_atom"
-                            ],
-                            "valid_rmse_f": eval_metrics["rmse_f"],
-                        }
-                    valid_loss = (
-                        valid_loss_head  # consider only the last head for the checkpoint
-                    )
-                if log_wandb:
-                    wandb.log(wandb_log_dict)
-            if rank == 0:
-                if valid_loss >= lowest_loss:
-                    patience_counter += 1
-                    if patience_counter >= patience:
-                        if swa is not None and epoch < swa.start:
-                            logging.info(
-                                f"Stopping optimization after {patience_counter} epochs without improvement and starting Stage Two"
-                            )
-                            epoch = swa.start
-                        else:
-                            logging.info(
-                                f"Stopping optimization after {patience_counter} epochs without improvement"
-                            )
-                            break
-                    if save_all_checkpoints:
-                        param_context = (
-                            ema.average_parameters()
-                            if ema is not None
-                            else nullcontext()
-                        )
-                        with param_context:
-                            checkpoint_handler.save(
-                                state=CheckpointState(model, optimizer, lr_scheduler),
-                                epochs=epoch,
-                                keep_last=True,
-                            )
-                else:
-                    lowest_loss = valid_loss
-                    patience_counter = 0
-                    param_context = (
-                        ema.average_parameters() if ema is not None else nullcontext()
-                    )
-                    with param_context:
-                        checkpoint_handler.save(
-                            state=CheckpointState(model, optimizer, lr_scheduler),
-                            epochs=epoch,
-                            keep_last=keep_last,
-                        )
-                    keep_last = False or save_all_checkpoints
+            if should_stop:
                 break
         if distributed:
             torch.distributed.barrier()
         epoch += 1
@@ -325,10 +265,12 @@ def train(

         if distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch_lbfgs(
+        lbfgs_optimizer = Minimizer(model.parameters(), method='l-bfgs', tol=1e-05)
+        train_one_epoch(
             model=model,
             loss_fn=loss_fn,
             data_loader=train_loader,
+            optimizer=lbfgs_optimizer,
             epoch=epoch,
             output_args=output_args,
             max_grad_norm=max_grad_norm,
@@ -337,6 +279,7 @@ def train(
             device=device,
             distributed_model=distributed_model,
             rank=rank,
+            lbfgs=lbfgs
         )

         if distributed:
@@ -368,25 +311,38 @@ def train_one_epoch(
     device: torch.device,
     distributed_model: Optional[DistributedDataParallel] = None,
     rank: Optional[int] = 0,
+    lbfgs: bool = False,
 ) -> None:
     model_to_train = model if distributed_model is None else distributed_model
     for batch in data_loader:
-        _, opt_metrics = take_step(
-            model=model_to_train,
-            loss_fn=loss_fn,
-            batch=batch,
-            optimizer=optimizer,
-            ema=ema,
-            output_args=output_args,
-            max_grad_norm=max_grad_norm,
-            device=device,
-        )
+        if lbfgs:
+            _, opt_metrics = take_step_lbfgs(
+                model=model_to_train,
+                loss_fn=loss_fn,
+                batch=batch,
+                optimizer=optimizer,
+                ema=ema,
+                output_args=output_args,
+                max_grad_norm=max_grad_norm,
+                device=device,
+            )
+        else:
+            _, opt_metrics = take_step(
+                model=model_to_train,
+                loss_fn=loss_fn,
+                batch=batch,
+                optimizer=optimizer,
+                ema=ema,
+                output_args=output_args,
+                max_grad_norm=max_grad_norm,
+                device=device,
+            )
         opt_metrics["mode"] = "opt"
         opt_metrics["epoch"] = epoch
         if rank == 0:
             logger.log(opt_metrics)
-
-
+
+
 def take_step(
     model: torch.nn.Module,
     loss_fn: torch.nn.Module,
@@ -425,112 +381,191 @@ def take_step(
     return loss, loss_dict


-def evaluate(
+def take_step_lbfgs(
     model: torch.nn.Module,
     loss_fn: torch.nn.Module,
-    data_loader: DataLoader,
+    batch: torch_geometric.batch.Batch,
+    optimizer: torch.optim.Optimizer,
+    ema: Optional[ExponentialMovingAverage],
     output_args: Dict[str, bool],
+    max_grad_norm: Optional[float],
     device: torch.device,
 ) -> Tuple[float, Dict[str, Any]]:
-    for param in model.parameters():
-        param.requires_grad = False
-
-    metrics = MACELoss(loss_fn=loss_fn).to(device)
-
+
     start_time = time.time()
-    for batch in data_loader:
-        batch = batch.to(device)
-        batch_dict = batch.to_dict()
+    batch_ = batch.to(device)
+
+    def closure():
+        optimizer.zero_grad()
+        batch_dict = batch_.to_dict()
         output = model(
             batch_dict,
-            training=False,
+            training=True,
             compute_force=output_args["forces"],
             compute_virials=output_args["virials"],
             compute_stress=output_args["stress"],
         )
-        avg_loss, aux = metrics(batch, output)
+        loss = loss_fn(pred=output, ref=batch_)
+        return loss

-    avg_loss, aux = metrics.compute()
-    aux["time"] = time.time() - start_time
-    metrics.reset()
+    optimizer.step(closure)
+    loss = closure()
+
+    if max_grad_norm is not None and loss.requires_grad:
+        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)

-    for param in model.parameters():
-        param.requires_grad = True
+    if ema is not None:
+        ema.update()

-    return avg_loss, aux
+    loss_dict = {
+        "loss": to_numpy(loss),
+        "time": time.time() - start_time,
+    }
+
+    return loss, loss_dict


-def train_one_epoch_lbfgs(
+def validate_and_checkpoint(
     model: torch.nn.Module,
     loss_fn: torch.nn.Module,
-    data_loader: DataLoader,
-    epoch: int,
-    output_args: Dict[str, bool],
-    max_grad_norm: Optional[float],
+    valid_loaders: Dict[str, DataLoader],
+    optimizer: torch.optim.Optimizer,
+    lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
     ema: Optional[ExponentialMovingAverage],
+    checkpoint_handler: CheckpointHandler,
     logger: MetricsLogger,
+    log_errors: str,
+    swa: Optional[SWAContainer],
+    epoch: int,
+    lowest_loss: float,
+    patience_counter: int,
+    patience: int,
+    save_all_checkpoints: bool,
+    log_wandb: bool,
+    distributed_model: Optional[DistributedDataParallel],
+    output_args: Dict[str, bool],
     device: torch.device,
-    distributed_model: Optional[DistributedDataParallel] = None,
     rank: Optional[int] = 0,
-) -> None:
-    model_to_train = model if distributed_model is None else distributed_model
-    for batch in data_loader:
-        _, opt_metrics = take_step_lbfgs(
-            model=model_to_train,
-            loss_fn=loss_fn,
-            batch=batch,
-            ema=ema,
-            output_args=output_args,
-            max_grad_norm=max_grad_norm,
-            device=device,
+) -> Tuple[bool, int, float, int]:
+    model_to_evaluate = (
+        model if distributed_model is None else distributed_model
+    )
+    param_context = (
+        ema.average_parameters() if ema is not None else nullcontext()
+    )
+    if "ScheduleFree" in type(optimizer).__name__:
+        optimizer.eval()
+    with param_context:
+        valid_loss = 0.0
+        wandb_log_dict = {}
+        for valid_loader_name, valid_loader in valid_loaders.items():
+            valid_loss_head, eval_metrics = evaluate(
+                model=model_to_evaluate,
+                loss_fn=loss_fn,
+                data_loader=valid_loader,
+                output_args=output_args,
+                device=device,
+            )
+            if rank == 0:
+                valid_err_log(
+                    valid_loss_head,
+                    eval_metrics,
+                    logger,
+                    log_errors,
+                    epoch,
+                    valid_loader_name,
+                )
+            if log_wandb:
+                wandb_log_dict[valid_loader_name] = {
+                    "epoch": epoch,
+                    "valid_loss": valid_loss_head,
+                    "valid_rmse_e_per_atom": eval_metrics[
+                        "rmse_e_per_atom"
+                    ],
+                    "valid_rmse_f": eval_metrics["rmse_f"],
+                }
+            valid_loss = (
+                valid_loss_head  # consider only the last head for the checkpoint
             )
-        opt_metrics["mode"] = "opt"
-        opt_metrics["epoch"] = epoch
-        if rank == 0:
-            logger.log(opt_metrics)
+        if log_wandb:
+            import wandb
+            wandb.log(wandb_log_dict)
+    if rank == 0:
+        if valid_loss >= lowest_loss:
+            patience_counter += 1
+            if patience_counter >= patience:
+                if swa is not None and epoch < swa.start:
+                    logging.info(
+                        f"Stopping optimization after {patience_counter} epochs without improvement and starting Stage Two"
+                    )
+                    epoch = swa.start
+                else:
+                    logging.info(
+                        f"Stopping optimization after {patience_counter} epochs without improvement"
+                    )
+                    return True, epoch, lowest_loss, patience_counter
+            if save_all_checkpoints:
+                param_context = (
+                    ema.average_parameters()
+                    if ema is not None
+                    else nullcontext()
+                )
+                with param_context:
+                    checkpoint_handler.save(
+                        state=CheckpointState(model, optimizer, lr_scheduler),
+                        epochs=epoch,
+                        keep_last=True,
+                    )
+        else:
+            lowest_loss = valid_loss
+            patience_counter = 0
+            param_context = (
+                ema.average_parameters() if ema is not None else nullcontext()
+            )
+            with param_context:
+                checkpoint_handler.save(
+                    state=CheckpointState(model, optimizer, lr_scheduler),
+                    epochs=epoch,
+                    keep_last=save_all_checkpoints,
+                )
+
+    return False, epoch, lowest_loss, patience_counter


-def take_step_lbfgs(
+def evaluate(
     model: torch.nn.Module,
     loss_fn: torch.nn.Module,
-    batch: torch_geometric.batch.Batch,
-    ema: Optional[ExponentialMovingAverage],
+    data_loader: DataLoader,
     output_args: Dict[str, bool],
-    max_grad_norm: Optional[float],
     device: torch.device,
 ) -> Tuple[float, Dict[str, Any]]:
-    batch_ = batch.to(device)
-    optimizer = Minimizer(model.parameters(), method='l-bfgs', tol=1e-05)
+    for param in model.parameters():
+        param.requires_grad = False

-    def closure():
-        optimizer.zero_grad()
-        batch_dict = batch_.to_dict()
+    metrics = MACELoss(loss_fn=loss_fn).to(device)
+
+    start_time = time.time()
+    for batch in data_loader:
+        batch = batch.to(device)
+        batch_dict = batch.to_dict()
         output = model(
             batch_dict,
-            training=True,
+            training=False,
             compute_force=output_args["forces"],
             compute_virials=output_args["virials"],
             compute_stress=output_args["stress"],
         )
-        loss = loss_fn(pred=output, ref=batch_)
-        return loss
-
-    start_time = time.time()
-    optimizer.step(closure)
-    loss = closure()
-
-    if max_grad_norm is not None and loss.requires_grad:
-        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
+        avg_loss, aux = metrics(batch, output)

-    if ema is not None:
-        ema.update()
+    avg_loss, aux = metrics.compute()
+    aux["time"] = time.time() - start_time
+    metrics.reset()

-    loss_dict = {
-        "loss": to_numpy(loss),
-        "time": time.time() - start_time,
-    }
+    for param in model.parameters():
+        param.requires_grad = True

-    return loss, loss_dict
+    return avg_loss, aux


 class MACELoss(Metric):
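
For reference, the closure-based update performed by the new take_step_lbfgs can be sketched with core PyTorch alone. The snippet below is a minimal, self-contained illustration using torch.optim.LBFGS in place of the patch's Minimizer wrapper (assumed here to be torchmin.Minimizer from pytorch-minimize); the model, batch, and loss are hypothetical stand-ins for the MACE-specific pieces. Unlike the patched closure, a plain torch.optim.LBFGS closure must call backward() itself.

import torch

# Hypothetical stand-ins for the MACE model, batch, and loss function.
model = torch.nn.Linear(4, 1)
x, y = torch.randn(32, 4), torch.randn(32, 1)
loss_fn = torch.nn.MSELoss()

# The patch's tol=1e-05 roughly corresponds to the gradient tolerance here.
optimizer = torch.optim.LBFGS(model.parameters(), tolerance_grad=1e-05)

def closure():
    # L-BFGS may re-evaluate the objective several times per step, so the
    # forward and backward passes live inside the closure.
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    return loss

loss = optimizer.step(closure)  # step() returns the loss from the first closure call
print(float(loss))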