fix stream sync
Signed-off-by: youkaichao <[email protected]>
youkaichao committed Dec 3, 2024
1 parent e116a89 commit 32b03fa
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions tests/distributed/test_pynccl.py
@@ -61,6 +61,7 @@ def worker_fn():
                         dtype=torch.float32).cuda(pynccl_comm.rank)
     with pynccl_comm.change_state(enable=True):
         tensor = pynccl_comm.all_reduce(tensor)
+    torch.cuda.synchronize()
     result = tensor.mean().cpu().item()
     assert result == pynccl_comm.world_size

@@ -86,10 +87,12 @@ def multiple_allreduce_worker_fn():
         if torch.distributed.get_rank() in [0, 1]:
             tensor = pynccl_comm.all_reduce(tensor)
             tensor = pynccl_comm.all_reduce(tensor)
+            torch.cuda.synchronize()
             result = tensor.mean().cpu().item()
             assert result == 4
         else:
             tensor = pynccl_comm.all_reduce(tensor)
+            torch.cuda.synchronize()
             result = tensor.mean().cpu().item()
             assert result == 2

@@ -112,10 +115,12 @@ def multiple_allreduce_with_vllm_worker_fn():
         if torch.distributed.get_rank() in [0, 1]:
             tensor = tensor_model_parallel_all_reduce(tensor)
             tensor = tensor_model_parallel_all_reduce(tensor)
+            torch.cuda.synchronize()
             result = tensor.mean().cpu().item()
             assert result == 4
         else:
             tensor = tensor_model_parallel_all_reduce(tensor)
+            torch.cuda.synchronize()
             result = tensor.mean().cpu().item()
             assert result == 2

@@ -141,9 +146,9 @@ def worker_fn_with_cudagraph():
             graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
                 enable=True):
         a_out = pynccl_comm.all_reduce(a)
-    pynccl_comm.stream.synchronize()
+    torch.cuda.synchronize()
     graph.replay()
-    pynccl_comm.stream.synchronize()
+    torch.cuda.synchronize()
     assert a_out.mean().cpu().item() == pynccl_comm.world_size**1


@@ -170,6 +175,7 @@ def all_gather_worker_fn():
 
     with pynccl_comm.change_state(enable=True):
         pynccl_comm.all_gather(result, tensor)
+    torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


@@ -207,6 +213,7 @@ def reduce_scatter_worker_fn():
 
     with pynccl_comm.change_state(enable=True):
         pynccl_comm.reduce_scatter(result, tensor)
+    torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


@@ -241,6 +248,7 @@ def send_recv_worker_fn():
             pynccl_comm.recv(tensor,
                              src=(pynccl_comm.rank - 1) %
                              pynccl_comm.world_size)
+    torch.cuda.synchronize()
     result = tensor.mean().cpu().item()
     assert result == 1

@@ -280,6 +288,7 @@ def multiple_send_recv_worker_fn():
             pynccl_comm.recv(tensor,
                              src=(pynccl_comm.rank - 1) %
                              pynccl_comm.world_size)
+    torch.cuda.synchronize()
     result = tensor.mean().cpu().item()
     if torch.distributed.get_rank() in [0, 2]:
         assert result == 1
@@ -319,6 +328,9 @@ def broadcast_worker_fn():
 
     for i in range(pynccl_comm.world_size):
         pynccl_comm.broadcast(recv_tensors[i], src=i)
+        # the broadcast op might be launched in a different stream
+        # need to synchronize to make sure the tensor is ready
+        torch.cuda.synchronize()
         assert torch.all(recv_tensors[i] == i).cpu().item()


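The comment in the last hunk states the rule behind all of these additions: work that a communicator launches on its own CUDA stream is not ordered with respect to the default stream, so host-side reads of the output tensor can race with the kernel. A minimal sketch of that hazard and the fix, using only plain PyTorch streams (no vLLM APIs; the tensor and stream names are illustrative and a CUDA device is assumed):

import torch

# Side stream standing in for the communicator's internal stream.
side_stream = torch.cuda.Stream()
x = torch.ones(1024, 1024, device="cuda")

with torch.cuda.stream(side_stream):
    # Enqueued on side_stream; the default stream does not wait for it.
    y = x * 2

# Without an explicit ordering point, reading `y` from the host may
# observe incomplete results. Either of these establishes the ordering:
torch.cuda.current_stream().wait_stream(side_stream)  # stream-to-stream wait
torch.cuda.synchronize()  # or a device-wide sync, as the tests above use

assert y.mean().item() == 2.0

The tests use the device-wide torch.cuda.synchronize() rather than syncing only pynccl_comm.stream, which also covers ops that may have been launched on some other stream, as the broadcast comment notes.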
