Skip to content

Commit

Permalink
MQ engine: remove guided decoding init from the client
Browse files Browse the repository at this point in the history
Currently with MQLLMEngine, we are initializing LogitsProcessors
on the client side, pickling the entire list of LogitsProcessors,
and sending them over ZeroMQ to the engine.

This was put in place so that the expensive initialization (tens
of seconds) of the Outlines LogitsProcessor could happen in a thread,
such that the client could defer submitting the request to the engine
until the initialization had completed.

This became an issue because recent (Rust-based) Outlines does not
support pickle serialization, but this has been resolved by
dottxt-ai/outlines-core#99.

However, this approach is also not desirable in the case of
XGrammar because the initialization is not expensive (hundreds of
milliseconds) and the serialization is just unnecessary complexity.

And so, let's remove the code from the client side of MQLLMEngine
to special case the creation of logits_processors based on guided
decoding params. This will now happen on the engine side once
again.

Signed-off-by: Mark McLoughlin <[email protected]>
  • Loading branch information
markmc committed Dec 3, 2024
1 parent cb5c807 commit 46d3005
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 42 deletions.
39 changes: 4 additions & 35 deletions vllm/engine/multiprocessing/client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import asyncio
import copy
import pickle
from contextlib import contextmanager, suppress
from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
Optional, Union, cast, overload)

import cloudpickle
import psutil
import zmq
import zmq.asyncio
Expand All @@ -19,8 +17,6 @@
from vllm.engine.arg_utils import AsyncEngineArgs
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.async_llm_engine import (
build_guided_decoding_logits_processor_async)
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T,
Expand Down Expand Up @@ -577,38 +573,12 @@ async def _process_request(
if self._errored_with is not None:
raise ENGINE_DEAD_ERROR(self._errored_with)

# Constructing guided decoding logits processors is expensive, so we do
# it here to avoid contending with cpu resources and the GIL on the
# backend process.
if isinstance(params, SamplingParams) and \
params.guided_decoding is not None:
params = await \
build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=await self.get_tokenizer(lora_request),
default_guided_backend=(self.decoding_config.guided_decoding_backend
if self.decoding_config
else DecodingConfig.guided_decoding_backend),
model_config=self.model_config
)

# 1) Create output queue for this requests.
queue: asyncio.Queue[Union[RequestOutput,
BaseException]] = asyncio.Queue()
self.output_queues[request_id] = queue

try:
# 2) Detach logits processors so that they can be pickled
# separately (may require cloudpickle which is slower)
if isinstance(params, SamplingParams) and params.logits_processors:
# Defensive shallow copy
params = copy.copy(params)
logits_processors = params.logits_processors
params.logits_processors = None
lp_bytes = cloudpickle.dumps(logits_processors)
else:
lp_bytes = None

request_bytes = pickle.dumps(
RPCProcessRequest(
prompt=prompt,
Expand All @@ -620,12 +590,11 @@ async def _process_request(
priority=priority,
))

# 3) Send the RPCGenerateRequest to the MQLLMEngine.
parts = (request_bytes,
lp_bytes) if lp_bytes else (request_bytes, )
await self.input_socket.send_multipart(parts, copy=False)
# 2) Send the RPCGenerateRequest to the MQLLMEngine.
await self.input_socket.send_multipart((request_bytes, ),
copy=False)

# 4) Stream the RequestOutputs from the output queue. Note
# 3) Stream the RequestOutputs from the output queue. Note
# that the output_loop pushes RequestOutput objects to this
# queue after pulling them from the zmq socket.
finished = False
Expand Down
8 changes: 1 addition & 7 deletions vllm/engine/multiprocessing/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
from contextlib import contextmanager
from typing import Iterator, List, Optional, Union

import cloudpickle
import zmq

from vllm import AsyncEngineArgs, SamplingParams
from vllm import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
# yapf conflicts with isort for this block
# yapf: disable
Expand Down Expand Up @@ -221,11 +220,6 @@ def handle_new_input(self):
request = pickle.loads(frames[0].buffer)

if isinstance(request, RPCProcessRequest):
if len(frames) > 1:
# Use cloudpickle for logits processors
assert isinstance(request.params, SamplingParams)
lprocs = cloudpickle.loads(frames[1].buffer)
request.params.logits_processors = lprocs
self._handle_process_request(request)
elif isinstance(request, RPCAbortRequest):
self._handle_abort_request(request)
Expand Down

0 comments on commit 46d3005

Please sign in to comment.