Skip to content

Commit

Permalink
fix tests (update for out-of-place allreduce)
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao committed Nov 26, 2024
1 parent a2dceca commit c50abd2
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions tests/distributed/test_pynccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def worker_fn():
tensor = torch.ones(16, 1024, 1024,
dtype=torch.float32).cuda(pynccl_comm.rank)
with pynccl_comm.change_state(enable=True):
pynccl_comm.all_reduce(tensor)
tensor = pynccl_comm.all_reduce(tensor)
result = tensor.mean().cpu().item()
assert result == pynccl_comm.world_size

Expand All @@ -84,12 +84,12 @@ def multiple_allreduce_worker_fn():
with pynccl_comm.change_state(enable=True):
# two groups can communicate independently
if torch.distributed.get_rank() in [0, 1]:
pynccl_comm.all_reduce(tensor)
pynccl_comm.all_reduce(tensor)
tensor = pynccl_comm.all_reduce(tensor)
tensor = pynccl_comm.all_reduce(tensor)
result = tensor.mean().cpu().item()
assert result == 4
else:
pynccl_comm.all_reduce(tensor)
tensor = pynccl_comm.all_reduce(tensor)
result = tensor.mean().cpu().item()
assert result == 2

Expand Down Expand Up @@ -140,14 +140,12 @@ def worker_fn_with_cudagraph():
with torch.cuda.graph(
graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
enable=True):
# operation during the graph capture is recorded but not executed
# see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa
pynccl_comm.all_reduce(a)
a_out = pynccl_comm.all_reduce(a)
pynccl_comm.stream.synchronize()
assert a.mean().cpu().item() == pynccl_comm.world_size**0
assert a_out.mean().cpu().item() == pynccl_comm.world_size**0
graph.replay()
pynccl_comm.stream.synchronize()
assert a.mean().cpu().item() == pynccl_comm.world_size**1
assert a_out.mean().cpu().item() == pynccl_comm.world_size**1


@worker_fn_wrapper
Expand Down

0 comments on commit c50abd2

Please sign in to comment.