
Commit a2dceca

add fallback
Signed-off-by: youkaichao <[email protected]>
youkaichao committed Nov 26, 2024
1 parent 100d26c commit a2dceca
Showing 1 changed file with 11 additions and 3 deletions.
vllm/distributed/parallel_state.py (11 additions, 3 deletions)
@@ -304,7 +304,8 @@ def graph_capture(
             if not pynccl_comm:
                 maybe_pynccl_context = nullcontext()
             else:
-                maybe_pynccl_context = pynccl_comm.change_state(stream=torch.cuda.current_stream())
+                maybe_pynccl_context = pynccl_comm.change_state(
+                    stream=torch.cuda.current_stream())
             with maybe_pynccl_context:
                 yield graph_capture_context

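Note on this hunk: change_state(stream=...) is used as a context manager so that, while a CUDA graph is being captured, the pynccl communicator issues its collectives on the capture stream and reverts to its previous stream afterwards. Below is a minimal sketch of that pattern, assuming a simplified stand-in class; FakeComm is hypothetical and only approximates what PyNcclCommunicator.change_state is assumed to do, since its implementation is not part of this diff.

```python
# Sketch only (not vLLM's code): a context manager that temporarily swaps
# the stream a communicator runs its collectives on, restoring the old
# stream on exit. Requires a CUDA device to run.
from contextlib import contextmanager

import torch


class FakeComm:
    """Hypothetical stand-in for a pynccl-style communicator."""

    def __init__(self):
        # stream on which collectives are issued by default
        self.stream = torch.cuda.current_stream()

    @contextmanager
    def change_state(self, stream: torch.cuda.Stream):
        # temporarily redirect collectives to `stream`
        prev = self.stream
        self.stream = stream
        try:
            yield self
        finally:
            # always restore the previous stream, even on error
            self.stream = prev
```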
@@ -360,8 +361,15 @@ def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
         assert pynccl_comm is not None
         # TODO: pynccl should not use `stream=`
         # it can just always use the current stream.
-        out = pynccl_comm.all_reduce(input_, stream=torch.cuda.current_stream())
-        assert out is not None
+        out = pynccl_comm.all_reduce(input_,
+                                     stream=torch.cuda.current_stream())
+        if out is None:
+            # fall back to the default all-reduce using PyTorch.
+            # this usually happens during testing.
+            # when we run the model, allreduce only happens for the TP
+            # group, where we always have either custom allreduce or pynccl.
+            out = input_.clone()
+            torch.distributed.all_reduce(out, group=self.device_group)
         return out

     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
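The new `if out is None` branch is the fallback the commit message refers to: when the pynccl all-reduce does not produce a result (which, per the in-code comment, usually happens during testing), the op falls back to torch.distributed.all_reduce. Because that PyTorch collective is in-place, the input is cloned first so _all_reduce_out_place keeps its out-of-place contract. The following standalone sketch shows just that fallback path on a single-process gloo group so it runs without GPUs; the group setup (address, port, world size) is illustrative, not vLLM's.

```python
# Illustrative sketch of the fallback path: an out-of-place all-reduce
# built on torch.distributed's in-place all_reduce.
import torch
import torch.distributed as dist


def all_reduce_out_place(input_: torch.Tensor,
                         group: dist.ProcessGroup) -> torch.Tensor:
    # clone first: torch.distributed.all_reduce mutates its argument,
    # but callers expect the input tensor to stay untouched.
    out = input_.clone()
    dist.all_reduce(out, group=group)
    return out


if __name__ == "__main__":
    # single-process gloo group, purely so the example runs anywhere
    dist.init_process_group(backend="gloo",
                            init_method="tcp://127.0.0.1:29500",
                            rank=0,
                            world_size=1)
    x = torch.ones(4)
    y = all_reduce_out_place(x, dist.group.WORLD)
    print(y)  # tensor([1., 1., 1., 1.]) with a single rank
    print(x)  # unchanged, since the reduction happened on the clone
    dist.destroy_process_group()
```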
