Skip to content

Commit

Permalink
Add validation-error logging and disable EMA during the L-BFGS step
Browse files: browse the repository at this point in the history
  • Loading branch information
vue1999 committed Nov 22, 2024
1 parent 95cc211 commit e9ba22d
Showing 1 changed file with 29 additions and 5 deletions.
34 changes: 29 additions & 5 deletions mace/tools/train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -266,8 +266,8 @@ def train(
epoch=10000 # TODO: fix the code instead of using this workaround
if distributed:
train_sampler.set_epoch(epoch)

lbfgs_optimizer = Minimizer(model.parameters(), method='l-bfgs', tol=1e-05)
print("HISTORY:20")
lbfgs_optimizer = Minimizer(model.parameters(), method='l-bfgs', tol=1e-05, options={'history_size': 20})
train_one_epoch(
model=model,
loss_fn=loss_fn,
Expand All @@ -276,7 +276,7 @@ def train(
epoch=epoch,
output_args=output_args,
max_grad_norm=max_grad_norm,
ema=ema,
ema=None,
logger=logger,
device=device,
distributed_model=distributed_model,
Expand All @@ -286,16 +286,40 @@ def train(

if distributed:
torch.distributed.barrier()


model_to_evaluate = (
model if distributed_model is None else distributed_model
)
param_context = (
ema.average_parameters() if ema is not None else nullcontext()
)

with param_context:
valid_loss = 0.0
checkpoint_handler.save(
state=CheckpointState(model, optimizer, lr_scheduler),
epochs=epoch,
keep_last=keep_last,
)
)

for valid_loader_name, valid_loader in valid_loaders.items():
valid_loss_head, eval_metrics = evaluate(
model=model_to_evaluate,
loss_fn=loss_fn,
data_loader=valid_loader,
output_args=output_args,
device=device,
)
if rank == 0:
valid_err_log(
valid_loss_head,
eval_metrics,
logger,
log_errors,
epoch,
valid_loader_name,
)


logging.info("Training complete")

Expand Down

0 comments on commit e9ba22d

Please sign in to comment.