Skip to content

Commit

Permalink
[Frontend] don't block event loop in tokenization (preprocess) in Ope…
Browse files Browse the repository at this point in the history
…nAI compatible server (vllm-project#10635)

Signed-off-by: Tomer Asida <[email protected]>
  • Loading branch information
tomeras91 authored and weilong.yu committed Dec 13, 2024
1 parent fa4bb00 commit ea04f55
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 56 deletions.
137 changes: 137 additions & 0 deletions tests/entrypoints/openai/test_async_tokenization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import asyncio
import contextlib
import random
import time
from typing import Callable

import openai
import pytest
import pytest_asyncio
import requests

from tests.utils import RemoteOpenAIServer

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"


@pytest.fixture(scope="module")
def server(): # noqa: F811
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"8192",
"--enforce-eager",
"--max-num-seqs",
"128",
"--load-format",
"dummy",
]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server


@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client


@pytest.mark.asyncio
@pytest.mark.parametrize(
ids=["completion", "chat"],
argnames=["create_func_gen", "content_body"],
argvalues=[
(lambda x: x.completions.create, {
"prompt": " ".join(['A'] * 10_000)
}),
(lambda x: x.chat.completions.create, {
"messages": [{
"role": "user",
"content": " ".join(['A'] * 10_000)
}]
}),
],
)
async def test_with_and_without_truncate(
server: RemoteOpenAIServer,
client: openai.AsyncOpenAI,
create_func_gen: Callable,
content_body: dict,
):
create_func = create_func_gen(client)
body = {"model": MODEL_NAME, **content_body, "max_tokens": 10}

num_requests = 10
truncate_prompt_tokens = ([1000] * (num_requests // 2) + [None] *
(num_requests - num_requests // 2))
random.shuffle(truncate_prompt_tokens)

bodies = [{
**body, "extra_body": {
'truncate_prompt_tokens': t
}
} for t in truncate_prompt_tokens]

async def get_status_code(**kwargs):
try:
await create_func(**kwargs)
return 200
except openai.APIStatusError as e:
return e.status_code

responses = await asyncio.gather(*[get_status_code(**b) for b in bodies])
assert 500 not in responses


@pytest.mark.asyncio
@pytest.mark.parametrize(
ids=["single completion", "multiple completions", "chat"],
argnames=["create_func_gen", "content_body"],
argvalues=[
(lambda x: x.completions.create, {
"prompt": " ".join(['A'] * 300_000)
}),
(lambda x: x.completions.create, {
"prompt": [" ".join(['A'] * 300_000)] * 2
}),
(lambda x: x.chat.completions.create, {
"messages": [{
"role": "user",
"content": " ".join(['A'] * 300_000)
}]
}),
],
)
async def test_healthcheck_response_time(
server: RemoteOpenAIServer,
client: openai.AsyncOpenAI,
create_func_gen: Callable,
content_body: dict,
):
num_requests = 50

create_func = create_func_gen(client)
body = {"model": MODEL_NAME, **content_body, "max_tokens": 10}

def get_response_time(url):
start_time = time.monotonic()
res = requests.get(url)
end_time = time.monotonic()
assert res.status_code == 200
return end_time - start_time

no_load_response_time = get_response_time(server.url_for("health"))
tasks = [
asyncio.create_task(create_func(**body)) for _ in range(num_requests)
]
await asyncio.sleep(1) # give the tasks a chance to start running
load_response_time = get_response_time(server.url_for("health"))

with contextlib.suppress(openai.APIStatusError):
await asyncio.gather(*tasks)

assert load_response_time < 100 * no_load_response_time
assert load_response_time < 0.1
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ async def create_completion(

tokenizer = await self.engine_client.get_tokenizer(lora_request)

request_prompts, engine_prompts = self._preprocess_completion(
request_prompts, engine_prompts = await self._preprocess_completion(
request,
tokenizer,
request.prompt,
Expand Down
15 changes: 8 additions & 7 deletions vllm/entrypoints/openai/serving_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,14 @@ async def create_embedding(
add_special_tokens=request.add_special_tokens,
)
else:
request_prompts, engine_prompts = self._preprocess_completion(
request,
tokenizer,
request.input,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
(request_prompts,
engine_prompts) = await self._preprocess_completion(
request,
tokenizer,
request.input,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
Expand Down
75 changes: 40 additions & 35 deletions vllm/entrypoints/openai/serving_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import pathlib
from concurrent.futures.thread import ThreadPoolExecutor
from dataclasses import dataclass
from http import HTTPStatus
from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
Expand Down Expand Up @@ -46,7 +47,7 @@
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import AtomicCounter, is_list_of
from vllm.utils import AtomicCounter, is_list_of, make_async

logger = init_logger(__name__)

Expand Down Expand Up @@ -140,6 +141,14 @@ def __init__(
self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids

self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)

self._tokenize_prompt_input_async = make_async(
self._tokenize_prompt_input, executor=self._tokenizer_executor)
self._tokenize_prompt_input_or_inputs_async = make_async(
self._tokenize_prompt_input_or_inputs,
executor=self._tokenizer_executor)

async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model."""
model_cards = [
Expand Down Expand Up @@ -368,53 +377,49 @@ def _tokenize_prompt_input_or_inputs(
input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = True,
) -> Iterator[TextTokensPrompt]:
) -> List[TextTokensPrompt]:
"""
Tokenize/detokenize depending on the input format.
According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
, each input can be a string or array of tokens. Note that each request
can pass one or more inputs.
"""
for prompt_input in parse_and_batch_prompt(input_or_inputs):
# Although our type checking is based on mypy,
# VSCode Pyright extension should still work properly
# "is True" is required for Pyright to perform type narrowing
# See: https://github.com/microsoft/pyright/issues/7672
if prompt_input["is_tokens"] is False:
yield self._normalize_prompt_text_to_input(
request,
tokenizer,
prompt=prompt_input["content"],
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
else:
yield self._normalize_prompt_tokens_to_input(
request,
tokenizer,
prompt_ids=prompt_input["content"],
truncate_prompt_tokens=truncate_prompt_tokens,
)
# Although our type checking is based on mypy,
# VSCode Pyright extension should still work properly
# "is True" is required for Pyright to perform type narrowing
# See: https://github.com/microsoft/pyright/issues/7672
return [
self._normalize_prompt_text_to_input(
request,
tokenizer,
prompt=prompt_input["content"],
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens)
if prompt_input["is_tokens"] is False else
self._normalize_prompt_tokens_to_input(
request,
tokenizer,
prompt_ids=prompt_input["content"],
truncate_prompt_tokens=truncate_prompt_tokens)
for prompt_input in parse_and_batch_prompt(input_or_inputs)
]

def _preprocess_completion(
async def _preprocess_completion(
self,
request: CompletionLikeRequest,
tokenizer: AnyTokenizer,
input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = True,
) -> Tuple[Sequence[TextTokensPrompt], List[TokensPrompt]]:
request_prompts = [
request_prompt
for request_prompt in self._tokenize_prompt_input_or_inputs(
request,
tokenizer,
input_or_inputs,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
]
) -> Tuple[List[TextTokensPrompt], List[TokensPrompt]]:
request_prompts = await self._tokenize_prompt_input_or_inputs_async(
request,
tokenizer,
input_or_inputs,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)

engine_prompts = [
TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"])
Expand Down Expand Up @@ -493,7 +498,7 @@ async def _preprocess_chat(
request=request)

if isinstance(request_prompt, str):
prompt_inputs = self._tokenize_prompt_input(
prompt_inputs = await self._tokenize_prompt_input_async(
request,
tokenizer,
request_prompt,
Expand Down
10 changes: 6 additions & 4 deletions vllm/entrypoints/openai/serving_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import merge_async_iterators, random_uuid
from vllm.utils import make_async, merge_async_iterators, random_uuid

logger = init_logger(__name__)

Expand Down Expand Up @@ -145,9 +145,11 @@ async def create_score(
tokenization_kwargs["truncation"] = True
tokenization_kwargs["max_length"] = truncate_prompt_tokens

prompt_inputs = tokenizer(text=q,
text_pair=t,
**tokenization_kwargs)
tokenize_async = make_async(tokenizer.__call__,
executor=self._tokenizer_executor)
prompt_inputs = await tokenize_async(text=q,
text_pair=t,
**tokenization_kwargs)
engine_prompt = TokensPrompt(
prompt_token_ids=prompt_inputs["input_ids"],
token_type_ids=prompt_inputs.get("token_type_ids"))
Expand Down
15 changes: 8 additions & 7 deletions vllm/entrypoints/openai/serving_tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,13 @@ async def create_tokenize(
add_special_tokens=request.add_special_tokens,
)
else:
request_prompts, engine_prompts = self._preprocess_completion(
request,
tokenizer,
request.prompt,
add_special_tokens=request.add_special_tokens,
)
(request_prompts,
engine_prompts) = await self._preprocess_completion(
request,
tokenizer,
request.prompt,
add_special_tokens=request.add_special_tokens,
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
Expand Down Expand Up @@ -134,7 +135,7 @@ async def create_detokenize(
# Silently ignore prompt adapter since it does not affect tokenization
# (Unlike in Embeddings API where an error is raised)

prompt_input = self._tokenize_prompt_input(
prompt_input = await self._tokenize_prompt_input_async(
request,
tokenizer,
request.tokens,
Expand Down
8 changes: 6 additions & 2 deletions vllm/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import asyncio
import concurrent
import contextlib
import datetime
import enum
Expand Down Expand Up @@ -351,7 +352,10 @@ def in_wsl() -> bool:
return "microsoft" in " ".join(uname()).lower()


def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
def make_async(
func: Callable[P, T],
executor: Optional[concurrent.futures.Executor] = None
) -> Callable[P, Awaitable[T]]:
"""Take a blocking function, and run it on in an executor thread.
This function prevents the blocking function from blocking the
Expand All @@ -362,7 +366,7 @@ def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future:
loop = asyncio.get_event_loop()
p_func = partial(func, *args, **kwargs)
return loop.run_in_executor(executor=None, func=p_func)
return loop.run_in_executor(executor=executor, func=p_func)

return _async_wrapper

Expand Down

0 comments on commit ea04f55

Please sign in to comment.