diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py
index c426e9b4ee899..ef34bebbb0f8c 100644
--- a/tests/entrypoints/llm/test_generate.py
+++ b/tests/entrypoints/llm/test_generate.py
@@ -6,6 +6,7 @@
 from vllm import LLM, RequestOutput, SamplingParams
 
 from ...conftest import cleanup
+from ..openai.test_vision import TEST_IMAGE_URLS
 
 MODEL_NAME = "facebook/opt-125m"
 
@@ -159,3 +160,36 @@ def test_chat():
     ]
     outputs = llm.chat(messages)
     assert len(outputs) == 1
+
+
+@pytest.mark.parametrize("image_urls",
+                         [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
+def test_chat_multi_image(image_urls: List[str]):
+    llm = LLM(
+        model="microsoft/Phi-3.5-vision-instruct",
+        dtype="bfloat16",
+        max_model_len=4096,
+        max_num_seqs=5,
+        enforce_eager=True,
+        trust_remote_code=True,
+        limit_mm_per_prompt={"image": 2},
+    )
+
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            *({
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            } for image_url in image_urls),
+            {
+                "type": "text",
+                "text": "What's in this image?"
+            },
+        ],
+    }]
+    outputs = llm.chat(messages)
+    assert len(outputs) >= 0
diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py
index 53f99189beb1c..6ded5102c9314 100644
--- a/tests/entrypoints/test_chat_utils.py
+++ b/tests/entrypoints/test_chat_utils.py
@@ -1,11 +1,14 @@
 import warnings
+from typing import Optional
 
 import pytest
 from PIL import Image
 
 from vllm.assets.image import ImageAsset
 from vllm.config import ModelConfig
-from vllm.entrypoints.chat_utils import parse_chat_messages
+from vllm.entrypoints.chat_utils import (parse_chat_messages,
+                                         parse_chat_messages_futures)
+from vllm.multimodal import MultiModalDataDict
 from vllm.multimodal.utils import encode_image_base64
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 
@@ -42,10 +45,28 @@ def image_url():
     return f"data:image/jpeg;base64,{base64}"
 
 
-@pytest.mark.asyncio
-async def test_parse_chat_messages_with_image_url(phi3v_model_config,
-                                                  phi3v_tokenizer, image_url):
-    conversation, mm_future = parse_chat_messages([{
+def _assert_mm_data_is_image_input(
+    mm_data: Optional[MultiModalDataDict],
+    image_count: int,
+) -> None:
+    assert mm_data is not None
+    assert set(mm_data.keys()) == {"image"}
+
+    image_data = mm_data.get("image")
+    assert image_data is not None
+
+    if image_count == 1:
+        assert isinstance(image_data, Image.Image)
+    else:
+        assert isinstance(image_data, list) and len(image_data) == image_count
+
+
+def test_parse_chat_messages_single_image(
+    phi3v_model_config,
+    phi3v_tokenizer,
+    image_url,
+):
+    conversation, mm_data = parse_chat_messages([{
         "role":
         "user",
         "content": [{
@@ -63,15 +84,42 @@ async def test_parse_chat_messages_with_image_url(phi3v_model_config,
         "role": "user",
         "content": "<|image_1|>\nWhat's in the image?"
     }]
-    mm_data = await mm_future
-    assert set(mm_data.keys()) == {"image"}
-    assert isinstance(mm_data["image"], Image.Image)
+    _assert_mm_data_is_image_input(mm_data, 1)
 
 
 @pytest.mark.asyncio
-async def test_parse_chat_messages_multiple_images(phi3v_model_config,
-                                                   phi3v_tokenizer, image_url):
-    conversation, mm_future = parse_chat_messages([{
+async def test_parse_chat_messages_single_image_async(
+    phi3v_model_config,
+    phi3v_tokenizer,
+    image_url,
+):
+    conversation, mm_future = parse_chat_messages_futures([{
+        "role":
+        "user",
+        "content": [{
+            "type": "image_url",
+            "image_url": {
+                "url": image_url
+            }
+        }, {
+            "type": "text",
+            "text": "What's in the image?"
+        }]
+    }], phi3v_model_config, phi3v_tokenizer)
+
+    assert conversation == [{
+        "role": "user",
+        "content": "<|image_1|>\nWhat's in the image?"
+    }]
+    _assert_mm_data_is_image_input(await mm_future, 1)
+
+
+def test_parse_chat_messages_multiple_images(
+    phi3v_model_config,
+    phi3v_tokenizer,
+    image_url,
+):
+    conversation, mm_data = parse_chat_messages([{
         "role":
         "user",
         "content": [{
@@ -96,15 +144,49 @@ async def test_parse_chat_messages_multiple_images(phi3v_model_config,
         "content":
         "<|image_1|>\n<|image_2|>\nWhat's in these images?"
     }]
-    mm_data = await mm_future
-    assert set(mm_data.keys()) == {"image"}
-    assert len(mm_data["image"]) == 2
+    _assert_mm_data_is_image_input(mm_data, 2)
 
 
 @pytest.mark.asyncio
-async def test_parse_chat_messages_placeholder_already_in_prompt(
-        phi3v_model_config, phi3v_tokenizer, image_url):
-    conversation, mm_future = parse_chat_messages([{
+async def test_parse_chat_messages_multiple_images_async(
+    phi3v_model_config,
+    phi3v_tokenizer,
+    image_url,
+):
+    conversation, mm_future = parse_chat_messages_futures([{
+        "role":
+        "user",
+        "content": [{
+            "type": "image_url",
+            "image_url": {
+                "url": image_url
+            }
+        }, {
+            "type": "image_url",
+            "image_url": {
+                "url": image_url
+            }
+        }, {
+            "type": "text",
+            "text": "What's in these images?"
+        }]
+    }], phi3v_model_config, phi3v_tokenizer)
+
+    assert conversation == [{
+        "role":
+        "user",
+        "content":
+        "<|image_1|>\n<|image_2|>\nWhat's in these images?"
+    }]
+    _assert_mm_data_is_image_input(await mm_future, 2)
+
+
+def test_parse_chat_messages_placeholder_already_in_prompt(
+    phi3v_model_config,
+    phi3v_tokenizer,
+    image_url,
+):
+    conversation, mm_data = parse_chat_messages([{
         "role":
         "user",
         "content": [{
@@ -131,15 +213,15 @@ async def test_parse_chat_messages_placeholder_already_in_prompt(
         "role": "user",
         "content": "What's in <|image_1|> and how does it compare to <|image_2|>?"
     }]
-    mm_data = await mm_future
-    assert set(mm_data.keys()) == {"image"}
-    assert len(mm_data["image"]) == 2
+    _assert_mm_data_is_image_input(mm_data, 2)
 
 
-@pytest.mark.asyncio
-async def test_parse_chat_messages_placeholder_one_already_in_prompt(
-        phi3v_model_config, phi3v_tokenizer, image_url):
-    conversation, mm_future = parse_chat_messages([{
+def test_parse_chat_messages_placeholder_one_already_in_prompt(
+    phi3v_model_config,
+    phi3v_tokenizer,
+    image_url,
+):
+    conversation, mm_data = parse_chat_messages([{
         "role":
         "user",
         "content": [{
@@ -167,15 +249,15 @@ async def test_parse_chat_messages_placeholder_one_already_in_prompt(
         "<|image_2|>\nWhat's in <|image_1|> and how does it compare to the "
         "other one?"
     }]
-    mm_data = await mm_future
-    assert set(mm_data.keys()) == {"image"}
-    assert len(mm_data["image"]) == 2
+    _assert_mm_data_is_image_input(mm_data, 2)
 
 
-@pytest.mark.asyncio
-async def test_parse_chat_messages_multiple_images_across_messages(
-        phi3v_model_config, phi3v_tokenizer, image_url):
-    conversation, mm_future = parse_chat_messages([{
+def test_parse_chat_messages_multiple_images_across_messages(
+    phi3v_model_config,
+    phi3v_tokenizer,
+    image_url,
+):
+    conversation, mm_data = parse_chat_messages([{
         "role":
         "user",
         "content": [{
@@ -218,14 +300,14 @@ async def test_parse_chat_messages_multiple_images_across_messages(
             "content": "<|image_2|>\nWhat about this one?"
         },
     ]
-    mm_data = await mm_future
-    assert set(mm_data.keys()) == {"image"}
-    assert len(mm_data["image"]) == 2
+    _assert_mm_data_is_image_input(mm_data, 2)
 
 
-@pytest.mark.asyncio
-async def test_parse_chat_messages_rejects_too_many_images_in_one_message(
-        phi3v_model_config, phi3v_tokenizer, image_url):
+def test_parse_chat_messages_rejects_too_many_images_in_one_message(
+    phi3v_model_config,
+    phi3v_tokenizer,
+    image_url,
+):
     with warnings.catch_warnings():
         warnings.filterwarnings(
             "ignore",
@@ -259,9 +341,11 @@ async def test_parse_chat_messages_rejects_too_many_images_in_one_message(
             }], phi3v_model_config, phi3v_tokenizer)
 
 
-@pytest.mark.asyncio
-async def test_parse_chat_messages_rejects_too_many_images_across_messages(
-        phi3v_model_config, phi3v_tokenizer, image_url):
+def test_parse_chat_messages_rejects_too_many_images_across_messages(
+    phi3v_model_config,
+    phi3v_tokenizer,
+    image_url,
+):
     with warnings.catch_warnings():
         warnings.filterwarnings(
             "ignore",
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index c70c6d9330b10..f205a99920892 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -1,10 +1,11 @@
 import asyncio
 import codecs
+from abc import ABC, abstractmethod
 from collections import defaultdict
 from functools import lru_cache
 from pathlib import Path
-from typing import (Any, Awaitable, Dict, Iterable, List, Literal, Mapping,
-                    Optional, Tuple, Union)
+from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal,
+                    Mapping, Optional, Tuple, TypeVar, Union)
 
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -23,7 +24,8 @@
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalDataDict
 from vllm.multimodal.utils import (async_get_and_parse_audio,
-                                   async_get_and_parse_image)
+                                   async_get_and_parse_image,
+                                   get_and_parse_audio, get_and_parse_image)
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 
 logger = init_logger(__name__)
@@ -81,7 +83,11 @@ class ConversationMessage(TypedDict):
     content: str
 
 
-class MultiModalItemTracker:
+ModalityStr = Literal["image", "audio"]
+_T = TypeVar("_T")
+
+
+class BaseMultiModalItemTracker(ABC, Generic[_T]):
     """
     Tracks multi-modal items in a given request and ensures that the number
     of multi-modal items in a given request does not exceed the configured
@@ -89,37 +95,28 @@ class MultiModalItemTracker:
     """
 
     def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
+        super().__init__()
+
         self._model_config = model_config
         self._tokenizer = tokenizer
         self._allowed_items = (model_config.multimodal_config.limit_per_prompt
                                if model_config.multimodal_config else {})
        self._consumed_items = {k: 0 for k in self._allowed_items}
-        self._futures: List[Awaitable[MultiModalDataDict]] = []
+
+        self._items: List[_T] = []
 
     @staticmethod
     @lru_cache(maxsize=None)
-    def _cached_token_str(tokenizer: AnyTokenizer, token_index: int):
+    def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
         return tokenizer.decode(token_index)
 
-    def add(self, modality: Literal["image", "audio"],
-            mm_future: Awaitable[MultiModalDataDict]) -> Optional[str]:
-        """
-        Adds the multi-modal item to the current prompt and returns the
-        placeholder string to use, if any.
-        """
-        allowed_count = self._allowed_items.get(modality, 1)
-        current_count = self._consumed_items.get(modality, 0) + 1
-        if current_count > allowed_count:
-            raise ValueError(
-                f"At most {allowed_count} {modality}(s) may be provided in "
-                "one request.")
-
-        self._consumed_items[modality] = current_count
-        self._futures.append(mm_future)
-
+    def _placeholder_str(self, modality: ModalityStr,
+                         current_count: int) -> Optional[str]:
         # TODO: Let user specify how to insert image tokens into prompt
         # (similar to chat template)
-        model_type = self._model_config.hf_config.model_type
+        hf_config = self._model_config.hf_config
+        model_type = hf_config.model_type
+
         if modality == "image":
             if model_type == "phi3_v":
                 # Workaround since this token is not defined in the tokenizer
@@ -130,9 +127,8 @@ def add(self, modality: Literal["image", "audio"],
                 # These models do not use image tokens in the prompt
                 return None
             if model_type.startswith("llava"):
-                return MultiModalItemTracker._cached_token_str(
-                    self._tokenizer,
-                    self._model_config.hf_config.image_token_index)
+                return self._cached_token_str(self._tokenizer,
+                                              hf_config.image_token_index)
             if model_type in ("chameleon", "internvl_chat"):
                 return "<image>"
 
@@ -145,11 +141,11 @@ def add(self, modality: Literal["image", "audio"],
         raise TypeError(f"Unknown modality: {modality}")
 
     @staticmethod
-    async def _combine(futures: List[Awaitable[MultiModalDataDict]]):
+    def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict:
         mm_lists: Mapping[str, List[object]] = defaultdict(list)
 
         # Merge all the multi-modal items
-        for single_mm_data in (await asyncio.gather(*futures)):
+        for single_mm_data in items:
             for mm_key, mm_item in single_mm_data.items():
                 if isinstance(mm_item, list):
                     mm_lists[mm_key].extend(mm_item)
@@ -162,9 +158,113 @@ async def _combine(futures: List[Awaitable[MultiModalDataDict]]):
             for mm_key, mm_list in mm_lists.items()
         }
 
-    def all_mm_data(self) -> Optional[Awaitable[MultiModalDataDict]]:
-        return MultiModalItemTracker._combine(
-            self._futures) if self._futures else None
+    def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
+        """
+        Add a multi-modal item to the current prompt and return the
+        placeholder string to use, if any.
+        """
+        allowed_count = self._allowed_items.get(modality, 1)
+        current_count = self._consumed_items.get(modality, 0) + 1
+        if current_count > allowed_count:
+            raise ValueError(
+                f"At most {allowed_count} {modality}(s) may be provided in "
+                "one request.")
+
+        self._consumed_items[modality] = current_count
+        self._items.append(item)
+
+        return self._placeholder_str(modality, current_count)
+
+    @abstractmethod
+    def create_parser(self) -> "BaseMultiModalContentParser":
+        raise NotImplementedError
+
+
+class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]):
+
+    def all_mm_data(self) -> Optional[MultiModalDataDict]:
+        return self._combine(self._items) if self._items else None
+
+    def create_parser(self) -> "BaseMultiModalContentParser":
+        return MultiModalContentParser(self)
+
+
+class AsyncMultiModalItemTracker(
+        BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]):
+
+    async def all_mm_data(self) -> Optional[MultiModalDataDict]:
+        if self._items:
+            items = await asyncio.gather(*self._items)
+            return self._combine(items)
+
+        return None
+
+    def create_parser(self) -> "BaseMultiModalContentParser":
+        return AsyncMultiModalContentParser(self)
+
+
+class BaseMultiModalContentParser(ABC):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+        # multimodal placeholder_string : count
+        self._placeholder_counts: Dict[str, int] = defaultdict(lambda: 0)
+
+    def _add_placeholder(self, placeholder: Optional[str]):
+        if placeholder:
+            self._placeholder_counts[placeholder] += 1
+
+    def mm_placeholder_counts(self) -> Dict[str, int]:
+        return dict(self._placeholder_counts)
+
+    @abstractmethod
+    def parse_image(self, image_url: str) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def parse_audio(self, audio_url: str) -> None:
+        raise NotImplementedError
+
+
+class MultiModalContentParser(BaseMultiModalContentParser):
+
+    def __init__(self, tracker: MultiModalItemTracker) -> None:
+        super().__init__()
+
+        self._tracker = tracker
+
+    def parse_image(self, image_url: str) -> None:
+        image = get_and_parse_image(image_url)
+
+        placeholder = self._tracker.add("image", image)
+        self._add_placeholder(placeholder)
+
+    def parse_audio(self, audio_url: str) -> None:
+        audio = get_and_parse_audio(audio_url)
+
+        placeholder = self._tracker.add("audio", audio)
+        self._add_placeholder(placeholder)
+
+
+class AsyncMultiModalContentParser(BaseMultiModalContentParser):
+
+    def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
+        super().__init__()
+
+        self._tracker = tracker
+
+    def parse_image(self, image_url: str) -> None:
+        image_coro = async_get_and_parse_image(image_url)
+
+        placeholder = self._tracker.add("image", image_coro)
+        self._add_placeholder(placeholder)
+
+    def parse_audio(self, audio_url: str) -> None:
+        audio_coro = async_get_and_parse_audio(audio_url)
+
+        placeholder = self._tracker.add("audio", audio_coro)
+        self._add_placeholder(placeholder)
 
 
 def load_chat_template(
@@ -197,10 +297,10 @@ def load_chat_template(
 # (similar to chat template)
 def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
                                      text_prompt: str) -> str:
-    """Combine multimodal prompts for a multimodal language model"""
+    """Combine multimodal prompts for a multimodal language model."""
 
     # Look through the text prompt to check for missing placeholders
-    missing_placeholders = []
+    missing_placeholders: List[str] = []
     for placeholder in placeholder_counts:
 
         # For any existing placeholder in the text prompt, we leave it as is
@@ -227,12 +327,11 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
 def _parse_chat_message_content_parts(
     role: str,
     parts: Iterable[ChatCompletionContentPartParam],
-    mm_tracker: MultiModalItemTracker,
+    mm_tracker: BaseMultiModalItemTracker,
 ) -> List[ConversationMessage]:
     texts: List[str] = []
 
-    # multimodal placeholder_string : count
-    mm_placeholder_counts: Dict[str, int] = {}
+    mm_parser = mm_tracker.create_parser()
 
     for part in parts:
         part_type = part["type"]
@@ -247,22 +346,16 @@ def _parse_chat_message_content_parts(
                     "'image_url.detail' is currently not supported and "
                     "will be ignored.")
 
-            image_coro = async_get_and_parse_image(image_url["url"])
-            placeholder = mm_tracker.add("image", image_coro)
-            if placeholder:
-                mm_placeholder_counts[placeholder] = mm_placeholder_counts.get(
-                    placeholder, 0) + 1
+            mm_parser.parse_image(image_url["url"])
         elif part_type == "audio_url":
             audio_url = _AudioParser.validate_python(part)["audio_url"]
-            audio_coro = async_get_and_parse_audio(audio_url["url"])
-            placeholder = mm_tracker.add("audio", audio_coro)
-            if placeholder:
-                mm_placeholder_counts[placeholder] = mm_placeholder_counts.get(
-                    placeholder, 0) + 1
+
+            mm_parser.parse_audio(audio_url["url"])
         else:
             raise NotImplementedError(f"Unknown part type: {part_type}")
 
     text_prompt = "\n".join(texts)
+    mm_placeholder_counts = mm_parser.mm_placeholder_counts()
     if mm_placeholder_counts:
         text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
                                                        text_prompt)
@@ -271,8 +364,9 @@ def _parse_chat_message_content_parts(
 
 
 def _parse_chat_message_content(
-        message: ChatCompletionMessageParam,
-        mm_tracker: MultiModalItemTracker) -> List[ConversationMessage]:
+    message: ChatCompletionMessageParam,
+    mm_tracker: BaseMultiModalItemTracker,
+) -> List[ConversationMessage]:
     role = message["role"]
     content = message.get("content")
 
@@ -292,7 +386,7 @@ def parse_chat_messages(
     messages: List[ChatCompletionMessageParam],
     model_config: ModelConfig,
     tokenizer: AnyTokenizer,
-) -> Tuple[List[ConversationMessage], Optional[Awaitable[MultiModalDataDict]]]:
+) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]:
     conversation: List[ConversationMessage] = []
     mm_tracker = MultiModalItemTracker(model_config, tokenizer)
 
@@ -304,6 +398,22 @@ def parse_chat_messages(
     return conversation, mm_tracker.all_mm_data()
 
 
+def parse_chat_messages_futures(
+    messages: List[ChatCompletionMessageParam],
+    model_config: ModelConfig,
+    tokenizer: AnyTokenizer,
+) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
+    conversation: List[ConversationMessage] = []
+    mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
+
+    for msg in messages:
+        sub_messages = _parse_chat_message_content(msg, mm_tracker)
+
+        conversation.extend(sub_messages)
+
+    return conversation, mm_tracker.all_mm_data()
+
+
 def apply_chat_template(
     tokenizer: AnyTokenizer,
     conversation: List[ConversationMessage],
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 0edd4bfaecd6a..b32c90a4df1aa 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -23,7 +23,7 @@
                                                get_cached_tokenizer)
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import Counter, deprecate_kwargs
+from vllm.utils import Counter, deprecate_kwargs, is_list_of
 
 logger = init_logger(__name__)
 
@@ -358,15 +358,18 @@ def chat(
         add_generation_prompt: bool = True,
     ) -> List[RequestOutput]:
         """
-        Generates responses for chat messages.
+        Generate responses for a chat conversation.
 
-        Converts the messages to prompts using the tokenizer and calls
-        the :meth:`generate` method to generate the responses.
+        The chat conversation is converted into a text prompt using the
+        tokenizer, and the :meth:`generate` method is called to generate
+        the responses.
+
+        Multi-modal inputs can be passed in the same way you would pass them
+        to the OpenAI API.
 
         Args:
-            messages: A list of messages to generate responses for. Each
-                message is a list of dictionaries with 'role' and 'content'
-                keys.
+            messages: A single conversation represented as a list of messages.
+                Each message is a dictionary with 'role' and 'content' keys.
             sampling_params: The sampling parameters for text generation. If
                 None, we use the default sampling parameters. When it
                 is a single value, it is applied to every prompt. When it
@@ -387,21 +390,25 @@ def chat(
         tokenizer = self.get_tokenizer()
         model_config = self.llm_engine.get_model_config()
 
-        conversations, _ = parse_chat_messages(messages, model_config,
-                                               tokenizer)
+        conversation, mm_data = parse_chat_messages(messages, model_config,
+                                                    tokenizer)
 
         prompt = apply_chat_template(
             tokenizer,
-            conversations,
+            conversation,
             chat_template=chat_template,
-            add_generation_prompt=add_generation_prompt)
+            add_generation_prompt=add_generation_prompt,
+        )
 
         inputs: PromptInputs
-        if isinstance(prompt, list) and isinstance(prompt[0], int):
+        if is_list_of(prompt, int):
             inputs = TokensPrompt(prompt_token_ids=prompt)
         else:
             inputs = TextPrompt(prompt=prompt)
 
+        if mm_data is not None:
+            inputs["multi_modal_data"] = mm_data
+
         return self.generate(
             inputs,
             sampling_params=sampling_params,
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index f7576509d06c8..a3bc0bb7b3554 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -11,7 +11,7 @@
 from vllm.entrypoints.chat_utils import (ConversationMessage,
                                          apply_chat_template,
                                          load_chat_template,
-                                         parse_chat_messages)
+                                         parse_chat_messages_futures)
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionLogProb, ChatCompletionLogProbs,
@@ -26,7 +26,6 @@
                                                     TextTokensPrompt)
 from vllm.inputs import TokensPrompt
 from vllm.logger import init_logger
-from vllm.multimodal import MultiModalDataDict
 from vllm.outputs import RequestOutput
 from vllm.sequence import Logprob
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
@@ -94,7 +93,7 @@ async def create_chat_completion(
             tokenizer = await self.async_engine_client.get_tokenizer(
                 lora_request)
 
-            conversation, mm_data_future = parse_chat_messages(
+            conversation, mm_data_future = parse_chat_messages_futures(
                 request.messages, model_config, tokenizer)
 
             tool_dicts = None if request.tools is None else [
@@ -114,10 +113,8 @@ async def create_chat_completion(
             logger.error("Error in applying chat template from request: %s",
                          e)
             return self.create_error_response(str(e))
-        mm_data: Optional[MultiModalDataDict] = None
         try:
-            if mm_data_future:
-                mm_data = await mm_data_future
+            mm_data = await mm_data_future
         except Exception as e:
             logger.error("Error in loading multi-modal data: %s", e)
             return self.create_error_response(str(e))
diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py
index fc9ca29e9cf86..c3c0d52072cd3 100644
--- a/vllm/entrypoints/openai/serving_tokenization.py
+++ b/vllm/entrypoints/openai/serving_tokenization.py
@@ -4,7 +4,7 @@
 from vllm.engine.protocol import AsyncEngineClient
 from vllm.entrypoints.chat_utils import (apply_chat_template,
                                          load_chat_template,
-                                         parse_chat_messages)
+                                         parse_chat_messages_futures)
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -65,10 +65,11 @@ async def create_tokenize(
         if isinstance(request, TokenizeChatRequest):
             model_config = self.model_config
 
-            conversation, mm_data_future = parse_chat_messages(
+            conversation, mm_data_future = parse_chat_messages_futures(
                 request.messages, model_config, tokenizer)
 
-            if mm_data_future:
+            mm_data = await mm_data_future
+            if mm_data:
                 logger.warning(
                     "Multi-modal inputs are ignored during tokenization")
 
diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py
index 4bed267e99637..b76b765bc677a 100644
--- a/vllm/multimodal/utils.py
+++ b/vllm/multimodal/utils.py
@@ -120,6 +120,16 @@ async def async_fetch_audio(
     return librosa.load(BytesIO(audio_bytes), sr=None)
 
 
+def get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
+    audio, sr = fetch_audio(audio_url)
+    return {"audio": (audio, sr)}
+
+
+def get_and_parse_image(image_url: str) -> MultiModalDataDict:
+    image = fetch_image(image_url)
+    return {"image": image}
+
+
 async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
     audio, sr = await async_fetch_audio(audio_url)
     return {"audio": (audio, sr)}
diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py
index 23ecfc0af6be4..533a86b787325 100644
--- a/vllm/transformers_utils/tokenizers/mistral.py
+++ b/vllm/transformers_utils/tokenizers/mistral.py
@@ -52,12 +52,13 @@ def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
         assert isinstance(self.tokenizer,
                           (Tekkenizer, SentencePieceTokenizer)), type(
                               self.tokenizer)
 
-        self._is_tekken = isinstance(self.tokenizer, Tekkenizer)
-        if self._is_tekken:
+        if (is_tekken := isinstance(self.tokenizer, Tekkenizer)):
             # Make sure special tokens will not raise
             self.tokenizer.special_token_policy = SpecialTokenPolicy.IGNORE
 
+        self._is_tekken = is_tekken
+
         # the following attributes are set to fit VLLM's design
         self.is_fast = True
         self.chat_template = True
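
Usage note (not part of the patch): the sketch below shows how the multi-image chat path added in this diff can be called through LLM.chat(). It mirrors the new test_chat_multi_image test; the model name, sampling defaults, and placeholder image URLs are assumptions taken from or substituted for the test values and may need adjusting for your environment.

# Minimal sketch of multi-image LLM.chat(), assuming the Phi-3.5-vision
# checkpoint used in the new test and two reachable image URLs.
from vllm import LLM

llm = LLM(
    model="microsoft/Phi-3.5-vision-instruct",
    max_model_len=4096,
    trust_remote_code=True,
    limit_mm_per_prompt={"image": 2},  # allow two images per prompt
)

# Placeholder URLs; any two accessible images work here.
image_urls = [
    "https://example.com/first.jpg",
    "https://example.com/second.jpg",
]

# A single conversation: image parts are passed OpenAI-style, and
# parse_chat_messages() attaches them as multi_modal_data internally.
outputs = llm.chat([{
    "role": "user",
    "content": [
        *({"type": "image_url", "image_url": {"url": url}} for url in image_urls),
        {"type": "text", "text": "What's in these images?"},
    ],
}])

print(outputs[0].outputs[0].text)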