From 0d07f68322d7d6647d52d3623ed84bccb1551755 Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Mon, 25 Nov 2024 13:59:04 -0500 Subject: [PATCH] MQ engine: remove guided decoding init from the client Currently with MQLLMEngine, we are initializing LogitsProcessors on the client side, pickling the entire list of LogitsProcessors, and sending them over ZeroMQ to the engine. This was put in place so that the expensive initialization (tens of second) of the Outlines LogitsProcessor could happen in a thread, such that the client could defer submitting the request to the engine until the initialization had completed. This became an issue because recent (Rust-based) Outlines does not support pickle serialization, but this has resolved by dottxt-ai/outlines-core#99. However, this approach is also not desirable in the case of XGrammar because the initialization is not expensive (hundreds of milliseconds) and the serialization is just unnecessary complexity. And so, let's remove the code from the client side of MQLLMEngine to special case the creation of logits_processors based on guided decoding params. This will now happen on the engine side once again. Signed-off-by: Mark McLoughlin --- vllm/engine/multiprocessing/client.py | 39 +++------------------------ vllm/engine/multiprocessing/engine.py | 8 +----- 2 files changed, 5 insertions(+), 42 deletions(-) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index d21136c03d7d2..ce59102c39085 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -1,11 +1,9 @@ import asyncio -import copy import pickle from contextlib import contextmanager, suppress from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, Optional, Union, cast, overload) -import cloudpickle import psutil import zmq import zmq.asyncio @@ -19,8 +17,6 @@ from vllm.engine.arg_utils import AsyncEngineArgs # yapf conflicts with isort for this block # yapf: disable -from vllm.engine.async_llm_engine import ( - build_guided_decoding_logits_processor_async) from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, RPC_REQUEST_T, @@ -577,38 +573,12 @@ async def _process_request( if self._errored_with is not None: raise ENGINE_DEAD_ERROR(self._errored_with) - # Constructing guided decoding logits processors is expensive, so we do - # it here to avoid contending with cpu resources and the GIL on the - # backend process. - if isinstance(params, SamplingParams) and \ - params.guided_decoding is not None: - params = await \ - build_guided_decoding_logits_processor_async( - sampling_params=params, - tokenizer=await self.get_tokenizer(lora_request), - default_guided_backend=(self.decoding_config.guided_decoding_backend - if self.decoding_config - else DecodingConfig.guided_decoding_backend), - model_config=self.model_config - ) - # 1) Create output queue for this requests. queue: asyncio.Queue[Union[RequestOutput, BaseException]] = asyncio.Queue() self.output_queues[request_id] = queue try: - # 2) Detach logits processors so that they can be pickled - # separately (may require cloudpickle which is slower) - if isinstance(params, SamplingParams) and params.logits_processors: - # Defensive shallow copy - params = copy.copy(params) - logits_processors = params.logits_processors - params.logits_processors = None - lp_bytes = cloudpickle.dumps(logits_processors) - else: - lp_bytes = None - request_bytes = pickle.dumps( RPCProcessRequest( prompt=prompt, @@ -620,12 +590,11 @@ async def _process_request( priority=priority, )) - # 3) Send the RPCGenerateRequest to the MQLLMEngine. - parts = (request_bytes, - lp_bytes) if lp_bytes else (request_bytes, ) - await self.input_socket.send_multipart(parts, copy=False) + # 2) Send the RPCGenerateRequest to the MQLLMEngine. + await self.input_socket.send_multipart((request_bytes, ), + copy=False) - # 4) Stream the RequestOutputs from the output queue. Note + # 3) Stream the RequestOutputs from the output queue. Note # that the output_loop pushes RequestOutput objects to this # queue after pulling them from the zmq socket. finished = False diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 49a90b321dac4..8caf63ce42fd7 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -3,10 +3,9 @@ from contextlib import contextmanager from typing import Iterator, List, Optional, Union -import cloudpickle import zmq -from vllm import AsyncEngineArgs, SamplingParams +from vllm import AsyncEngineArgs from vllm.engine.llm_engine import LLMEngine # yapf conflicts with isort for this block # yapf: disable @@ -221,11 +220,6 @@ def handle_new_input(self): request = pickle.loads(frames[0].buffer) if isinstance(request, RPCProcessRequest): - if len(frames) > 1: - # Use cloudpickle for logits processors - assert isinstance(request.params, SamplingParams) - lprocs = cloudpickle.loads(frames[1].buffer) - request.params.logits_processors = lprocs self._handle_process_request(request) elif isinstance(request, RPCAbortRequest): self._handle_abort_request(request)