[core] clean up executor class hierarchy between v1 and v0 #12171

Open
wants to merge 4 commits into base: main
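For orientation, here is a minimal sketch of the class layout this PR converges on, reconstructed from the diffs below; it is not part of the change itself, and the v0 classes are replaced by stubs so the snippet stands alone:

# Orientation sketch only: stubs stand in for the real vLLM classes.
class ExecutorBase:
    """Shared v0 base class (vllm/executor/executor_base.py)."""

class UniProcExecutorV0(ExecutorBase):
    """Stand-in for vllm.executor.uniproc_executor.UniProcExecutor."""

class ExecutorWithExternalLauncherV0(ExecutorBase):
    """Stand-in for vllm.executor.uniproc_executor.ExecutorWithExternalLauncher."""

class RayDistributedExecutorV0(ExecutorBase):
    """Stand-in for vllm.executor.ray_distributed_executor.RayDistributedExecutor."""

class Executor(ExecutorBase):
    """v1 interface (vllm/v1/executor/abstract.py): adds initialize,
    determine_available_memory, get_kv_cache_spec, execute_model and profile
    on top of the shared base."""

class UniProcExecutor(UniProcExecutorV0, Executor):
    """v0 implementation reused under the v1 interface."""

class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
    """Same pattern for the external-launcher executor."""

class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
    """Same pattern for the Ray-based executor."""

# Every concrete v1 executor now satisfies both hierarchies.
assert issubclass(RayDistributedExecutor, Executor)
assert issubclass(RayDistributedExecutor, ExecutorBase)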
10 changes: 0 additions & 10 deletions vllm/executor/executor_base.py
@@ -79,16 +79,6 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
b = min([r[1] for r in results])
return a, b

def initialize(self, num_gpu_blocks: int) -> None:
"""
Initialize the KV caches and begin the model execution loop of the
underlying workers.
For V1 compatibility.
"""
logger.info("# GPU blocks: %d", num_gpu_blocks)
self.collective_rpc("initialize_cache", args=(num_gpu_blocks, ))
self.collective_rpc("compile_or_warm_up_model")

def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
87 changes: 58 additions & 29 deletions vllm/v1/executor/abstract.py
@@ -1,63 +1,92 @@
from abc import ABC, abstractmethod
from typing import Type

from vllm.config import VllmConfig
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_distributed_executor import ( # noqa
RayDistributedExecutor as RayDistributedExecutorV0)
from vllm.executor.uniproc_executor import ( # noqa
ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0)
from vllm.executor.uniproc_executor import ( # noqa
UniProcExecutor as UniProcExecutorV0)
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput


class Executor(ABC):
"""Abstract class for executors."""
class Executor(ExecutorBase):
"""
Abstract class for v1 executors, mainly defining v1-specific methods.
Methods shared by v0 and v1 are defined in ExecutorBase."""

@staticmethod
def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
executor_class: Type[Executor]
parallel_config = vllm_config.parallel_config
distributed_executor_backend = (
vllm_config.parallel_config.distributed_executor_backend)
parallel_config.distributed_executor_backend)
if distributed_executor_backend is None:
# If the user does not specify the distributed executor backend,
# we will choose the backend based on the world size.
if parallel_config.world_size > 1:
distributed_executor_backend = "mp"
else:
distributed_executor_backend = "uni"

if distributed_executor_backend == "ray":
from vllm.executor.ray_distributed_executor import ( # noqa
RayDistributedExecutor)
executor_class = RayDistributedExecutor
elif distributed_executor_backend == "mp":
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
executor_class = MultiprocExecutor
elif distributed_executor_backend == "uni":
executor_class = UniProcExecutor
elif distributed_executor_backend == "external_launcher":
# TODO: make v1 scheduling deterministic
# to support external launcher
executor_class = ExecutorWithExternalLauncher
else:
assert (distributed_executor_backend is None)
from vllm.v1.executor.uniproc_executor import UniprocExecutor
executor_class = UniprocExecutor
raise ValueError("Unknown distributed executor backend: "
f"{distributed_executor_backend}")
return executor_class

@abstractmethod
def __init__(self, vllm_config: VllmConfig) -> None:
raise NotImplementedError

@abstractmethod
def initialize(self, kv_cache_config: KVCacheConfig) -> None:
raise NotImplementedError
"""
Initialize the KV caches and begin the model execution loop of the
underlying workers.
"""
self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
self.collective_rpc("compile_or_warm_up_model")

@abstractmethod
def determine_available_memory(self) -> int: # in bytes
raise NotImplementedError
output = self.collective_rpc("determine_available_memory")
# Since we use a shared centralized controller, we take the minimum
# memory size across all workers to make sure all the memory
# operators can be applied to all workers.
return min(output)

@abstractmethod
def get_kv_cache_spec(self) -> KVCacheSpec:
raise NotImplementedError
output = self.collective_rpc("get_kv_cache_spec")
for x in output:
assert x == output[0]
return output[0]

@abstractmethod
def execute_model(
self,
scheduler_output,
) -> ModelRunnerOutput:
raise NotImplementedError
output = self.collective_rpc("execute_model",
args=(scheduler_output, ))
return output[0]

@abstractmethod
def profile(self, is_start: bool = True):
raise NotImplementedError
self.collective_rpc("profile", args=(is_start, ))


class UniProcExecutor(UniProcExecutorV0, Executor):
pass


class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
pass

@abstractmethod
def shutdown(self):
pass

@abstractmethod
def check_health(self) -> None:
raise NotImplementedError
class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
pass
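Net effect of this file: the v1 Executor now provides default implementations expressed entirely through collective_rpc, instead of forcing every backend to repeat them. A self-contained toy of that broadcast-then-reduce pattern, with in-process fake workers standing in for real worker processes (all names here are illustrative, not vLLM APIs):

from typing import Any, List

class FakeWorker:
    """Stands in for a vLLM worker; returns canned values."""
    def __init__(self, free_bytes: int, kv_cache_spec: str) -> None:
        self.free_bytes = free_bytes
        self.kv_cache_spec = kv_cache_spec

    def determine_available_memory(self) -> int:
        return self.free_bytes

    def get_kv_cache_spec(self) -> str:
        return self.kv_cache_spec

class ToyExecutor:
    """Mimics the defaults added to vllm/v1/executor/abstract.py."""
    def __init__(self, workers: List[FakeWorker]) -> None:
        self.workers = workers

    def collective_rpc(self, method: str, args: tuple = ()) -> List[Any]:
        # Broadcast the call to every worker and gather the results.
        return [getattr(w, method)(*args) for w in self.workers]

    def determine_available_memory(self) -> int:
        # Take the minimum so the shared controller never over-allocates
        # on the most memory-constrained worker.
        return min(self.collective_rpc("determine_available_memory"))

    def get_kv_cache_spec(self) -> str:
        # All workers must agree on the spec; return rank 0's copy.
        specs = self.collective_rpc("get_kv_cache_spec")
        assert all(s == specs[0] for s in specs)
        return specs[0]

executor = ToyExecutor([FakeWorker(8 << 30, "spec"), FakeWorker(6 << 30, "spec")])
assert executor.determine_available_memory() == (6 << 30)
assert executor.get_kv_cache_spec() == "spec"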
53 changes: 5 additions & 48 deletions vllm/v1/executor/multiproc_executor.py
@@ -19,14 +19,13 @@
destroy_model_parallel)
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
MessageQueue)
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.multiproc_worker_utils import (
_add_prefix, set_multiprocessing_worker_envs)
from vllm.logger import init_logger
from vllm.utils import (get_distributed_init_method, get_mp_context,
get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx)
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)
@@ -35,9 +34,9 @@
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000


class MultiprocExecutor(Executor):
class MultiprocExecutor(ExecutorBase, Executor):

def __init__(self, vllm_config: VllmConfig) -> None:
def _init_executor(self) -> None:
# Call self.shutdown at exit to clean up
# and ensure workers will be terminated.
self._finalizer = weakref.finalize(self, self.shutdown)
@@ -55,9 +54,6 @@ def sigusr1_handler(signum, frame):

signal.signal(signal.SIGUSR1, sigusr1_handler)

self.vllm_config = vllm_config
self.parallel_config = vllm_config.parallel_config

self.world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size
assert self.world_size == tensor_parallel_size, (
@@ -82,7 +78,8 @@ def sigusr1_handler(signum, frame):
# Create workers
self.workers: List[WorkerProcHandle] = []
for rank in range(self.world_size):
worker = WorkerProc.make_worker_process(vllm_config, rank, rank,
worker = WorkerProc.make_worker_process(self.vllm_config, rank,
rank,
distributed_init_method,
scheduler_output_handle)
self.workers.append(worker)
@@ -93,34 +90,6 @@ def sigusr1_handler(signum, frame):
for w in self.workers:
w.worker_response_mq.wait_until_ready()

def initialize(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize the KV caches and begin the model execution loop of the
underlying workers.
"""
self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
self.collective_rpc("compile_or_warm_up_model")

def determine_available_memory(self) -> int:
"""
Determine the available memory (in bytes) for KV cache by invoking the
underlying worker.
"""
memory_sizes = self.collective_rpc("determine_available_memory")

# Since we use a shared centralized controller, we take the minimum
# memory size across all workers to make sure all the memory
# operators can be applied to all workers.
return min(memory_sizes)

def get_kv_cache_spec(self) -> KVCacheSpec:
"""
Get all kv cache needed by the model by invoking the underlying worker.
"""
kv_cache_specs = self.collective_rpc("get_kv_cache_spec")
assert all(s == kv_cache_specs[0] for s in kv_cache_specs)
return kv_cache_specs[0]

def collective_rpc(self,
method: Union[str, Callable],
timeout: Optional[float] = None,
@@ -172,18 +141,6 @@ def collective_rpc(self,
# Re-raise any other exceptions
raise e

def execute_model(
self,
scheduler_output,
) -> ModelRunnerOutput:
model_output = self.collective_rpc("execute_model",
args=(scheduler_output, ))[0]
return model_output

def profile(self, is_start: bool = True):
self.collective_rpc("profile", args=(is_start, ))
return

def _ensure_worker_termination(self):
"""Ensure that all worker processes are terminated. Assumes workers have
received termination requests. Waits for processing, then sends
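The rename from __init__ to _init_executor, together with the dropped self.vllm_config and self.parallel_config assignments, only works if the shared base constructor stores the configs and then hands off to a subclass hook. A rough sketch of that template-method arrangement; this is an assumption about how ExecutorBase behaves, not code copied from this diff:

class VllmConfigStub:
    """Placeholder for vllm.config.VllmConfig in this sketch."""
    def __init__(self, parallel_config: object = None) -> None:
        self.parallel_config = parallel_config

class ExecutorBaseSketch:
    """Assumed shape of the shared v0 base constructor."""
    def __init__(self, vllm_config: VllmConfigStub) -> None:
        # The base stores the configs once ...
        self.vllm_config = vllm_config
        self.parallel_config = vllm_config.parallel_config
        # ... then calls the subclass-specific hook, so subclasses like
        # MultiprocExecutor only need to implement _init_executor.
        self._init_executor()

    def _init_executor(self) -> None:
        raise NotImplementedError

class MultiprocExecutorSketch(ExecutorBaseSketch):
    def _init_executor(self) -> None:
        # Worker-spawning logic would live here; self.vllm_config and
        # self.parallel_config are already set by the base class.
        self.initialized = True

executor = MultiprocExecutorSketch(VllmConfigStub(parallel_config="pc"))
assert executor.initialized and executor.parallel_config == "pc"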