diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 7f1ca621d91c4..882742c2fc61b 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -112,7 +112,11 @@ def __init__(self, ipc_path: str, engine_config: VllmConfig, # Stream for each individual request. self.output_queues: Dict[str, asyncio.Queue] = {} - self.output_loop = asyncio.create_task(self.run_output_handler_loop()) + + # Loop to handle output of the LLMEngine periodically. + # Started after the MQLLMEngine is ready so that we can + # build the Client in an executor to enable clean shutdown. + self.output_loop: Optional[asyncio.Task] = None # Loop to check health of the LLMEngine periodically. # Started after the MQLLMEngine is ready. @@ -247,6 +251,9 @@ async def run_output_handler_loop(self): async def setup(self): """Setup the client before it starts sending server requests.""" + # Start output_loop + self.output_loop = asyncio.create_task(self.run_output_handler_loop()) + with self.get_data_socket() as socket: # Wait until server is ready. response = await self._wait_for_server_rpc(socket) @@ -265,7 +272,8 @@ def close(self): # Cancel background tasks. if self.health_loop is not None: self.health_loop.cancel() - self.output_loop.cancel() + if self.output_loop is not None: + self.output_loop.cancel() def _set_errored(self, e: BaseException): logger.exception(repr(e)) diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index a73b4c825b11c..9dd6fa5b14315 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -349,16 +349,22 @@ def stop_profile(self) -> None: self.engine.model_executor._run_workers("stop_profile") +def signal_handler(*_) -> None: + raise KeyboardInterrupt("MQLLMEngine terminated") + + def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, - ipc_path: str): + ipc_path: str, engine_alive): + try: + engine = MQLLMEngine.from_engine_args(engine_args=engine_args, + usage_context=usage_context, + ipc_path=ipc_path) - def signal_handler(*_) -> None: - # Interrupt server on sigterm - raise KeyboardInterrupt("MQLLMEngine terminated") + signal.signal(signal.SIGTERM, signal_handler) - signal.signal(signal.SIGTERM, signal_handler) + engine.start() - engine = MQLLMEngine.from_engine_args(engine_args=engine_args, - usage_context=usage_context, - ipc_path=ipc_path) - engine.start() + except BaseException as e: + logger.exception(e) + engine_alive.value = False + raise e diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 95fd56d916050..bef36ffdbfcd3 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -171,39 +171,44 @@ async def build_async_engine_client_from_engine_args( # so we need to spawn a new process context = multiprocessing.get_context("spawn") + # The Process can raise an exception during startup, which may + # not actually result in an exitcode being reported. As a result + # we use a shared variable to communicate the information. + engine_alive = multiprocessing.Value('b', True, lock=False) engine_process = context.Process(target=run_mp_engine, args=(engine_args, UsageContext.OPENAI_API_SERVER, - ipc_path)) + ipc_path, engine_alive)) engine_process.start() engine_pid = engine_process.pid - assert engine_pid is not None, "Engine process failed to start" + assert engine_pid is not None, "Engine process failed to start." logger.info("Started engine process with PID %d", engine_pid) # Build RPCClient, which conforms to EngineClient Protocol. - # NOTE: Actually, this is not true yet. We still need to support - # embedding models via RPC (see TODO above) engine_config = engine_args.create_engine_config() - mp_engine_client = MQLLMEngineClient(ipc_path, engine_config, - engine_pid) - + build_client = partial(MQLLMEngineClient, ipc_path, engine_config, + engine_pid) + mq_engine_client = await asyncio.get_running_loop().run_in_executor( + None, build_client) try: while True: try: - await mp_engine_client.setup() + await mq_engine_client.setup() break except TimeoutError: - if not engine_process.is_alive(): + if (not engine_process.is_alive() + or not engine_alive.value): raise RuntimeError( - "Engine process failed to start") from None + "Engine process failed to start. See stack " + "trace for the root cause.") from None - yield mp_engine_client # type: ignore[misc] + yield mq_engine_client # type: ignore[misc] finally: # Ensure rpc server process was terminated engine_process.terminate() # Close all open connections to the backend - mp_engine_client.close() + mq_engine_client.close() # Wait for engine process to join engine_process.join(4)