From aa4867791ecd73a5f55b7bad4d9372954e661fe4 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 26 Jul 2024 12:39:49 -0700 Subject: [PATCH] [Misc][TPU] Support TPU in initialize_ray_cluster (#6812) --- vllm/executor/ray_utils.py | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index fcbfa30d7a38a..58b864070f727 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -3,7 +3,7 @@ from vllm.config import ParallelConfig from vllm.logger import init_logger from vllm.sequence import ExecuteModelRequest -from vllm.utils import get_ip, is_hip, is_xpu +from vllm.utils import get_ip, is_hip, is_tpu, is_xpu from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -93,32 +93,38 @@ def initialize_ray_cluster( # Placement group is already set. return + device_str = "GPU" if not is_tpu() else "TPU" # Create placement group for worker processes current_placement_group = ray.util.get_current_placement_group() if current_placement_group: # We are in a placement group bundles = current_placement_group.bundle_specs # Verify that we can use the placement group. - gpu_bundles = 0 + device_bundles = 0 for bundle in bundles: - bundle_gpus = bundle.get("GPU", 0) - if bundle_gpus > 1: + bundle_devices = bundle.get(device_str, 0) + if bundle_devices > 1: raise ValueError( - "Placement group bundle cannot have more than 1 GPU.") - if bundle_gpus: - gpu_bundles += 1 - if parallel_config.world_size > gpu_bundles: + "Placement group bundle cannot have more than 1 " + f"{device_str}.") + if bundle_devices: + device_bundles += 1 + if parallel_config.world_size > device_bundles: raise ValueError( - "The number of required GPUs exceeds the total number of " - "available GPUs in the placement group.") + f"The number of required {device_str}s exceeds the total " + f"number of available {device_str}s in the placement group." + f"Required number of devices: {parallel_config.world_size}. " + f"Total number of devices: {device_bundles}.") else: - num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0) - if parallel_config.world_size > num_gpus_in_cluster: + num_devices_in_cluster = ray.cluster_resources().get(device_str, 0) + if parallel_config.world_size > num_devices_in_cluster: raise ValueError( - "The number of required GPUs exceeds the total number of " - "available GPUs in the cluster.") + f"The number of required {device_str}s exceeds the total " + f"number of available {device_str}s in the placement group.") # Create a new placement group - placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size) + placement_group_specs = ([{ + device_str: 1 + }] * parallel_config.world_size) current_placement_group = ray.util.placement_group( placement_group_specs) # Wait until PG is ready - this will block until all