Skip to content

Commit

Permalink
fixes for the lbfgs step
Browse files Browse the repository at this point in the history
  • Loading branch information
ttompa committed Dec 22, 2024
1 parent 76b333e commit 246655c
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions mace/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,15 +429,15 @@ def take_step_lbfgs(
device: torch.device,
) -> Tuple[float, Dict[str, Any]]:
start_time = time.time()
batch = batch.to(device)
batch_dict = batch.to_dict()

def closure():
optimizer.zero_grad(set_to_none=True)
total_loss = 0.0
total_loss = torch.tensor(0.0, device=device)

# Process each batch but then collect the results we pass to the optimizer
# Process each batch and then collect the results we pass to the optimizer
for batch in data_loader:
batch = batch.to(device)
batch_dict = batch.to_dict()
output = model(
batch_dict,
training=True,
Expand All @@ -446,12 +446,11 @@ def closure():
compute_stress=output_args["stress"],
)
batch_loss = loss_fn(pred=output, ref=batch)

batch_loss = batch_loss / len(data_loader)

# Accumulate gradients without updating weights (remove for torchmin)
batch_loss.backward()
total_loss += batch_loss.item()
total_loss += batch_loss

if max_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
Expand Down

0 comments on commit 246655c

Please sign in to comment.