[Core] Do async init of xgrammar in the engine #10871

Closed
39 changes: 4 additions & 35 deletions vllm/engine/multiprocessing/client.py
@@ -1,11 +1,9 @@
import asyncio
import copy
import pickle
from contextlib import contextmanager, suppress
from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
Optional, Union, cast, overload)

import cloudpickle
import psutil
import zmq
import zmq.asyncio
@@ -19,8 +17,6 @@
from vllm.engine.arg_utils import AsyncEngineArgs
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.async_llm_engine import (
build_guided_decoding_logits_processor_async)
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T,
@@ -577,38 +573,12 @@ async def _process_request(
if self._errored_with is not None:
raise ENGINE_DEAD_ERROR(self._errored_with)

# Constructing guided decoding logits processors is expensive, so we do
# it here to avoid contending with cpu resources and the GIL on the
# backend process.
if isinstance(params, SamplingParams) and \
params.guided_decoding is not None:
params = await \
build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=await self.get_tokenizer(lora_request),
default_guided_backend=(self.decoding_config.guided_decoding_backend
if self.decoding_config
else DecodingConfig.guided_decoding_backend),
model_config=self.model_config
)

# 1) Create output queue for this requests.
queue: asyncio.Queue[Union[RequestOutput,
BaseException]] = asyncio.Queue()
self.output_queues[request_id] = queue

try:
# 2) Detach logits processors so that they can be pickled
# separately (may require cloudpickle which is slower)
if isinstance(params, SamplingParams) and params.logits_processors:
# Defensive shallow copy
params = copy.copy(params)
logits_processors = params.logits_processors
params.logits_processors = None
lp_bytes = cloudpickle.dumps(logits_processors)
else:
lp_bytes = None

request_bytes = pickle.dumps(
RPCProcessRequest(
prompt=prompt,
Expand All @@ -620,12 +590,11 @@ async def _process_request(
priority=priority,
))

# 3) Send the RPCGenerateRequest to the MQLLMEngine.
parts = (request_bytes,
lp_bytes) if lp_bytes else (request_bytes, )
await self.input_socket.send_multipart(parts, copy=False)
# 2) Send the RPCGenerateRequest to the MQLLMEngine.
await self.input_socket.send_multipart((request_bytes, ),
copy=False)

# 4) Stream the RequestOutputs from the output queue. Note
# 3) Stream the RequestOutputs from the output queue. Note
# that the output_loop pushes RequestOutput objects to this
# queue after pulling them from the zmq socket.
finished = False
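For context, here is a minimal sketch of the simplified client-side send path after this change: the request is always serialized into a single pickled frame, so the extra cloudpickle frame for detached logits processors is no longer needed. This is illustrative only, not the actual vLLM client; the function name and argument names are made up.

```python
import pickle

import zmq
import zmq.asyncio


async def send_process_request(input_socket: zmq.asyncio.Socket,
                               request: object) -> None:
    # Serialize the RPC request with plain pickle; no logits processors
    # travel over the wire anymore, so a single frame is always enough.
    request_bytes = pickle.dumps(request)
    await input_socket.send_multipart((request_bytes, ), copy=False)
```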
8 changes: 1 addition & 7 deletions vllm/engine/multiprocessing/engine.py
@@ -3,10 +3,9 @@
from contextlib import contextmanager
from typing import Iterator, List, Optional, Union

import cloudpickle
import zmq

from vllm import AsyncEngineArgs, SamplingParams
from vllm import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
# yapf conflicts with isort for this block
# yapf: disable
@@ -221,11 +220,6 @@ def handle_new_input(self):
request = pickle.loads(frames[0].buffer)

if isinstance(request, RPCProcessRequest):
if len(frames) > 1:
# Use cloudpickle for logits processors
assert isinstance(request.params, SamplingParams)
lprocs = cloudpickle.loads(frames[1].buffer)
request.params.logits_processors = lprocs
self._handle_process_request(request)
elif isinstance(request, RPCAbortRequest):
self._handle_abort_request(request)
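Taken together with the client change above, the wire format is now plain data end to end: the guided-decoding request travels inside SamplingParams and the processor is only built on the engine side, so ordinary pickle suffices. A quick illustrative check of that property, assuming a vLLM build from around this change and that these import paths are unchanged:

```python
import pickle

from vllm.sampling_params import GuidedDecodingParams, SamplingParams

params = SamplingParams(
    guided_decoding=GuidedDecodingParams(json='{"type": "object"}'))

# Plain pickle is enough now: nothing like a logits processor is attached
# on the client side, so no cloudpickle fallback is required.
restored = pickle.loads(pickle.dumps(params))
assert restored.guided_decoding is not None
```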
44 changes: 30 additions & 14 deletions vllm/model_executor/guided_decoding/xgrammar_decoding.py
@@ -1,6 +1,7 @@
# noqa: UP007
from __future__ import annotations

import concurrent.futures
import json
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, NamedTuple
@@ -20,18 +21,27 @@
from vllm.config import ModelConfig
from vllm.sampling_params import GuidedDecodingParams

_thread_pool = None


# TODO: passing batch size to max threads here
def get_local_xgrammar_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer,
model_config: ModelConfig,
max_threads: int = 8):

global _thread_pool
if _thread_pool is None:
_thread_pool = concurrent.futures.ThreadPoolExecutor(max_threads)

config = GrammarConfig.from_guided_params(guided_params=guided_params,
model_config=model_config,
tokenizer=tokenizer,
max_threads=max_threads)
return XGrammarLogitsProcessor(config)
xgr_proc = XGrammarLogitsProcessor(config)
xgr_proc.async_init(_thread_pool)
return xgr_proc


class TokenizerData(NamedTuple):
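The executor above is created lazily and shared process-wide, so repeated guided-decoding requests reuse the same worker threads instead of paying pool-construction cost per request. A stripped-down sketch of just that pattern; the names here are illustrative, not the vLLM ones:

```python
import concurrent.futures
from typing import Optional

_pool: Optional[concurrent.futures.ThreadPoolExecutor] = None


def get_pool(max_threads: int = 8) -> concurrent.futures.ThreadPoolExecutor:
    # Create the pool on first use only; later calls return the same object.
    global _pool
    if _pool is None:
        _pool = concurrent.futures.ThreadPoolExecutor(max_threads)
    return _pool


assert get_pool() is get_pool()  # the same pool is reused across calls
```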
@@ -184,6 +194,11 @@ class XGrammarLogitsProcessor:
batch_size: int = field(default=1)
prefilled: bool = field(default=False)

_future: concurrent.futures.Future[Any] | None = None

def async_init(self, thread_pool: concurrent.futures.ThreadPoolExecutor):
self._future = thread_pool.submit(self._init_ctx)

def __getstate__(self) -> dict[str, Any]:
return {'config': self.config}

@@ -196,24 +211,25 @@ def __setstate__(self, state: dict[str, Any]):
self.token_bitmask = None # type: ignore[assignment]
self.prefilled = False

def _ensure_ctx(self):
def _init_ctx(self):
"""Lazily initialize the processor in the worker process"""
if self.ctx is None:
compiler = GrammarCompilerCache.get_compiler(self.config)
if self.config.json_str is not None:
self.ctx = compiler.compile_json_schema(self.config.json_str)
elif self.config.grammar_str is not None:
self.ctx = compiler.compile_grammar(self.config.grammar_str)
elif self.config.json_object:
self.ctx = compiler.compile_builtin_json_grammar()
else:
raise ValueError(
"Invalid configuration for xgrammar logits processor")
compiler = GrammarCompilerCache.get_compiler(self.config)
if self.config.json_str is not None:
return compiler.compile_json_schema(self.config.json_str)
elif self.config.grammar_str is not None:
return compiler.compile_grammar(self.config.grammar_str)
elif self.config.json_object:
return compiler.compile_builtin_json_grammar()
else:
raise ValueError(
"Invalid configuration for xgrammar logits processor")

def __call__(self, input_ids: list[int],
scores: torch.Tensor) -> torch.Tensor:
if self.ctx is None:
self._ensure_ctx()
assert self._future is not None
self.ctx = self._future.result()
self._future = None

if len(self.matchers) == 0:
self.matchers = [
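Putting the pieces together, the pattern adopted here is: submit the expensive grammar compilation to the thread pool when the processor is created (async_init), and only block on the Future the first time the processor is actually called. Below is a self-contained sketch of that lifecycle with a sleep standing in for compilation; none of these class or method bodies are vLLM's.

```python
import concurrent.futures
import time
from typing import Any, Optional


class LazyInitProcessor:
    """Toy stand-in for a logits processor whose context compiles off-thread."""

    def __init__(self) -> None:
        self.ctx: Any = None
        self._future: Optional[concurrent.futures.Future] = None

    def async_init(self, pool: concurrent.futures.ThreadPoolExecutor) -> None:
        # Kick off the expensive build without blocking the caller.
        self._future = pool.submit(self._init_ctx)

    def _init_ctx(self) -> str:
        time.sleep(0.2)  # stands in for grammar compilation
        return "compiled-grammar"

    def __call__(self, token_id: int) -> str:
        if self.ctx is None:
            # First call: wait for (or simply pick up) the compilation result.
            assert self._future is not None
            self.ctx = self._future.result()
            self._future = None
        return f"{self.ctx}:{token_id}"


pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)
proc = LazyInitProcessor()
proc.async_init(pool)   # compilation starts in the background
print(proc(42))         # blocks here only if compilation has not finished yet
pool.shutdown()
```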