diff --git a/vllm/distributed/device_communicators/base_communicator.py b/vllm/distributed/device_communicators/base_communicator.py
new file mode 100644
index 0000000000000..e564a4713406c
--- /dev/null
+++ b/vllm/distributed/device_communicators/base_communicator.py
@@ -0,0 +1,110 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+
+class CommunicatorABC(ABC):
+    """
+    Abstract base class for device communicators.
+    """
+
+    @abstractmethod
+    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError
+
+    @abstractmethod
+    def gather(self,
+               input_: torch.Tensor,
+               dst: int = 0,
+               dim: int = -1) -> Optional[torch.Tensor]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        raise NotImplementedError
+
+
+class CommunicatorBase(CommunicatorABC):
+    """
+    Base class for device-specific communicators.
+
+    The class defines the primary communication ops that device backends
+    frequently override, and provides default implementations of these
+    ops that are used when no device-specific communicator implementation
+    is available.
+    """
+
+    def __init__(self, group: ProcessGroup, unique_name: str = ""):
+        self.group = group
+        self.unique_name = unique_name
+        self.rank = dist.get_rank(group)
+        self.world_size = dist.get_world_size(self.group)
+        self.ranks = dist.get_process_group_ranks(self.group)
+        global_rank = dist.get_rank()
+        self.rank_in_group = dist.get_group_rank(self.group, global_rank)
+
+    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
+        """
+        All-reduce function for CPU and GPU only.
+
+        NOTE: `torch.ops.vllm.all_reduce` is registered in
+        `vllm/distributed/parallel_state.py`
+        """
+        if input_.is_cpu:
+            import intel_extension_for_pytorch as ipex
+            ipex.distributed.all_reduce(input_, group=self.group)
+            return input_
+
+        return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
+
+    def gather(self,
+               input_: torch.Tensor,
+               dst: int = 0,
+               dim: int = -1) -> Optional[torch.Tensor]:
+        """
+        NOTE: We assume that the input tensor is on the same device across
+        all the ranks.
+        NOTE: `dst` is the local rank of the destination rank.
+        """
+
+        # Allocate output tensor.
+        if self.rank_in_group == dst:
+            gather_list = [
+                torch.empty_like(input_) for _ in range(self.world_size)
+            ]
+        else:
+            gather_list = None
+        # Gather.
+        dist.gather(input_, gather_list, dst=self.ranks[dst], group=self.group)
+        if self.rank_in_group == dst:
+            output_tensor = torch.cat(gather_list, dim=dim)
+        else:
+            output_tensor = None
+        return output_tensor
+
+    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        if dim < 0:
+            # Convert negative dim to positive.
+            dim += input_.dim()
+        input_size = input_.size()
+        # NOTE: we have to use concat-style all-gather here,
+        # stack-style all-gather has compatibility issues with
+        # torch.compile. See https://github.com/pytorch/pytorch/issues/138795
+        output_size = (input_size[0] * self.world_size, ) + input_size[1:]
+        # Allocate output tensor.
+        output_tensor = torch.empty(output_size,
+                                    dtype=input_.dtype,
+                                    device=input_.device)
+        # All-gather.
+        dist.all_gather_into_tensor(output_tensor, input_, group=self.group)
+        # Reshape
+        output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
+        output_tensor = output_tensor.movedim(0, dim)
+        output_tensor = output_tensor.reshape(input_size[:dim] +
+                                              (self.world_size *
+                                               input_size[dim], ) +
+                                              input_size[dim + 1:])
+        return output_tensor
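For reference, a device backend now only needs to subclass `CommunicatorBase` and override the ops it wants to specialize; anything left untouched falls back to the defaults above. A minimal sketch under that assumption (the `MyDeviceCommunicator` name is illustrative, not part of this diff):

```python
# Hypothetical example, not part of this diff: a backend that only
# customizes all_reduce and inherits the default gather/all_gather
# from CommunicatorBase.
import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.base_communicator import (
    CommunicatorBase)


class MyDeviceCommunicator(CommunicatorBase):

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        # Plain in-place all-reduce over the process group set up by
        # CommunicatorBase.__init__; gather/all_gather are inherited.
        dist.all_reduce(input_, group=self.group)
        return input_
```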
diff --git a/vllm/distributed/device_communicators/hpu_communicator.py b/vllm/distributed/device_communicators/hpu_communicator.py
index cc9b19ce022b5..302c1cc550f06 100644
--- a/vllm/distributed/device_communicators/hpu_communicator.py
+++ b/vllm/distributed/device_communicators/hpu_communicator.py
@@ -1,44 +1,37 @@
 import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
 
+from vllm.distributed.device_communicators.base_communicator import (
+    CommunicatorBase)
 from vllm.platforms import current_platform
 
 if current_platform.is_hpu():
     import habana_frameworks.torch as htorch  # noqa: F401
 
 
-class HpuCommunicator:
+class HpuCommunicator(CommunicatorBase):
 
-    def __init__(self, group: ProcessGroup):
-        if not current_platform.is_hpu():
-            self.disabled = True
-            return
-        self.disabled = False
-        self.group = group
-        self.world_size = dist.get_world_size(self.group)
-
-    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
+    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
         # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
         # (which is required for tensor parallel HPUGraph inference)
         htorch.core.mark_step()
-        dist.all_reduce(x, group=self.group)
-        return x
+        dist.all_reduce(input_, group=self.group)
+        return input_
 
-    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         world_size = self.world_size
         if dim < 0:
             # Convert negative dim to positive.
-            dim += x.dim()
-        input_size = x.size()
+            dim += input_.dim()
+        input_size = input_.size()
         # Allocate output tensor.
         output_tensor = torch.empty((world_size, ) + input_size,
-                                    dtype=x.dtype,
-                                    device=x.device)
+                                    dtype=input_.dtype,
+                                    device=input_.device)
         # All-gather.
         htorch.core.mark_step()
-        dist.all_gather_into_tensor(output_tensor, x, group=self.group)
+        dist.all_gather_into_tensor(output_tensor, input_, group=self.group)
         # Reshape
         output_tensor = output_tensor.movedim(0, dim)
         output_tensor = output_tensor.reshape(input_size[:dim] +
diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py
index 765a0f9cb1c87..06f235ed6f4b2 100644
--- a/vllm/distributed/device_communicators/tpu_communicator.py
+++ b/vllm/distributed/device_communicators/tpu_communicator.py
@@ -1,9 +1,10 @@
 import os
 
 import torch
-import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
+from vllm.distributed.device_communicators.base_communicator import (
+    CommunicatorBase)
 from vllm.platforms import current_platform
 
 if current_platform.is_tpu():
@@ -14,19 +15,10 @@
     from vllm.executor import ray_utils
 
 
-class TpuCommunicator:
+class TpuCommunicator(CommunicatorBase):
 
-    def __init__(self, group: ProcessGroup):
-        if not current_platform.is_tpu():
-            self.disabled = True
-            return
-        self.disabled = False
-
-        # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
-        # must be used together. Therefore, the local rank and world size can
-        # be simply calculated as follows.
-        global_rank = dist.get_rank(group)
-        global_world_size = dist.get_world_size(group)
+    def __init__(self, group: ProcessGroup, unique_name: str = ""):
+        super().__init__(group, unique_name)
 
         # Calculate how many TPU nodes are in the current deployment. This
         # is the Ray placement group if it is deployed with Ray. Default
@@ -39,23 +31,23 @@ def __init__(self, group: ProcessGroup):
         if num_nodes_in_pg > 0:
             num_nodes = num_nodes_in_pg
 
-        local_world_size = global_world_size // num_nodes
-        local_rank = global_rank % local_world_size
+        local_world_size = self.world_size // num_nodes
+        local_rank = self.rank % local_world_size
 
         # Ensure environment variables are set for multihost deployments.
         # On GKE, this is needed for libtpu and TPU driver to know which TPU
         # chip is actually visible. Otherwise the TPU driver will fail to
         # initialize because the number of devices would be different from
         # the number of visible worker addresses.
-        os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank)
+        os.environ["CLOUD_TPU_TASK_ID"] = str(self.rank)
         os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank)
 
         pjrt.initialize_multiprocess(local_rank, local_world_size)
         xr._init_world_size_ordinal()
 
-    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
-        return xm.all_reduce(xm.REDUCE_SUM, x)
+    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
+        return xm.all_reduce(xm.REDUCE_SUM, input_)
 
-    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         assert dim == -1, "TPUs only support dim=-1 for all-gather."
-        return xm.all_gather(x, dim=dim)
+        return xm.all_gather(input_, dim=dim)
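The local rank derivation in `TpuCommunicator.__init__` is plain integer arithmetic over `self.rank` and `self.world_size` inherited from the base class. A worked example (the node and chip counts are illustrative):

```python
# Illustrative numbers only: 2 TPU nodes with 4 chips each.
world_size, num_nodes = 8, 2
local_world_size = world_size // num_nodes  # 4 chips per node
for rank in range(world_size):
    local_rank = rank % local_world_size
    # ranks 0-3 map to local ranks 0-3 on node 0,
    # ranks 4-7 map to local ranks 0-3 on node 1.
    print(f"rank={rank} -> local_rank={local_rank}")
```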
diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py
index eafd3c2f67749..ecc5648e05fb4 100644
--- a/vllm/distributed/device_communicators/xpu_communicator.py
+++ b/vllm/distributed/device_communicators/xpu_communicator.py
@@ -1,29 +1,17 @@
 import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
 
-from vllm.platforms import current_platform
+from vllm.distributed.device_communicators.base_communicator import (
+    CommunicatorBase)
 
 
-class XpuCommunicator:
+class XpuCommunicator(CommunicatorBase):
 
-    def __init__(self, group: ProcessGroup):
-        if not current_platform.is_xpu():
-            self.disabled = True
-            return
-        self.disabled = False
-        self.group = group
-        self.world_size = dist.get_world_size(self.group)
+    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
+        dist.all_reduce(input_, group=self.group)
+        return input_
 
-    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
-        dist.all_reduce(x, group=self.group)
-        return x
-
-    def gather(self,
-               input_: torch.Tensor,
-               rank_in_group: int,
-               dst: int = 0,
-               dim: int = -1):
+    def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1):
         # For xpu path, gather doesn't work properly together with ray
         # cluster so we use all_gather instead for now.
         input_size = input_.size()
@@ -35,7 +23,7 @@ def gather(self,
         torch.distributed.all_gather_into_tensor(output_tensor,
                                                  input_,
                                                  group=self.group)
-        if rank_in_group == dst:
+        if self.rank_in_group == dst:
             # Reshape
             output_tensor = output_tensor.movedim(0, dim)
             output_tensor = output_tensor.reshape(input_size[:dim] +
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index bf8b30cccd5f6..15afcb422df6d 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -39,7 +39,8 @@
 import vllm.envs as envs
 from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
-from vllm.utils import direct_register_custom_op, supports_custom_op
+from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
+                        supports_custom_op)
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
@@ -162,9 +163,7 @@ def __init__(
         torch_distributed_backend: Union[str, Backend],
         use_pynccl: bool,
         use_custom_allreduce: bool,
-        use_tpu_communicator: bool,
-        use_hpu_communicator: bool,
-        use_xpu_communicator: bool,
+        use_device_specific_communicator: bool,
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
     ):
@@ -180,6 +179,7 @@ def __init__(
         for ranks in group_ranks:
             device_group = torch.distributed.new_group(
                 ranks, backend=torch_distributed_backend)
+
             # a group with `gloo` backend, to allow direct coordination between
             # processes through the CPU.
             cpu_group = torch.distributed.new_group(ranks, backend="gloo")
@@ -201,11 +201,11 @@ def __init__(
 
         self.use_pynccl = use_pynccl
         self.use_custom_allreduce = use_custom_allreduce
-        self.use_tpu_communicator = use_tpu_communicator
-        self.use_hpu_communicator = use_hpu_communicator
-        self.use_xpu_communicator = use_xpu_communicator
+        self.use_device_specific_communicator = use_device_specific_communicator
 
         # lazy import to avoid documentation build error
+        from vllm.distributed.device_communicators.base_communicator import (
+            CommunicatorABC)
         from vllm.distributed.device_communicators.custom_all_reduce import (
             CustomAllreduce)
         from vllm.distributed.device_communicators.pynccl import (
@@ -226,23 +226,12 @@ def __init__(
                 device=self.device,
             )
 
-        from vllm.distributed.device_communicators.tpu_communicator import (
-            TpuCommunicator)
-        self.tpu_communicator: Optional[TpuCommunicator] = None
-        if use_tpu_communicator and self.world_size > 1:
-            self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
-
-        from vllm.distributed.device_communicators.hpu_communicator import (
-            HpuCommunicator)
-        self.hpu_communicator: Optional[HpuCommunicator]
-        if use_hpu_communicator and self.world_size > 1:
-            self.hpu_communicator = HpuCommunicator(group=self.device_group)
-
-        from vllm.distributed.device_communicators.xpu_communicator import (
-            XpuCommunicator)
-        self.xpu_communicator: Optional[XpuCommunicator]
-        if use_xpu_communicator and self.world_size > 1:
-            self.xpu_communicator = XpuCommunicator(group=self.device_group)
+        self.communicator: CommunicatorABC
+        if use_device_specific_communicator and self.world_size > 1:
+            device_comm_cls = resolve_obj_by_qualname(
+                current_platform.get_device_communicator_cls())
+            self.communicator = device_comm_cls(group=self.device_group,
+                                                unique_name=self.unique_name)
 
         from vllm.distributed.device_communicators.shm_broadcast import (
             MessageQueue)
@@ -326,27 +315,12 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         if self.world_size == 1:
             return input_
 
-        if input_.is_cpu:
-            import intel_extension_for_pytorch as ipex
-            ipex.distributed.all_reduce(input_, group=self.device_group)
-            return input_
-
-        if self.tpu_communicator is not None and \
-                not self.tpu_communicator.disabled:
-            # TPU handles Dynamo with its own logic.
-            return self.tpu_communicator.all_reduce(input_)
-
-        if self.hpu_communicator is not None and \
-                not self.hpu_communicator.disabled:
-            return self.hpu_communicator.all_reduce(input_)
-
-        if self.xpu_communicator is not None and \
-                not self.xpu_communicator.disabled:
-            return self.xpu_communicator.all_reduce(input_)
-
-        return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
+        return self.communicator.all_reduce(input_)
 
     def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
+        """
+        Out-of-place all-reduce operation, registered in `torch.ops.vllm`.
+        """
         # always try custom allreduce first,
         # and then pynccl.
         ca_comm = self.ca_comm
@@ -374,41 +348,7 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
             return input_
         assert -input_.dim() <= dim < input_.dim(), (
             f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
-
-        # For TPUs, use TPU communicator.
-        tpu_comm = self.tpu_communicator
-        if tpu_comm is not None and not tpu_comm.disabled:
-            return tpu_comm.all_gather(input_, dim)
-
-        # For HPUs, use HPU communicator.
-        hpu_comm = self.hpu_communicator
-        if hpu_comm is not None and not hpu_comm.disabled:
-            return hpu_comm.all_gather(input_, dim)
-
-        if dim < 0:
-            # Convert negative dim to positive.
-            dim += input_.dim()
-        input_size = input_.size()
-        # NOTE: we have to use concat-style all-gather here,
-        # stack-style all-gather has compatibility issues with
-        # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
-        output_size = (input_size[0] * world_size, ) + input_size[1:]
-        # Allocate output tensor.
-        output_tensor = torch.empty(output_size,
-                                    dtype=input_.dtype,
-                                    device=input_.device)
-        # All-gather.
-        torch.distributed.all_gather_into_tensor(output_tensor,
-                                                 input_,
-                                                 group=self.device_group)
-        # Reshape
-        output_tensor = output_tensor.reshape((world_size, ) + input_size)
-        output_tensor = output_tensor.movedim(0, dim)
-        output_tensor = output_tensor.reshape(input_size[:dim] +
-                                              (world_size *
-                                               input_size[dim], ) +
-                                              input_size[dim + 1:])
-        return output_tensor
+        return self.communicator.all_gather(input_, dim)
 
     def gather(self,
                input_: torch.Tensor,
@@ -428,25 +368,8 @@ def gather(self,
         if dim < 0:
             # Convert negative dim to positive.
             dim += input_.dim()
-        if self.xpu_communicator is not None and \
-                not self.xpu_communicator.disabled:
-            return self.xpu_communicator.gather(input_, self.rank_in_group,
-                                                dst, dim)
-        # Allocate output tensor.
-        if self.rank_in_group == dst:
-            gather_list = [torch.empty_like(input_) for _ in range(world_size)]
-        else:
-            gather_list = None
-        # Gather.
-        torch.distributed.gather(input_,
-                                 gather_list,
-                                 dst=self.ranks[dst],
-                                 group=self.device_group)
-        if self.rank_in_group == dst:
-            output_tensor = torch.cat(gather_list, dim=dim)
-        else:
-            output_tensor = None
-        return output_tensor
+
+        return self.communicator.gather(input_, dst, dim)
 
     def broadcast(self, input_: torch.Tensor, src: int = 0):
         """Broadcast the input tensor.
@@ -845,9 +768,7 @@ def init_world_group(ranks: List[int], local_rank: int,
         torch_distributed_backend=backend,
         use_pynccl=False,
         use_custom_allreduce=False,
-        use_tpu_communicator=False,
-        use_hpu_communicator=False,
-        use_xpu_communicator=False,
+        use_device_specific_communicator=False,
         group_name="world",
     )
 
@@ -870,9 +791,7 @@ def init_model_parallel_group(
         use_pynccl=current_platform.is_cuda_alike(),
         use_custom_allreduce=current_platform.is_cuda_alike()
         and use_custom_allreduce,
-        use_tpu_communicator=True,
-        use_hpu_communicator=True,
-        use_xpu_communicator=True,
+        use_device_specific_communicator=True,
         use_message_queue_broadcaster=use_message_queue_broadcaster,
         group_name=group_name,
     )
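With the hunks above, `GroupCoordinator` no longer hard-codes per-device branches; it asks the current platform for a dotted class path and instantiates whatever comes back. A rough, self-contained approximation of that dispatch (the `resolve_obj_by_qualname` body here is a stand-in for the helper in `vllm.utils`, not a verbatim copy):

```python
# Sketch of the dispatch now done in GroupCoordinator.__init__:
# the platform advertises a dotted path, the path is resolved to a
# class, and the class is instantiated with the device group.
import importlib


def resolve_obj_by_qualname(qualname: str):
    # Functional stand-in for vllm.utils.resolve_obj_by_qualname.
    module_name, obj_name = qualname.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), obj_name)


# On HPU, for example, the platform returns
# "vllm.distributed.device_communicators.hpu_communicator.HpuCommunicator",
# so the coordinator ends up doing roughly:
#
#     device_comm_cls = resolve_obj_by_qualname(
#         current_platform.get_device_communicator_cls())
#     communicator = device_comm_cls(group=device_group,
#                                    unique_name=unique_name)
```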
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 8350177b68ade..2e2bd5c787558 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -247,6 +247,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
         logger.info("Using Flash Attention backend.")
         return "vllm.attention.backends.flash_attn.FlashAttentionBackend"
 
+    @classmethod
+    def get_device_communicator_cls(cls) -> str:
+        return "vllm.distributed.device_communicators.base_communicator.CommunicatorBase"  # noqa: E501
+
     @classmethod
     def get_punica_wrapper(cls) -> str:
         return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py
index 242c2c127979a..14e03b75e7875 100644
--- a/vllm/platforms/hpu.py
+++ b/vllm/platforms/hpu.py
@@ -33,6 +33,10 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
         return True
 
+    @classmethod
+    def get_device_communicator_cls(cls) -> str:
+        return "vllm.distributed.device_communicators.hpu_communicator.HpuCommunicator"  # noqa: E501
+
     @staticmethod
     def inference_mode():
         return torch.no_grad()
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index f2ecec3203fb7..07a47286bae29 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -293,6 +293,13 @@ def get_punica_wrapper(cls) -> str:
         """
         raise NotImplementedError
 
+    @classmethod
+    def get_device_communicator_cls(cls) -> str:
+        """
+        Get device-specific communicator class for distributed communication.
+        """
+        raise NotImplementedError
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
+ """ + raise NotImplementedError + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 67a9e816cb658..e5cef39e66506 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -93,6 +93,10 @@ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: major, minor = torch.cuda.get_device_capability(device_id) return DeviceCapability(major=major, minor=minor) + @classmethod + def get_device_communicator_cls(cls) -> str: + return "vllm.distributed.device_communicators.base_communicator.CommunicatorBase" # noqa: E501" + @classmethod @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 05a3aa4305cfa..45f438f32bf5e 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -47,6 +47,10 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: return True + @classmethod + def get_device_communicator_cls(cls) -> str: + return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator" # noqa: E501 + @classmethod def inference_mode(cls): return torch.no_grad() diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index c34b5b58672e7..0656a8a9f5d71 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -52,6 +52,10 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: return True + @classmethod + def get_device_communicator_cls(cls) -> str: + return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa: E501 + @staticmethod def inference_mode(): return torch.no_grad()