
Commit

Replace mean with torch.all in test_pynccl.py
Signed-off-by: Tyler Michael Smith <[email protected]>
tlrmchlsmth committed Dec 3, 2024
1 parent 3bc94ca commit d5a27f8
Showing 1 changed file with 9 additions and 16 deletions.
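
The point of the change is that a mean-based check can pass even when individual elements of the reduced tensor are wrong, as long as the errors average out; the element-wise torch.all check fails on any mismatch. A minimal CPU-only sketch of the difference, using hypothetical values rather than the tensors the tests actually build:

import torch

# Suppose every element should equal 2 after an all-reduce across 2 ranks.
expected = 2
tensor = torch.tensor([1.0, 3.0, 2.0, 2.0])  # two elements are corrupted

# The old mean-based check still passes, because the errors cancel out.
assert tensor.mean().cpu().item() == expected

# The new element-wise check correctly reports the mismatch.
assert not torch.all(tensor == expected).cpu().item()
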
25 changes: 9 additions & 16 deletions tests/distributed/test_pynccl.py
@@ -62,8 +62,7 @@ def worker_fn():
with pynccl_comm.change_state(enable=True):
tensor = pynccl_comm.all_reduce(tensor)
torch.cuda.synchronize()
-result = tensor.mean().cpu().item()
-assert result == pynccl_comm.world_size
+assert torch.all(tensor == pynccl_comm.world_size).cpu().item()


@pytest.mark.skipif(torch.cuda.device_count() < 2,
@@ -88,13 +87,11 @@ def multiple_allreduce_worker_fn():
tensor = pynccl_comm.all_reduce(tensor)
tensor = pynccl_comm.all_reduce(tensor)
torch.cuda.synchronize()
-result = tensor.mean().cpu().item()
-assert result == 4
+assert torch.all(tensor == 4).cpu().item()
else:
tensor = pynccl_comm.all_reduce(tensor)
torch.cuda.synchronize()
-result = tensor.mean().cpu().item()
-assert result == 2
+assert torch.all(tensor == 2).cpu().item()


@pytest.mark.skipif(torch.cuda.device_count() < 4,
@@ -116,13 +113,11 @@ def multiple_allreduce_with_vllm_worker_fn():
tensor = tensor_model_parallel_all_reduce(tensor)
tensor = tensor_model_parallel_all_reduce(tensor)
torch.cuda.synchronize()
-result = tensor.mean().cpu().item()
-assert result == 4
+assert torch.all(tensor == 4).cpu().item()
else:
tensor = tensor_model_parallel_all_reduce(tensor)
torch.cuda.synchronize()
-result = tensor.mean().cpu().item()
-assert result == 2
+assert torch.all(tensor == 2).cpu().item()


@pytest.mark.skipif(torch.cuda.device_count() < 4,
@@ -149,7 +144,7 @@ def worker_fn_with_cudagraph():
torch.cuda.synchronize()
graph.replay()
torch.cuda.synchronize()
-assert a_out.mean().cpu().item() == pynccl_comm.world_size**1
+assert torch.all(a_out == pynccl_comm.world_size).cpu().item()


@worker_fn_wrapper
@@ -249,8 +244,7 @@ def send_recv_worker_fn():
src=(pynccl_comm.rank - 1) %
pynccl_comm.world_size)
torch.cuda.synchronize()
-result = tensor.mean().cpu().item()
-assert result == 1
+assert torch.all(tensor == 1).cpu().item()


@pytest.mark.skipif(torch.cuda.device_count() < 2,
@@ -289,11 +283,10 @@ def multiple_send_recv_worker_fn():
src=(pynccl_comm.rank - 1) %
pynccl_comm.world_size)
torch.cuda.synchronize()
-result = tensor.mean().cpu().item()
if torch.distributed.get_rank() in [0, 2]:
-assert result == 1
+assert torch.all(tensor == 1).cpu().item()
else:
-assert result == 2
+assert torch.all(tensor == 2).cpu().item()


@pytest.mark.skipif(torch.cuda.device_count() < 4,
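
A note on the new assertion pattern: torch.all(tensor == x) returns a zero-dimensional boolean tensor, and .cpu().item() turns it into a plain Python bool for the assert. A small standalone sketch of the same pattern on a hypothetical CPU tensor (smaller than the tensors in the tests):

import torch

# Hypothetical stand-in for a post-all-reduce tensor in which every element is 2.
tensor = torch.full((16, 1024), 2.0)

ok = torch.all(tensor == 2)  # zero-dimensional bool tensor
assert ok.cpu().item()       # .item() yields a Python bool (True here)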
