From a2dceca926cf570977b40f28e131363fe8c61d6d Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Mon, 25 Nov 2024 16:34:14 -0800
Subject: [PATCH] add fallback

Signed-off-by: youkaichao
---
 vllm/distributed/parallel_state.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index d71c84f7c098e..ccbe00386c5da 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -304,7 +304,8 @@ def graph_capture(
             if not pynccl_comm:
                 maybe_pynccl_context = nullcontext()
             else:
-                maybe_pynccl_context = pynccl_comm.change_state(stream=torch.cuda.current_stream())
+                maybe_pynccl_context = pynccl_comm.change_state(
+                    stream=torch.cuda.current_stream())
             with maybe_pynccl_context:
                 yield graph_capture_context
 
@@ -360,8 +361,15 @@ def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
         assert pynccl_comm is not None
         # TODO: pynccl should not use `stream=`
         # it can just always use the current stream.
-        out = pynccl_comm.all_reduce(input_, stream=torch.cuda.current_stream())
-        assert out is not None
+        out = pynccl_comm.all_reduce(input_,
+                                     stream=torch.cuda.current_stream())
+        if out is None:
+            # fall back to the default all-reduce using PyTorch.
+            # this usually happens during testing.
+            # when we run the model, allreduce only happens for the TP
+            # group, where we always have either custom allreduce or pynccl.
+            out = input_.clone()
+            torch.distributed.all_reduce(out, group=self.device_group)
         return out
 
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
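
Not part of the patch: below is a minimal, self-contained sketch of the fallback pattern this change introduces. The names fast_all_reduce and all_reduce_out_place are hypothetical stand-ins for pynccl_comm.all_reduce and GroupCoordinator._all_reduce_out_place, and the single-process gloo group exists only so the sketch runs without GPUs. The point it illustrates: when the fast path returns None, reduce a clone of the input with torch.distributed.all_reduce so the caller's tensor is left untouched (out-of-place semantics).

import os
from typing import Optional

import torch
import torch.distributed as dist


def fast_all_reduce(t: torch.Tensor) -> Optional[torch.Tensor]:
    # Hypothetical stand-in for pynccl_comm.all_reduce: pretend the fast
    # path is unavailable (e.g. under test) by returning None.
    return None


def all_reduce_out_place(input_: torch.Tensor,
                         group: dist.ProcessGroup) -> torch.Tensor:
    out = fast_all_reduce(input_)
    if out is None:
        # Fall back to PyTorch's built-in (in-place) all-reduce, applied to
        # a clone so the input tensor itself is not modified.
        out = input_.clone()
        dist.all_reduce(out, group=group)
    return out


if __name__ == "__main__":
    # Single-process gloo group, purely to make the sketch runnable.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)
    x = torch.ones(4)
    print(all_reduce_out_place(x, dist.group.WORLD))  # tensor([1., 1., 1., 1.])
    dist.destroy_process_group()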