[Misc][TPU] Support TPU in initialize_ray_cluster (vllm-project#6812)
WoosukKwon authored Jul 26, 2024
1 parent 71734f1 commit aa48677
Showing 1 changed file with 21 additions and 15 deletions.
36 changes: 21 additions & 15 deletions vllm/executor/ray_utils.py
--- a/vllm/executor/ray_utils.py
+++ b/vllm/executor/ray_utils.py
@@ -3,7 +3,7 @@
 from vllm.config import ParallelConfig
 from vllm.logger import init_logger
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import get_ip, is_hip, is_xpu
+from vllm.utils import get_ip, is_hip, is_tpu, is_xpu
 from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
@@ -93,32 +93,38 @@ def initialize_ray_cluster(
         # Placement group is already set.
         return
 
+    device_str = "GPU" if not is_tpu() else "TPU"
     # Create placement group for worker processes
     current_placement_group = ray.util.get_current_placement_group()
     if current_placement_group:
         # We are in a placement group
         bundles = current_placement_group.bundle_specs
         # Verify that we can use the placement group.
-        gpu_bundles = 0
+        device_bundles = 0
         for bundle in bundles:
-            bundle_gpus = bundle.get("GPU", 0)
-            if bundle_gpus > 1:
+            bundle_devices = bundle.get(device_str, 0)
+            if bundle_devices > 1:
                 raise ValueError(
-                    "Placement group bundle cannot have more than 1 GPU.")
-            if bundle_gpus:
-                gpu_bundles += 1
-        if parallel_config.world_size > gpu_bundles:
+                    "Placement group bundle cannot have more than 1 "
+                    f"{device_str}.")
+            if bundle_devices:
+                device_bundles += 1
+        if parallel_config.world_size > device_bundles:
             raise ValueError(
-                "The number of required GPUs exceeds the total number of "
-                "available GPUs in the placement group.")
+                f"The number of required {device_str}s exceeds the total "
+                f"number of available {device_str}s in the placement group."
+                f"Required number of devices: {parallel_config.world_size}. "
+                f"Total number of devices: {device_bundles}.")
     else:
-        num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
-        if parallel_config.world_size > num_gpus_in_cluster:
+        num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
+        if parallel_config.world_size > num_devices_in_cluster:
             raise ValueError(
-                "The number of required GPUs exceeds the total number of "
-                "available GPUs in the cluster.")
+                f"The number of required {device_str}s exceeds the total "
+                f"number of available {device_str}s in the placement group.")
         # Create a new placement group
-        placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size)
+        placement_group_specs = ([{
+            device_str: 1
+        }] * parallel_config.world_size)
         current_placement_group = ray.util.placement_group(
             placement_group_specs)
         # Wait until PG is ready - this will block until all
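
For readers unfamiliar with Ray placement groups, the sketch below replays the device-agnostic logic this commit introduces using plain Ray, outside of vLLM. It is a minimal illustration, not vLLM code: resolve_device_str and create_device_placement_group are hypothetical helpers standing in for the is_tpu() check and the body of initialize_ray_cluster, and the world_size value is an arbitrary example.

# Illustrative sketch only: mirrors the device-agnostic placement-group
# creation added in this commit, outside of vLLM. Assumes Ray is installed
# and a local or remote cluster is reachable; the device key and world_size
# below are example values, not vLLM configuration.
import ray
from ray.util.placement_group import placement_group


def resolve_device_str(use_tpu: bool) -> str:
    # Hypothetical stand-in for vLLM's is_tpu() check.
    return "TPU" if use_tpu else "GPU"


def create_device_placement_group(world_size: int, use_tpu: bool = False):
    device_str = resolve_device_str(use_tpu)

    # Refuse early if the cluster does not advertise enough devices,
    # mirroring the check in initialize_ray_cluster.
    available = ray.cluster_resources().get(device_str, 0)
    if world_size > available:
        raise ValueError(
            f"Need {world_size} {device_str}s but only {available} are "
            "available in the cluster.")

    # One bundle per worker, each requesting a single device of the
    # resolved type.
    bundles = [{device_str: 1}] * world_size
    pg = placement_group(bundles)
    ray.get(pg.ready())  # Block until every bundle has been scheduled.
    return pg


if __name__ == "__main__":
    ray.init()  # Connect to (or start) a Ray instance.
    pg = create_device_placement_group(world_size=1, use_tpu=False)
    print("Placement group ready:", pg.bundle_specs)

Requesting one bundle per worker, each with a single {device_str: 1} entry, is what lets the same code path serve both GPU and TPU clusters; only the resource key changes.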
