diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index d50d557a9545b..dd6f200962351 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -198,6 +198,8 @@ def _get_ipc_tensors(self, inp: torch.Tensor) -> List[torch.Tensor]: for i, obj in enumerate(all_meta): func = obj[0][0] args = list(obj[0][1]) + # This might break in the future since what `args` encompasses + # may change. args[6] = inp.device.index if i != self.rank: all_tensors.append(func(*args))