diff --git a/mace/tools/train.py b/mace/tools/train.py
index 3bddfc6c..8ae628cf 100644
--- a/mace/tools/train.py
+++ b/mace/tools/train.py
@@ -429,15 +429,15 @@ def take_step_lbfgs(
     device: torch.device,
 ) -> Tuple[float, Dict[str, Any]]:
     start_time = time.time()
-    batch = batch.to(device)
-    batch_dict = batch.to_dict()

     def closure():
         optimizer.zero_grad(set_to_none=True)
-        total_loss = 0.0
+        total_loss = torch.tensor(0.0, device=device)

-        # Process each batch but then collect the results we pass to the optimizer
+        # Process each batch and then collect the results we pass to the optimizer
         for batch in data_loader:
+            batch = batch.to(device)
+            batch_dict = batch.to_dict()
             output = model(
                 batch_dict,
                 training=True,
@@ -446,12 +446,11 @@ def closure():
                 compute_stress=output_args["stress"],
             )
             batch_loss = loss_fn(pred=output, ref=batch)
-            batch_loss = batch_loss / len(data_loader)
-
+            # Accumulate gradients without updating weights (remove for torchmin)
             batch_loss.backward()
-            total_loss += batch_loss.item()
+            total_loss += batch_loss

         if max_grad_norm is not None:
             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
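
For context (not part of the patch): a minimal, self-contained sketch of how a closure-based torch.optim.LBFGS step typically drives a closure like the one modified above. The toy model, data loader, and loss function here are placeholders standing in for MACE's, not the project's actual code.

    import torch

    # Toy stand-ins (assumptions for illustration only).
    device = torch.device("cpu")
    model = torch.nn.Linear(4, 1).to(device)
    data_loader = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(3)]
    loss_fn = torch.nn.MSELoss()

    optimizer = torch.optim.LBFGS(model.parameters(), lr=0.5, max_iter=20)

    def closure():
        # L-BFGS may call this repeatedly per step; it must recompute gradients each time.
        optimizer.zero_grad(set_to_none=True)
        total_loss = torch.tensor(0.0, device=device)
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            batch_loss = loss_fn(model(inputs), targets)
            batch_loss.backward()          # gradients accumulate across batches
            total_loss += batch_loss.detach()
        return total_loss

    final_loss = optimizer.step(closure)   # step() invokes closure() one or more times

Keeping total_loss as a tensor on the target device (as the patch does) avoids per-batch host synchronisation from .item() while still giving the optimizer a scalar loss to report.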