[Frontend] Factor out chat message parsing (vllm-project#7055)
DarkLight1337 authored Aug 3, 2024
1 parent 69ea15e commit 8c025fa
Showing 3 changed files with 39 additions and 27 deletions.
28 changes: 23 additions & 5 deletions vllm/entrypoints/chat_utils.py
@@ -1,7 +1,8 @@
 import codecs
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from functools import lru_cache
-from typing import Awaitable, Iterable, List, Optional, Union, cast, final
+from typing import (Awaitable, Iterable, List, Optional, Tuple, Union, cast,
+                    final)
 
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -65,8 +66,7 @@ class ConversationMessage(TypedDict):
 @dataclass(frozen=True)
 class ChatMessageParseResult:
     messages: List[ConversationMessage]
-    mm_futures: List[Awaitable[MultiModalDataDict]] = field(
-        default_factory=list)
+    mm_futures: List[Awaitable[MultiModalDataDict]]
 
 
 def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
@@ -174,7 +174,7 @@ def _parse_chat_message_content_parts(
     return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
 
 
-def parse_chat_message_content(
+def _parse_chat_message_content(
     message: ChatCompletionMessageParam,
     model_config: ModelConfig,
     tokenizer: PreTrainedTokenizer,
@@ -190,3 +190,21 @@ def parse_chat_message_content(
 
     return _parse_chat_message_content_parts(role, content, model_config,
                                              tokenizer)
+
+
+def parse_chat_messages(
+    messages: List[ChatCompletionMessageParam],
+    model_config: ModelConfig,
+    tokenizer: PreTrainedTokenizer,
+) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
+    conversation: List[ConversationMessage] = []
+    mm_futures: List[Awaitable[MultiModalDataDict]] = []
+
+    for msg in messages:
+        parse_result = _parse_chat_message_content(msg, model_config,
+                                                    tokenizer)
+
+        conversation.extend(parse_result.messages)
+        mm_futures.extend(parse_result.mm_futures)
+
+    return conversation, mm_futures
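
The new parse_chat_messages helper simply folds the per-message parse results together, so callers no longer carry the conversation/mm_futures bookkeeping themselves (ChatMessageParseResult.mm_futures also loses its default_factory, so parsers must now pass it explicitly). A minimal usage sketch follows; model_config and tokenizer are placeholders assumed to be obtained from the engine, as the OpenAI servers below do:

from vllm.entrypoints.chat_utils import parse_chat_messages

# Sketch only: model_config (a vllm ModelConfig) and tokenizer (a HF
# PreTrainedTokenizer) are assumed to already exist.
conversation, mm_futures = parse_chat_messages(
    [{"role": "user", "content": "Hello!"}],  # OpenAI-style chat messages
    model_config,
    tokenizer,
)
# conversation: flat list of ConversationMessage dicts for the chat template.
# mm_futures: pending multi-modal fetches the caller may await or ignore.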
17 changes: 5 additions & 12 deletions vllm/entrypoints/openai/serving_chat.py
@@ -1,6 +1,5 @@
 import time
-from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, List,
-                    Optional)
+from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
 from typing import Sequence as GenericSequence
 from typing import Union
 
@@ -11,7 +10,7 @@
 from vllm.engine.protocol import AsyncEngineClient
 from vllm.entrypoints.chat_utils import (ConversationMessage,
                                          load_chat_template,
-                                         parse_chat_message_content)
+                                         parse_chat_messages)
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionLogProb, ChatCompletionLogProbs,
@@ -92,15 +91,8 @@ async def create_chat_completion(
             tokenizer = await self.async_engine_client.get_tokenizer(
                 lora_request)
 
-            conversation: List[ConversationMessage] = []
-            mm_futures: List[Awaitable[MultiModalDataDict]] = []
-
-            for msg in request.messages:
-                chat_parsed_result = parse_chat_message_content(
-                    msg, model_config, tokenizer)
-
-                conversation.extend(chat_parsed_result.messages)
-                mm_futures.extend(chat_parsed_result.mm_futures)
+            conversation, mm_futures = parse_chat_messages(
+                request.messages, model_config, tokenizer)
 
             tool_dicts = None if request.tools is None else [
                 tool.model_dump() for tool in request.tools
@@ -115,6 +107,7 @@ async def create_chat_completion(
                 chat_template=request.chat_template or self.chat_template,
                 **(request.chat_template_kwargs or {}),
             )
+            assert isinstance(prompt, str)
         except Exception as e:
             logger.error("Error in applying chat template from request: %s", e)
             return self.create_error_response(str(e))
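
The added assert narrows the chat template's return type: Hugging Face's apply_chat_template can return either the rendered prompt string or token IDs depending on its arguments, and the code after this hunk treats the prompt as a string. A rough sketch of the guarded call; only chat_template and chat_template_kwargs are visible in the hunk above, the other arguments are assumptions:

# Hypothetical reconstruction, not the exact call in serving_chat.py.
prompt = tokenizer.apply_chat_template(
    conversation,
    tokenize=False,  # ask for the rendered string rather than token IDs
    add_generation_prompt=True,
    chat_template=request.chat_template or self.chat_template,
    **(request.chat_template_kwargs or {}),
)
assert isinstance(prompt, str)  # would fail if token IDs came back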
21 changes: 11 additions & 10 deletions vllm/entrypoints/openai/serving_tokenization.py
@@ -1,13 +1,11 @@
 from typing import List, Optional, Union
 
 from vllm.config import ModelConfig
-# yapf conflicts with isort for this block
-# yapf: disable
 from vllm.engine.protocol import AsyncEngineClient
-from vllm.entrypoints.chat_utils import (ConversationMessage,
-                                         load_chat_template,
-                                         parse_chat_message_content)
+from vllm.entrypoints.chat_utils import load_chat_template, parse_chat_messages
 from vllm.entrypoints.logger import RequestLogger
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
                                               DetokenizeResponse,
                                               ErrorResponse,
@@ -17,8 +15,11 @@
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing)
+from vllm.logger import init_logger
 from vllm.utils import random_uuid
 
+logger = init_logger(__name__)
+
 
 class OpenAIServingTokenization(OpenAIServing):
 
@@ -62,12 +63,12 @@ async def create_tokenize(
         if isinstance(request, TokenizeChatRequest):
             model_config = self.model_config
 
-            conversation: List[ConversationMessage] = []
+            conversation, mm_futures = parse_chat_messages(
+                request.messages, model_config, tokenizer)
 
-            for message in request.messages:
-                result = parse_chat_message_content(message, model_config,
-                                                    tokenizer)
-                conversation.extend(result.messages)
+            if mm_futures:
+                logger.warning(
+                    "Multi-modal inputs are ignored during tokenization")
 
             prompt = tokenizer.apply_chat_template(
                 add_generation_prompt=request.add_generation_prompt,
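
The tokenization endpoint reuses the same helper but deliberately drops the multi-modal futures, since only the text needs to be tokenized. A hypothetical request showing when the new warning fires; the message below is invented and follows the OpenAI image_url content-part format:

# Sketch only: model_config and tokenizer are assumed to exist as above.
conversation, mm_futures = parse_chat_messages(
    [{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this picture?"},
            {"type": "image_url",
             "image_url": {"url": "https://example.com/cat.png"}},
        ],
    }],
    model_config,
    tokenizer,
)
if mm_futures:  # the image fetch is pending here but is never awaited
    logger.warning("Multi-modal inputs are ignored during tokenization")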
