diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index 157d873a75b4d..a0b6edd566561 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -89,7 +89,7 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=772, total_tokens=782) + completion_tokens=10, prompt_tokens=775, total_tokens=785) message = choice.message message = chat_completion.choices[0].message @@ -181,7 +181,7 @@ async def test_single_chat_session_image_base64encoded( choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=772, total_tokens=782) + completion_tokens=10, prompt_tokens=775, total_tokens=785) message = choice.message message = chat_completion.choices[0].message diff --git a/tests/entrypoints/openai/test_vision_embedding.py b/tests/entrypoints/openai/test_vision_embedding.py index d0c43b47bf0af..425f2a10ec855 100644 --- a/tests/entrypoints/openai/test_vision_embedding.py +++ b/tests/entrypoints/openai/test_vision_embedding.py @@ -95,5 +95,5 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str, assert len(embeddings["data"]) == 1 assert len(embeddings["data"][0]["embedding"]) == 3072 assert embeddings["usage"]["completion_tokens"] == 0 - assert embeddings["usage"]["prompt_tokens"] == 762 - assert embeddings["usage"]["total_tokens"] == 762 + assert embeddings["usage"]["prompt_tokens"] == 765 + assert embeddings["usage"]["total_tokens"] == 765 diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py index 60a8f63eb5faa..c16192a1e1438 100644 --- a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py +++ b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py @@ -2,12 +2,10 @@ from typing import Optional import pytest -import torch -from transformers import AutoImageProcessor, AutoTokenizer +from transformers import AutoTokenizer -from vllm.inputs import InputContext, token_inputs +from vllm.inputs import InputContext, InputProcessingContext from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID -from vllm.multimodal import MultiModalRegistry from .....conftest import _ImageAssets from ....utils import build_model_context @@ -17,15 +15,9 @@ # Wrap lazy imports to avoid initializing CUDA during test collection @pytest.fixture() -def input_processor_for_phi3v(): - from vllm.model_executor.models.phi3v import input_processor_for_phi3v - return input_processor_for_phi3v - - -@pytest.fixture() -def dummy_data_for_phi3v(): - from vllm.model_executor.models.phi3v import dummy_data_for_phi3v - return dummy_data_for_phi3v +def processor_for_phi3v(): + from vllm.model_executor.models.phi3v import Phi3VProcessor + return Phi3VProcessor @pytest.fixture() @@ -34,53 +26,6 @@ def get_max_phi3v_image_tokens(): return get_max_phi3v_image_tokens -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("num_crops", [4, 16, None]) -def test_input_mapper_override(model: str, image_assets: _ImageAssets, - num_crops: Optional[int]): - """Ensure that the [default] input mapper handles num_crops properly.""" - # We pass the processor kwargs here since for this model, we fall back to - # the default mapper; this will fall back to the HF mapper and forward - # mm_processor_kwargs to it. - mm_processor_kwargs = { - "num_crops": num_crops - } if num_crops is not None else {} - ctx = build_model_context( - model_name=model, - tokenizer_name=model, - trust_remote_code=True, - mm_processor_kwargs=mm_processor_kwargs, - ) - - hf_processor = AutoImageProcessor.from_pretrained(model, - trust_remote_code=True, - **mm_processor_kwargs) - - mm_registry = MultiModalRegistry() - mm_registry.init_mm_limits_per_prompt(ctx.model_config) - - image = image_assets[0].pil_image - hf_result = hf_processor.preprocess( - image, - return_tensors="pt", - ) - - vllm_result = mm_registry.map_input( - ctx.model_config, - {"image": image}, - ) - - assert torch.all(hf_result["image_sizes"] == vllm_result["image_sizes"]) - assert torch.all( - hf_result["num_img_tokens"] == vllm_result["num_img_tokens"]) - - # For pixel values, the second axis should be the num_crops + 1 - # for the rescaled original image. The default value in VLLM falls - # back to the HF config, which is why we compare to the processor num_crops - assert torch.all(hf_result["pixel_values"] == vllm_result["pixel_values"]) - assert vllm_result["pixel_values"].shape[1] == hf_processor.num_crops + 1 - - @pytest.mark.parametrize("model", models) @pytest.mark.parametrize("num_crops,expected_max_tokens", [ (4, 781), @@ -112,48 +57,20 @@ def test_max_tokens_override(get_max_phi3v_image_tokens, model: str, @pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("num_crops,toks_per_img,num_imgs", [ - (4, 781, 1), - (4, 781, 2), - (16, 2653, 1), - (16, 2653, 2), -]) -def test_dummy_data_override(dummy_data_for_phi3v, model: str, num_crops: int, - toks_per_img: int, num_imgs: int): - """Ensure dummy_data_for_phi3v handles num_crops properly.""" - # Same as the previous test - don't initialize mm_processor_kwargs - # in this test and assume that the kwargs will be correctly expanded by - # the partial when calling the dummy data func. - ctx = build_model_context( - model_name=model, - tokenizer_name=model, - trust_remote_code=True, - mm_processor_kwargs=None, - ) - - dummy_data = dummy_data_for_phi3v( - ctx=ctx, - seq_len=8192, # Should be bigger than num_imgs * toks_per_img - mm_counts={"image": num_imgs}, - num_crops=num_crops, - ) - sequence_data = dummy_data.seq_data - # Ensure we have the right number of placeholders per num_crops size - img_tok_count = sequence_data.get_token_ids().count(_IMAGE_TOKEN_ID) - assert img_tok_count == toks_per_img * num_imgs - - -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("num_crops,expected_toks_per_img,num_imgs", [ - (4, 757, 1), - (4, 757, 2), - (16, 1921, 1), - (16, 1921, 2), -]) -def test_input_processor_override(input_processor_for_phi3v, - image_assets: _ImageAssets, model: str, - num_crops: int, expected_toks_per_img: int, - num_imgs: int): +@pytest.mark.parametrize( + "num_crops,expected_toks_per_img,num_imgs", + [ + (4, 757, 1), + (4, 757, 2), + (16, 1921, 1), + (16, 1921, 2), + # the default num_crops of phi-3.5-vision is 4 + (None, 757, 2), + (None, 757, 2), + ]) +def test_processor_override(processor_for_phi3v, image_assets: _ImageAssets, + model: str, num_crops: Optional[int], + expected_toks_per_img: int, num_imgs: int): """Ensure input_processor_for_phi3v handles num_crops properly.""" # Same as the previous test - don't initialize mm_processor_kwargs # in this test and assume that the kwargs will be correctly expanded by @@ -163,19 +80,20 @@ def test_input_processor_override(input_processor_for_phi3v, tokenizer_name=model, trust_remote_code=True, ) - tokenizer = AutoTokenizer.from_pretrained(model) + tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + ctx = InputProcessingContext(ctx.model_config, tokenizer) # Build the image str / prompt based on the number of images we pass img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)]) prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n" images = [image_assets[0].pil_image] * num_imgs - inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt), - prompt=prompt, - multi_modal_data={"image": images}) + mm_data = {"image": images} + mm_processor_kwargs = {} + if num_crops is not None: + mm_processor_kwargs = {"num_crops": num_crops} - processed_inputs = input_processor_for_phi3v(ctx, - inputs, - num_crops=num_crops) + processor = processor_for_phi3v(ctx) + processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) # Ensure we have the right number of placeholders per num_crops size img_tok_count = processed_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID) diff --git a/tests/multimodal/test_processor_kwargs.py b/tests/multimodal/test_processor_kwargs.py index e6c8793989e13..d141cdf1f083b 100644 --- a/tests/multimodal/test_processor_kwargs.py +++ b/tests/multimodal/test_processor_kwargs.py @@ -15,13 +15,13 @@ # Used for fast tests where the model doesn't matter DUMMY_MODEL_ID = "facebook/opt-125m" # Used for tests that need a multimodal model -MULTIMODAL_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" +MULTIMODAL_MODEL_ID = "OpenGVLab/InternVL2-2B" # For mm_processor_kwargs - we test overrides by defining mocks for each place # it is used, and ensuring that we can pass processor kwargs an override value # to receive the intended result for things like sequence length etc. -DEFAULT_NUM_CROPS = 4 -NUM_CROPS_OVERRIDE = 16 +DEFAULT_MAX_DYNAMIC_PATCH = 6 +MAX_DYNAMIC_PATCH_OVERRIDE = 4 # Mocks for all of the places that we use the mm_processor_kwargs @@ -33,10 +33,11 @@ def use_processor_mock(): def custom_processor(ctx: InputContext, inputs: DecoderOnlyInputs, *, - num_crops=DEFAULT_NUM_CROPS): + max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH): # For testing purposes, we don't worry about the prompt - return token_inputs(prompt_token_ids=[], - mm_processor_kwargs={"num_crops": num_crops}) + return token_inputs( + prompt_token_ids=[], + mm_processor_kwargs={"max_dynamic_patch": max_dynamic_patch}) with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor", return_value=custom_processor): @@ -52,9 +53,9 @@ def custom_dummy_data_factory(self, seq_len: int, mm_counts: Mapping[str, int], *, - num_crops=DEFAULT_NUM_CROPS): + max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH): seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops)) + array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * max_dynamic_patch)) return DummyData(seq_data, None) with patch( @@ -65,15 +66,15 @@ def custom_dummy_data_factory(self, # Lazy import to avoid CUDA reinitialization error def mm_model_cls(): - from vllm.model_executor.models.phi3v import Phi3VForCausalLM + from vllm.model_executor.models.internvl import InternVLChatModel - return Phi3VForCausalLM + return InternVLChatModel # lambda whose signature matches max token calcs extra & mapper + extra kwargs -get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops -custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: { - "pixel_values": torch.zeros(size=(1, num_crops + 1, 3, 336, 336)) +get_max_dynamic_patch = lambda ctx, *, max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH: max_dynamic_patch # noqa: E501 +custom_mapper = lambda ctx, data, *, max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH: { # noqa: E501 + "pixel_values": torch.zeros(size=(1, max_dynamic_patch + 1, 3, 448, 448)) } @@ -88,27 +89,28 @@ def test_default_processor_is_a_noop(): assert proc_inputs is proc_outputs -def _get_num_crops_info(init_num_crops: int, inference_num_crops: int): - """Get the init / inference kwargs and expected num_crops for this test.""" - # If we have a value for num_crops, pass the override value and make +def _get_max_dynamic_patch_info(init_max_dynamic_patch: int, + inference_max_dynamic_patch: int): + """Get the init / inference kwargs and expected max_dynamic_patch.""" + # If we have a value for max_dynamic_patch, pass the override value and make # sure we get that value as a return-value from out mock processor, # otherwise fall back to the default value - init_kwargs = None if init_num_crops is None else { - "num_crops": init_num_crops + init_kwargs = None if init_max_dynamic_patch is None else { + "max_dynamic_patch": init_max_dynamic_patch } - inference_kwargs = None if inference_num_crops is None else { - "num_crops": inference_num_crops + inference_kwargs = None if inference_max_dynamic_patch is None else { + "max_dynamic_patch": inference_max_dynamic_patch } - if inference_num_crops is not None: - expected_seq_count = inference_num_crops - elif init_num_crops is not None: - expected_seq_count = init_num_crops + if inference_max_dynamic_patch is not None: + expected_seq_count = inference_max_dynamic_patch + elif init_max_dynamic_patch is not None: + expected_seq_count = init_max_dynamic_patch else: - expected_seq_count = DEFAULT_NUM_CROPS + expected_seq_count = DEFAULT_MAX_DYNAMIC_PATCH return init_kwargs, inference_kwargs, expected_seq_count -def _get_processed_num_crops( +def _get_processed_max_dynamic_patch( processor: Callable[[ProcessorInputs], ProcessorInputs], inference_kwargs: Optional[Dict[str, int]], ) -> int: @@ -120,27 +122,30 @@ def _get_processed_num_crops( assert "type" in processed_inputs assert processed_inputs["type"] == "token" assert "mm_processor_kwargs" in processed_inputs - return processed_inputs["mm_processor_kwargs"]["num_crops"] + return processed_inputs["mm_processor_kwargs"]["max_dynamic_patch"] -@pytest.mark.parametrize("init_num_crops,inference_num_crops", [ - (None, None), - (NUM_CROPS_OVERRIDE, None), - (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE), -]) -def test_input_processor_kwargs(use_processor_mock, init_num_crops, - inference_num_crops): +@pytest.mark.parametrize( + "init_max_dynamic_patch,inference_max_dynamic_patch", [ + (None, None), + (MAX_DYNAMIC_PATCH_OVERRIDE, None), + (DEFAULT_MAX_DYNAMIC_PATCH, MAX_DYNAMIC_PATCH_OVERRIDE), + ]) +def test_input_processor_kwargs(use_processor_mock, init_max_dynamic_patch, + inference_max_dynamic_patch): """Ensure input processors can use processor kwargs.""" dummy_registry = InputRegistry() - init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info( - init_num_crops, inference_num_crops) + (init_kwargs, inference_kwargs, + expected_seq_count) = _get_max_dynamic_patch_info( + init_max_dynamic_patch, inference_max_dynamic_patch) ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs) processor = dummy_registry.create_input_processor(ctx.model_config) - num_crops_val = _get_processed_num_crops(processor, inference_kwargs) + max_dynamic_patch_val = _get_processed_max_dynamic_patch( + processor, inference_kwargs) - assert num_crops_val == expected_seq_count + assert max_dynamic_patch_val == expected_seq_count @pytest.mark.parametrize( @@ -165,18 +170,21 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock, processor = dummy_registry.create_input_processor(ctx.model_config) # Should filter out the inference time kwargs - num_crops_val = _get_processed_num_crops(processor, mm_processor_kwargs) - assert num_crops_val == DEFAULT_NUM_CROPS + max_dynamic_patch_val = _get_processed_max_dynamic_patch( + processor, mm_processor_kwargs) + assert max_dynamic_patch_val == DEFAULT_MAX_DYNAMIC_PATCH ### Test overrides for the dummy data -@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) -def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): +@pytest.mark.parametrize("max_dynamic_patch", + [None, MAX_DYNAMIC_PATCH_OVERRIDE]) +def test_dummy_data_kwarg_overrides(use_dummy_data_mock, max_dynamic_patch): """Ensure dummy data factories can use processor kwargs.""" - mm_processor_kwargs = None if num_crops is None else { - "num_crops": num_crops + mm_processor_kwargs = None if max_dynamic_patch is None else { + "max_dynamic_patch": max_dynamic_patch } - expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops + expected_seq_count = (DEFAULT_MAX_DYNAMIC_PATCH + if max_dynamic_patch is None else max_dynamic_patch) dummy_registry = InputRegistry() ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=mm_processor_kwargs) @@ -217,17 +225,20 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, # len is solely dependent on the value of the mm_processor_kwargs. dummy_data = dummy_registry.dummy_data_for_profiling( ctx.model_config, seq_len=-1, mm_registry=mm_registry) - assert len(dummy_data.seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS + assert len( + dummy_data.seq_data.prompt_token_ids) == DEFAULT_MAX_DYNAMIC_PATCH ### Test overrides for the max token count per multimodal instance -@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) -def test_max_tokens_kwarg_overrides(num_crops): +@pytest.mark.parametrize("max_dynamic_patch", + [None, MAX_DYNAMIC_PATCH_OVERRIDE]) +def test_max_tokens_kwarg_overrides(max_dynamic_patch): """Ensure max token calcs can use processor kwargs.""" - mm_processor_kwargs = None if num_crops is None else { - "num_crops": num_crops + mm_processor_kwargs = None if max_dynamic_patch is None else { + "max_dynamic_patch": max_dynamic_patch } - expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops + expected_seq_count = (DEFAULT_MAX_DYNAMIC_PATCH + if max_dynamic_patch is None else max_dynamic_patch) ctx = build_model_context(MULTIMODAL_MODEL_ID, task="generate", @@ -239,11 +250,11 @@ def test_max_tokens_kwarg_overrides(num_crops): mm_registry.init_mm_limits_per_prompt(ctx.model_config) # Patch the image registry for phi3v with our lambda that is compatible # with overrides, then ensure that calling the method correctly echos - # our num_crops value back from the mm_processor_kwargs. + # our max_dynamic_patch value back from the mm_processor_kwargs. with patch.object( mm_registry._get_plugin("image"), "_max_mm_tokens", - {mm_model_cls(): get_num_crops}, + {mm_model_cls(): get_max_dynamic_patch}, ): max_multimodal_tokens = mm_registry.get_max_multimodal_tokens( ctx.model_config) @@ -279,26 +290,29 @@ def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs): with patch.object( mm_registry._get_plugin("image"), "_max_mm_tokens", - {mm_model_cls(): get_num_crops}, + {mm_model_cls(): get_max_dynamic_patch}, ): max_multimodal_tokens = mm_registry.get_max_multimodal_tokens( ctx.model_config) - assert max_multimodal_tokens == DEFAULT_NUM_CROPS + assert max_multimodal_tokens == DEFAULT_MAX_DYNAMIC_PATCH ### Test overrides for the mapper -@pytest.mark.parametrize("num_crops", [DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE]) -def test_default_mapper_with_processor_kwargs(image_assets, num_crops): +@pytest.mark.parametrize( + "max_dynamic_patch", + [DEFAULT_MAX_DYNAMIC_PATCH, MAX_DYNAMIC_PATCH_OVERRIDE]) +def test_default_mapper_with_processor_kwargs(image_assets, max_dynamic_patch): """Ensure that the mapper processor kwargs can fall back to HF models.""" # NOTE - we don't validate bad inputs for the default mapper, because it's # through the automodel interface in transformers, so we can't easily # inspect what kwargs are or are not allowed. - ctx = build_model_context(MULTIMODAL_MODEL_ID, - task="generate", - trust_remote_code=True, - mm_processor_kwargs={"num_crops": num_crops}, - limit_mm_per_prompt={"image": 1}) + ctx = build_model_context( + MULTIMODAL_MODEL_ID, + task="generate", + trust_remote_code=True, + mm_processor_kwargs={"max_dynamic_patch": max_dynamic_patch}, + limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(ctx.model_config) @@ -307,20 +321,22 @@ def test_default_mapper_with_processor_kwargs(image_assets, num_crops): mm_inputs = {"image": image} mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs) - # Phi3v pixel vals should have shape: [batch, num_crops+1, 3, 336, 336] - assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1 + # pixel vals should have shape: [batch, max_dynamic_patch+1, ...] + assert mapped_inputs["pixel_values"].shape[1] == max_dynamic_patch + 1 -@pytest.mark.parametrize("init_num_crops,inference_num_crops", [ - (None, None), - (NUM_CROPS_OVERRIDE, None), - (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE), -]) -def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops, - inference_num_crops): +@pytest.mark.parametrize( + "init_max_dynamic_patch,inference_max_dynamic_patch", [ + (None, None), + (MAX_DYNAMIC_PATCH_OVERRIDE, None), + (DEFAULT_MAX_DYNAMIC_PATCH, MAX_DYNAMIC_PATCH_OVERRIDE), + ]) +def test_custom_mapper_kwarg_overrides(image_assets, init_max_dynamic_patch, + inference_max_dynamic_patch): """Ensure custom mappers can use processor kwargs.""" - init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info( - init_num_crops, inference_num_crops) + (init_kwargs, inference_kwargs, + expected_seq_count) = _get_max_dynamic_patch_info( + init_max_dynamic_patch, inference_max_dynamic_patch) ctx = build_model_context(MULTIMODAL_MODEL_ID, task="generate", @@ -335,7 +351,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops, # Patch the image registry for phi3v with our lambda that is compatible # with overrides, then ensure that calling the method correctly echos - # our num_crops value back from the mm_processor_kwargs. + # our max_dynamic_patch value back from the mm_processor_kwargs. mm_registry._get_plugin("image").register_input_mapper(custom_mapper)( mm_model_cls()) mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs, @@ -373,11 +389,12 @@ def test_custom_mapper_with_sad_kwarg_overrides(image_assets, # Patch the image registry for phi3v with our lambda that is compatible # with overrides, then ensure that calling the method correctly echos - # our num_crops value back from the mm_processor_kwargs. + # our max_dynamic_patch value back from the mm_processor_kwargs. mm_registry._get_plugin("image").register_input_mapper(custom_mapper)( mm_model_cls()) # Should filter out the inference time kwargs mapped_inputs = mm_registry.map_input( ctx.model_config, mm_inputs, mm_processor_kwargs=mm_processor_kwargs) - assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1 + assert mapped_inputs["pixel_values"].shape[1] == ( + DEFAULT_MAX_DYNAMIC_PATCH + 1) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 646554c72481a..0dfed3b7e61bf 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -69,12 +69,12 @@ class InputProcessingContext(InputContext): tokenizer: AnyTokenizer """The tokenizer used to tokenize the inputs.""" - def get_hf_processor(self) -> ProcessorMixin: + def get_hf_processor(self, **kwargs) -> ProcessorMixin: return cached_get_processor( self.model_config.tokenizer, tokenizer=self.tokenizer, # Override the tokenizer with ours trust_remote_code=self.model_config.trust_remote_code, - ) + **kwargs) N = TypeVar("N", bound=Type[nn.Module]) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index eef23029a2aca..3c7854ce388ab 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -12,22 +12,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import itertools -import re -from functools import cached_property, lru_cache -from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set, - Tuple, TypedDict, Union) +from functools import cached_property +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, + TypedDict, Union) -import numpy as np import torch import torch.nn as nn -from PIL import Image -from transformers import CLIPVisionConfig, PretrainedConfig +from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig, + ProcessorMixin) from vllm.attention import AttentionMetadata -from vllm.config import ModelConfig, VllmConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext, token_inputs) +from vllm.config import VllmConfig +from vllm.inputs import InputContext from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler @@ -36,12 +32,18 @@ from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import NestedTensors, PlaceholderRange -from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token +from vllm.multimodal.image import cached_get_image_processor +from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors +from vllm.multimodal.processing import (BaseMultiModalProcessor, + InputProcessingContext, + ModalityProcessingMetadata, + MultiModalDataDict, + MultiModalProcessingMetadata, + PromptReplacement) from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of -from .clip import dummy_image_for_clip, dummy_seq_data_for_clip +from .clip import dummy_image_for_clip from .interfaces import SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, init_vllm_registered_model, maybe_prefix, @@ -303,231 +305,99 @@ def add_image_newline(self, image_features_hd): return image_features_hd_newline -# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57 -def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336): - target_height = int(np.ceil(height / padding_unit) * padding_unit) - top_padding = int((target_height - height) / 2) - bottom_padding = target_height - height - top_padding - padded_width = width - padded_height = height + top_padding + bottom_padding - return padded_width, padded_height - - -# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90 -def _calc_hd_transform_size(*, width: int, height: int, hd_num: int): - transposed = False - if width < height: - width, height = height, width - transposed = True - - ratio = width / height - scale = 1 - while scale * np.ceil(scale / ratio) <= hd_num: - scale += 1 - scale -= 1 - - new_width = int(scale * 336) - new_height = int(new_width / ratio) - - padded_width, padded_height = _calc_padded_size(width=new_width, - height=new_height) - - if transposed: - padded_width, padded_height = padded_height, padded_width - - return padded_width, padded_height - - -# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181 -def get_phi3v_image_feature_size( - hf_config: Dict[str, Any], - *, - input_height: int, - input_width: int, - num_crops: int, -) -> int: - if num_crops is None: - num_crops = hf_config.get("num_crops", 16) - new_width, new_height = _calc_hd_transform_size(width=input_width, - height=input_height, - hd_num=num_crops) - - return (new_height // 336 * new_width // 336 + 1) * 144 + 1 \ - + (new_height // 336 + 1) * 12 - - def get_max_phi3v_image_tokens(ctx: InputContext, *, num_crops: Optional[int] = None): + mm_processor_kwargs = {} + if num_crops is not None: + mm_processor_kwargs["num_crops"] = num_crops - return get_phi3v_image_feature_size( - ctx.get_hf_image_processor_config(), - input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, - input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, - num_crops=num_crops, + model_config = ctx.model_config + image_processor = cached_get_image_processor( + model_config.model, + trust_remote_code=model_config.trust_remote_code, + **mm_processor_kwargs, + ) + + num_tokens = image_processor.calc_num_image_tokens_from_image_size( + width=MAX_IMAGE_FEATURE_SIZE_WIDTH, + height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, ) + return num_tokens -def dummy_data_for_phi3v(ctx: InputContext, - seq_len: int, - mm_counts: Mapping[str, int], - *, - num_crops: Optional[int] = None): +def dummy_mm_kwargs_for_phi3v(ctx: InputProcessingContext, + mm_counts: Mapping[str, int]): num_images = mm_counts["image"] - image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops) - - seq_data, ranges = dummy_seq_data_for_clip( - CLIP_VIT_LARGE_PATCH14_336_CONFIG, - seq_len, - num_images, - image_token_id=_IMAGE_TOKEN_ID, - image_feature_size_override=image_feature_size, - ) - mm_data = dummy_image_for_clip( + data = dummy_image_for_clip( CLIP_VIT_LARGE_PATCH14_336_CONFIG, num_images, image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH, image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, ) - return DummyData(seq_data, mm_data, ranges) - + hf_processor = ctx.get_hf_processor() + image_processor = hf_processor.image_processor # type: ignore + hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt") -@lru_cache -def _get_image_placeholder_token_id_candidates( - model_config: ModelConfig, - idx: int, -) -> List[List[int]]: - assert idx > 0 + return MultiModalKwargs(**hf_inputs) - tokenizer = cached_get_tokenizer(model_config.tokenizer) - # This is used when the image token is at the start of the string - start_candidate = tokenizer.encode(f"<|image_{idx}|>", - add_special_tokens=False) +def create_metadata_for_phi3v( + ctx: InputProcessingContext) -> MultiModalProcessingMetadata: + return { + "image": + ModalityProcessingMetadata(prompt_repls=[ + PromptReplacement(target=[_IMAGE_TOKEN_ID], + repl_unit=[_IMAGE_TOKEN_ID], + repl_count=get_max_phi3v_image_tokens(ctx)), + ]), + } - # This is used when the image token is in the middle of the string - # We need to get the token for "<", not "▁<" - # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json - a_token_id, = tokenizer.encode("a", add_special_tokens=False) - a_token_id_, *middle_candidate = tokenizer.encode(f"a<|image_{idx}|>", - add_special_tokens=False) - assert a_token_id == a_token_id_ - return [start_candidate, middle_candidate] +class Phi3VProcessor(BaseMultiModalProcessor): + def __init__(self, ctx: InputProcessingContext) -> None: + super().__init__( + ctx=ctx, + metadata=create_metadata_for_phi3v(ctx), + ) -def input_processor_for_phi3v(ctx: InputContext, - inputs: DecoderOnlyInputs, - *, - num_crops: Optional[int] = None): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return inputs - - model_config = ctx.model_config - hf_config = ctx.get_hf_image_processor_config() - - image_data = multi_modal_data["image"] - if isinstance(image_data, Image.Image): - w, h = image_data.size - image_feature_size = [ - get_phi3v_image_feature_size(hf_config, - input_width=w, - input_height=h, - num_crops=num_crops) - ] - image_data = [image_data] - elif is_list_of(image_data, Image.Image): - image_feature_size = [] - for image in image_data: - w, h = image.size - image_feature_size.append( - get_phi3v_image_feature_size(hf_config, - input_width=w, - input_height=h, - num_crops=num_crops)) - elif isinstance(image_data, torch.Tensor): - image_feature_size = [image_data.shape[0]] - image_data = [image_data] - elif is_list_of(image_data, torch.Tensor): - image_feature_size = [item.shape[0] for item in image_data] - else: - raise TypeError(f"Invalid image type: {type(image_data)}") - - prompt = inputs.get("prompt") - if prompt is None: - # for async server request, we assume prompt and its token_ids is always - # in correct format. And num_image_tags == len(image_data) always True. - image_idx = range(1, len(image_data) + 1) - new_prompt = None - else: - image_idx = sorted(map(int, re.findall(r"<\|image_(\d+)\|>+", prompt))) - if prompt.count("<|image|>") > 0: - logger.warning("Please follow the prompt format that is " - "documented on HuggingFace which does not involve " - "repeating <|image|> tokens.") - elif (num_image_tags := len(image_idx)) > 1: - assert num_image_tags == len( - image_data), "The count of image_placeholder not match image's" - new_prompt = prompt - - prompt_token_ids = inputs["prompt_token_ids"].copy() - - # masked placeholder with image token id - for idx in image_idx: - candidates = _get_image_placeholder_token_id_candidates(model_config, - idx=idx) - - for candidate in candidates: - for i in range(len(prompt_token_ids) - len(candidate) + 1): - if prompt_token_ids[i:i + len(candidate)] == candidate: - prompt_token_ids[i:i + - len(candidate)] = ([_IMAGE_TOKEN_ID] * - len(candidate)) - break - - # merge consecutive tag ids - merged_token_ids: List[int] = [] - for is_placeholder, token_ids in itertools.groupby( - prompt_token_ids, lambda x: x == _IMAGE_TOKEN_ID): - if is_placeholder: - merged_token_ids.append(_IMAGE_TOKEN_ID) - else: - merged_token_ids.extend(list(token_ids)) - - # TODO: Move this to utils or integrate with clip. - new_token_ids: List[int] = [] - placeholder_ranges: List[PlaceholderRange] = [] - placeholder_idx = 0 - while merged_token_ids: - token_id = merged_token_ids.pop(0) - if token_id == _IMAGE_TOKEN_ID: - replacement_ids = repeat_and_pad_token( - _IMAGE_TOKEN_ID, - repeat_count=image_feature_size[placeholder_idx], - ) - placeholder_ranges.append({ - "offset": len(new_token_ids), - "length": len(replacement_ids) - }) - new_token_ids.extend(replacement_ids) - placeholder_idx += 1 - else: - new_token_ids.append(token_id) - - # NOTE: Create a defensive copy of the original inputs - return token_inputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data, - multi_modal_placeholders={"image": placeholder_ranges}) + def _get_hf_processor( + self, + *, + num_crops: Optional[int] = None, + ) -> ProcessorMixin: + if num_crops is not None: + return self.ctx.get_hf_processor(num_crops=num_crops) + return self.ctx.get_hf_processor() + + def _apply_hf_processor( + self, + prompt: str, + mm_data: MultiModalDataDict, + mm_processor_kwargs: Mapping[str, object], + ) -> BatchFeature: + processed_outputs = super()._apply_hf_processor( + prompt, mm_data, mm_processor_kwargs) + # Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids, + # which will cause OverflowError when decoding the prompt_ids. + # Therefore, we need to do an early replacement here + token_ids = processed_outputs['input_ids'] + token_ids[token_ids < 0] = _IMAGE_TOKEN_ID + processed_outputs['input_ids'] = token_ids + return processed_outputs + + def _get_dummy_mm_kwargs( + self, + mm_counts: Mapping[str, int], + ) -> MultiModalKwargs: + return dummy_mm_kwargs_for_phi3v(self.ctx, mm_counts) -@MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v) -@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v) +@MULTIMODAL_REGISTRY.register_processor(Phi3VProcessor) class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index c3a95d60e6fe6..922c83b6fd8a9 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -3,7 +3,8 @@ from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence from dataclasses import dataclass from functools import lru_cache -from typing import Any, Generic, NamedTuple, Optional, Protocol, TypeVar, Union +from typing import (Any, Dict, Generic, NamedTuple, Optional, Protocol, + TypeVar, Union, cast) import torch from transformers import BatchFeature, ProcessorMixin @@ -11,7 +12,8 @@ from vllm.inputs import DummyData, InputProcessingContext from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import flatten_2d_lists, full_groupby, is_list_of +from vllm.utils import (flatten_2d_lists, full_groupby, is_list_of, + resolve_mm_processor_kwargs) from .inputs import (AudioItem, ImageItem, MultiModalDataDict, MultiModalInputsV2, MultiModalKwargs, PlaceholderRange, @@ -543,8 +545,14 @@ def __init__( self.ctx = ctx self.metadata = metadata + self.init_mm_processor_kwargs = (ctx.model_config.mm_processor_kwargs + or {}) - def _get_hf_processor(self) -> ProcessorMixin: + def _get_hf_processor( + self, + **mm_processor_kwargs: Mapping[str, object], + ) -> ProcessorMixin: + # by default, we won't pass any kwargs to the processor initialization return self.ctx.get_hf_processor() def _get_tokenizer(self) -> AnyTokenizer: @@ -581,7 +589,13 @@ def _apply_hf_processor( mm_data: MultiModalDataDict, mm_processor_kwargs: Mapping[str, object], ) -> BatchFeature: - hf_processor = self._get_hf_processor() + # some mm_processor_kwargs may be used in processor initialization + # instead of processor call + processor_init_kwargs = { + **self.init_mm_processor_kwargs, + **mm_processor_kwargs, + } + hf_processor = self._get_hf_processor(**processor_init_kwargs) processor_data = dict[str, Any]() passthrough_data = dict[str, Any]() @@ -601,6 +615,13 @@ def _apply_hf_processor( else: processor_data[k] = v + # filter mm_processor_kwargs used in processor call + mm_processor_kwargs = resolve_mm_processor_kwargs( + self.init_mm_processor_kwargs, + cast(Dict[str, Any], mm_processor_kwargs), + hf_processor, + ) + try: hf_inputs = hf_processor( text=prompt, # type: ignore