diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 564d8a8343bef..0696caf88385d 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -22,6 +22,7 @@ from vllm.v1.engine.detokenizer import Detokenizer from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor +from vllm.v1.executor.ray_utils import initialize_ray_cluster logger = init_logger(__name__) @@ -131,7 +132,11 @@ def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]: executor_class: Type[Executor] distributed_executor_backend = ( vllm_config.parallel_config.distributed_executor_backend) - if distributed_executor_backend == "mp": + if distributed_executor_backend == "ray": + initialize_ray_cluster(vllm_config.parallel_config) + from vllm.v1.executor.ray_executor import RayExecutor + executor_class = RayExecutor + elif distributed_executor_backend == "mp": from vllm.v1.executor.multiproc_executor import MultiprocExecutor executor_class = MultiprocExecutor else: