[V1] Refactor get_executor_cls #11754

Merged: 2 commits, Jan 6, 2025
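Summary (as reflected in the diff below): the duplicated _get_executor_cls classmethods on AsyncLLM and LLMEngine are removed and replaced by a single static Executor.get_class(vllm_config) on the abstract executor base class; the V1 engine entry points and tests are updated to call the shared method.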
6 changes: 3 additions & 3 deletions tests/v1/engine/test_engine_core.py
@@ -8,8 +8,8 @@
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.async_llm import AsyncLLM
 from vllm.v1.engine.core import EngineCore
+from vllm.v1.executor.abstract import Executor

 if not current_platform.is_cuda():
     pytest.skip(reason="V1 currently only supported on CUDA.",
@@ -43,7 +43,7 @@ def test_engine_core(monkeypatch):
         """Setup the EngineCore."""
         engine_args = EngineArgs(model=MODEL_NAME)
         vllm_config = engine_args.create_engine_config()
-        executor_class = AsyncLLM._get_executor_cls(vllm_config)
+        executor_class = Executor.get_class(vllm_config)

         engine_core = EngineCore(vllm_config=vllm_config,
                                  executor_class=executor_class)
@@ -149,7 +149,7 @@ def test_engine_core_advanced_sampling(monkeypatch):
         """Setup the EngineCore."""
         engine_args = EngineArgs(model=MODEL_NAME)
         vllm_config = engine_args.create_engine_config()
-        executor_class = AsyncLLM._get_executor_cls(vllm_config)
+        executor_class = Executor.get_class(vllm_config)

         engine_core = EngineCore(vllm_config=vllm_config,
                                  executor_class=executor_class)
6 changes: 3 additions & 3 deletions tests/v1/engine/test_engine_core_client.py
@@ -11,8 +11,8 @@
 from vllm.platforms import current_platform
 from vllm.usage.usage_lib import UsageContext
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.async_llm import AsyncLLM
 from vllm.v1.engine.core_client import EngineCoreClient
+from vllm.v1.executor.abstract import Executor

 if not current_platform.is_cuda():
     pytest.skip(reason="V1 currently only supported on CUDA.",
@@ -84,7 +84,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
         engine_args = EngineArgs(model=MODEL_NAME, compilation_config=3)
         vllm_config = engine_args.create_engine_config(
             UsageContext.UNKNOWN_CONTEXT)
-        executor_class = AsyncLLM._get_executor_cls(vllm_config)
+        executor_class = Executor.get_class(vllm_config)
         client = EngineCoreClient.make_client(
             multiprocess_mode=multiprocessing_mode,
             asyncio_mode=False,
@@ -152,7 +152,7 @@ async def test_engine_core_client_asyncio(monkeypatch):
         engine_args = EngineArgs(model=MODEL_NAME)
         vllm_config = engine_args.create_engine_config(
             usage_context=UsageContext.UNKNOWN_CONTEXT)
-        executor_class = AsyncLLM._get_executor_cls(vllm_config)
+        executor_class = Executor.get_class(vllm_config)
         client = EngineCoreClient.make_client(
             multiprocess_mode=True,
             asyncio_mode=True,
21 changes: 1 addition & 20 deletions vllm/v1/engine/async_llm.py
@@ -22,7 +22,6 @@
 from vllm.v1.engine.detokenizer import Detokenizer
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
-from vllm.v1.executor.ray_utils import initialize_ray_cluster

 logger = init_logger(__name__)

@@ -105,7 +104,7 @@ def from_engine_args(
         else:
             vllm_config = engine_config

-        executor_class = cls._get_executor_cls(vllm_config)
+        executor_class = Executor.get_class(vllm_config)

         # Create the AsyncLLM.
         return cls(
@@ -127,24 +126,6 @@ def shutdown(self):
         if handler := getattr(self, "output_handler", None):
             handler.cancel()

-    @classmethod
-    def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
-        executor_class: Type[Executor]
-        distributed_executor_backend = (
-            vllm_config.parallel_config.distributed_executor_backend)
-        if distributed_executor_backend == "ray":
-            initialize_ray_cluster(vllm_config.parallel_config)
-            from vllm.v1.executor.ray_executor import RayExecutor
-            executor_class = RayExecutor
-        elif distributed_executor_backend == "mp":
-            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
-            executor_class = MultiprocExecutor
-        else:
-            assert (distributed_executor_backend is None)
-            from vllm.v1.executor.uniproc_executor import UniprocExecutor
-            executor_class = UniprocExecutor
-        return executor_class
-
     async def add_request(
         self,
         request_id: str,

20 changes: 1 addition & 19 deletions vllm/v1/engine/llm_engine.py
@@ -89,7 +89,7 @@ def from_engine_args(

         # Create the engine configs.
         vllm_config = engine_args.create_engine_config(usage_context)
-        executor_class = cls._get_executor_cls(vllm_config)
+        executor_class = Executor.get_class(vllm_config)

         if VLLM_ENABLE_V1_MULTIPROCESSING:
             logger.debug("Enabling multiprocessing for LLMEngine.")
@@ -103,24 +103,6 @@
             stat_loggers=stat_loggers,
             multiprocess_mode=enable_multiprocessing)

-    @classmethod
-    def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
-        executor_class: Type[Executor]
-        distributed_executor_backend = (
-            vllm_config.parallel_config.distributed_executor_backend)
-        if distributed_executor_backend == "ray":
-            from vllm.v1.executor.ray_executor import RayExecutor
-            executor_class = RayExecutor
-        elif distributed_executor_backend == "mp":
-            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
-            executor_class = MultiprocExecutor
-        else:
-            assert (distributed_executor_backend is None)
-            from vllm.v1.executor.uniproc_executor import UniprocExecutor
-            executor_class = UniprocExecutor
-
-        return executor_class
-
     def get_num_unfinished_requests(self) -> int:
         return self.detokenizer.get_num_unfinished_requests()

19 changes: 18 additions & 1 deletion vllm/v1/executor/abstract.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Tuple
+from typing import Tuple, Type

 from vllm.config import VllmConfig
 from vllm.v1.outputs import ModelRunnerOutput
@@ -8,6 +8,23 @@
 class Executor(ABC):
     """Abstract class for executors."""

+    @staticmethod
+    def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
+        executor_class: Type[Executor]
+        distributed_executor_backend = (
+            vllm_config.parallel_config.distributed_executor_backend)
+        if distributed_executor_backend == "ray":
+            from vllm.v1.executor.ray_executor import RayExecutor
+            executor_class = RayExecutor
+        elif distributed_executor_backend == "mp":
+            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
+            executor_class = MultiprocExecutor
+        else:
+            assert (distributed_executor_backend is None)
+            from vllm.v1.executor.uniproc_executor import UniprocExecutor
+            executor_class = UniprocExecutor
+        return executor_class
+
     @abstractmethod
     def __init__(self, vllm_config: VllmConfig) -> None:
         raise NotImplementedError
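
For reference, a minimal usage sketch of the new entry point, mirroring the updated tests above (the model name is a placeholder; any supported model works):

from vllm.engine.arg_utils import EngineArgs
from vllm.v1.engine.core import EngineCore
from vllm.v1.executor.abstract import Executor

# Build the engine config; "facebook/opt-125m" is a stand-in model name.
engine_args = EngineArgs(model="facebook/opt-125m")
vllm_config = engine_args.create_engine_config()

# Resolve the executor implementation from the parallel config:
# "ray" -> RayExecutor, "mp" -> MultiprocExecutor, None -> UniprocExecutor.
executor_class = Executor.get_class(vllm_config)

engine_core = EngineCore(vllm_config=vllm_config,
                         executor_class=executor_class)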