diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py
index 89751c37..fde15256 100644
--- a/mace/cli/run_train.py
+++ b/mace/cli/run_train.py
@@ -630,7 +630,6 @@ def run(args: argparse.Namespace) -> None:
         swa=swa,
         ema=ema,
         lbfgs=args.lbfgs,
-        lbfgs_config=args.lbfgs_config,
         max_grad_norm=args.clip_grad,
         log_errors=args.error_table,
         log_wandb=args.wandb,
diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py
index 42efb621..d4062481 100644
--- a/mace/tools/arg_parser.py
+++ b/mace/tools/arg_parser.py
@@ -532,7 +532,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
         help="Optimizer for parameter optimization",
         type=str,
         default="adam",
-        choices=["adam", "adamw", "schedulefree"],
+        choices=["adam", "adamw", "schedulefree", "lbfgs"],
     )
     parser.add_argument(
         "--beta",
@@ -910,4 +910,4 @@ def parse_dict(value: str):
         return parsed_dict
     except (ValueError, SyntaxError) as e:
-        raise argparse.ArgumentTypeError(f"Invalid dictionary format: {e}")
\ No newline at end of file
+        raise argparse.ArgumentTypeError(f"Invalid dictionary format: {e}")
diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py
index 9371e600..2c759eb1 100644
--- a/mace/tools/scripts_utils.py
+++ b/mace/tools/scripts_utils.py
@@ -21,6 +21,7 @@
 from mace import data, modules, tools
 from mace.tools.train import SWAContainer
+from mace.tools.lbfgsnew import LBFGSNew
 
 
 @dataclasses.dataclass
@@ -675,6 +676,18 @@ def get_optimizer(
             ) from exc
         _param_options = {k: v for k, v in param_options.items() if k != "amsgrad"}
         optimizer = adamw_schedulefree.AdamWScheduleFree(**_param_options)
+    elif args.optimizer == "lbfgs":
+        lbfgs_config = args.lbfgs_config
+        max_iter = lbfgs_config.get("max_iter", 200)
+        history_size = lbfgs_config.get("history", 240)
+        batch_mode = lbfgs_config.get("batch_mode", False)
+
+        optimizer = LBFGSNew(**param_options,
+                tolerance_grad=1e-6,
+                history_size=history_size,
+                max_iter=max_iter,
+                line_search_fn=False,
+                batch_mode=batch_mode)
     else:
         optimizer = torch.optim.Adam(**param_options)
     return optimizer
diff --git a/mace/tools/train.py b/mace/tools/train.py
index 92c56f8a..7a95b116 100644
--- a/mace/tools/train.py
+++ b/mace/tools/train.py
@@ -19,7 +19,6 @@
 from torch.utils.data.distributed import DistributedSampler
 from torch_ema import ExponentialMovingAverage
 from torchmetrics import Metric
-from .lbfgsnew import LBFGSNew
 from . import torch_geometric
 from .checkpoint import CheckpointHandler, CheckpointState
@@ -153,7 +152,6 @@ def train(
     log_errors: str,
     swa: Optional[SWAContainer] = None,
     lbfgs: bool = False,
-    lbfgs_config: Dict = None,
     ema: Optional[ExponentialMovingAverage] = None,
     max_grad_norm: Optional[float] = 10.0,
     log_wandb: bool = False,
@@ -231,6 +229,7 @@ def train(
             device=device,
             distributed_model=distributed_model,
             rank=rank,
+            use_lbfgs=lbfgs
         )
         if distributed:
             torch.distributed.barrier()
@@ -322,85 +321,6 @@ def train(
             torch.distributed.barrier()
         epoch += 1
 
-    if lbfgs:
-        epoch=10000 #TODO: fix code instead of workaround
-
-        lbfgsepochs=lbfgs_config.get("epochs", 50)
-        max_iter=lbfgs_config.get("max_iter", 200)
-        history_size=lbfgs_config.get("history", 240)
-        batch_mode=lbfgs_config.get("batch_mode", False)
-
-        lbfgs_optimizer=LBFGSNew(model.parameters(),
-                tolerance_grad=1e-6,
-                history_size=history_size,
-                max_iter=max_iter,
-                line_search_fn=False,
-                batch_mode=batch_mode)
-        while epoch < 10000+lbfgsepochs:
-
-            if epoch % 10 == 0:
-                logging.info(f"LBFGS epoch: {epoch}, history: {history_size}, max_iter: {max_iter}, epochs {lbfgsepochs}, batch_mode {batch_mode}")
-                logging.info("GPU Memory Report:")
-                logging.info(f"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
-                logging.info(f"Cached: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
-                logging.info(f"Total: {torch.cuda.get_device_properties(0).total_memory / 1024**2:.2f} MB")
-                logging.info(f"Max Allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
-                logging.info(f"Free memory: {torch.cuda.mem_get_info()[0] / (1024**2):.2f} MB")
-
-            if distributed:
-                train_sampler.set_epoch()
-
-            train_one_epoch(
-                model=model,
-                loss_fn=loss_fn,
-                data_loader=train_loader,
-                optimizer=lbfgs_optimizer,
-                epoch=epoch,
-                output_args=output_args,
-                max_grad_norm=max_grad_norm,
-                ema=ema,
-                logger=logger,
-                device=device,
-                distributed_model=distributed_model,
-                rank=rank,
-                lbfgs=lbfgs
-            )
-
-            if distributed:
-                torch.distributed.barrier()
-
-            model_to_evaluate = (
-                model if distributed_model is None else distributed_model
-            )
-            param_context = (
-                ema.average_parameters() if ema is not None else nullcontext()
-            )
-            with param_context:
-                valid_loss = 0.0
-                checkpoint_handler.save(
-                    state=CheckpointState(model, optimizer, lr_scheduler),
-                    epochs=epoch,
-                    keep_last=keep_last,
-                )
-                for valid_loader_name, valid_loader in valid_loaders.items():
-                    valid_loss_head, eval_metrics = evaluate(
-                        model=model_to_evaluate,
-                        loss_fn=loss_fn,
-                        data_loader=valid_loader,
-                        output_args=output_args,
-                        device=device,
-                    )
-                    if rank == 0:
-                        valid_err_log(
-                            valid_loss_head,
-                            eval_metrics,
-                            logger,
-                            log_errors,
-                            epoch,
-                            valid_loader_name,
-                        )
-            epoch+=1
-
     logging.info("Training complete")
@@ -417,10 +337,10 @@ def train_one_epoch(
     device: torch.device,
     distributed_model: Optional[DistributedDataParallel] = None,
     rank: Optional[int] = 0,
-    lbfgs: bool = False,
+    use_lbfgs: bool = False,
 ) -> None:
     model_to_train = model if distributed_model is None else distributed_model
-    take_step_fn = take_lbfgs_step if lbfgs else take_step
+    take_step_fn = take_lbfgs_step if use_lbfgs else take_step
     for batch in data_loader:
         _, opt_metrics = take_step_fn(
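
Not part of the patch: a minimal sketch of how the new "lbfgs" branch in get_optimizer() resolves its settings. The optimizer is selected with --optimizer lbfgs and configured through the dict behind args.lbfgs_config (keys "max_iter", "history", "batch_mode"). The build_lbfgs helper below is hypothetical; its defaults and keyword arguments simply mirror the values hard-coded in the diff above.

from typing import Any, Dict

from mace.tools.lbfgsnew import LBFGSNew


def build_lbfgs(param_options: Dict[str, Any], lbfgs_config: Dict[str, Any]) -> LBFGSNew:
    """Hypothetical helper mirroring the `elif args.optimizer == "lbfgs"` branch."""
    return LBFGSNew(
        **param_options,                                   # same params/lr dict forwarded to the other optimizers
        tolerance_grad=1e-6,                               # hard-coded in the patch
        history_size=lbfgs_config.get("history", 240),     # L-BFGS memory length
        max_iter=lbfgs_config.get("max_iter", 200),        # inner iterations per optimizer.step()
        line_search_fn=False,                              # LBFGSNew takes a boolean here
        batch_mode=lbfgs_config.get("batch_mode", False),  # mini-batch L-BFGS variant
    )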