add lbfgs as an option for the optimizer argument
ttompa committed Dec 17, 2024
1 parent 2677ae2 commit 337a607
Showing 4 changed files with 18 additions and 86 deletions.
1 change: 0 additions & 1 deletion mace/cli/run_train.py
@@ -630,7 +630,6 @@ def run(args: argparse.Namespace) -> None:
         swa=swa,
         ema=ema,
         lbfgs=args.lbfgs,
-        lbfgs_config=args.lbfgs_config,
         max_grad_norm=args.clip_grad,
         log_errors=args.error_table,
         log_wandb=args.wandb,
4 changes: 2 additions & 2 deletions mace/tools/arg_parser.py
@@ -532,7 +532,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
         help="Optimizer for parameter optimization",
         type=str,
         default="adam",
-        choices=["adam", "adamw", "schedulefree"],
+        choices=["adam", "adamw", "schedulefree", "lbfgs"],
     )
     parser.add_argument(
         "--beta",
@@ -910,4 +910,4 @@ def parse_dict(value: str):

         return parsed_dict
     except (ValueError, SyntaxError) as e:
-        raise argparse.ArgumentTypeError(f"Invalid dictionary format: {e}")
+        raise argparse.ArgumentTypeError(f"Invalid dictionary format: {e}")
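The two flags touched above travel together: --optimizer now admits "lbfgs" as a choice, and --lbfgs_config supplies its tuning knobs as a dict literal parsed by the parse_dict helper whose tail is shown. A minimal, self-contained sketch of that round trip; the body of parse_dict (ast.literal_eval) and the --lbfgs_config defaults are assumptions, since only the helper's error handling appears in this diff:

import argparse
import ast

def parse_dict(value: str):
    try:
        parsed_dict = ast.literal_eval(value)
        if not isinstance(parsed_dict, dict):
            raise ValueError("expected a dict literal")
        return parsed_dict
    except (ValueError, SyntaxError) as e:
        raise argparse.ArgumentTypeError(f"Invalid dictionary format: {e}")

parser = argparse.ArgumentParser()
parser.add_argument("--optimizer", type=str, default="adam",
                    choices=["adam", "adamw", "schedulefree", "lbfgs"])
parser.add_argument("--lbfgs_config", type=parse_dict, default={})

# argparse now accepts the new choice and rejects anything else:
args = parser.parse_args(["--optimizer", "lbfgs",
                          "--lbfgs_config", "{'max_iter': 100, 'history': 120}"])
print(args.optimizer, args.lbfgs_config)  # lbfgs {'max_iter': 100, 'history': 120}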
13 changes: 13 additions & 0 deletions mace/tools/scripts_utils.py
@@ -21,6 +21,7 @@

 from mace import data, modules, tools
 from mace.tools.train import SWAContainer
+from mace.tools.lbfgsnew import LBFGSNew
 
 
 @dataclasses.dataclass
@@ -675,6 +676,18 @@ def get_optimizer(
         ) from exc
         _param_options = {k: v for k, v in param_options.items() if k != "amsgrad"}
         optimizer = adamw_schedulefree.AdamWScheduleFree(**_param_options)
+    elif args.optimizer == "lbfgs":
+        lbfgs_config = args.lbfgs_config
+        max_iter = lbfgs_config.get("max_iter", 200)
+        history_size = lbfgs_config.get("history", 240)
+        batch_mode = lbfgs_config.get("batch_mode", False)
+
+        optimizer = LBFGSNew(**param_options,
+                             tolerance_grad=1e-6,
+                             history_size=history_size,
+                             max_iter=max_iter,
+                             line_search_fn=False,
+                             batch_mode=batch_mode)
     else:
         optimizer = torch.optim.Adam(**param_options)
     return optimizer
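Unlike Adam, an L-BFGS-style optimizer re-evaluates the objective several times per step, so its step() takes a closure. A minimal sketch of how an optimizer constructed as above is driven, using the built-in torch.optim.LBFGS as a stand-in for LBFGSNew (which additionally accepts batch_mode and a boolean line_search_fn); the toy model and data are placeholders:

import torch

model = torch.nn.Linear(4, 1)
x, y = torch.randn(64, 4), torch.randn(64, 1)

optimizer = torch.optim.LBFGS(
    model.parameters(),
    tolerance_grad=1e-6,
    history_size=240,  # mirrors the "history" default above
    max_iter=200,      # mirrors the "max_iter" default above
)

def closure():
    # step() calls this repeatedly to re-evaluate loss and gradients
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

optimizer.step(closure)  # runs up to max_iter inner iterations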
86 changes: 3 additions & 83 deletions mace/tools/train.py
@@ -19,7 +19,6 @@
 from torch.utils.data.distributed import DistributedSampler
 from torch_ema import ExponentialMovingAverage
 from torchmetrics import Metric
-from .lbfgsnew import LBFGSNew
 
 from . import torch_geometric
 from .checkpoint import CheckpointHandler, CheckpointState
@@ -153,7 +152,6 @@ def train(
     log_errors: str,
     swa: Optional[SWAContainer] = None,
     lbfgs: bool = False,
-    lbfgs_config: Dict = None,
     ema: Optional[ExponentialMovingAverage] = None,
     max_grad_norm: Optional[float] = 10.0,
     log_wandb: bool = False,
@@ -231,6 +229,7 @@
             device=device,
             distributed_model=distributed_model,
             rank=rank,
+            use_lbfgs=lbfgs
         )
         if distributed:
             torch.distributed.barrier()
@@ -322,85 +321,6 @@
             torch.distributed.barrier()
         epoch += 1
 
-    if lbfgs:
-        epoch=10000 #TODO: fix code instead of workaround
-
-        lbfgsepochs=lbfgs_config.get("epochs", 50)
-        max_iter=lbfgs_config.get("max_iter", 200)
-        history_size=lbfgs_config.get("history", 240)
-        batch_mode=lbfgs_config.get("batch_mode", False)
-
-        lbfgs_optimizer=LBFGSNew(model.parameters(),
-                                 tolerance_grad=1e-6,
-                                 history_size=history_size,
-                                 max_iter=max_iter,
-                                 line_search_fn=False,
-                                 batch_mode=batch_mode)
-        while epoch < 10000+lbfgsepochs:
-
-            if epoch % 10 == 0:
-                logging.info(f"LBFGS epoch: {epoch}, history: {history_size}, max_iter: {max_iter}, epochs {lbfgsepochs}, batch_mode {batch_mode}")
-                logging.info("GPU Memory Report:")
-                logging.info(f"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
-                logging.info(f"Cached: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
-                logging.info(f"Total: {torch.cuda.get_device_properties(0).total_memory / 1024**2:.2f} MB")
-                logging.info(f"Max Allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
-                logging.info(f"Free memory: {torch.cuda.mem_get_info()[0] / (1024**2):.2f} MB")
-
-            if distributed:
-                train_sampler.set_epoch()
-
-            train_one_epoch(
-                model=model,
-                loss_fn=loss_fn,
-                data_loader=train_loader,
-                optimizer=lbfgs_optimizer,
-                epoch=epoch,
-                output_args=output_args,
-                max_grad_norm=max_grad_norm,
-                ema=ema,
-                logger=logger,
-                device=device,
-                distributed_model=distributed_model,
-                rank=rank,
-                lbfgs=lbfgs
-            )
-
-            if distributed:
-                torch.distributed.barrier()
-
-            model_to_evaluate = (
-                model if distributed_model is None else distributed_model
-            )
-            param_context = (
-                ema.average_parameters() if ema is not None else nullcontext()
-            )
-            with param_context:
-                valid_loss = 0.0
-                checkpoint_handler.save(
-                    state=CheckpointState(model, optimizer, lr_scheduler),
-                    epochs=epoch,
-                    keep_last=keep_last,
-                )
-                for valid_loader_name, valid_loader in valid_loaders.items():
-                    valid_loss_head, eval_metrics = evaluate(
-                        model=model_to_evaluate,
-                        loss_fn=loss_fn,
-                        data_loader=valid_loader,
-                        output_args=output_args,
-                        device=device,
-                    )
-                    if rank == 0:
-                        valid_err_log(
-                            valid_loss_head,
-                            eval_metrics,
-                            logger,
-                            log_errors,
-                            epoch,
-                            valid_loader_name,
-                        )
-            epoch+=1
-
     logging.info("Training complete")


@@ -417,10 +337,10 @@ def train_one_epoch(
     device: torch.device,
     distributed_model: Optional[DistributedDataParallel] = None,
     rank: Optional[int] = 0,
-    lbfgs: bool = False,
+    use_lbfgs: bool = False,
 ) -> None:
     model_to_train = model if distributed_model is None else distributed_model
-    take_step_fn = take_lbfgs_step if lbfgs else take_step
+    take_step_fn = take_lbfgs_step if use_lbfgs else take_step
 
     for batch in data_loader:
         _, opt_metrics = take_step_fn(
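The renamed use_lbfgs flag only switches which step function train_one_epoch dispatches to. The bodies of take_step and take_lbfgs_step are not part of this commit, so the outline below is purely illustrative of the split, not mace's actual code:

import torch

def take_step(model, loss_fn, batch, optimizer):
    # First-order path: one forward/backward, then a single parameter update.
    optimizer.zero_grad()
    loss = loss_fn(model(batch[0]), batch[1])
    loss.backward()
    optimizer.step()
    return loss

def take_lbfgs_step(model, loss_fn, batch, optimizer):
    # L-BFGS path: optimizer.step() drives repeated forward/backward passes
    # through a closure instead of consuming one precomputed gradient.
    def closure():
        optimizer.zero_grad()
        loss = loss_fn(model(batch[0]), batch[1])
        loss.backward()
        return loss
    return optimizer.step(closure)

use_lbfgs = True
take_step_fn = take_lbfgs_step if use_lbfgs else take_step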
