diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py
index 4e27babf12cc3..3e9b0e10a11d8 100644
--- a/tests/distributed/test_pynccl.py
+++ b/tests/distributed/test_pynccl.py
@@ -62,8 +62,7 @@ def worker_fn():
     with pynccl_comm.change_state(enable=True):
         tensor = pynccl_comm.all_reduce(tensor)
     torch.cuda.synchronize()
-    result = tensor.mean().cpu().item()
-    assert result == pynccl_comm.world_size
+    assert torch.all(tensor == pynccl_comm.world_size).cpu().item()


 @pytest.mark.skipif(torch.cuda.device_count() < 2,
@@ -88,13 +87,11 @@ def multiple_allreduce_worker_fn():
             tensor = pynccl_comm.all_reduce(tensor)
             tensor = pynccl_comm.all_reduce(tensor)
             torch.cuda.synchronize()
-            result = tensor.mean().cpu().item()
-            assert result == 4
+            assert torch.all(tensor == 4).cpu().item()
         else:
             tensor = pynccl_comm.all_reduce(tensor)
             torch.cuda.synchronize()
-            result = tensor.mean().cpu().item()
-            assert result == 2
+            assert torch.all(tensor == 2).cpu().item()


 @pytest.mark.skipif(torch.cuda.device_count() < 4,
@@ -116,13 +113,11 @@ def multiple_allreduce_with_vllm_worker_fn():
             tensor = tensor_model_parallel_all_reduce(tensor)
             tensor = tensor_model_parallel_all_reduce(tensor)
             torch.cuda.synchronize()
-            result = tensor.mean().cpu().item()
-            assert result == 4
+            assert torch.all(tensor == 4).cpu().item()
         else:
             tensor = tensor_model_parallel_all_reduce(tensor)
             torch.cuda.synchronize()
-            result = tensor.mean().cpu().item()
-            assert result == 2
+            assert torch.all(tensor == 2).cpu().item()


 @pytest.mark.skipif(torch.cuda.device_count() < 4,
@@ -149,7 +144,7 @@ def worker_fn_with_cudagraph():
         torch.cuda.synchronize()
         graph.replay()
         torch.cuda.synchronize()
-        assert a_out.mean().cpu().item() == pynccl_comm.world_size**1
+        assert torch.all(a_out == pynccl_comm.world_size).cpu().item()


 @worker_fn_wrapper
@@ -249,8 +244,7 @@ def send_recv_worker_fn():
                              src=(pynccl_comm.rank - 1) %
                              pynccl_comm.world_size)
     torch.cuda.synchronize()
-    result = tensor.mean().cpu().item()
-    assert result == 1
+    assert torch.all(tensor == 1).cpu().item()


 @pytest.mark.skipif(torch.cuda.device_count() < 2,
@@ -289,11 +283,10 @@ def multiple_send_recv_worker_fn():
                              src=(pynccl_comm.rank - 1) %
                              pynccl_comm.world_size)
     torch.cuda.synchronize()
-    result = tensor.mean().cpu().item()
     if torch.distributed.get_rank() in [0, 2]:
-        assert result == 1
+        assert torch.all(tensor == 1).cpu().item()
     else:
-        assert result == 2
+        assert torch.all(tensor == 2).cpu().item()


 @pytest.mark.skipif(torch.cuda.device_count() < 4,
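
For context, a minimal standalone sketch (plain PyTorch on CPU, not part of the patch) of why the element-wise `torch.all` check is stricter than the mean-based assertion it replaces: a mean can hit the expected value even when individual elements are wrong, because errors cancel out.

```python
import torch

world_size = 2

# Wrong element values whose errors cancel out: the old mean-based
# assertion would still pass here.
bad = torch.tensor([1.0, 3.0])
assert bad.mean().item() == world_size

# The element-wise check introduced by this diff catches the mismatch.
assert not torch.all(bad == world_size).item()

# A genuinely correct all-reduce result passes the stricter check.
good = torch.full((4,), float(world_size))
assert torch.all(good == world_size).item()
```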