[V1] [7/N] API Server: Multiprocessing Detokenizer [ DO NOT MERGE ] #11636

Closed
5 changes: 4 additions & 1 deletion benchmarks/benchmark_throughput.py
@@ -414,14 +414,17 @@ def main(args: argparse.Namespace):
for request in requests)
total_output_tokens = sum(request.expected_output_len
for request in requests)
total_input_tokens = total_num_tokens - total_output_tokens
if is_multi_modal:
print("\033[91mWARNING\033[0m: Multi-modal request detected. The "
"following metrics are not accurate because image tokens are not"
" counted. See vllm-project/vllm/issues/9778 for details.")
# TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
f"{total_output_tokens / elapsed_time:.2f} output tokens/s, "
f"{total_input_tokens / len(requests)} input tokens/req, "
f"{(total_output_tokens) / len(requests)} output tokens/req")

# Output JSON results if specified
if args.output_json:
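For context, a quick worked example of the metrics this hunk prints. All values below are made up for illustration; none come from the benchmark itself:

elapsed_time = 10.0          # seconds (hypothetical)
num_requests = 20            # hypothetical
total_num_tokens = 12_000    # prompt + completion tokens (hypothetical)
total_output_tokens = 4_000  # hypothetical
total_input_tokens = total_num_tokens - total_output_tokens  # 8_000

print(f"Throughput: {num_requests / elapsed_time:.2f} requests/s, "   # 2.00
      f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "       # 1200.00
      f"{total_output_tokens / elapsed_time:.2f} output tokens/s, "   # 400.00
      f"{total_input_tokens / num_requests:.2f} input tokens/req, "   # 400.00
      f"{total_output_tokens / num_requests:.2f} output tokens/req")  # 200.00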
4 changes: 2 additions & 2 deletions vllm/utils.py
@@ -1903,11 +1903,11 @@ def make_zmq_socket(
if type == zmq.constants.PULL:
socket.setsockopt(zmq.constants.RCVHWM, 0)
socket.setsockopt(zmq.constants.RCVBUF, buf_size)
socket.connect(path)
socket.bind(path)
elif type == zmq.constants.PUSH:
socket.setsockopt(zmq.constants.SNDHWM, 0)
socket.setsockopt(zmq.constants.SNDBUF, buf_size)
socket.bind(path)
socket.connect(path)
else:
raise ValueError(f"Unknown Socket Type: {type}")

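The likely motivation for this bind/connect swap: with PUSH/PULL pairs, the single stable endpoint should bind and the (possibly many) peers should connect. In this PR the Detokenizer's PULL socket receives from both AsyncLLM and EngineCore, so it binds while each PUSH sender connects. A minimal standalone pyzmq sketch of that pairing; the path and payload names here are illustrative, not vLLM's:

import zmq

ctx = zmq.Context()
path = "ipc:///tmp/example_pull_socket"  # illustrative IPC path

# The one stable receiver binds...
pull = ctx.socket(zmq.PULL)
pull.bind(path)

# ...and any number of senders connect to it.
push_a = ctx.socket(zmq.PUSH)
push_a.connect(path)
push_b = ctx.socket(zmq.PUSH)
push_b.connect(path)

push_a.send_pyobj({"from": "a"})
push_b.send_pyobj({"from": "b"})
print(pull.recv_pyobj(), pull.recv_pyobj())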
13 changes: 9 additions & 4 deletions vllm/v1/engine/__init__.py
@@ -56,6 +56,11 @@ class EngineCoreOutputs(
outputs: List[EngineCoreOutput]


@dataclass
class EngineCoreAbort:
request_ids: List[str]


@dataclass
class EngineCoreProfile:
is_start: bool
@@ -66,9 +71,9 @@ class EngineCoreRequestType(enum.Enum):
Request types defined as hex byte strings, so they can be sent over sockets
without a separate encoding step.
"""
ADD = b'\x00'
ABORT = b'\x01'
PROFILE = b'\x02'
FROM_ENGINE_CORE = b'\x00'
FROM_ENGINE = b'\x01'


EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile, List[str]]
EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile,
EngineCoreAbort]
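With the request types collapsed to two direction flags, a receiving process can dispatch on the first frame of each multipart message. A sketch of what that receive side might look like, assuming the two-frame (type, pickle) format used by _send_pyobj later in this PR; the handle_* functions are hypothetical placeholders:

import pickle

import zmq

from vllm.v1.engine import EngineCoreRequestType

def recv_loop(socket: zmq.Socket) -> None:
    while True:
        type_frame, data_frame = socket.recv_multipart()
        obj = pickle.loads(data_frame)
        if type_frame == EngineCoreRequestType.FROM_ENGINE.value:
            handle_from_engine(obj)       # hypothetical: new request or abort
        elif type_frame == EngineCoreRequestType.FROM_ENGINE_CORE.value:
            handle_from_engine_core(obj)  # hypothetical: EngineCoreOutputs
        else:
            raise ValueError(f"Unknown request type: {type_frame!r}")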
163 changes: 108 additions & 55 deletions vllm/v1/engine/async_llm.py
@@ -1,7 +1,30 @@
# Copyright 2023-2024 The vLLM team.
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Inspired by https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/managers/tokenizer_manager.py

import asyncio
import os
import pickle
import signal
from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union
import weakref
from typing import (Any, AsyncGenerator, Dict, List, Mapping, Optional, Type,
Union)

import zmq
import zmq.asyncio

from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
@@ -18,9 +41,11 @@
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import kill_process_tree
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.detokenizer import Detokenizer
from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
make_zmq_socket)
from vllm.v1.engine import EngineCoreAbort, EngineCoreRequestType
from vllm.v1.engine.core import EngineCoreProc
from vllm.v1.engine.detokenizer import DetokenizerProc
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor

@@ -41,6 +66,9 @@ def __init__(
log_requests: bool = True,
start_engine_loop: bool = True,
) -> None:
# Call self.shutdown at exit to clean up
# and ensure workers will be terminated.
self._finalizer = weakref.finalize(self, self.shutdown)

# The child processes will send SIGQUIT when unrecoverable
# errors happen. We kill the process tree here so that the
@@ -65,47 +93,59 @@ def sigquit_handler(signum, frame):
self.model_config = vllm_config.model_config

# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
parallel_config=vllm_config.parallel_config,
lora_config=vllm_config.lora_config)
self.tokenizer.ping()
tokenizer.ping()

# Request streams (map of request_id -> queue).
self.rid_to_queue: Dict[str, asyncio.Queue] = {}

# Processor (converts Inputs --> EngineCoreRequests).
# Processor (in this process).
self.processor = Processor(
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
tokenizer=self.tokenizer,
tokenizer=tokenizer,
input_registry=input_registry,
)

# Detokenizer (converts EngineCoreOutputs --> RequestOutput).
self.detokenizer = Detokenizer(
tokenizer_name=vllm_config.model_config.tokenizer,
tokenizer_mode=vllm_config.model_config.tokenizer_mode,
trust_remote_code=vllm_config.model_config.trust_remote_code,
revision=vllm_config.model_config.tokenizer_revision,
# Set up ZMQ IPC. Message flow is:
# AsyncLLM <-> Detokenizer <-> EngineCore
to_engine_core_path = get_open_zmq_ipc_path()
to_detokenizer_path = get_open_zmq_ipc_path()
from_detokenizer_path = get_open_zmq_ipc_path()
self.ctx = zmq.asyncio.Context(io_threads=2)
self.to_detokenizer = make_zmq_socket(self.ctx, to_detokenizer_path,
zmq.constants.PUSH)
self.from_detokenizer = make_zmq_socket(self.ctx,
from_detokenizer_path,
zmq.constants.PULL)

# Detokenizer (in background process).
self.detokenizer_handle = DetokenizerProc.make_process(
input_path=to_detokenizer_path,
output_path=from_detokenizer_path,
to_engine_core_path=to_engine_core_path,
tokenizer_name=self.model_config.tokenizer,
tokenizer_mode=self.model_config.tokenizer_mode,
trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.revision,
)

# EngineCore (starts the engine in background process).
self.engine_core = EngineCoreClient.make_client(
multiprocess_mode=True,
asyncio_mode=True,
# EngineCore (in background process).
self.engine_core_handle = EngineCoreProc.make_process(
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=self.log_stats,
input_path=to_engine_core_path,
output_path=to_detokenizer_path,
log_stats=log_stats,
)

self.output_handler: Optional[asyncio.Task] = None

def __del__(self):
self.shutdown()

@classmethod
def from_engine_args(
cls,
@@ -137,16 +177,24 @@ def from_engine_args(
)

def shutdown(self):
"""Shutdown, cleaning up the background proc and IPC."""
"""Shutdown, cleaning up the background procs and IPC."""
# ZMQ.
self.ctx.destroy(linger=0)

if engine_core := getattr(self, "engine_core", None):
engine_core.shutdown()
# EngineCore background process.
if hasattr(self, "engine_core_handle"):
self.engine_core_handle.shutdown()

if handler := getattr(self, "output_handler", None):
handler.cancel()
# Detokenizer background process.
if hasattr(self, "detokenizer_handle"):
    self.detokenizer_handle.shutdown()

@classmethod
def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
# Output handler background task.
if hasattr(self, "output_handler") and self.output_handler:
self.output_handler.cancel()

@staticmethod
def _get_executor_cls(vllm_config: VllmConfig) -> Type[Executor]:
executor_class: Type[Executor]
distributed_executor_backend = (
vllm_config.parallel_config.distributed_executor_backend)
@@ -184,11 +232,10 @@ async def add_request(
prompt_adapter_request,
priority)

# 3) Add the request to Detokenizer (this process).
self.detokenizer.add_request(request)

# 4) Add the EngineCoreRequest to EngineCore (separate process).
await self.engine_core.add_request_async(request)
# 3) Send to Detokenizer (which forwards to EngineCore).
# note(rob): we forward the request rather than sending to each
# process separately to avoid race conditions.
await self._send_pyobj(self.to_detokenizer, request)

if self.log_requests:
logger.info("Added request %s.", request_id)
@@ -246,12 +293,12 @@ async def generate(
# The output_handler task pushes items into the queue.
# This task pulls from the queue and yields to caller.
while True:
# Note: drain queue without await if possible (avoids
# task switching under load which helps performance).
# note(rob): drain queue without await if possible
# (avoids task switching under load for performance).
out = q.get_nowait() if q.qsize() > 0 else await q.get()

# Note: both Detokenizer and EngineCore handle their
# own request cleanup based on finished.
# note(rob): both Detokenizer and EngineCore handle
# their own request cleanup based on finished.
if out.finished:
del self.rid_to_queue[request_id]
yield out
@@ -283,17 +330,16 @@ async def _run_output_handler(self):

try:
while True:
# 1) Pull EngineCoreOutput from the EngineCore.
outputs = await self.engine_core.get_output_async()

# 2) Detokenize based on the output.
request_outputs, reqs_to_abort = self.detokenizer.step(outputs)

# 3) Put the RequestOutputs into the per-request queues.
self._process_request_outputs(request_outputs)
# note(rob): use socket directly to avoid calling await multiple
# times, which causes too much task switching at high QPS.
outputs: List[RequestOutput] = await self.from_detokenizer.recv_pyobj()

# 4) Abort any requests that finished due to stop strings.
await self.engine_core.abort_requests_async(reqs_to_abort)
for out in outputs:
# note(rob): it is possible that a request was aborted
# due to cancellation, so we just skip if not found.
if out.request_id in self.rid_to_queue:
self.rid_to_queue[out.request_id].put_nowait(out)

except Exception as e:
logger.exception("EngineCore output handler hit an error: %s", e)
@@ -302,15 +348,23 @@ async def abort(self, request_id: str) -> None:
async def abort(self, request_id: str) -> None:
"""Abort RequestId in self, detokenizer, and engine core."""

request_ids = [request_id]
await self.engine_core.abort_requests_async(request_ids)
self.detokenizer.abort_requests(request_ids)
# Alert detokenizer that we have an abort (message is forwarded
# to the EngineCore).
await self._send_pyobj(self.to_detokenizer,
EngineCoreAbort([request_id]))

# If a request finishes while we await then the request_id
# will be removed from the tracked queues before we get here.
if request_id in self.rid_to_queue:
del self.rid_to_queue[request_id]

@staticmethod
async def _send_pyobj(socket: zmq.asyncio.Socket, obj: Any):
"""Send object to Detokenizer with a FROM_ENGINE flag."""

msg = (EngineCoreRequestType.FROM_ENGINE.value, pickle.dumps(obj))
await socket.send_multipart(msg, copy=False)

def encode(
self,
prompt: PromptType,
@@ -335,8 +389,7 @@ async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
assert lora_request is None
return self.detokenizer.tokenizer
return self.processor.tokenizer.get_lora_tokenizer(lora_request)

async def is_tracing_enabled(self) -> bool:
return False
@@ -352,10 +405,10 @@ async def check_health(self) -> None:
logger.debug("Called check_health.")

async def start_profile(self) -> None:
await self.engine_core.profile_async(True)
pass

async def stop_profile(self) -> None:
await self.engine_core.profile_async(False)
pass

@property
def is_running(self) -> bool:
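The generate() / output-handler split above is a standard fan-out pattern: one background task routes batched outputs into per-request queues, and each generate() call drains only its own queue. A minimal standalone sketch using plain asyncio, with illustrative names and no vLLM types:

import asyncio
from typing import AsyncGenerator, Dict

async def output_handler(source: asyncio.Queue,
                         rid_to_queue: Dict[str, asyncio.Queue]) -> None:
    # One background task routes each batch item into its per-request queue.
    while True:
        outputs = await source.get()  # batch of (request_id, item, finished)
        for rid, item, finished in outputs:
            if rid in rid_to_queue:   # skip requests aborted in the meantime
                rid_to_queue[rid].put_nowait((item, finished))

async def generate(rid: str,
                   rid_to_queue: Dict[str, asyncio.Queue]) -> AsyncGenerator:
    q: asyncio.Queue = asyncio.Queue()
    rid_to_queue[rid] = q
    while True:
        # Drain without awaiting when items are already queued: fewer task
        # switches under load, mirroring the q.get_nowait() trick above.
        item, finished = q.get_nowait() if q.qsize() > 0 else await q.get()
        yield item
        if finished:
            del rid_to_queue[rid]
            return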