[Distributed][refactor] Add base class for device-specific communicator
Signed-off-by: Mengqing Cao <[email protected]>
MengqingCao committed Jan 15, 2025
1 parent 994fc65 commit b085f82
Showing 11 changed files with 191 additions and 162 deletions.
110 changes: 110 additions & 0 deletions vllm/distributed/device_communicators/base_communicator.py
@@ -0,0 +1,110 @@
from abc import ABC, abstractmethod
from typing import Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


class CommunicatorABC(ABC):
    """
    Abstract base class defining the communication ops a device-specific
    communicator must provide.
    """

    @abstractmethod
    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @abstractmethod
    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
               dim: int = -1) -> Optional[torch.Tensor]:
        raise NotImplementedError

    @abstractmethod
    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        raise NotImplementedError


class CommunicatorBase(CommunicatorABC):
    """
    Base class for device-specific communicators.
    It provides the communication ops that devices most frequently override
    and supplies default implementations of those ops, which are used when a
    device has no communicator implementation of its own.
    """

    def __init__(self, group: ProcessGroup, unique_name: str = ""):
        self.group = group
        self.unique_name = unique_name
        self.rank = dist.get_rank(group)
        self.world_size = dist.get_world_size(self.group)
        self.ranks = dist.get_process_group_ranks(self.group)
        global_rank = dist.get_rank()
        self.rank_in_group = dist.get_group_rank(self.group, global_rank)

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
        Default all-reduce, used only for CPU and GPU.
        NOTE: `torch.ops.vllm.all_reduce` is registered in
        `vllm/distributed/parallel_state.py`.
        """
        if input_.is_cpu:
            import intel_extension_for_pytorch as ipex
            ipex.distributed.all_reduce(input_, group=self.group)
            return input_

        return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)

    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
               dim: int = -1) -> Optional[torch.Tensor]:
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """

        # Allocate output tensor.
        if self.rank_in_group == dst:
            gather_list = [
                torch.empty_like(input_) for _ in range(self.world_size)
            ]
        else:
            gather_list = None
        # Gather.
        dist.gather(input_, gather_list, dst=self.ranks[dst], group=self.group)
        if self.rank_in_group == dst:
            output_tensor = torch.cat(gather_list, dim=dim)
        else:
            output_tensor = None
        return output_tensor

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        input_size = input_.size()
        # NOTE: we have to use concat-style all-gather here:
        # stack-style all-gather has compatibility issues with
        # torch.compile; see https://github.com/pytorch/pytorch/issues/138795
        output_size = (input_size[0] * self.world_size, ) + input_size[1:]
        # Allocate output tensor.
        output_tensor = torch.empty(output_size,
                                    dtype=input_.dtype,
                                    device=input_.device)
        # All-gather.
        dist.all_gather_into_tensor(output_tensor, input_, group=self.group)
        # Reshape
        output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(input_size[:dim] +
                                              (self.world_size *
                                               input_size[dim], ) +
                                              input_size[dim + 1:])
        return output_tensor
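With this base class in place, a device backend only needs to subclass CommunicatorBase and override the ops it actually customizes; the constructor and the gather/all_gather defaults are inherited. A minimal sketch of that pattern (the MyDeviceCommunicator class and its comments are illustrative, not part of the commit; the base-class API is taken from the file above):

import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.base_communicator import (
    CommunicatorBase)


class MyDeviceCommunicator(CommunicatorBase):
    """Hypothetical device-specific communicator built on the new base class."""

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        # A real backend would insert its device-specific synchronization or
        # graph-capture hooks here before delegating to torch.distributed,
        # as HpuCommunicator does below with htorch.core.mark_step().
        dist.all_reduce(input_, group=self.group)
        return input_

    # gather() and all_gather() fall back to the CommunicatorBase defaults,
    # so only ops with device-specific behavior need overriding.

With the inherited all_gather and dim=-1, a per-rank tensor of shape (B, H) comes back as (B, world_size * H), matching the concat-style reshape above.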
31 changes: 12 additions & 19 deletions vllm/distributed/device_communicators/hpu_communicator.py
@@ -1,44 +1,37 @@
 import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
 
+from vllm.distributed.device_communicators.base_communicator import (
+    CommunicatorBase)
 from vllm.platforms import current_platform
 
 if current_platform.is_hpu():
     import habana_frameworks.torch as htorch  # noqa: F401
 
 
-class HpuCommunicator:
+class HpuCommunicator(CommunicatorBase):
 
-    def __init__(self, group: ProcessGroup):
-        if not current_platform.is_hpu():
-            self.disabled = True
-            return
-        self.disabled = False
-        self.group = group
-        self.world_size = dist.get_world_size(self.group)
-
-    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
+    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
         # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
         # (which is required for tensor parallel HPUGraph inference)
         htorch.core.mark_step()
-        dist.all_reduce(x, group=self.group)
-        return x
+        dist.all_reduce(input_, group=self.group)
+        return input_
 
-    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         world_size = self.world_size
         if dim < 0:
             # Convert negative dim to positive.
-            dim += x.dim()
-        input_size = x.size()
+            dim += input_.dim()
+        input_size = input_.size()
         # Allocate output tensor.
         output_tensor = torch.empty((world_size, ) + input_size,
-                                    dtype=x.dtype,
-                                    device=x.device)
+                                    dtype=input_.dtype,
+                                    device=input_.device)
         # All-gather.
         htorch.core.mark_step()
-        dist.all_gather_into_tensor(output_tensor, x, group=self.group)
+        dist.all_gather_into_tensor(output_tensor, input_, group=self.group)
         # Reshape
         output_tensor = output_tensor.movedim(0, dim)
         output_tensor = output_tensor.reshape(input_size[:dim] +
32 changes: 12 additions & 20 deletions vllm/distributed/device_communicators/tpu_communicator.py
@@ -1,9 +1,10 @@
 import os
 
 import torch
-import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
+from vllm.distributed.device_communicators.base_communicator import (
+    CommunicatorBase)
 from vllm.platforms import current_platform
 
 if current_platform.is_tpu():
@@ -14,19 +15,10 @@
     from vllm.executor import ray_utils
 
 
-class TpuCommunicator:
+class TpuCommunicator(CommunicatorBase):
 
-    def __init__(self, group: ProcessGroup):
-        if not current_platform.is_tpu():
-            self.disabled = True
-            return
-        self.disabled = False
-
-        # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
-        # must be used together. Therefore, the local rank and world size can
-        # be simply calculated as follows.
-        global_rank = dist.get_rank(group)
-        global_world_size = dist.get_world_size(group)
+    def __init__(self, group: ProcessGroup, unique_name: str = ""):
+        super().__init__(group, unique_name)
 
         # Calculate how many TPU nodes are in the current deployment. This
         # is the Ray placement group if it is deployed with Ray. Default
@@ -39,23 +31,23 @@ def __init__(self, group: ProcessGroup):
         if num_nodes_in_pg > 0:
             num_nodes = num_nodes_in_pg
 
-        local_world_size = global_world_size // num_nodes
-        local_rank = global_rank % local_world_size
+        local_world_size = self.world_size // num_nodes
+        local_rank = self.rank % local_world_size
 
         # Ensure environment variables are set for multihost deployments.
         # On GKE, this is needed for libtpu and TPU driver to know which TPU
         # chip is actually visible. Otherwise the TPU driver will fail to
         # initialize because the number of devices would be different from
         # the number of visible worker addresses.
-        os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank)
+        os.environ["CLOUD_TPU_TASK_ID"] = str(self.rank)
         os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank)
 
         pjrt.initialize_multiprocess(local_rank, local_world_size)
         xr._init_world_size_ordinal()
 
-    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
-        return xm.all_reduce(xm.REDUCE_SUM, x)
+    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
+        return xm.all_reduce(xm.REDUCE_SUM, input_)
 
-    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         assert dim == -1, "TPUs only support dim=-1 for all-gather."
-        return xm.all_gather(x, dim=dim)
+        return xm.all_gather(input_, dim=dim)
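The local-rank arithmetic above now reads self.rank and self.world_size off the base class instead of querying torch.distributed directly. A quick worked example, assuming a hypothetical 2-node deployment with 8 ranks (the concrete numbers are illustrative, not from the commit):

world_size = 8                                # self.world_size from CommunicatorBase
num_nodes = 2                                 # detected via the Ray placement group above
local_world_size = world_size // num_nodes    # 4 chips per node
for rank in range(world_size):                # rank corresponds to self.rank
    local_rank = rank % local_world_size
    # e.g. rank 5 maps to local chip 1 and would export
    # CLOUD_TPU_TASK_ID="5" and TPU_VISIBLE_CHIPS="1".
    print(rank, local_rank)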
28 changes: 8 additions & 20 deletions vllm/distributed/device_communicators/xpu_communicator.py
@@ -1,29 +1,17 @@
 import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
 
-from vllm.platforms import current_platform
+from vllm.distributed.device_communicators.base_communicator import (
+    CommunicatorBase)
 
 
-class XpuCommunicator:
+class XpuCommunicator(CommunicatorBase):
 
-    def __init__(self, group: ProcessGroup):
-        if not current_platform.is_xpu():
-            self.disabled = True
-            return
-        self.disabled = False
-        self.group = group
-        self.world_size = dist.get_world_size(self.group)
+    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
+        dist.all_reduce(input_, group=self.group)
+        return input_
 
-    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
-        dist.all_reduce(x, group=self.group)
-        return x
-
-    def gather(self,
-               input_: torch.Tensor,
-               rank_in_group: int,
-               dst: int = 0,
-               dim: int = -1):
+    def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1):
         # For xpu path, gather doesn't work properly together with ray
         # cluster so we use all_gather instead for now.
         input_size = input_.size()
@@ -35,7 +23,7 @@ def gather(self,
         torch.distributed.all_gather_into_tensor(output_tensor,
                                                  input_,
                                                  group=self.group)
-        if rank_in_group == dst:
+        if self.rank_in_group == dst:
             # Reshape
             output_tensor = output_tensor.movedim(0, dim)
             output_tensor = output_tensor.reshape(input_size[:dim] +
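The gather hunk above is truncated, so here is a self-contained sketch of the gather-via-all_gather workaround it implements. The free-function form and its name are assumptions for illustration; the collective calls and the dst-only reshape mirror the XPU code shown above.

from typing import Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


def gather_via_all_gather(input_: torch.Tensor,
                          group: ProcessGroup,
                          rank_in_group: int,
                          dst: int = 0,
                          dim: int = -1) -> Optional[torch.Tensor]:
    """Emulate gather with all-gather, as the XPU path does for Ray clusters."""
    world_size = dist.get_world_size(group)
    if dim < 0:
        dim += input_.dim()
    input_size = input_.size()
    # Every rank participates in the collective ...
    output_tensor = torch.empty((world_size, ) + input_size,
                                dtype=input_.dtype,
                                device=input_.device)
    dist.all_gather_into_tensor(output_tensor, input_, group=group)
    # ... but only the destination rank keeps the reshaped result.
    if rank_in_group == dst:
        output_tensor = output_tensor.movedim(0, dim)
        return output_tensor.reshape(input_size[:dim] +
                                     (world_size * input_size[dim], ) +
                                     input_size[dim + 1:])
    return None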