From 2d1b9baa8f57fc59912c7bcd07fd630fb9d72c9d Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 17 Dec 2024 13:26:32 -0700 Subject: [PATCH] [Bugfix] Fix request cancellation without polling (#11190) --- tests/entrypoints/openai/test_basic.py | 51 ++++++++++++++++ tests/test_utils.py | 6 +- tests/utils.py | 11 ++-- vllm/engine/async_llm_engine.py | 46 +++++++++------ vllm/entrypoints/api_server.py | 11 ++-- vllm/entrypoints/openai/api_server.py | 8 +++ vllm/entrypoints/openai/serving_chat.py | 5 -- vllm/entrypoints/openai/serving_completion.py | 3 +- vllm/entrypoints/openai/serving_embedding.py | 5 +- vllm/entrypoints/openai/serving_score.py | 5 +- vllm/entrypoints/utils.py | 57 ++++++++++++++++++ vllm/utils.py | 59 ++----------------- 12 files changed, 164 insertions(+), 103 deletions(-) create mode 100644 vllm/entrypoints/utils.py diff --git a/tests/entrypoints/openai/test_basic.py b/tests/entrypoints/openai/test_basic.py index 4616f363cc04a..547c1fd020928 100644 --- a/tests/entrypoints/openai/test_basic.py +++ b/tests/entrypoints/openai/test_basic.py @@ -1,6 +1,8 @@ +import asyncio from http import HTTPStatus from typing import List +import openai import pytest import pytest_asyncio import requests @@ -103,3 +105,52 @@ async def test_check_health(server: RemoteOpenAIServer): response = requests.get(server.url_for("health")) assert response.status_code == HTTPStatus.OK + + +@pytest.mark.parametrize( + "server_args", + [ + pytest.param(["--max-model-len", "10100"], + id="default-frontend-multiprocessing"), + pytest.param( + ["--disable-frontend-multiprocessing", "--max-model-len", "10100"], + id="disable-frontend-multiprocessing") + ], + indirect=True, +) +@pytest.mark.asyncio +async def test_request_cancellation(server: RemoteOpenAIServer): + # clunky test: send an ungodly amount of load in with short timeouts + # then ensure that it still responds quickly afterwards + + chat_input = [{"role": "user", "content": "Write a long story"}] + client = server.get_async_client(timeout=0.5) + tasks = [] + # Request about 2 million tokens + for _ in range(200): + task = asyncio.create_task( + client.chat.completions.create(messages=chat_input, + model=MODEL_NAME, + max_tokens=10000, + extra_body={"min_tokens": 10000})) + tasks.append(task) + + done, pending = await asyncio.wait(tasks, + return_when=asyncio.ALL_COMPLETED) + + # Make sure all requests were sent to the server and timed out + # (We don't want to hide other errors like 400s that would invalidate this + # test) + assert len(pending) == 0 + for d in done: + with pytest.raises(openai.APITimeoutError): + d.result() + + # If the server had not cancelled all the other requests, then it would not + # be able to respond to this one within the timeout + client = server.get_async_client(timeout=5) + response = await client.chat.completions.create(messages=chat_input, + model=MODEL_NAME, + max_tokens=10) + + assert len(response.choices) == 1 diff --git a/tests/test_utils.py b/tests/test_utils.py index 0bc9e5bc32a46..32a6b0aed66aa 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,6 @@ import asyncio import os import socket -from functools import partial from typing import AsyncIterator, Tuple import pytest @@ -26,10 +25,7 @@ async def mock_async_iterator(idx: int): print(f"iterator {idx} cancelled") iterators = [mock_async_iterator(i) for i in range(3)] - merged_iterator = merge_async_iterators(*iterators, - is_cancelled=partial(asyncio.sleep, - 0, - result=False)) + merged_iterator = merge_async_iterators(*iterators) async def stream_output(generator: AsyncIterator[Tuple[int, str]]): async for idx, output in generator: diff --git a/tests/utils.py b/tests/utils.py index afeb708f3bcdc..bf3d88194e4ca 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -163,12 +163,11 @@ def get_client(self): api_key=self.DUMMY_API_KEY, ) - def get_async_client(self): - return openai.AsyncOpenAI( - base_url=self.url_for("v1"), - api_key=self.DUMMY_API_KEY, - max_retries=0, - ) + def get_async_client(self, **kwargs): + return openai.AsyncOpenAI(base_url=self.url_for("v1"), + api_key=self.DUMMY_API_KEY, + max_retries=0, + **kwargs) def _test_completion( diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 32396fd10188d..f50e20cf70323 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1065,16 +1065,20 @@ async def generate( >>> # Process and return the final output >>> ... """ - async for output in await self.add_request( - request_id, - prompt, - sampling_params, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority, - ): - yield LLMEngine.validate_output(output, RequestOutput) + try: + async for output in await self.add_request( + request_id, + prompt, + sampling_params, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + ): + yield LLMEngine.validate_output(output, RequestOutput) + except asyncio.CancelledError: + await self.abort(request_id) + raise async def encode( self, @@ -1147,15 +1151,19 @@ async def encode( >>> # Process and return the final output >>> ... """ - async for output in await self.add_request( - request_id, - prompt, - pooling_params, - lora_request=lora_request, - trace_headers=trace_headers, - priority=priority, - ): - yield LLMEngine.validate_output(output, PoolingRequestOutput) + try: + async for output in await self.add_request( + request_id, + prompt, + pooling_params, + lora_request=lora_request, + trace_headers=trace_headers, + priority=priority, + ): + yield LLMEngine.validate_output(output, PoolingRequestOutput) + except asyncio.CancelledError: + await self.abort(request_id) + raise async def abort(self, request_id: str) -> None: """Abort a request. diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index ea3c93f733038..95da1c6e7b9bf 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -17,11 +17,11 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.utils import with_cancellation from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.usage.usage_lib import UsageContext -from vllm.utils import (FlexibleArgumentParser, iterate_with_cancellation, - random_uuid) +from vllm.utils import FlexibleArgumentParser, random_uuid from vllm.version import __version__ as VLLM_VERSION logger = init_logger("vllm.entrypoints.api_server") @@ -47,6 +47,11 @@ async def generate(request: Request) -> Response: - other fields: the sampling parameters (See `SamplingParams` for details). """ request_dict = await request.json() + return await _generate(request_dict, raw_request=request) + + +@with_cancellation +async def _generate(request_dict: dict, raw_request: Request) -> Response: prompt = request_dict.pop("prompt") stream = request_dict.pop("stream", False) sampling_params = SamplingParams(**request_dict) @@ -54,8 +59,6 @@ async def generate(request: Request) -> Response: assert engine is not None results_generator = engine.generate(prompt, sampling_params, request_id) - results_generator = iterate_with_cancellation( - results_generator, is_cancelled=request.is_disconnected) # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 14e3a34ce141c..00e2d1a56f160 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -59,6 +59,7 @@ from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.entrypoints.utils import with_cancellation from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path, @@ -311,6 +312,7 @@ async def health(raw_request: Request) -> Response: @router.post("/tokenize") +@with_cancellation async def tokenize(request: TokenizeRequest, raw_request: Request): handler = tokenization(raw_request) @@ -325,6 +327,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request): @router.post("/detokenize") +@with_cancellation async def detokenize(request: DetokenizeRequest, raw_request: Request): handler = tokenization(raw_request) @@ -353,6 +356,7 @@ async def show_version(): @router.post("/v1/chat/completions") +@with_cancellation async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): handler = chat(raw_request) @@ -373,6 +377,7 @@ async def create_chat_completion(request: ChatCompletionRequest, @router.post("/v1/completions") +@with_cancellation async def create_completion(request: CompletionRequest, raw_request: Request): handler = completion(raw_request) if handler is None: @@ -390,6 +395,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request): @router.post("/v1/embeddings") +@with_cancellation async def create_embedding(request: EmbeddingRequest, raw_request: Request): handler = embedding(raw_request) if handler is None: @@ -407,6 +413,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): @router.post("/score") +@with_cancellation async def create_score(request: ScoreRequest, raw_request: Request): handler = score(raw_request) if handler is None: @@ -424,6 +431,7 @@ async def create_score(request: ScoreRequest, raw_request: Request): @router.post("/v1/score") +@with_cancellation async def create_score_v1(request: ScoreRequest, raw_request: Request): logger.warning( "To indicate that Score API is not part of standard OpenAI API, we " diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 527418c635093..81bce0dd370bb 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -32,7 +32,6 @@ from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls -from vllm.utils import iterate_with_cancellation logger = init_logger(__name__) @@ -234,10 +233,6 @@ async def create_chat_completion( assert len(generators) == 1 result_generator, = generators - if raw_request: - result_generator = iterate_with_cancellation( - result_generator, raw_request.is_disconnected) - # Streaming response if request.stream: return self.chat_completion_stream_generator( diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index bd39a4c42e938..5cf9df92e296e 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -159,8 +159,7 @@ async def create_completion( # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) - result_generator = merge_async_iterators( - *generators, is_cancelled=raw_request.is_disconnected) + result_generator = merge_async_iterators(*generators) model_name = self._get_model_name(lora_request) num_prompts = len(engine_prompts) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index fd501ad4f833e..879276646d2ba 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -202,10 +202,7 @@ async def create_embedding( # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) - result_generator = merge_async_iterators( - *generators, - is_cancelled=raw_request.is_disconnected if raw_request else None, - ) + result_generator = merge_async_iterators(*generators) num_prompts = len(engine_prompts) diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 6f5cc14ac37cc..101d170bee4d6 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -186,10 +186,7 @@ async def create_score( # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) - result_generator = merge_async_iterators( - *generators, - is_cancelled=raw_request.is_disconnected if raw_request else None, - ) + result_generator = merge_async_iterators(*generators) num_prompts = len(engine_prompts) diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py new file mode 100644 index 0000000000000..e8a78d216d0f0 --- /dev/null +++ b/vllm/entrypoints/utils.py @@ -0,0 +1,57 @@ +import asyncio +import functools + +from fastapi import Request + + +async def listen_for_disconnect(request: Request) -> None: + """Returns if a disconnect message is received""" + while True: + message = await request.receive() + if message["type"] == "http.disconnect": + break + + +def with_cancellation(handler_func): + """Decorator that allows a route handler to be cancelled by client + disconnections. + + This does _not_ use request.is_disconnected, which does not work with + middleware. Instead this follows the pattern from + starlette.StreamingResponse, which simultaneously awaits on two tasks- one + to wait for an http disconnect message, and the other to do the work that we + want done. When the first task finishes, the other is cancelled. + + A core assumption of this method is that the body of the request has already + been read. This is a safe assumption to make for fastapi handlers that have + already parsed the body of the request into a pydantic model for us. + This decorator is unsafe to use elsewhere, as it will consume and throw away + all incoming messages for the request while it looks for a disconnect + message. + + In the case where a `StreamingResponse` is returned by the handler, this + wrapper will stop listening for disconnects and instead the response object + will start listening for disconnects. + """ + + # Functools.wraps is required for this wrapper to appear to fastapi as a + # normal route handler, with the correct request type hinting. + @functools.wraps(handler_func) + async def wrapper(*args, **kwargs): + + # The request is either the second positional arg or `raw_request` + request = args[1] if len(args) > 1 else kwargs["raw_request"] + + handler_task = asyncio.create_task(handler_func(*args, **kwargs)) + cancellation_task = asyncio.create_task(listen_for_disconnect(request)) + + done, pending = await asyncio.wait([handler_task, cancellation_task], + return_when=asyncio.FIRST_COMPLETED) + for task in pending: + task.cancel() + + if handler_task in done: + return handler_task.result() + return None + + return wrapper diff --git a/vllm/utils.py b/vllm/utils.py index 73d2ae25f15ca..38c7dea6d2d3d 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -20,7 +20,7 @@ import uuid import warnings import weakref -from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task +from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from collections import UserDict, defaultdict from collections.abc import Iterable, Mapping from dataclasses import dataclass, field @@ -370,72 +370,23 @@ def _next_task(iterator: AsyncGenerator[T, None], return loop.create_task(iterator.__anext__()) # type: ignore[arg-type] -async def iterate_with_cancellation( - iterator: AsyncGenerator[T, None], - is_cancelled: Callable[[], Awaitable[bool]], -) -> AsyncGenerator[T, None]: - """Convert async iterator into one that polls the provided function - at least once per second to check for client cancellation. - """ - - loop = asyncio.get_running_loop() - - awaits: List[Future[T]] = [_next_task(iterator, loop)] - next_cancel_check: float = 0 - while True: - done, pending = await asyncio.wait(awaits, timeout=1.5) - - # Check for cancellation at most once per second - time_now = time.time() - if time_now >= next_cancel_check: - if await is_cancelled(): - with contextlib.suppress(BaseException): - awaits[0].cancel() - await iterator.aclose() - raise asyncio.CancelledError("client cancelled") - next_cancel_check = time_now + 1 - - if done: - try: - item = await awaits[0] - awaits[0] = _next_task(iterator, loop) - yield item - except StopAsyncIteration: - # we are done - return - - async def merge_async_iterators( - *iterators: AsyncGenerator[T, None], - is_cancelled: Optional[Callable[[], Awaitable[bool]]] = None, -) -> AsyncGenerator[Tuple[int, T], None]: + *iterators: AsyncGenerator[T, + None], ) -> AsyncGenerator[Tuple[int, T], None]: """Merge multiple asynchronous iterators into a single iterator. This method handle the case where some iterators finish before others. When it yields, it yields a tuple (i, item) where i is the index of the iterator that yields the item. - - It also optionally polls a provided function at least once per second - to check for client cancellation. """ loop = asyncio.get_running_loop() awaits = {_next_task(pair[1], loop): pair for pair in enumerate(iterators)} - timeout = None if is_cancelled is None else 1.5 - next_cancel_check: float = 0 try: while awaits: - done, pending = await asyncio.wait(awaits.keys(), - return_when=FIRST_COMPLETED, - timeout=timeout) - if is_cancelled is not None: - # Check for cancellation at most once per second - time_now = time.time() - if time_now >= next_cancel_check: - if await is_cancelled(): - raise asyncio.CancelledError("client cancelled") - next_cancel_check = time_now + 1 + done, _ = await asyncio.wait(awaits.keys(), + return_when=FIRST_COMPLETED) for d in done: pair = awaits.pop(d) try: