From d5be8c4f71b060e65edb2fb7a9f5b8657c9ec22a Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 7 Dec 2024 16:50:58 +0800 Subject: [PATCH] [3/N] Support and implement merged input processor for LLaVA model (#10676) Signed-off-by: DarkLight1337 Co-authored-by: Roger Wang --- tests/multimodal/test_mapper.py | 49 +-- tests/multimodal/test_processing.py | 277 +++++++++++----- .../vllm_add_dummy_model/my_llava.py | 12 +- vllm/inputs/registry.py | 42 ++- vllm/model_executor/models/llava.py | 219 +++++------- vllm/multimodal/base.py | 51 ++- vllm/multimodal/processing.py | 313 +++++++++++------- vllm/multimodal/registry.py | 67 +++- vllm/v1/engine/mm_input_mapper.py | 1 + vllm/v1/engine/processor.py | 16 +- 10 files changed, 626 insertions(+), 421 deletions(-) diff --git a/tests/multimodal/test_mapper.py b/tests/multimodal/test_mapper.py index 13ad4a7966b9d..71832acbd17b8 100644 --- a/tests/multimodal/test_mapper.py +++ b/tests/multimodal/test_mapper.py @@ -2,7 +2,7 @@ import numpy as np import pytest -from transformers import CLIPImageProcessor, LlavaNextImageProcessor +from transformers import LlavaNextImageProcessor from vllm.config import ModelConfig from vllm.multimodal import MultiModalRegistry @@ -14,49 +14,6 @@ def mm_registry(): return MultiModalRegistry() -@pytest.mark.parametrize("dtype", ["half", "float"]) -@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0]) -def test_clip_image_processor(image_assets, mm_registry, dtype, size_factor): - MODEL_NAME = "llava-hf/llava-1.5-7b-hf" - - hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME) - assert isinstance(hf_processor, CLIPImageProcessor) - - model_config = ModelConfig( - model=MODEL_NAME, - task="auto", - tokenizer=MODEL_NAME, - tokenizer_mode="auto", - trust_remote_code=False, - seed=0, - dtype=dtype, - revision=None, - limit_mm_per_prompt={"image": 1}, - ) - - mm_registry.init_mm_limits_per_prompt(model_config) - - for asset in image_assets: - image = rescale_image_size(asset.pil_image, size_factor) - - hf_result = hf_processor.preprocess( - image, - return_tensors="pt", - ) - vllm_result = mm_registry.map_input( - model_config, - {"image": image}, - ) - - assert hf_result.keys() == vllm_result.keys() - for key, hf_tensor in hf_result.items(): - hf_arr: np.ndarray = hf_tensor.numpy() - vllm_arr: np.ndarray = vllm_result[key].numpy() - - assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}" - assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}" - - @pytest.mark.parametrize("dtype", ["half", "float"]) @pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0]) def test_llava_next_image_processor(image_assets, mm_registry, dtype, @@ -107,7 +64,7 @@ def test_llava_next_image_processor(image_assets, mm_registry, dtype, (2, 1, False), (2, 2, True)], ) def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid): - MODEL_NAME = "llava-hf/llava-1.5-7b-hf" + MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf" model_config = ModelConfig( model=MODEL_NAME, @@ -138,7 +95,7 @@ def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid): # NOTE: We don't test zero images since the HF processor doesn't support it @pytest.mark.parametrize("num_images", [1, 2]) def test_image_mapper_multi(image_assets, mm_registry, num_images): - MODEL_NAME = "llava-hf/llava-1.5-7b-hf" + MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf" model_config = ModelConfig( model=MODEL_NAME, diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index b2367060c6c1b..ae668d1dd56c8 100644 
--- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -3,50 +3,15 @@ import pytest from transformers import BatchFeature -from vllm.multimodal.processing import (PromptReplacement, find_text_matches, - find_token_matches, iter_token_matches, - iter_token_runs, replace_text_matches) +from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo, + find_text_matches, find_token_matches, + iter_placeholders, iter_token_matches, + replace_text_matches, + replace_token_matches) from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import full_groupby -# yapf: disable -@pytest.mark.parametrize( - ("token_ids", "expected"), - [ - ([], []), - ( - [32000, 32000, 32000], - [{ "token_id": 32000, "start_idx": 0, "length": 3 }], - ), - ( - [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], - [ - { "token_id": 9833, "start_idx": 0, "length": 1 }, - { "token_id": 28747, "start_idx": 1, "length": 1 }, - { "token_id": 32000, "start_idx": 2, "length": 3 }, - { "token_id": 9833, "start_idx": 5, "length": 1 }, - { "token_id": 28747, "start_idx": 6, "length": 1 }, - { "token_id": 32000, "start_idx": 7, "length": 2 }, - { "token_id": 918, "start_idx": 9, "length": 1 }, - ], - ), - ], -) -# yapf: enable -def test_iter_token_runs(token_ids, expected): - result = list(iter_token_runs(token_ids)) - - # Only displayed on error - print("result:", result) - - # Manually constructed results - assert [item._asdict() for item in result] == expected - - # Invariants - assert sum(run_info.length for run_info in result) == len(token_ids) - - # yapf: disable @pytest.mark.parametrize( ("token_ids", "match_ids", "expected"), @@ -170,13 +135,11 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key): # Should not be used since there is nothing to convert to token IDs mock_tokenizer = cast(AnyTokenizer, object()) - result = find_token_matches( - prompt, - [ - PromptReplacement(target, [], 0).bind(key, mock_tokenizer) - for key, target in target_by_key.items() - ], - ) + prompt_repls = [ + PromptReplacement(target, [], 0).bind(key, mock_tokenizer) + for key, target in target_by_key.items() + ] + result = find_token_matches(prompt, prompt_repls) # Only displayed on error print("result:", result) @@ -279,13 +242,11 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): # Should not be used since there is nothing to convert to text mock_tokenizer = cast(AnyTokenizer, object()) - result = find_text_matches( - prompt, - [ - PromptReplacement(target, [], 0).bind(key, mock_tokenizer) - for key, target in target_by_key.items() - ], - ) + prompt_repls = [ + PromptReplacement(target, [], 0).bind(key, mock_tokenizer) + for key, target in target_by_key.items() + ] + result = find_text_matches(prompt, prompt_repls) # Only displayed on error print("result:", result) @@ -303,7 +264,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): # yapf: disable @pytest.mark.parametrize( - ("prompt", "target_by_key", "repl_by_key", "expected_by_mm_count"), + ("prompt", "target_by_key", "repl_by_key"), [ ( "Image:Image:!", @@ -322,49 +283,201 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): # Test multiple repl_count "pattern_3": ("?", 2), }, - { - # Test no replacement - 0: "Image:Image:!", - # Test single replacement - 1: "Image:??", - # Test repeated replacement - 2: "??", - }, ), ] ) +@pytest.mark.parametrize( + ("mm_count", "expected"), + [ + (0, "Image:Image:!"), + (1, "Image:??"), + (2, "??"), + 
] +) # yapf: enable def test_find_replace_text( prompt, target_by_key, repl_by_key, - expected_by_mm_count, + mm_count, + expected, ): # Should not be used since there is nothing to convert to text mock_tokenizer = cast(AnyTokenizer, object()) - matches = find_text_matches( + prompt_repls = [ + PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer) + for key, target in target_by_key.items() + ] + matches = find_text_matches(prompt, prompt_repls) + + result = replace_text_matches( prompt, - [ - PromptReplacement(target, *repl_by_key[key]) \ - .bind(key, mock_tokenizer) - for key, target in target_by_key.items() - ], + matches, + {key: list(range(mm_count)) + for key in repl_by_key}, + BatchFeature(), ) - result_by_mm_count = { - mm_count: replace_text_matches( - prompt, - matches, - {key: list(range(mm_count)) - for key in repl_by_key}, - BatchFeature(), - ) - for mm_count in expected_by_mm_count - } # Only displayed on error print("matches:", matches) - print("result_by_mm_count:", result_by_mm_count) + print("result:", result) + + # Manually constructed results + assert result == expected + + +# yapf: disable +@pytest.mark.parametrize( + ("prompt", "target_by_key", "repl_by_key"), + [ + # Tokenized test cases of `test_find_replace_text` + # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf + ( + [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], + { + # We use `` before `Image:` to test matches that + # occur out of order + "pattern_1": [32000], + "pattern_2": [9833, 28747], + "pattern_3": [918], + }, + { + # Test whether target is confused with repl_unit + "pattern_1": ([32000, 32000], 1), + # Test empty repl_unit + "pattern_2": ([], 1), + # Test multiple repl_count + "pattern_3": ([1550], 2), + }, + ), + ] +) +@pytest.mark.parametrize( + ("mm_count", "expected"), + [ + (0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]), + (1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550]), + (2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 1550]), + ] +) +# yapf: enable +def test_find_replace_tokens( + prompt, + target_by_key, + repl_by_key, + mm_count, + expected, +): + # Should not be used since there is nothing to convert to tokens + mock_tokenizer = cast(AnyTokenizer, object()) + + prompt_repls = [ + PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer) + for key, target in target_by_key.items() + ] + matches = find_token_matches(prompt, prompt_repls) + + result = replace_token_matches( + prompt, + matches, + {key: list(range(mm_count)) + for key in repl_by_key}, + BatchFeature(), + ) + + # Only displayed on error + print("matches:", matches) + print("result:", result) + + # Manually constructed results + assert result == expected + + +# yapf: disable +@pytest.mark.parametrize( + "repl_by_key", + [ + { + "pattern_1": ([32000, 32000], 1), + "pattern_2": ([], 1), + "pattern_3": ([1550], 2), + }, + ], +) +@pytest.mark.parametrize( + ("prompt", "expected"), + [ + ( + [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], + [ + _PlaceholderInfo( + modality="pattern_1", + start_idx=6, + unit=[32000, 32000], + unit_count=1, + ), + ], + ), + ( + [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550], + [ + _PlaceholderInfo( + modality="pattern_1", + start_idx=1, + unit=[32000, 32000], + unit_count=1, + ), + _PlaceholderInfo( + modality="pattern_1", + start_idx=5, + unit=[32000, 32000], + unit_count=1, + ), + _PlaceholderInfo( + modality="pattern_3", + start_idx=7, + unit=[1550], + unit_count=2, + ), + ], + ), + ( + [1, 32000, 32000, 32000, 
32000, 32000, 1550, 1550], + [ + _PlaceholderInfo( + modality="pattern_1", + start_idx=1, + unit=[32000, 32000], + unit_count=2, + ), + _PlaceholderInfo( + modality="pattern_3", + start_idx=6, + unit=[1550], + unit_count=2, + ), + ], + ), + ] +) +def test_iter_placeholders( + repl_by_key, + prompt, + expected, +): + # Should not be used since there is nothing to convert to tokens + mock_tokenizer = cast(AnyTokenizer, object()) + + prompt_repls = [ + PromptReplacement([], *repl).bind(key, mock_tokenizer) + for key, repl in repl_by_key.items() + ] + + result = list(iter_placeholders(prompt_repls, prompt)) + + # Only displayed on error + print("result:", result) # Manually constructed results - assert result_by_mm_count == expected_by_mm_count + assert result == expected diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py index 3ebd7864b8fc8..f2fc0755cae01 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py @@ -2,19 +2,17 @@ import torch -from vllm.inputs import INPUT_REGISTRY from vllm.model_executor.models.llava import (LlavaForConditionalGeneration, - dummy_data_for_llava, - get_max_llava_image_tokens, - input_processor_for_llava) + create_metadata_for_llava, + dummy_mm_kwargs_for_llava, + get_max_llava_image_tokens) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -@MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava) -@INPUT_REGISTRY.register_input_processor(input_processor_for_llava) +@MULTIMODAL_REGISTRY.register_processor_by_metadata(create_metadata_for_llava, + dummy_mm_kwargs_for_llava) class MyLlava(LlavaForConditionalGeneration): def compute_logits( diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 85ab4355cc2e4..646554c72481a 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -232,19 +232,35 @@ def dummy_data_for_profiling( """ # Avoid circular import from vllm.model_executor.model_loader import get_model_architecture - - model_cls, _ = get_model_architecture(model_config) - if is_encoder_data: - dummy_factory = self._get_dummy_encoder_data_factory(model_cls) + from vllm.multimodal import MultiModalKwargs + from vllm.multimodal.utils import cached_get_tokenizer + + if mm_registry.has_processor(model_config): + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code, + ) + processor = mm_registry.create_processor(model_config, tokenizer) + + mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) + mm_max_tokens = mm_registry.get_max_tokens_by_modality( + model_config) + + dummy_data = processor.get_dummy_data(seq_len, mm_counts, + mm_max_tokens) else: - dummy_factory = self._get_dummy_data_factory(model_cls) - mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) - mm_processor_kwargs = get_allowed_kwarg_only_overrides( - dummy_factory, overrides=model_config.mm_processor_kwargs) + model_cls, _ = get_model_architecture(model_config) + if is_encoder_data: + dummy_factory = self._get_dummy_encoder_data_factory(model_cls) + else: + dummy_factory = self._get_dummy_data_factory(model_cls) + mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) + mm_processor_kwargs = 
get_allowed_kwarg_only_overrides( + dummy_factory, overrides=model_config.mm_processor_kwargs) - dummy_data = dummy_factory(InputContext(model_config), seq_len, - _MultiModalCounts(mm_counts), - **mm_processor_kwargs) + dummy_data = dummy_factory(InputContext(model_config), seq_len, + _MultiModalCounts(mm_counts), + **mm_processor_kwargs) # Having more tokens is over-conservative but otherwise fine num_tokens = dummy_data.seq_data.prompt_token_ids @@ -257,7 +273,9 @@ def dummy_data_for_profiling( raise AssertionError( f"Expected at least {seq_len} dummy tokens for profiling, " f"but found {len(num_tokens)} tokens instead.") - if dummy_data.multi_modal_data is not None: + + if (dummy_data.multi_modal_data is not None and + not isinstance(dummy_data.multi_modal_data, MultiModalKwargs)): for k, v in dummy_data.multi_modal_data.items(): num_items = len(v) if isinstance(v, list) else 1 num_expected = mm_counts[k] diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index d375c1c9da2a9..953b89f1842af 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,17 +1,19 @@ from functools import cached_property +from types import MethodType from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set, Tuple, TypedDict, Union) import torch import torch.nn as nn -from PIL import Image -from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig, - PretrainedConfig, SiglipVisionConfig) +from PIL.Image import Image +from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig, + PixtralVisionConfig, PretrainedConfig, + ProcessorMixin, SiglipVisionConfig) +from transformers.models.pixtral import PixtralProcessor from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext) +from vllm.inputs import InputContext from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) @@ -19,21 +21,20 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import NestedTensors +from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors +from vllm.multimodal.processing import (InputProcessingContext, + ModalityProcessingMetadata, + MultiModalProcessingMetadata, + MultiModalProcessor, PromptReplacement) from vllm.sequence import IntermediateTensors -from vllm.utils import is_list_of from .clip import (CLIPVisionModel, dummy_image_for_clip, - dummy_seq_data_for_clip, get_max_clip_image_tokens, - input_processor_for_clip) + get_max_clip_image_tokens) from .interfaces import SupportsMultiModal, SupportsPP from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf, - dummy_seq_data_for_pixtral_hf, - get_max_pixtral_hf_image_tokens, - input_processor_for_pixtral_hf) + get_max_pixtral_hf_image_tokens) from .siglip import (SiglipVisionModel, dummy_image_for_siglip, - dummy_seq_data_for_siglip, get_max_siglip_image_tokens, - input_processor_for_siglip) + get_max_siglip_image_tokens) from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) @@ -113,102 +114,86 @@ def get_max_llava_image_tokens(ctx: InputContext): raise ValueError(f"Unexpected select feature strategy: {strategy}") 
-def dummy_data_for_llava(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): +def dummy_mm_kwargs_for_llava(ctx: InputProcessingContext, + mm_counts: Mapping[str, int]): hf_config = ctx.get_hf_config(LlavaConfig) vision_config = hf_config.vision_config num_images = mm_counts["image"] - image_feature_size = get_max_llava_image_tokens(ctx) - if isinstance(vision_config, CLIPVisionConfig): - seq_data, ranges = dummy_seq_data_for_clip( - vision_config, - seq_len, - num_images, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) - - mm_data = dummy_image_for_clip(vision_config, num_images) - return DummyData(seq_data, mm_data, ranges) + data = dummy_image_for_clip(vision_config, num_images) elif isinstance(vision_config, SiglipVisionConfig): - seq_data, ranges = dummy_seq_data_for_siglip( - vision_config, - seq_len, - num_images, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) - - mm_data = dummy_image_for_siglip(vision_config, num_images) - return DummyData(seq_data, mm_data, ranges) + data = dummy_image_for_siglip(vision_config, num_images) elif isinstance(vision_config, PixtralVisionConfig): - seq_data, ranges = dummy_seq_data_for_pixtral_hf( - vision_config, - seq_len, - num_images, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) - - mm_data = dummy_image_for_pixtral_hf(vision_config, num_images) - return DummyData(seq_data, mm_data, ranges) + data = dummy_image_for_pixtral_hf(vision_config, num_images) + else: + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) + hf_processor = ctx.get_hf_processor() + image_processor = hf_processor.image_processor # type: ignore + hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt") + is_pixtral = isinstance(hf_processor, PixtralProcessor) + return MultiModalKwargs( + **hf_inputs, + is_pixtral=torch.tensor(is_pixtral), + ) -def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return inputs - model_config = ctx.model_config +def create_metadata_for_llava( + ctx: InputProcessingContext) -> MultiModalProcessingMetadata: hf_config = ctx.get_hf_config(LlavaConfig) - vision_config = hf_config.vision_config + image_token_id = hf_config.image_token_index + + def get_repl_count( + mm_items: list[Image], + hf_inputs: BatchFeature, + item_idx: int, + ) -> int: + return get_max_llava_image_tokens(ctx) + + return { + "image": + ModalityProcessingMetadata(prompt_repls=[ + PromptReplacement(target=[image_token_id], + repl_unit=[image_token_id], + repl_count=get_repl_count), + ]), + } - image_data = multi_modal_data["image"] - if isinstance(image_data, Image.Image): - image_feature_size = get_max_llava_image_tokens(ctx) - elif is_list_of(image_data, Image.Image): - image_feature_size = [get_max_llava_image_tokens(ctx) - ] * len(image_data) - elif isinstance(image_data, torch.Tensor): - num_images, image_feature_size, hidden_size = image_data.shape - elif is_list_of(image_data, torch.Tensor): - image_feature_size = [item.shape[1] for item in image_data] - else: - raise TypeError(f"Invalid image type: {type(image_data)}") - if isinstance(vision_config, CLIPVisionConfig): - return input_processor_for_clip( - 
model_config, - vision_config, - inputs, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) - elif isinstance(vision_config, SiglipVisionConfig): - return input_processor_for_siglip( - model_config, - vision_config, - inputs, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) - elif isinstance(vision_config, PixtralVisionConfig): - # We ignore image_feature_size_override since we have non-uniform - # image sizes for Pixtral - return input_processor_for_pixtral_hf( - model_config, - vision_config, - inputs, - image_token_id=hf_config.image_token_index, - ) +class LlavaProcessor(MultiModalProcessor): - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) + def _patch_pixtral_processor(self, hf_processor: PixtralProcessor): + if getattr(hf_processor, "__is_patched__", False): + return # Already patched + + image_processor = hf_processor.image_processor # type: ignore + orig_preprocess = image_processor.preprocess + + def preprocess(__self, *args, **kwargs): + hf_inputs = orig_preprocess(*args, **kwargs) + hf_inputs["is_pixtral"] = torch.tensor(True) + return hf_inputs + + image_processor.preprocess = MethodType(preprocess, image_processor) + + hf_processor.__is_patched__ = True # type: ignore + + def _get_hf_processor(self) -> ProcessorMixin: + hf_processor = self.ctx.get_hf_processor() + + if isinstance(hf_processor, PixtralProcessor): + self._patch_pixtral_processor(hf_processor) + + return hf_processor + + def _get_dummy_mm_kwargs( + self, + mm_counts: Mapping[str, int], + ) -> MultiModalKwargs: + return dummy_mm_kwargs_for_llava(self.ctx, mm_counts) class LlavaLikeConfig(Protocol): @@ -291,10 +276,11 @@ def init_vision_tower_for_llava( raise NotImplementedError(msg) -@MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava) -@INPUT_REGISTRY.register_input_processor(input_processor_for_llava) +@MULTIMODAL_REGISTRY.register_processor(lambda ctx: LlavaProcessor( + ctx=ctx, + metadata=create_metadata_for_llava(ctx), +)) class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): # BitandBytes specific attributes bitsandbytes_stacked_params_mapping = { @@ -367,38 +353,10 @@ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: return data - def _validate_image_sizes(self, images: List[torch.Tensor], - sizes: List[torch.Tensor]) -> List[torch.Tensor]: - if not isinstance(sizes, list): - sizes = [sizes] - - total_images = sum(size.numel() // 2 for size in sizes) - if total_images != len(images): - raise ValueError("Mismatch in number of images. " - f"Expected {total_images}, got {len(images)}") - img_idx = 0 - for size in sizes: - # Flatten the size tensor to a list of (height, width) pairs - size = size.view(-1, 2).tolist() - for expected_h, expected_w in size: - if img_idx >= len(images): - raise ValueError("Ran out of images before sizes. " - f"{img_idx} >= {len(images)}") - img = images[img_idx] - if img.shape[-2:] != (expected_h, expected_w): - raise ValueError( - "Image size mismatch. Expected " - f"{(expected_h, expected_w)}, got {img.shape[-2:]}") - if img.shape[-3] != 3: - raise ValueError("Image channel mismatch. 
Expected 3, " - f"got {img.shape[-3]}") - img_idx += 1 - return images - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[LlavaImageInputs]: pixel_values = kwargs.pop("pixel_values", None) - image_sizes = kwargs.pop("image_sizes", None) + is_pixtral = kwargs.pop("is_pixtral", torch.tensor([False])) image_embeds = kwargs.pop("image_embeds", None) if pixel_values is None and image_embeds is None: @@ -409,9 +367,8 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") - # Case for models like PixtralHF that have dynamic image sizes - # so we need to produce a list of tensors - if image_sizes is not None: + assert isinstance(is_pixtral, torch.Tensor) + if is_pixtral.any(): images = pixel_values def flatten_to_3d_tensors(item): @@ -434,7 +391,7 @@ def flatten_to_3d_tensors(item): return LlavaImagePixelInputs( type="pixel_values", - data=self._validate_image_sizes(images, image_sizes), + data=images, ) return LlavaImagePixelInputs( diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index f93722523728d..7dba94b885b6d 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -226,16 +226,16 @@ def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: """ # Avoid circular import from vllm.model_executor.model_loader import get_model_architecture + from vllm.model_executor.models import supports_multimodal model_cls, _ = get_model_architecture(model_config) - if model_cls not in self._input_mappers: + if not supports_multimodal(model_cls): return 0 max_mm_tokens = self._max_mm_tokens.get(model_cls) if max_mm_tokens is None: - raise KeyError(f"No maximum number of multi-modal tokens is given " - f"for model class {model_cls.__name__} in {self}.") + return 0 if callable(max_mm_tokens): mm_processor_kwargs = get_allowed_kwarg_only_overrides( @@ -326,26 +326,47 @@ def from_seq_group( src_ranges = [] dest_ranges = [] """ - if (not seq_group.multi_modal_data - or not seq_group.multi_modal_placeholders): - return seq_group.multi_modal_data, {} + seq_mm_data = seq_group.multi_modal_data + seq_mm_placeholders = seq_group.multi_modal_placeholders + + if not seq_mm_data or not seq_mm_placeholders: + return seq_mm_data, {} + + # For merged processor, we directly use mm_kwargs as mm_data + if isinstance(seq_mm_data, MultiModalKwargs): + placeholder_maps = dict[str, MultiModalPlaceholderMap]() + + for modality, placeholders in seq_mm_placeholders.items(): + placeholder_map = MultiModalPlaceholderMap() + + if positions: + placeholder_map.append_items_from_seq_group( + positions, + # Dummy, since we don't care about intersecting items + [None] * len(placeholders), + placeholders, + ) + + placeholder_maps[modality] = placeholder_map + + return seq_mm_data, placeholder_maps - mm_data = {**seq_group.multi_modal_data} - placeholder_maps: Dict[str, MultiModalPlaceholderMap] = defaultdict( + mm_data = {**seq_mm_data} + placeholder_maps = defaultdict[str, MultiModalPlaceholderMap]( MultiModalPlaceholderMap) - for ( - modality, - placeholders, - ) in seq_group.multi_modal_placeholders.items(): + for modality, placeholders in seq_mm_placeholders.items(): mm_items = mm_data.pop(modality) if not isinstance(mm_items, list): mm_items = [mm_items] if positions: - intersecting_items = placeholder_maps[ - modality].append_items_from_seq_group( - positions, mm_items, placeholders) + intersecting_items = placeholder_maps[modality] \ + .append_items_from_seq_group( + positions, + mm_items, + 
placeholders, + ) if intersecting_items: mm_data[modality] = intersecting_items diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 28c8dda581982..4a1737991534f 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -3,14 +3,13 @@ from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence from dataclasses import dataclass from functools import lru_cache -from itertools import groupby from typing import Any, Generic, NamedTuple, Optional, Protocol, TypeVar, Union -import numpy as np -from transformers import BatchFeature +import torch +from transformers import BatchFeature, ProcessorMixin from typing_extensions import TypeAlias, TypedDict -from vllm.inputs import InputProcessingContext +from vllm.inputs import DummyData, InputProcessingContext from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.utils import flatten_2d_lists, full_groupby, is_list_of @@ -256,63 +255,6 @@ def to_multi_format(data: MultiModalDataDict) -> dict[str, list[Any]]: return multi_data -class _TokenRun(NamedTuple): - token_id: int - - start_idx: int - length: int - - -def iter_token_runs(token_ids: list[int]) -> Iterable[_TokenRun]: - """ - Yield the starting index and length of each run of tokens that are the same. - """ - start_idx = 0 - - for token_id, it in groupby(token_ids): - length = sum(1 for _ in it) - yield _TokenRun(token_id=token_id, start_idx=start_idx, length=length) - - start_idx += length - - -class _PlaceholderInfo(NamedTuple): - modality: str - offset: int - length: int - - def to_range(self) -> PlaceholderRange: - return PlaceholderRange(offset=self.offset, length=self.length) - - -def iter_placeholders( - prompt_repls: Sequence[_BoundPromptReplacement[Any]], - token_ids: list[int], - *, - min_placeholder_count: int, -) -> Iterable[_PlaceholderInfo]: - """Yield each set of placeholder tokens found in :code:`token_ids`.""" - placeholder_ids_by_modality = { - modality: { - token_id - for prompt_repl in repls - for token_id in prompt_repl.repl_unit.token_ids - } - for modality, repls in full_groupby_modality(prompt_repls) - } - - for run_info in iter_token_runs(token_ids): - if run_info.length > min_placeholder_count: - for (modality, - placeholder_ids) in placeholder_ids_by_modality.items(): - if run_info.token_id in placeholder_ids: - yield _PlaceholderInfo( - modality=modality, - offset=run_info.start_idx, - length=run_info.length, - ) - - class _TokenMatch(NamedTuple): start_idx: int end_idx: int @@ -353,13 +295,9 @@ def start_idx(self) -> int: def end_idx(self) -> int: raise NotImplementedError + @property @abstractmethod - def get_repl( - self, - mm_items: list[_T], - hf_inputs: BatchFeature, - item_idx: int, - ) -> _S: + def repl_unit(self) -> _S: raise NotImplementedError def __repr__(self) -> str: @@ -380,15 +318,9 @@ def start_idx(self) -> int: def end_idx(self) -> int: return self.match.end_idx - def get_repl( - self, - mm_items: list[_T], - hf_inputs: BatchFeature, - item_idx: int, - ) -> list[int]: - prompt_repl = self.prompt_repl - count = prompt_repl.get_count(mm_items, hf_inputs, item_idx) - return prompt_repl.repl_unit.token_ids * count + @property + def repl_unit(self) -> list[int]: + return self.prompt_repl.repl_unit.token_ids @dataclass(repr=False) @@ -404,15 +336,26 @@ def start_idx(self) -> int: def end_idx(self) -> int: return self.match.end() - def get_repl( - self, - mm_items: list[_T], - hf_inputs: BatchFeature, - item_idx: int, - ) -> str: - prompt_repl = 
self.prompt_repl - count = prompt_repl.get_count(mm_items, hf_inputs, item_idx) - return prompt_repl.repl_unit.text * count + @property + def repl_unit(self) -> str: + return self.prompt_repl.repl_unit.text + + +class _PlaceholderInfo(NamedTuple): + modality: str + start_idx: int + unit: list[int] + unit_count: int + + @property + def length(self) -> int: + return len(self.unit) * self.unit_count + + def to_range(self) -> PlaceholderRange: + return PlaceholderRange( + offset=self.start_idx, + length=self.length, + ) def find_token_matches( @@ -447,15 +390,17 @@ def _resolve_matches( Resolve :code:`matches` to ensure that there are no overlapping matches, and sort them such that earlier matches take priority over later ones. """ - num_matches_by_idx = np.zeros(len(prompt), dtype=int) + seen_matches: list[Optional[_PromptReplacementMatch[_T, _S]]] \ + = [None] * len(prompt) + for match in matches: - num_matches_by_idx[match.start_idx:match.end_idx] += 1 + for idx in range(match.start_idx, match.end_idx): + if seen_matches[idx] is not None: + raise ValueError("Found overlapping matches " + f"({seen_matches[idx]} and {match}) " + f"at index={idx} of prompt={prompt}") - duplicate_matches_idxs, = np.nonzero(num_matches_by_idx > 1) - if len(duplicate_matches_idxs) > 0: - raise ValueError("Unable to find a unique replacement " - f"at indices={duplicate_matches_idxs} " - f"of prompt={prompt}") + seen_matches[idx] = match return sorted(matches, key=lambda x: x.start_idx) @@ -480,9 +425,12 @@ def _replace_matches( start_idx = match.start_idx end_idx = match.end_idx - repl_ids = match.get_repl(mm_items, hf_inputs, item_idx) + repl_unit = match.repl_unit + repl_info = match.prompt_repl + repl_count = repl_info.get_count(mm_items, hf_inputs, item_idx) - out_seqs.append(prompt[prev_end_idx:start_idx] + repl_ids) + out_seqs.append(prompt[prev_end_idx:start_idx] + + repl_unit * repl_count) prev_end_idx = end_idx next_idx_by_modality[modality] += 1 @@ -531,7 +479,57 @@ def replace_text_matches( return "".join(texts) -class MultiModalProcessor: +def _merge_placeholder_matches( + matches: Iterable[_PromptReplacementTokenMatch], +) -> Iterable[_PromptReplacementTokenMatch]: + current_match = None + + for match in sorted(matches, key=lambda x: x.start_idx): + if current_match is None: + current_match = match + elif (current_match.prompt_repl == match.prompt_repl + and current_match.end_idx == match.start_idx): + current_match = _PromptReplacementTokenMatch( + current_match.prompt_repl, + match=_TokenMatch(current_match.start_idx, match.end_idx), + ) + else: + yield current_match + current_match = match + + if current_match is not None: + yield current_match + + +def iter_placeholders( + prompt_repls: Sequence[_BoundPromptReplacement[Any]], + prompt: list[int], + *, + min_unit_count: int = 1, +) -> Iterable[_PlaceholderInfo]: + """Yield each set of placeholder tokens found in :code:`token_ids`.""" + if min_unit_count <= 0: + raise ValueError("`min_unit_count` must be a positive integer") + + matches = (_PromptReplacementTokenMatch(prompt_repl, match) + for prompt_repl in prompt_repls + if len(repl_unit := prompt_repl.repl_unit.token_ids) > 0 + for match in iter_token_matches(prompt, repl_unit)) + + for match in _merge_placeholder_matches(matches): + unit = match.repl_unit + placeholder = _PlaceholderInfo( + modality=match.modality, + start_idx=match.start_idx, + unit=unit, + unit_count=(match.end_idx - match.start_idx) // len(unit), + ) + + if placeholder.unit_count >= min_unit_count: + yield placeholder + + 
+class MultiModalProcessor(ABC): """ Helper class to process multi-modal inputs to be used in vLLM. """ @@ -546,6 +544,12 @@ def __init__( self.ctx = ctx self.metadata = metadata + def _get_hf_processor(self) -> ProcessorMixin: + return self.ctx.get_hf_processor() + + def _get_tokenizer(self) -> AnyTokenizer: + return self.ctx.tokenizer + def __call__( self, prompt: str, @@ -562,13 +566,13 @@ def _find_placeholders( # To avoid false positives from multi-input when detecting # whether placeholder tokens have been inserted, in case # the target sequence is a subset of the replacement tokens - min_placeholder_count: int = 16, + min_unit_count: int = 16, ) -> list[_PlaceholderInfo]: return list( iter_placeholders( all_prompt_repls, new_token_ids, - min_placeholder_count=min_placeholder_count, + min_unit_count=min_unit_count, )) def _apply_hf_processor( @@ -577,19 +581,49 @@ def _apply_hf_processor( mm_data: MultiModalDataDict, mm_processor_kwargs: Mapping[str, object], ) -> BatchFeature: - hf_processor = self.ctx.get_hf_processor() + hf_processor = self._get_hf_processor() + + processor_data = dict[str, Any]() + passthrough_data = dict[str, Any]() + for k, v in mm_data.items(): + # TODO: Make a separate modality for embedding inputs + # to avoid confusion + if k in ("image", "video", "audio"): + if isinstance(v, torch.Tensor) and v.ndim == 3: + # Pass through embedding inputs (single) + passthrough_data[f"{k}_embeds"] = [v] + elif is_list_of(v, torch.Tensor) and v[0].ndim == 2: + # Pass through embedding inputs (multi) + passthrough_data[f"{k}_embeds"] = v + else: + # Map keys to plural form, e.g.: image -> images + processor_data[f"{k}s"] = v + else: + processor_data[k] = v + + try: + hf_inputs = hf_processor( + text=prompt, # type: ignore + **processor_data, + **mm_processor_kwargs, + return_tensors="pt", + ) + except Exception as exc: + data = dict(text=prompt, **processor_data) - return hf_processor( - text=prompt, # type: ignore - **mm_data, - **mm_processor_kwargs, - ) + raise RuntimeError( + f"Failed to apply {type(hf_processor).__name__} " + f"on data={data} with kwargs={mm_processor_kwargs}") from exc + + hf_inputs.update(passthrough_data) + + return hf_inputs def _bind_prompt_replacements( self, mm_data: MultiModalDataDict, ) -> list[_BoundPromptReplacement[Any]]: - tokenizer = self.ctx.tokenizer + tokenizer = self._get_tokenizer() return [ prompt_repl.bind(modality, tokenizer) @@ -604,7 +638,7 @@ def _apply_prompt_replacements( token_ids: list[int], prompt_repls: Sequence[_BoundPromptReplacement[Any]], ) -> tuple[list[int], str, list[_PlaceholderInfo]]: - tokenizer = self.ctx.tokenizer + tokenizer = self._get_tokenizer() mm_items = to_multi_format(mm_data) token_matches = find_token_matches(token_ids, prompt_repls) @@ -620,7 +654,7 @@ def _apply_prompt_replacements( # of the search text in the prompt, we instead perform string # replacement on the decoded token IDs, then encode them back. 
if all( - len(matches) >= len(mm_data[modality]) + len(matches) >= len(mm_items[modality]) for modality, matches in full_groupby_modality(token_matches) ): # yapf: disable token_ids = replace_token_matches( @@ -648,15 +682,6 @@ def _apply_prompt_replacements( placeholders = self._find_placeholders(matched_repls, token_ids) - # Sanity check - assert len(placeholders) == len(matched_repls), dict( - # Log this information for easier debugging - text=text, - token_ids=token_ids, - placeholders=placeholders, - matched_repls=matched_repls, - ) - return token_ids, text, placeholders def apply( @@ -678,7 +703,7 @@ def apply( 3. Extract information about the placeholder tokens from the processed token IDs. """ - tokenizer = self.ctx.tokenizer + tokenizer = self._get_tokenizer() hf_inputs = self._apply_hf_processor(prompt_text, mm_data, mm_processor_kwargs) @@ -717,3 +742,59 @@ def apply( mm_kwargs=mm_kwargs, mm_placeholders=mm_placeholders, ) + + @abstractmethod + def _get_dummy_mm_kwargs( + self, + mm_counts: Mapping[str, int], + ) -> MultiModalKwargs: + """ + Build the input that corresponds to `mm_max_tokens` in + :meth:`get_dummy_data`. + """ + raise NotImplementedError + + def get_dummy_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_max_tokens: Mapping[str, int], + ) -> DummyData: + # Avoid circular import + from vllm.sequence import SequenceData + + tokenizer = self._get_tokenizer() + + mm_placeholders = dict[str, _PlaceholderInfo]() + offset = 0 + + for modality, max_tokens in mm_max_tokens.items(): + if max_tokens == 0: + continue + + metadata = self.metadata[modality] + repl = metadata.prompt_repls[0].bind(modality, tokenizer) + repl_token_ids = repl.repl_unit.token_ids + + placeholders = _PlaceholderInfo( + modality=modality, + start_idx=offset, + unit=repl_token_ids, + unit_count=max_tokens // len(repl_token_ids), + ) + + mm_placeholders[modality] = placeholders + offset += placeholders.length + + prompt_token_ids = flatten_2d_lists( + [p.unit * p.unit_count for p in mm_placeholders.values()]) + prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids))) + + return DummyData( + seq_data=SequenceData.from_seqs(prompt_token_ids), + multi_modal_data=self._get_dummy_mm_kwargs(mm_counts), + multi_modal_placeholders={ + modality: [p.to_range()] + for modality, p in mm_placeholders.items() + }, + ) diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index b73daee98bd80..f51da8972d15b 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -15,7 +15,7 @@ from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc from .image import ImagePlugin from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors -from .processing import MultiModalProcessor +from .processing import MultiModalProcessingMetadata, MultiModalProcessor from .video import VideoPlugin if TYPE_CHECKING: @@ -200,9 +200,12 @@ def register_max_image_tokens( """ return self.register_max_multimodal_tokens("image", max_mm_tokens) - def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: + def get_max_tokens_by_modality( + self, + model_config: "ModelConfig", + ) -> Mapping[str, int]: """ - Get the maximum number of multi-modal tokens + Get the maximum number of tokens from each modality for profiling the memory usage of a model. See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details. 
@@ -212,9 +215,23 @@ def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: """ limits_per_plugin = self._limits_by_model[model_config] - return sum((limits_per_plugin[key] * - plugin.get_max_multimodal_tokens(model_config)) - for key, plugin in self._plugins.items()) + return { + key: (limits_per_plugin[key] * + plugin.get_max_multimodal_tokens(model_config)) + for key, plugin in self._plugins.items() + } + + def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: + """ + Get the maximum number of multi-modal tokens + for profiling the memory usage of a model. + + See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details. + + Note: + This should be called after :meth:`init_mm_limits_per_prompt`. + """ + return sum(self.get_max_tokens_by_modality(model_config).values()) def init_mm_limits_per_prompt( self, @@ -270,7 +287,8 @@ def register_processor( factory: MultiModalProcessorFactory, ): """ - Register a multi-modal processor to a model class. + Register a multi-modal processor to a model class. The processor + is constructed lazily, hence a factory method should be passed. When the model receives multi-modal data, the provided function is invoked to transform the data into a dictionary of model inputs. @@ -293,6 +311,41 @@ def wrapper(model_cls: N) -> N: return wrapper + def register_processor_by_metadata( + self, + metadata_factory: Callable[[InputProcessingContext], + MultiModalProcessingMetadata], + get_dummy_mm_kwargs: Callable[ + [InputProcessingContext, Mapping[str, int]], MultiModalKwargs], + ): + """ + Convenience method to register a multi-modal processor to a model class + according to a function that constructs its metadata. + + When the model receives multi-modal data, the provided function is + invoked to transform the data into a dictionary of model inputs. + + See also: + - :ref:`input_processing_pipeline` + - :ref:`enabling_multimodal_inputs` + """ + + class ConcreteMultiModalProcessor(MultiModalProcessor): + + def _get_dummy_mm_kwargs( + self, + mm_counts: Mapping[str, int], + ) -> MultiModalKwargs: + return get_dummy_mm_kwargs(self.ctx, mm_counts) + + def factory(ctx: InputProcessingContext): + return ConcreteMultiModalProcessor( + ctx=ctx, + metadata=metadata_factory(ctx), + ) + + return self.register_processor(factory) + def has_processor(self, model_config: "ModelConfig") -> bool: """ Test whether a multi-modal processor is defined for a specific model. 
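Usage note: the register_processor_by_metadata helper added above gives model plugins a one-decorator path onto the merged processor. A minimal sketch, consolidating what the my_llava.py plugin hunk earlier in this patch already does (all imported names come from vllm/model_executor/models/llava.py as changed in this same patch):

    from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
                                                  create_metadata_for_llava,
                                                  dummy_mm_kwargs_for_llava,
                                                  get_max_llava_image_tokens)
    from vllm.multimodal import MULTIMODAL_REGISTRY

    @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
    @MULTIMODAL_REGISTRY.register_processor_by_metadata(create_metadata_for_llava,
                                                        dummy_mm_kwargs_for_llava)
    class MyLlava(LlavaForConditionalGeneration):
        ...

The registry wraps the two callables in a processor subclass internally, so the plugin no longer needs the separate register_image_input_mapper, register_dummy_data, and register_input_processor decorators removed in this patch.
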
diff --git a/vllm/v1/engine/mm_input_mapper.py b/vllm/v1/engine/mm_input_mapper.py index 594c973678235..45882f8f076d4 100644 --- a/vllm/v1/engine/mm_input_mapper.py +++ b/vllm/v1/engine/mm_input_mapper.py @@ -12,6 +12,7 @@ def __init__( model_config: ModelConfig, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): + self.model_config = model_config self.mm_registry = mm_registry self.multi_modal_input_mapper = mm_registry.create_input_mapper( model_config) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 7a1ea2530abda..120fc64969552 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -7,7 +7,8 @@ from vllm.inputs.parse import is_encoder_decoder_inputs from vllm.inputs.preprocess import InputPreprocessor from vllm.lora.request import LoRARequest -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, + MultiModalRegistry) from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams @@ -101,10 +102,15 @@ def process_inputs( self.generation_config_fields, eos_token_id) # Preprocess multi-modal data - mm_inputs = self.mm_input_mapper.process_inputs( - decoder_inputs.multi_modal_data, - decoder_inputs.mm_processor_kwargs) if len( - decoder_inputs.multi_modal_data) > 0 else None + if len(decoder_inputs.multi_modal_data) == 0: + mm_inputs = None + elif isinstance(decoder_inputs.multi_modal_data, MultiModalKwargs): + mm_inputs = [decoder_inputs.multi_modal_data] + else: + mm_inputs = self.mm_input_mapper.process_inputs( + decoder_inputs.multi_modal_data, + decoder_inputs.mm_processor_kwargs, + ) # Make Request for Detokenizer. detokenizer_request = DetokenizerRequest(
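
For reference, a rough sketch of the request-time flow these changes enable, assuming a model architecture with a registered merged processor (has_processor, create_processor, and cached_get_tokenizer all appear in the hunks above; the prompt string, image variable, and exact call site are illustrative only):

    from vllm.multimodal import MULTIMODAL_REGISTRY
    from vllm.multimodal.utils import cached_get_tokenizer

    if MULTIMODAL_REGISTRY.has_processor(model_config):
        tokenizer = cached_get_tokenizer(
            model_config.tokenizer,
            trust_remote_code=model_config.trust_remote_code,
        )
        processor = MULTIMODAL_REGISTRY.create_processor(model_config, tokenizer)

        # Tokenizes the prompt, runs the HF processor, applies the registered
        # prompt replacements, and returns prompt token IDs, MultiModalKwargs,
        # and placeholder ranges in a single pass.
        mm_inputs = processor(
            "USER: <image>\nWhat is in this image? ASSISTANT:",
            {"image": image},
            {},
        )

Because the result already carries MultiModalKwargs, the V1 engine processor hunk above passes it through as-is instead of invoking the legacy multi-modal input mapper.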