From 8c025fa7030350a81bfeb665c99ad622667bdac0 Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Sat, 3 Aug 2024 12:31:27 +0800
Subject: [PATCH] [Frontend] Factor out chat message parsing (#7055)

---
 vllm/entrypoints/chat_utils.py              | 28 +++++++++++++++----
 vllm/entrypoints/openai/serving_chat.py     | 17 ++++-------
 .../openai/serving_tokenization.py          | 21 +++++++-------
 3 files changed, 39 insertions(+), 27 deletions(-)

diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index fbb7f70b55e16..072450a6146ee 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -1,7 +1,8 @@
 import codecs
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from functools import lru_cache
-from typing import Awaitable, Iterable, List, Optional, Union, cast, final
+from typing import (Awaitable, Iterable, List, Optional, Tuple, Union, cast,
+                    final)
 
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -65,8 +66,7 @@ class ConversationMessage(TypedDict):
 @dataclass(frozen=True)
 class ChatMessageParseResult:
     messages: List[ConversationMessage]
-    mm_futures: List[Awaitable[MultiModalDataDict]] = field(
-        default_factory=list)
+    mm_futures: List[Awaitable[MultiModalDataDict]]
 
 
 def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
@@ -174,7 +174,7 @@ def _parse_chat_message_content_parts(
     return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
 
 
-def parse_chat_message_content(
+def _parse_chat_message_content(
     message: ChatCompletionMessageParam,
     model_config: ModelConfig,
     tokenizer: PreTrainedTokenizer,
@@ -190,3 +190,21 @@
 
     return _parse_chat_message_content_parts(role, content, model_config,
                                              tokenizer)
+
+
+def parse_chat_messages(
+    messages: List[ChatCompletionMessageParam],
+    model_config: ModelConfig,
+    tokenizer: PreTrainedTokenizer,
+) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
+    conversation: List[ConversationMessage] = []
+    mm_futures: List[Awaitable[MultiModalDataDict]] = []
+
+    for msg in messages:
+        parse_result = _parse_chat_message_content(msg, model_config,
+                                                    tokenizer)
+
+        conversation.extend(parse_result.messages)
+        mm_futures.extend(parse_result.mm_futures)
+
+    return conversation, mm_futures
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index ebb1d57fbb9a6..d215754993e82 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -1,6 +1,5 @@
 import time
-from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, List,
-                    Optional)
+from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
 from typing import Sequence as GenericSequence
 from typing import Union
 
@@ -11,7 +10,7 @@
 from vllm.engine.protocol import AsyncEngineClient
 from vllm.entrypoints.chat_utils import (ConversationMessage,
                                          load_chat_template,
-                                         parse_chat_message_content)
+                                         parse_chat_messages)
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionLogProb, ChatCompletionLogProbs,
@@ -92,15 +91,8 @@ async def create_chat_completion(
             tokenizer = await self.async_engine_client.get_tokenizer(
                 lora_request)
 
-            conversation: List[ConversationMessage] = []
-            mm_futures: List[Awaitable[MultiModalDataDict]] = []
-
-            for msg in request.messages:
-                chat_parsed_result = parse_chat_message_content(
-                    msg, model_config, tokenizer)
-
-                conversation.extend(chat_parsed_result.messages)
-                mm_futures.extend(chat_parsed_result.mm_futures)
+            conversation, mm_futures = parse_chat_messages(
+                request.messages, model_config, tokenizer)
 
             tool_dicts = None if request.tools is None else [
                 tool.model_dump() for tool in request.tools
@@ -115,6 +107,7 @@
                 chat_template=request.chat_template or self.chat_template,
                 **(request.chat_template_kwargs or {}),
             )
+            assert isinstance(prompt, str)
         except Exception as e:
             logger.error("Error in applying chat template from request: %s", e)
             return self.create_error_response(str(e))
diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py
index c4350881a27a6..5b6b979b9b9e7 100644
--- a/vllm/entrypoints/openai/serving_tokenization.py
+++ b/vllm/entrypoints/openai/serving_tokenization.py
@@ -1,13 +1,11 @@
 from typing import List, Optional, Union
 
 from vllm.config import ModelConfig
-# yapf conflicts with isort for this block
-# yapf: disable
 from vllm.engine.protocol import AsyncEngineClient
-from vllm.entrypoints.chat_utils import (ConversationMessage,
-                                         load_chat_template,
-                                         parse_chat_message_content)
+from vllm.entrypoints.chat_utils import load_chat_template, parse_chat_messages
 from vllm.entrypoints.logger import RequestLogger
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
                                               DetokenizeResponse,
                                               ErrorResponse,
@@ -17,8 +15,11 @@
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing)
+from vllm.logger import init_logger
 from vllm.utils import random_uuid
 
+logger = init_logger(__name__)
+
 
 class OpenAIServingTokenization(OpenAIServing):
 
@@ -62,12 +63,12 @@ async def create_tokenize(
         if isinstance(request, TokenizeChatRequest):
             model_config = self.model_config
 
-            conversation: List[ConversationMessage] = []
+            conversation, mm_futures = parse_chat_messages(
+                request.messages, model_config, tokenizer)
 
-            for message in request.messages:
-                result = parse_chat_message_content(message, model_config,
-                                                    tokenizer)
-                conversation.extend(result.messages)
+            if mm_futures:
+                logger.warning(
+                    "Multi-modal inputs are ignored during tokenization")
 
             prompt = tokenizer.apply_chat_template(
                 add_generation_prompt=request.add_generation_prompt,
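
A minimal usage sketch (not part of the patch itself) showing how a caller is expected to use the new parse_chat_messages() helper, mirroring the updated call sites in serving_chat.py and serving_tokenization.py above. The render_prompt() wrapper and its arguments are hypothetical; only parse_chat_messages() and its (conversation, mm_futures) return value come from this change.

    # Illustrative only: mirrors the call pattern introduced by this patch.
    from vllm.entrypoints.chat_utils import parse_chat_messages

    def render_prompt(request_messages, model_config, tokenizer):
        # One call replaces the per-message parse_chat_message_content() loop
        # that each endpoint previously duplicated.
        conversation, mm_futures = parse_chat_messages(
            request_messages, model_config, tokenizer)

        # mm_futures holds awaitable multi-modal inputs; the tokenization
        # endpoint ignores them (logging a warning), while the chat endpoint
        # resolves them later when building the model inputs.
        prompt = tokenizer.apply_chat_template(
            conversation=conversation,
            tokenize=False,
            add_generation_prompt=True,
        )
        return prompt, mm_futures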