diff --git a/mace/tools/train.py b/mace/tools/train.py
index 600d93ad..e0b561b3 100644
--- a/mace/tools/train.py
+++ b/mace/tools/train.py
@@ -458,10 +458,9 @@ def closure():
         torch.distributed.barrier()
         return total_loss
 
-    if torch.distributed.get_rank() == 0:
-        loss = optimizer.step(closure)
-        for param in model.parameters():
-            torch.distributed.broadcast(param.data, src=0)
+    loss = optimizer.step(closure)
+    for param in model.parameters():
+        torch.distributed.broadcast(param.data, src=0)
 
     if ema is not None:
         ema.update()
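
For context on the change: the removed code let only rank 0 call `optimizer.step(closure)`, but the closure ends with `torch.distributed.barrier()`, a collective that every rank must reach; the other ranks presumably never enter the closure, so the barrier cannot complete. After the patch every rank runs the LBFGS step, and rank 0's parameters are then broadcast so the replicas stay identical. The sketch below reproduces that pattern outside of MACE; the toy `Linear` model, synthetic data, seeded RNG, gloo backend, and port are assumptions made for this example only, not code from `train.py`.

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    # Toy two-process setup on CPU (gloo); address/port are arbitrary choices.
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29500",
        rank=rank,
        world_size=world_size,
    )

    # Seed so every rank builds the same model and data; this keeps the number
    # of closure evaluations (and hence the barriers) identical across ranks.
    torch.manual_seed(0)
    model = torch.nn.Linear(1, 1)
    optimizer = torch.optim.LBFGS(model.parameters(), max_iter=5)
    data = torch.randn(8, 1)
    target = torch.randn(8, 1)

    def closure():
        # Every rank must execute the closure: it contains a collective call
        # (a barrier, as in the patched train.py), and collectives hang unless
        # all ranks reach them.
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(data), target)
        loss.backward()
        dist.barrier()
        return loss

    # All ranks step; rank 0's parameters then overwrite the others so the
    # replicas stay bit-identical, mirroring the "+" lines in the diff above.
    loss = optimizer.step(closure)
    for param in model.parameters():
        dist.broadcast(param.data, src=0)

    if rank == 0:
        print(f"final loss on rank 0: {loss.item():.4f}")
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)
```

Running the sketch directly completes on both ranks; re-adding a `get_rank() == 0` guard around `optimizer.step(closure)` leaves rank 1 with no matching barrier, and the run deadlocks, which is the failure mode the patch removes.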