diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 6f1498a5d0852..b7390759d72f8 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -319,8 +319,7 @@ def broadcast_worker_fn(): for i in range(pynccl_comm.world_size): pynccl_comm.broadcast(recv_tensors[i], src=i) - result = recv_tensors[i].mean().cpu().item() - assert result == i + assert torch.all(recv_tensors[i] == i).cpu().item() def test_ncclGetUniqueId():