From 7b6c4f1e46d2177a556ec1b824de60f65e576062 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 26 Nov 2024 16:22:26 +0000 Subject: [PATCH 01/21] Add `get_dummy_data` to `MultiModalProcessor`; fix and test `iter_placeholders` Signed-off-by: DarkLight1337 --- tests/multimodal/test_processing.py | 277 ++++++++++++++++++++-------- vllm/inputs/registry.py | 45 +++-- vllm/multimodal/processing.py | 239 ++++++++++++++---------- vllm/multimodal/registry.py | 30 ++- 4 files changed, 393 insertions(+), 198 deletions(-) diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index b2367060c6c1b..ae668d1dd56c8 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -3,50 +3,15 @@ import pytest from transformers import BatchFeature -from vllm.multimodal.processing import (PromptReplacement, find_text_matches, - find_token_matches, iter_token_matches, - iter_token_runs, replace_text_matches) +from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo, + find_text_matches, find_token_matches, + iter_placeholders, iter_token_matches, + replace_text_matches, + replace_token_matches) from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import full_groupby -# yapf: disable -@pytest.mark.parametrize( - ("token_ids", "expected"), - [ - ([], []), - ( - [32000, 32000, 32000], - [{ "token_id": 32000, "start_idx": 0, "length": 3 }], - ), - ( - [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], - [ - { "token_id": 9833, "start_idx": 0, "length": 1 }, - { "token_id": 28747, "start_idx": 1, "length": 1 }, - { "token_id": 32000, "start_idx": 2, "length": 3 }, - { "token_id": 9833, "start_idx": 5, "length": 1 }, - { "token_id": 28747, "start_idx": 6, "length": 1 }, - { "token_id": 32000, "start_idx": 7, "length": 2 }, - { "token_id": 918, "start_idx": 9, "length": 1 }, - ], - ), - ], -) -# yapf: enable -def test_iter_token_runs(token_ids, expected): - result = list(iter_token_runs(token_ids)) - - # Only displayed on error - print("result:", result) - - # Manually constructed results - assert [item._asdict() for item in result] == expected - - # Invariants - assert sum(run_info.length for run_info in result) == len(token_ids) - - # yapf: disable @pytest.mark.parametrize( ("token_ids", "match_ids", "expected"), @@ -170,13 +135,11 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key): # Should not be used since there is nothing to convert to token IDs mock_tokenizer = cast(AnyTokenizer, object()) - result = find_token_matches( - prompt, - [ - PromptReplacement(target, [], 0).bind(key, mock_tokenizer) - for key, target in target_by_key.items() - ], - ) + prompt_repls = [ + PromptReplacement(target, [], 0).bind(key, mock_tokenizer) + for key, target in target_by_key.items() + ] + result = find_token_matches(prompt, prompt_repls) # Only displayed on error print("result:", result) @@ -279,13 +242,11 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): # Should not be used since there is nothing to convert to text mock_tokenizer = cast(AnyTokenizer, object()) - result = find_text_matches( - prompt, - [ - PromptReplacement(target, [], 0).bind(key, mock_tokenizer) - for key, target in target_by_key.items() - ], - ) + prompt_repls = [ + PromptReplacement(target, [], 0).bind(key, mock_tokenizer) + for key, target in target_by_key.items() + ] + result = find_text_matches(prompt, prompt_repls) # Only displayed on error print("result:", result) @@ -303,7 +264,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): # yapf: disable @pytest.mark.parametrize( - ("prompt", "target_by_key", "repl_by_key", "expected_by_mm_count"), + ("prompt", "target_by_key", "repl_by_key"), [ ( "Image:Image:!", @@ -322,49 +283,201 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): # Test multiple repl_count "pattern_3": ("?", 2), }, - { - # Test no replacement - 0: "Image:Image:!", - # Test single replacement - 1: "Image:??", - # Test repeated replacement - 2: "??", - }, ), ] ) +@pytest.mark.parametrize( + ("mm_count", "expected"), + [ + (0, "Image:Image:!"), + (1, "Image:??"), + (2, "??"), + ] +) # yapf: enable def test_find_replace_text( prompt, target_by_key, repl_by_key, - expected_by_mm_count, + mm_count, + expected, ): # Should not be used since there is nothing to convert to text mock_tokenizer = cast(AnyTokenizer, object()) - matches = find_text_matches( + prompt_repls = [ + PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer) + for key, target in target_by_key.items() + ] + matches = find_text_matches(prompt, prompt_repls) + + result = replace_text_matches( prompt, - [ - PromptReplacement(target, *repl_by_key[key]) \ - .bind(key, mock_tokenizer) - for key, target in target_by_key.items() - ], + matches, + {key: list(range(mm_count)) + for key in repl_by_key}, + BatchFeature(), ) - result_by_mm_count = { - mm_count: replace_text_matches( - prompt, - matches, - {key: list(range(mm_count)) - for key in repl_by_key}, - BatchFeature(), - ) - for mm_count in expected_by_mm_count - } # Only displayed on error print("matches:", matches) - print("result_by_mm_count:", result_by_mm_count) + print("result:", result) + + # Manually constructed results + assert result == expected + + +# yapf: disable +@pytest.mark.parametrize( + ("prompt", "target_by_key", "repl_by_key"), + [ + # Tokenized test cases of `test_find_replace_text` + # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf + ( + [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], + { + # We use `` before `Image:` to test matches that + # occur out of order + "pattern_1": [32000], + "pattern_2": [9833, 28747], + "pattern_3": [918], + }, + { + # Test whether target is confused with repl_unit + "pattern_1": ([32000, 32000], 1), + # Test empty repl_unit + "pattern_2": ([], 1), + # Test multiple repl_count + "pattern_3": ([1550], 2), + }, + ), + ] +) +@pytest.mark.parametrize( + ("mm_count", "expected"), + [ + (0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]), + (1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550]), + (2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 1550]), + ] +) +# yapf: enable +def test_find_replace_tokens( + prompt, + target_by_key, + repl_by_key, + mm_count, + expected, +): + # Should not be used since there is nothing to convert to tokens + mock_tokenizer = cast(AnyTokenizer, object()) + + prompt_repls = [ + PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer) + for key, target in target_by_key.items() + ] + matches = find_token_matches(prompt, prompt_repls) + + result = replace_token_matches( + prompt, + matches, + {key: list(range(mm_count)) + for key in repl_by_key}, + BatchFeature(), + ) + + # Only displayed on error + print("matches:", matches) + print("result:", result) + + # Manually constructed results + assert result == expected + + +# yapf: disable +@pytest.mark.parametrize( + "repl_by_key", + [ + { + "pattern_1": ([32000, 32000], 1), + "pattern_2": ([], 1), + "pattern_3": ([1550], 2), + }, + ], +) +@pytest.mark.parametrize( + ("prompt", "expected"), + [ + ( + [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], + [ + _PlaceholderInfo( + modality="pattern_1", + start_idx=6, + unit=[32000, 32000], + unit_count=1, + ), + ], + ), + ( + [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550], + [ + _PlaceholderInfo( + modality="pattern_1", + start_idx=1, + unit=[32000, 32000], + unit_count=1, + ), + _PlaceholderInfo( + modality="pattern_1", + start_idx=5, + unit=[32000, 32000], + unit_count=1, + ), + _PlaceholderInfo( + modality="pattern_3", + start_idx=7, + unit=[1550], + unit_count=2, + ), + ], + ), + ( + [1, 32000, 32000, 32000, 32000, 32000, 1550, 1550], + [ + _PlaceholderInfo( + modality="pattern_1", + start_idx=1, + unit=[32000, 32000], + unit_count=2, + ), + _PlaceholderInfo( + modality="pattern_3", + start_idx=6, + unit=[1550], + unit_count=2, + ), + ], + ), + ] +) +def test_iter_placeholders( + repl_by_key, + prompt, + expected, +): + # Should not be used since there is nothing to convert to tokens + mock_tokenizer = cast(AnyTokenizer, object()) + + prompt_repls = [ + PromptReplacement([], *repl).bind(key, mock_tokenizer) + for key, repl in repl_by_key.items() + ] + + result = list(iter_placeholders(prompt_repls, prompt)) + + # Only displayed on error + print("result:", result) # Manually constructed results - assert result_by_mm_count == expected_by_mm_count + assert result == expected diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 68b4756331e6d..927e8b3cc820a 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -230,21 +230,38 @@ def dummy_data_for_profiling( This should be called after :meth:`~MultiModalRegistry.init_mm_limits_per_prompt`. """ - # Avoid circular import - from vllm.model_executor.model_loader import get_model_architecture - - model_cls, _ = get_model_architecture(model_config) - if is_encoder_data: - dummy_factory = self._get_dummy_encoder_data_factory(model_cls) + if mm_registry.has_processor(model_config): + # Avoid circular import + from vllm.multimodal.utils import cached_get_tokenizer + + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code, + ) + processor = mm_registry.create_processor(model_config, tokenizer) + + mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) + mm_max_tokens = mm_registry.get_max_tokens_by_modality( + model_config) + + dummy_data = processor.get_dummy_data(seq_len, mm_counts, + mm_max_tokens) else: - dummy_factory = self._get_dummy_data_factory(model_cls) - mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) - mm_processor_kwargs = get_allowed_kwarg_only_overrides( - dummy_factory, overrides=model_config.mm_processor_kwargs) - - dummy_data = dummy_factory(InputContext(model_config), seq_len, - _MultiModalCounts(mm_counts), - **mm_processor_kwargs) + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + + model_cls, _ = get_model_architecture(model_config) + if is_encoder_data: + dummy_factory = self._get_dummy_encoder_data_factory(model_cls) + else: + dummy_factory = self._get_dummy_data_factory(model_cls) + mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) + mm_processor_kwargs = get_allowed_kwarg_only_overrides( + dummy_factory, overrides=model_config.mm_processor_kwargs) + + dummy_data = dummy_factory(InputContext(model_config), seq_len, + _MultiModalCounts(mm_counts), + **mm_processor_kwargs) # Having more tokens is over-conservative but otherwise fine num_tokens = dummy_data.seq_data.prompt_token_ids diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 28c8dda581982..15a6b5600f4c3 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -3,14 +3,13 @@ from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence from dataclasses import dataclass from functools import lru_cache -from itertools import groupby from typing import Any, Generic, NamedTuple, Optional, Protocol, TypeVar, Union import numpy as np from transformers import BatchFeature from typing_extensions import TypeAlias, TypedDict -from vllm.inputs import InputProcessingContext +from vllm.inputs import DummyData, InputProcessingContext from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.utils import flatten_2d_lists, full_groupby, is_list_of @@ -256,63 +255,6 @@ def to_multi_format(data: MultiModalDataDict) -> dict[str, list[Any]]: return multi_data -class _TokenRun(NamedTuple): - token_id: int - - start_idx: int - length: int - - -def iter_token_runs(token_ids: list[int]) -> Iterable[_TokenRun]: - """ - Yield the starting index and length of each run of tokens that are the same. - """ - start_idx = 0 - - for token_id, it in groupby(token_ids): - length = sum(1 for _ in it) - yield _TokenRun(token_id=token_id, start_idx=start_idx, length=length) - - start_idx += length - - -class _PlaceholderInfo(NamedTuple): - modality: str - offset: int - length: int - - def to_range(self) -> PlaceholderRange: - return PlaceholderRange(offset=self.offset, length=self.length) - - -def iter_placeholders( - prompt_repls: Sequence[_BoundPromptReplacement[Any]], - token_ids: list[int], - *, - min_placeholder_count: int, -) -> Iterable[_PlaceholderInfo]: - """Yield each set of placeholder tokens found in :code:`token_ids`.""" - placeholder_ids_by_modality = { - modality: { - token_id - for prompt_repl in repls - for token_id in prompt_repl.repl_unit.token_ids - } - for modality, repls in full_groupby_modality(prompt_repls) - } - - for run_info in iter_token_runs(token_ids): - if run_info.length > min_placeholder_count: - for (modality, - placeholder_ids) in placeholder_ids_by_modality.items(): - if run_info.token_id in placeholder_ids: - yield _PlaceholderInfo( - modality=modality, - offset=run_info.start_idx, - length=run_info.length, - ) - - class _TokenMatch(NamedTuple): start_idx: int end_idx: int @@ -353,13 +295,9 @@ def start_idx(self) -> int: def end_idx(self) -> int: raise NotImplementedError + @property @abstractmethod - def get_repl( - self, - mm_items: list[_T], - hf_inputs: BatchFeature, - item_idx: int, - ) -> _S: + def repl_unit(self) -> _S: raise NotImplementedError def __repr__(self) -> str: @@ -380,15 +318,9 @@ def start_idx(self) -> int: def end_idx(self) -> int: return self.match.end_idx - def get_repl( - self, - mm_items: list[_T], - hf_inputs: BatchFeature, - item_idx: int, - ) -> list[int]: - prompt_repl = self.prompt_repl - count = prompt_repl.get_count(mm_items, hf_inputs, item_idx) - return prompt_repl.repl_unit.token_ids * count + @property + def repl_unit(self) -> list[int]: + return self.prompt_repl.repl_unit.token_ids @dataclass(repr=False) @@ -404,15 +336,41 @@ def start_idx(self) -> int: def end_idx(self) -> int: return self.match.end() - def get_repl( - self, - mm_items: list[_T], - hf_inputs: BatchFeature, - item_idx: int, - ) -> str: - prompt_repl = self.prompt_repl - count = prompt_repl.get_count(mm_items, hf_inputs, item_idx) - return prompt_repl.repl_unit.text * count + @property + def repl_unit(self) -> str: + return self.prompt_repl.repl_unit.text + + +class _PlaceholderInfo(NamedTuple): + modality: str + start_idx: int + unit: list[int] + unit_count: int + + @property + def length(self) -> int: + return len(self.unit) * self.unit_count + + def can_merge(self, next_: "_PlaceholderInfo") -> bool: + return (self.modality == next_.modality and self.unit == next_.unit + and self.start_idx + self.length == next_.start_idx) + + def merge(self, next_: "_PlaceholderInfo") -> "_PlaceholderInfo": + if not self.can_merge(next_): + raise ValueError(f"Unable to merge {self} and {next_}") + + return _PlaceholderInfo( + modality=self.modality, + start_idx=self.start_idx, + unit=self.unit, + unit_count=self.unit_count + next_.unit_count, + ) + + def to_range(self) -> PlaceholderRange: + return PlaceholderRange( + offset=self.start_idx, + length=self.length, + ) def find_token_matches( @@ -480,9 +438,12 @@ def _replace_matches( start_idx = match.start_idx end_idx = match.end_idx - repl_ids = match.get_repl(mm_items, hf_inputs, item_idx) + repl_unit = match.repl_unit + repl_info = match.prompt_repl + repl_count = repl_info.get_count(mm_items, hf_inputs, item_idx) - out_seqs.append(prompt[prev_end_idx:start_idx] + repl_ids) + out_seqs.append(prompt[prev_end_idx:start_idx] + + repl_unit * repl_count) prev_end_idx = end_idx next_idx_by_modality[modality] += 1 @@ -531,6 +492,48 @@ def replace_text_matches( return "".join(texts) +def iter_placeholders( + prompt_repls: Sequence[_BoundPromptReplacement[Any]], + prompt: list[int], + *, + min_unit_count: int = 1, +) -> Iterable[_PlaceholderInfo]: + """Yield each set of placeholder tokens found in :code:`token_ids`.""" + if min_unit_count <= 0: + raise ValueError("`min_placeholder_count` must be a positive integer") + + matches = [ + _PromptReplacementTokenMatch(prompt_repl, match) + for prompt_repl in prompt_repls + if len(repl_unit := prompt_repl.repl_unit.token_ids) > 0 + for match in iter_token_matches(prompt, repl_unit) + ] + + current_placeholder = None + + for match in _resolve_matches(prompt, matches): + match_placeholder = _PlaceholderInfo( + modality=match.modality, + start_idx=match.start_idx, + unit=match.prompt_repl.repl_unit.token_ids, + unit_count=1, + ) + + if current_placeholder is None: + current_placeholder = match_placeholder + elif current_placeholder.can_merge(match_placeholder): + current_placeholder = current_placeholder.merge(match_placeholder) + else: + if current_placeholder.unit_count >= min_unit_count: + yield current_placeholder + + current_placeholder = match_placeholder + + if (current_placeholder is not None + and current_placeholder.unit_count >= min_unit_count): + yield current_placeholder + + class MultiModalProcessor: """ Helper class to process multi-modal inputs to be used in vLLM. @@ -562,13 +565,13 @@ def _find_placeholders( # To avoid false positives from multi-input when detecting # whether placeholder tokens have been inserted, in case # the target sequence is a subset of the replacement tokens - min_placeholder_count: int = 16, + min_unit_count: int = 16, ) -> list[_PlaceholderInfo]: return list( iter_placeholders( all_prompt_repls, new_token_ids, - min_placeholder_count=min_placeholder_count, + min_unit_count=min_unit_count, )) def _apply_hf_processor( @@ -579,10 +582,15 @@ def _apply_hf_processor( ) -> BatchFeature: hf_processor = self.ctx.get_hf_processor() + # Map keys to plural form, e.g.: image -> images + mm_data = {(k if k.endswith("s") else f"{k}s"): v + for k, v in mm_data.items()} + return hf_processor( text=prompt, # type: ignore **mm_data, **mm_processor_kwargs, + return_tensors="pt", ) def _bind_prompt_replacements( @@ -648,15 +656,6 @@ def _apply_prompt_replacements( placeholders = self._find_placeholders(matched_repls, token_ids) - # Sanity check - assert len(placeholders) == len(matched_repls), dict( - # Log this information for easier debugging - text=text, - token_ids=token_ids, - placeholders=placeholders, - matched_repls=matched_repls, - ) - return token_ids, text, placeholders def apply( @@ -717,3 +716,51 @@ def apply( mm_kwargs=mm_kwargs, mm_placeholders=mm_placeholders, ) + + def get_dummy_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_max_tokens: Mapping[str, int], + ) -> DummyData: + # Avoid circular import + from vllm.sequence import SequenceData + + tokenizer = self.ctx.tokenizer + + mm_placeholders = dict[str, list[_PlaceholderInfo]]() + offset = 0 + + for modality, num_items in mm_counts.items(): + max_tokens = mm_max_tokens[modality] + if max_tokens == 0: + continue + + metadata = self.metadata[modality] + repl = metadata.prompt_repls[0].bind(modality, tokenizer) + repl_token_ids = repl.repl_unit.token_ids + + placeholders = _PlaceholderInfo( + modality=modality, + start_idx=offset, + unit=repl_token_ids, + unit_count=max_tokens // len(repl_token_ids), + ) + + mm_placeholders[modality] = [placeholders] * num_items + offset += placeholders.length + + prompt_token_ids = flatten_2d_lists([ + p.unit * p.unit_count for placeholders in mm_placeholders.values() + for p in placeholders + ]) + prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids))) + + return DummyData( + seq_data=SequenceData.from_seqs(prompt_token_ids), + multi_modal_data=None, + multi_modal_placeholders={ + modality: [p.to_range() for p in placeholders] + for modality, placeholders in mm_placeholders.items() + }, + ) diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index b992442d3b314..c1330dfcac51f 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -199,9 +199,12 @@ def register_max_image_tokens( """ return self.register_max_multimodal_tokens("image", max_mm_tokens) - def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: + def get_max_tokens_by_modality( + self, + model_config: "ModelConfig", + ) -> Mapping[str, int]: """ - Get the maximum number of multi-modal tokens + Get the maximum number of tokens from each modality for profiling the memory usage of a model. See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details. @@ -211,9 +214,23 @@ def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: """ limits_per_plugin = self._limits_by_model[model_config] - return sum((limits_per_plugin[key] * - plugin.get_max_multimodal_tokens(model_config)) - for key, plugin in self._plugins.items()) + return { + key: (limits_per_plugin[key] * + plugin.get_max_multimodal_tokens(model_config)) + for key, plugin in self._plugins.items() + } + + def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: + """ + Get the maximum number of multi-modal tokens + for profiling the memory usage of a model. + + See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details. + + Note: + This should be called after :meth:`init_mm_limits_per_prompt`. + """ + return sum(self.get_max_tokens_by_modality(model_config).values()) def init_mm_limits_per_prompt( self, @@ -269,7 +286,8 @@ def register_processor( factory: MultiModalProcessorFactory, ): """ - Register a multi-modal processor to a model class. + Register a multi-modal processor to a model class. The processor + is constructed lazily, hence a factory method should be passed. When the model receives multi-modal data, the provided function is invoked to transform the data into a dictionary of model inputs. From de8332aeb99ad85ad2f2260439b99322b43e2233 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 26 Nov 2024 16:23:13 +0000 Subject: [PATCH 02/21] Use merged processor for llava model Signed-off-by: DarkLight1337 --- vllm/model_executor/models/llava.py | 91 ++++++++++------------------- vllm/multimodal/registry.py | 27 ++++++++- 2 files changed, 56 insertions(+), 62 deletions(-) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 05c6cc62efcd7..65e8fa74741ac 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -4,34 +4,34 @@ import torch import torch.nn as nn -from PIL import Image -from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig, - PretrainedConfig, SiglipVisionConfig) +from PIL.Image import Image +from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig, + PixtralVisionConfig, PretrainedConfig, + SiglipVisionConfig) from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext) +from vllm.inputs import INPUT_REGISTRY, DummyData, InputContext from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import NestedTensors +from vllm.multimodal.processing import (InputProcessingContext, + ModalityProcessingMetadata, + MultiModalProcessingMetadata, + PromptReplacement) from vllm.sequence import IntermediateTensors -from vllm.utils import is_list_of from .clip import (CLIPVisionModel, dummy_image_for_clip, - dummy_seq_data_for_clip, get_max_clip_image_tokens, - input_processor_for_clip) + dummy_seq_data_for_clip, get_max_clip_image_tokens) from .interfaces import SupportsMultiModal, SupportsPP from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf, dummy_seq_data_for_pixtral_hf, - get_max_pixtral_hf_image_tokens, - input_processor_for_pixtral_hf) + get_max_pixtral_hf_image_tokens) from .siglip import (SiglipVisionModel, dummy_image_for_siglip, - dummy_seq_data_for_siglip, get_max_siglip_image_tokens, - input_processor_for_siglip) + dummy_seq_data_for_siglip, get_max_siglip_image_tokens) from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) @@ -150,56 +150,26 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int, raise NotImplementedError(msg) -def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return inputs - - model_config = ctx.model_config +def create_metadata( + ctx: InputProcessingContext) -> MultiModalProcessingMetadata: hf_config = ctx.get_hf_config(LlavaConfig) - vision_config = hf_config.vision_config + image_token_id = hf_config.image_token_index - image_data = multi_modal_data["image"] - if isinstance(image_data, Image.Image): - image_feature_size = get_max_llava_image_tokens(ctx) - elif is_list_of(image_data, Image.Image): - image_feature_size = [get_max_llava_image_tokens(ctx) - ] * len(image_data) - elif isinstance(image_data, torch.Tensor): - num_images, image_feature_size, hidden_size = image_data.shape - elif is_list_of(image_data, torch.Tensor): - image_feature_size = [item.shape[1] for item in image_data] - else: - raise TypeError(f"Invalid image type: {type(image_data)}") - - if isinstance(vision_config, CLIPVisionConfig): - return input_processor_for_clip( - model_config, - vision_config, - inputs, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) - elif isinstance(vision_config, SiglipVisionConfig): - return input_processor_for_siglip( - model_config, - vision_config, - inputs, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) - elif isinstance(vision_config, PixtralVisionConfig): - # We ignore image_feature_size_override since we have non-uniform - # image sizes for Pixtral - return input_processor_for_pixtral_hf( - model_config, - vision_config, - inputs, - image_token_id=hf_config.image_token_index, - ) + def get_repl_count( + mm_items: list[Image], + hf_inputs: BatchFeature, + item_idx: int, + ) -> int: + return get_max_llava_image_tokens(ctx) - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) + return { + "image": + ModalityProcessingMetadata(prompt_repls=[ + PromptReplacement(target=[image_token_id], + repl_unit=[image_token_id], + repl_count=get_repl_count), + ]), + } class LlavaLikeConfig(Protocol): @@ -282,10 +252,9 @@ def init_vision_tower_for_llava( raise NotImplementedError(msg) -@MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava) -@INPUT_REGISTRY.register_input_processor(input_processor_for_llava) +@MULTIMODAL_REGISTRY.register_processor_by_metadata(create_metadata) class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index c1330dfcac51f..de674e89c9528 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -14,7 +14,7 @@ from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc from .image import ImagePlugin from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors -from .processing import MultiModalProcessor +from .processing import MultiModalProcessingMetadata, MultiModalProcessor from .video import VideoPlugin if TYPE_CHECKING: @@ -310,6 +310,31 @@ def wrapper(model_cls: N) -> N: return wrapper + def register_processor_by_metadata( + self, + metadata_factory: Callable[[InputProcessingContext], + MultiModalProcessingMetadata], + ): + """ + Convenience method to register a multi-modal processor to a model class + according to a function that constructs its metadata. + + When the model receives multi-modal data, the provided function is + invoked to transform the data into a dictionary of model inputs. + + See also: + - :ref:`input_processing_pipeline` + - :ref:`enabling_multimodal_inputs` + """ + + def factory(ctx: InputProcessingContext): + return MultiModalProcessor( + ctx=ctx, + metadata=metadata_factory(ctx), + ) + + return self.register_processor(factory) + def has_processor(self, model_config: "ModelConfig") -> bool: """ Test whether a multi-modal processor is defined for a specific model. From 8b6804e783ffbbb32f89675b60d481ad6314bd58 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 26 Nov 2024 16:29:18 +0000 Subject: [PATCH 03/21] format Signed-off-by: DarkLight1337 --- vllm/multimodal/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 15a6b5600f4c3..db57154d8b9ae 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -725,7 +725,7 @@ def get_dummy_data( ) -> DummyData: # Avoid circular import from vllm.sequence import SequenceData - + tokenizer = self.ctx.tokenizer mm_placeholders = dict[str, list[_PlaceholderInfo]]() From 26e3fdfdb3153cd21e16ba22c30edf8c0ad92663 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 26 Nov 2024 16:32:21 +0000 Subject: [PATCH 04/21] Fix typo Signed-off-by: DarkLight1337 --- vllm/multimodal/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index db57154d8b9ae..c1e0fa7856b2a 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -500,7 +500,7 @@ def iter_placeholders( ) -> Iterable[_PlaceholderInfo]: """Yield each set of placeholder tokens found in :code:`token_ids`.""" if min_unit_count <= 0: - raise ValueError("`min_placeholder_count` must be a positive integer") + raise ValueError("`min_unit_count` must be a positive integer") matches = [ _PromptReplacementTokenMatch(prompt_repl, match) From 93d27bc87671e28435202091d5bb828f876b6213 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 26 Nov 2024 17:03:17 +0000 Subject: [PATCH 05/21] Enable the test to pass on V1 Signed-off-by: DarkLight1337 --- vllm/v1/engine/mm_input_mapper.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/v1/engine/mm_input_mapper.py b/vllm/v1/engine/mm_input_mapper.py index 594c973678235..e80f22e4b1efc 100644 --- a/vllm/v1/engine/mm_input_mapper.py +++ b/vllm/v1/engine/mm_input_mapper.py @@ -12,6 +12,7 @@ def __init__( model_config: ModelConfig, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): + self.model_config = model_config self.mm_registry = mm_registry self.multi_modal_input_mapper = mm_registry.create_input_mapper( model_config) @@ -22,6 +23,9 @@ def process_inputs( mm_data: MultiModalDataDict, mm_processor_kwargs: Optional[Dict[str, Any]], ) -> List[MultiModalKwargs]: + if self.mm_registry.has_processor(self.model_config): + return [MultiModalKwargs(mm_data)] # Already processed + image_inputs = mm_data["image"] if not isinstance(image_inputs, list): image_inputs = [image_inputs] From d697241f1e3984008fa1fdbb74f1c0089dcd8ac4 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 26 Nov 2024 17:28:53 +0000 Subject: [PATCH 06/21] Handle embedding inputs Signed-off-by: DarkLight1337 --- vllm/multimodal/processing.py | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index c1e0fa7856b2a..9fe2770551ba0 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -6,6 +6,7 @@ from typing import Any, Generic, NamedTuple, Optional, Protocol, TypeVar, Union import numpy as np +import torch from transformers import BatchFeature from typing_extensions import TypeAlias, TypedDict @@ -581,17 +582,34 @@ def _apply_hf_processor( mm_processor_kwargs: Mapping[str, object], ) -> BatchFeature: hf_processor = self.ctx.get_hf_processor() - - # Map keys to plural form, e.g.: image -> images - mm_data = {(k if k.endswith("s") else f"{k}s"): v - for k, v in mm_data.items()} - - return hf_processor( + + processor_data = dict[str, Any]() + passthrough_data = dict[str, Any]() + for k, v in mm_data.items(): + # TODO: Make a separate modality for embedding inputs + # to avoid confusion + if k in ("image", "video", "audio"): + if isinstance(v, torch.Tensor) and v.ndim == 3: + # Pass through embedding inputs (single) + passthrough_data[f"{k}_embeds"] = [v] + elif is_list_of(v, torch.Tensor) and v[0].ndim == 3: + # Pass through embedding inputs (multi) + passthrough_data[f"{k}_embeds"] = v + else: + # Map keys to plural form, e.g.: image -> images + processor_data[f"{k}s"] = v + else: + processor_data[k] = v + + hf_inputs = hf_processor( text=prompt, # type: ignore - **mm_data, + **processor_data, **mm_processor_kwargs, return_tensors="pt", ) + hf_inputs.update(passthrough_data) + + return hf_inputs def _bind_prompt_replacements( self, From ca11cc944d8b78181812cd180ee343b10139550a Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 26 Nov 2024 17:33:57 +0000 Subject: [PATCH 07/21] format Signed-off-by: DarkLight1337 --- vllm/multimodal/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 9fe2770551ba0..19c4873336122 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -582,7 +582,7 @@ def _apply_hf_processor( mm_processor_kwargs: Mapping[str, object], ) -> BatchFeature: hf_processor = self.ctx.get_hf_processor() - + processor_data = dict[str, Any]() passthrough_data = dict[str, Any]() for k, v in mm_data.items(): From 6c5c9ca8202c40c3871c7a0f5f082a057a3e594c Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 27 Nov 2024 02:39:55 +0000 Subject: [PATCH 08/21] Fix wrong ndim Signed-off-by: DarkLight1337 --- vllm/multimodal/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 19c4873336122..724e5efc6fb03 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -592,7 +592,7 @@ def _apply_hf_processor( if isinstance(v, torch.Tensor) and v.ndim == 3: # Pass through embedding inputs (single) passthrough_data[f"{k}_embeds"] = [v] - elif is_list_of(v, torch.Tensor) and v[0].ndim == 3: + elif is_list_of(v, torch.Tensor) and v[0].ndim == 2: # Pass through embedding inputs (multi) passthrough_data[f"{k}_embeds"] = v else: From 01943242a89d6c7d4268348e96a2d62ab8ae69cb Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 27 Nov 2024 02:52:46 +0000 Subject: [PATCH 09/21] Factor out `merge_placeholders` Signed-off-by: DarkLight1337 --- vllm/multimodal/processing.py | 49 +++++++++++++++++++---------------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 724e5efc6fb03..37aae835938ec 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -493,6 +493,24 @@ def replace_text_matches( return "".join(texts) +def _merge_placeholders( + placeholders: Iterable[_PlaceholderInfo] +) -> Iterable[_PlaceholderInfo]: + current_placeholder = None + + for placeholder in placeholders: + if current_placeholder is None: + current_placeholder = placeholder + elif current_placeholder.can_merge(placeholder): + current_placeholder = current_placeholder.merge(placeholder) + else: + yield current_placeholder + current_placeholder = placeholder + + if current_placeholder is not None: + yield current_placeholder + + def iter_placeholders( prompt_repls: Sequence[_BoundPromptReplacement[Any]], prompt: list[int], @@ -510,29 +528,16 @@ def iter_placeholders( for match in iter_token_matches(prompt, repl_unit) ] - current_placeholder = None + match_placeholders = (_PlaceholderInfo( + modality=match.modality, + start_idx=match.start_idx, + unit=match.prompt_repl.repl_unit.token_ids, + unit_count=1, + ) for match in _resolve_matches(prompt, matches)) - for match in _resolve_matches(prompt, matches): - match_placeholder = _PlaceholderInfo( - modality=match.modality, - start_idx=match.start_idx, - unit=match.prompt_repl.repl_unit.token_ids, - unit_count=1, - ) - - if current_placeholder is None: - current_placeholder = match_placeholder - elif current_placeholder.can_merge(match_placeholder): - current_placeholder = current_placeholder.merge(match_placeholder) - else: - if current_placeholder.unit_count >= min_unit_count: - yield current_placeholder - - current_placeholder = match_placeholder - - if (current_placeholder is not None - and current_placeholder.unit_count >= min_unit_count): - yield current_placeholder + for placeholder in _merge_placeholders(match_placeholders): + if placeholder.unit_count >= min_unit_count: + yield placeholder class MultiModalProcessor: From 09618d0710533c15b4c01c50e8e252b067edd56d Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 27 Nov 2024 05:54:32 +0000 Subject: [PATCH 10/21] Fix placeholder maps handling on V0 Signed-off-by: DarkLight1337 --- vllm/multimodal/base.py | 45 ++++++++++++++++++++++--------- vllm/v1/engine/mm_input_mapper.py | 5 ++-- 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 6eec660e42ac4..b5c19cb00e49b 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -326,26 +326,47 @@ def from_seq_group( src_ranges = [] dest_ranges = [] """ - if (not seq_group.multi_modal_data - or not seq_group.multi_modal_placeholders): - return seq_group.multi_modal_data, {} + seq_mm_data = seq_group.multi_modal_data + seq_mm_placeholders = seq_group.multi_modal_placeholders - mm_data = {**seq_group.multi_modal_data} - placeholder_maps: Dict[str, MultiModalPlaceholderMap] = defaultdict( + if not seq_mm_data or not seq_mm_placeholders: + return seq_mm_data, {} + + # For merged processor, we directly use mm_kwargs as mm_data + if isinstance(seq_mm_data, MultiModalKwargs): + placeholder_maps = dict[str, MultiModalPlaceholderMap]() + + for modality, placeholders in seq_mm_placeholders.items(): + placeholder_map = MultiModalPlaceholderMap() + + if positions: + placeholder_map.append_items_from_seq_group( + positions, + # Dummy, since we don't care about intersecting items + [None] * len(placeholders), + placeholders, + ) + + placeholder_maps[modality] = placeholder_map + + return seq_mm_data, placeholder_maps + + mm_data = {**seq_mm_data} + placeholder_maps = defaultdict[str, MultiModalPlaceholderMap]( MultiModalPlaceholderMap) - for ( - modality, - placeholders, - ) in seq_group.multi_modal_placeholders.items(): + for modality, placeholders in seq_mm_placeholders.items(): mm_items = mm_data.pop(modality) if not isinstance(mm_items, list): mm_items = [mm_items] if positions: - intersecting_items = placeholder_maps[ - modality].append_items_from_seq_group( - positions, mm_items, placeholders) + intersecting_items = placeholder_maps[modality] \ + .append_items_from_seq_group( + positions, + mm_items, + placeholders, + ) if intersecting_items: mm_data[modality] = intersecting_items diff --git a/vllm/v1/engine/mm_input_mapper.py b/vllm/v1/engine/mm_input_mapper.py index e80f22e4b1efc..f056e60908cbb 100644 --- a/vllm/v1/engine/mm_input_mapper.py +++ b/vllm/v1/engine/mm_input_mapper.py @@ -23,8 +23,9 @@ def process_inputs( mm_data: MultiModalDataDict, mm_processor_kwargs: Optional[Dict[str, Any]], ) -> List[MultiModalKwargs]: - if self.mm_registry.has_processor(self.model_config): - return [MultiModalKwargs(mm_data)] # Already processed + # Skip this redundant step if merged processor has been applied + if isinstance(mm_data, MultiModalKwargs): + return [mm_data] image_inputs = mm_data["image"] if not isinstance(image_inputs, list): From 5501458449398f552f1d9c25c684d20b83221187 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 27 Nov 2024 06:15:36 +0000 Subject: [PATCH 11/21] Remove unused dummy data code Signed-off-by: DarkLight1337 --- vllm/model_executor/models/llava.py | 63 +++-------------------------- 1 file changed, 6 insertions(+), 57 deletions(-) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 0436ea3d0f67e..eb2f62c845abe 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,6 +1,6 @@ from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set, - Tuple, TypedDict, Union) +from typing import (Iterable, List, Literal, Optional, Protocol, Set, Tuple, + TypedDict, Union) import torch import torch.nn as nn @@ -11,7 +11,7 @@ from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.inputs import INPUT_REGISTRY, DummyData, InputContext +from vllm.inputs import InputContext from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler @@ -24,14 +24,10 @@ PromptReplacement) from vllm.sequence import IntermediateTensors -from .clip import (CLIPVisionModel, dummy_image_for_clip, - dummy_seq_data_for_clip, get_max_clip_image_tokens) +from .clip import CLIPVisionModel, get_max_clip_image_tokens from .interfaces import SupportsMultiModal, SupportsPP -from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf, - dummy_seq_data_for_pixtral_hf, - get_max_pixtral_hf_image_tokens) -from .siglip import (SiglipVisionModel, dummy_image_for_siglip, - dummy_seq_data_for_siglip, get_max_siglip_image_tokens) +from .pixtral import PixtralHFVisionModel, get_max_pixtral_hf_image_tokens +from .siglip import SiglipVisionModel, get_max_siglip_image_tokens from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) @@ -104,52 +100,6 @@ def get_max_llava_image_tokens(ctx: InputContext): raise ValueError(f"Unexpected select feature strategy: {strategy}") -def dummy_data_for_llava(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): - hf_config = ctx.get_hf_config(LlavaConfig) - vision_config = hf_config.vision_config - num_images = mm_counts["image"] - - image_feature_size = get_max_llava_image_tokens(ctx) - - if isinstance(vision_config, CLIPVisionConfig): - seq_data, ranges = dummy_seq_data_for_clip( - vision_config, - seq_len, - num_images, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) - - mm_data = dummy_image_for_clip(vision_config, num_images) - return DummyData(seq_data, mm_data, ranges) - elif isinstance(vision_config, SiglipVisionConfig): - seq_data, ranges = dummy_seq_data_for_siglip( - vision_config, - seq_len, - num_images, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) - - mm_data = dummy_image_for_siglip(vision_config, num_images) - return DummyData(seq_data, mm_data, ranges) - elif isinstance(vision_config, PixtralVisionConfig): - seq_data, ranges = dummy_seq_data_for_pixtral_hf( - vision_config, - seq_len, - num_images, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) - - mm_data = dummy_image_for_pixtral_hf(vision_config, num_images) - return DummyData(seq_data, mm_data, ranges) - - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) - - def create_metadata( ctx: InputProcessingContext) -> MultiModalProcessingMetadata: hf_config = ctx.get_hf_config(LlavaConfig) @@ -253,7 +203,6 @@ def init_vision_tower_for_llava( @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava) @MULTIMODAL_REGISTRY.register_processor_by_metadata(create_metadata) class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): From f3673c7f35947ac9cdf46fbbef2d19e41ab7467d Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 27 Nov 2024 06:19:00 +0000 Subject: [PATCH 12/21] Update dummy model Signed-off-by: DarkLight1337 --- .../vllm_add_dummy_model/my_llava.py | 10 +++------- vllm/model_executor/models/llava.py | 4 ++-- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py index 3ebd7864b8fc8..5fa86896c6b6b 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py @@ -2,19 +2,15 @@ import torch -from vllm.inputs import INPUT_REGISTRY from vllm.model_executor.models.llava import (LlavaForConditionalGeneration, - dummy_data_for_llava, - get_max_llava_image_tokens, - input_processor_for_llava) + create_metadata_for_llava, + get_max_llava_image_tokens) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -@MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava) -@INPUT_REGISTRY.register_input_processor(input_processor_for_llava) +@MULTIMODAL_REGISTRY.register_processor_by_metadata(create_metadata_for_llava) class MyLlava(LlavaForConditionalGeneration): def compute_logits( diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index eb2f62c845abe..c1159ab8e4589 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -100,7 +100,7 @@ def get_max_llava_image_tokens(ctx: InputContext): raise ValueError(f"Unexpected select feature strategy: {strategy}") -def create_metadata( +def create_metadata_for_llava( ctx: InputProcessingContext) -> MultiModalProcessingMetadata: hf_config = ctx.get_hf_config(LlavaConfig) image_token_id = hf_config.image_token_index @@ -203,7 +203,7 @@ def init_vision_tower_for_llava( @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) -@MULTIMODAL_REGISTRY.register_processor_by_metadata(create_metadata) +@MULTIMODAL_REGISTRY.register_processor_by_metadata(create_metadata_for_llava) class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: From 37bc0085dcb761877a5b9a2162e132c7ce4b5b4b Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 27 Nov 2024 07:17:06 +0000 Subject: [PATCH 13/21] Enable overriding hf processor and tokenizer; fix `_apply_prompt_replacements` Signed-off-by: DarkLight1337 --- vllm/multimodal/processing.py | 40 +++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 37aae835938ec..febe02557d562 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -7,7 +7,7 @@ import numpy as np import torch -from transformers import BatchFeature +from transformers import BatchFeature, ProcessorMixin from typing_extensions import TypeAlias, TypedDict from vllm.inputs import DummyData, InputProcessingContext @@ -555,6 +555,12 @@ def __init__( self.ctx = ctx self.metadata = metadata + def _get_hf_processor(self) -> ProcessorMixin: + return self.ctx.get_hf_processor() + + def _get_tokenizer(self) -> AnyTokenizer: + return self.ctx.tokenizer + def __call__( self, prompt: str, @@ -586,7 +592,7 @@ def _apply_hf_processor( mm_data: MultiModalDataDict, mm_processor_kwargs: Mapping[str, object], ) -> BatchFeature: - hf_processor = self.ctx.get_hf_processor() + hf_processor = self._get_hf_processor() processor_data = dict[str, Any]() passthrough_data = dict[str, Any]() @@ -606,12 +612,20 @@ def _apply_hf_processor( else: processor_data[k] = v - hf_inputs = hf_processor( - text=prompt, # type: ignore - **processor_data, - **mm_processor_kwargs, - return_tensors="pt", - ) + try: + hf_inputs = hf_processor( + text=prompt, # type: ignore + **processor_data, + **mm_processor_kwargs, + return_tensors="pt", + ) + except Exception as exc: + data = dict(text=prompt, **processor_data) + + raise RuntimeError( + f"Failed to apply {type(hf_processor).__name__} " + f"on data={data} with kwargs={mm_processor_kwargs}") from exc + hf_inputs.update(passthrough_data) return hf_inputs @@ -620,7 +634,7 @@ def _bind_prompt_replacements( self, mm_data: MultiModalDataDict, ) -> list[_BoundPromptReplacement[Any]]: - tokenizer = self.ctx.tokenizer + tokenizer = self._get_tokenizer() return [ prompt_repl.bind(modality, tokenizer) @@ -635,7 +649,7 @@ def _apply_prompt_replacements( token_ids: list[int], prompt_repls: Sequence[_BoundPromptReplacement[Any]], ) -> tuple[list[int], str, list[_PlaceholderInfo]]: - tokenizer = self.ctx.tokenizer + tokenizer = self._get_tokenizer() mm_items = to_multi_format(mm_data) token_matches = find_token_matches(token_ids, prompt_repls) @@ -651,7 +665,7 @@ def _apply_prompt_replacements( # of the search text in the prompt, we instead perform string # replacement on the decoded token IDs, then encode them back. if all( - len(matches) >= len(mm_data[modality]) + len(matches) >= len(mm_items[modality]) for modality, matches in full_groupby_modality(token_matches) ): # yapf: disable token_ids = replace_token_matches( @@ -700,7 +714,7 @@ def apply( 3. Extract information about the placeholder tokens from the processed token IDs. """ - tokenizer = self.ctx.tokenizer + tokenizer = self._get_tokenizer() hf_inputs = self._apply_hf_processor(prompt_text, mm_data, mm_processor_kwargs) @@ -749,7 +763,7 @@ def get_dummy_data( # Avoid circular import from vllm.sequence import SequenceData - tokenizer = self.ctx.tokenizer + tokenizer = self._get_tokenizer() mm_placeholders = dict[str, list[_PlaceholderInfo]]() offset = 0 From 4805a9e0e0252edfae4a03296bcc8ca505ce8478 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 27 Nov 2024 13:04:36 +0000 Subject: [PATCH 14/21] Improve error handling in `_resolve_matches`; merge matches directly Signed-off-by: DarkLight1337 --- vllm/multimodal/processing.py | 88 +++++++++++++++-------------------- 1 file changed, 38 insertions(+), 50 deletions(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index febe02557d562..5fd5629376234 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -5,7 +5,6 @@ from functools import lru_cache from typing import Any, Generic, NamedTuple, Optional, Protocol, TypeVar, Union -import numpy as np import torch from transformers import BatchFeature, ProcessorMixin from typing_extensions import TypeAlias, TypedDict @@ -352,21 +351,6 @@ class _PlaceholderInfo(NamedTuple): def length(self) -> int: return len(self.unit) * self.unit_count - def can_merge(self, next_: "_PlaceholderInfo") -> bool: - return (self.modality == next_.modality and self.unit == next_.unit - and self.start_idx + self.length == next_.start_idx) - - def merge(self, next_: "_PlaceholderInfo") -> "_PlaceholderInfo": - if not self.can_merge(next_): - raise ValueError(f"Unable to merge {self} and {next_}") - - return _PlaceholderInfo( - modality=self.modality, - start_idx=self.start_idx, - unit=self.unit, - unit_count=self.unit_count + next_.unit_count, - ) - def to_range(self) -> PlaceholderRange: return PlaceholderRange( offset=self.start_idx, @@ -406,15 +390,16 @@ def _resolve_matches( Resolve :code:`matches` to ensure that there are no overlapping matches, and sort them such that earlier matches take priority over later ones. """ - num_matches_by_idx = np.zeros(len(prompt), dtype=int) + seen_matches = dict[int, _PromptReplacementMatch[_T, _S]]() + for match in matches: - num_matches_by_idx[match.start_idx:match.end_idx] += 1 + for idx in range(match.start_idx, match.end_idx): + if idx in seen_matches: + raise ValueError("Found overlapping matches " + f"({seen_matches[idx]} and {match}) " + f"at index={idx} of prompt={prompt}") - duplicate_matches_idxs, = np.nonzero(num_matches_by_idx > 1) - if len(duplicate_matches_idxs) > 0: - raise ValueError("Unable to find a unique replacement " - f"at indices={duplicate_matches_idxs} " - f"of prompt={prompt}") + seen_matches[idx] = match return sorted(matches, key=lambda x: x.start_idx) @@ -493,22 +478,26 @@ def replace_text_matches( return "".join(texts) -def _merge_placeholders( - placeholders: Iterable[_PlaceholderInfo] -) -> Iterable[_PlaceholderInfo]: - current_placeholder = None +def _merge_placeholder_matches( + matches: Iterable[_PromptReplacementTokenMatch], +) -> Iterable[_PromptReplacementTokenMatch]: + current_match = None - for placeholder in placeholders: - if current_placeholder is None: - current_placeholder = placeholder - elif current_placeholder.can_merge(placeholder): - current_placeholder = current_placeholder.merge(placeholder) + for match in sorted(matches, key=lambda x: x.start_idx): + if current_match is None: + current_match = match + elif (current_match.prompt_repl == match.prompt_repl + and current_match.end_idx == match.start_idx): + current_match = _PromptReplacementTokenMatch( + current_match.prompt_repl, + match=_TokenMatch(current_match.start_idx, match.end_idx), + ) else: - yield current_placeholder - current_placeholder = placeholder + yield current_match + current_match = match - if current_placeholder is not None: - yield current_placeholder + if current_match is not None: + yield current_match def iter_placeholders( @@ -521,21 +510,20 @@ def iter_placeholders( if min_unit_count <= 0: raise ValueError("`min_unit_count` must be a positive integer") - matches = [ - _PromptReplacementTokenMatch(prompt_repl, match) - for prompt_repl in prompt_repls - if len(repl_unit := prompt_repl.repl_unit.token_ids) > 0 - for match in iter_token_matches(prompt, repl_unit) - ] - - match_placeholders = (_PlaceholderInfo( - modality=match.modality, - start_idx=match.start_idx, - unit=match.prompt_repl.repl_unit.token_ids, - unit_count=1, - ) for match in _resolve_matches(prompt, matches)) + matches = (_PromptReplacementTokenMatch(prompt_repl, match) + for prompt_repl in prompt_repls + if len(repl_unit := prompt_repl.repl_unit.token_ids) > 0 + for match in iter_token_matches(prompt, repl_unit)) + + for match in _merge_placeholder_matches(matches): + unit = match.repl_unit + placeholder = _PlaceholderInfo( + modality=match.modality, + start_idx=match.start_idx, + unit=unit, + unit_count=(match.end_idx - match.start_idx) // len(unit), + ) - for placeholder in _merge_placeholders(match_placeholders): if placeholder.unit_count >= min_unit_count: yield placeholder From 85390084c57f84bcb4269e7624e51bce266d1467 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 27 Nov 2024 13:08:27 +0000 Subject: [PATCH 15/21] Avoid hashing Signed-off-by: DarkLight1337 --- vllm/multimodal/processing.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 5fd5629376234..286bfef077e48 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -390,11 +390,12 @@ def _resolve_matches( Resolve :code:`matches` to ensure that there are no overlapping matches, and sort them such that earlier matches take priority over later ones. """ - seen_matches = dict[int, _PromptReplacementMatch[_T, _S]]() + seen_matches: list[Optional[_PromptReplacementMatch[_T, _S]]] \ + = [None] * len(prompt) for match in matches: for idx in range(match.start_idx, match.end_idx): - if idx in seen_matches: + if seen_matches[idx] is not None: raise ValueError("Found overlapping matches " f"({seen_matches[idx]} and {match}) " f"at index={idx} of prompt={prompt}") From 1e82a4ad9997220cd2ff0675c13e3aa5c76529ca Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 27 Nov 2024 13:15:30 +0000 Subject: [PATCH 16/21] Support and test Mantis model Signed-off-by: DarkLight1337 --- .buildkite/test-pipeline.yaml | 6 ++-- examples/offline_inference_vision_language.py | 17 +++++++++ requirements-test.in | 3 -- tests/conftest.py | 1 - .../vision_language/test_models.py | 28 +++++++++++---- .../vision_language/vlm_utils/core.py | 16 ++++++--- .../vision_language/vlm_utils/model_utils.py | 35 ++++++++++++++++++- .../vision_language/vlm_utils/types.py | 16 +++++---- tests/models/registry.py | 1 + vllm/model_executor/models/llava.py | 32 +++++++++++++++-- vllm/model_executor/models/registry.py | 1 + 11 files changed, 129 insertions(+), 27 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index fc23c9cff0d87..10a89f45e3d55 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -334,7 +334,6 @@ steps: commands: - pytest -v -s models/decoder_only/language -m 'core_model or quant_model' - pytest -v -s models/embedding/language -m core_model - - pytest -v -s models/embedding/vision_language -m core_model - label: Language Models Test (Extended) # 50min optional: true @@ -346,7 +345,6 @@ steps: commands: - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model' - pytest -v -s models/embedding/language -m 'not core_model' - - pytest -v -s models/embedding/vision_language -m 'not core_model' - label: Multi-Modal Models Test (Standard) # 26min #mirror_hardwares: [amd] @@ -357,8 +355,10 @@ steps: - tests/models/embedding/vision_language - tests/models/encoder_decoder/vision_language commands: + - python -m pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model' - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model' + - pytest -v -s models/embedding/vision_language -m core_model - pytest -v -s models/encoder_decoder/language -m core_model - pytest -v -s models/encoder_decoder/vision_language -m core_model @@ -371,11 +371,13 @@ steps: - tests/models/embedding/vision_language - tests/models/encoder_decoder/vision_language commands: + - python -m pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model' # HACK - run phi3v tests separately to sidestep this transformers bug # https://github.com/huggingface/transformers/issues/34307 - pytest -v -s models/decoder_only/vision_language/test_phi3v.py - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model' + - pytest -v -s models/embedding/vision_language -m 'not core_model' - pytest -v -s models/encoder_decoder/language -m 'not core_model' - pytest -v -s models/encoder_decoder/vision_language -m 'not core_model' diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index f08f22eec164a..e7ca816503c1b 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -419,6 +419,22 @@ def run_aria(question: str, modality: str): return llm, prompt, stop_token_ids +# Mantis +def run_mantis(question: str, modality: str): + assert modality == "image" + + llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa: E501 + prompt = llama3_template.format(f"{question}\n") + + llm = LLM( + model="TIGER-Lab/Mantis-8B-siglip-llama3", + max_model_len=4096, + hf_overrides={"architectures": ["MantisForConditionalGeneration"]}, + ) + stop_token_ids = [128009] + return llm, prompt, stop_token_ids + + model_example_map = { "llava": run_llava, "llava-next": run_llava_next, @@ -441,6 +457,7 @@ def run_aria(question: str, modality: str): "glm4v": run_glm4v, "idefics3": run_idefics3, "aria": run_aria, + "mantis": run_mantis, } diff --git a/requirements-test.in b/requirements-test.in index 76f6de2f77c34..1ea0c44ee1035 100644 --- a/requirements-test.in +++ b/requirements-test.in @@ -24,9 +24,6 @@ mistral_common[opencv] >= 1.4.4 # required for pixtral test datamodel_code_generator # required for minicpm3 test lm-eval[api]==0.4.4 # required for model evaluation test -# TODO: Add this after fully implementing llava(mantis) -# git+https://github.com/TIGER-AI-Lab/Mantis.git # required for llava(mantis) test - # quantization bitsandbytes>=0.44.0 buildkite-test-collector==0.1.9 diff --git a/tests/conftest.py b/tests/conftest.py index d56942d8912af..36f1d477fab59 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -263,7 +263,6 @@ def __init__( dtype: str = "half", *, model_kwargs: Optional[Dict[str, Any]] = None, - is_embedding_model: bool = False, is_sentence_transformer: bool = False, is_cross_encoder: bool = False, skip_tokenizer_init: bool = False, diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 3f6d8ef42cd5f..0bd6e97019258 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -34,7 +34,7 @@ "dtype": "half", "max_tokens": 5, "tensor_parallel_size": 2, - "model_kwargs": {"device_map": "auto"}, + "hf_model_kwargs": {"device_map": "auto"}, "image_size_factors": [(.25, 0.5, 1.0)], "distributed_executor_backend": ( "ray", @@ -108,7 +108,7 @@ "cherry_blossom": "What is in the picture?", }), auto_cls=AutoModelForVision2Seq, - postprocess_inputs=model_utils.get_key_type_post_processor( + postprocess_inputs=model_utils.cast_dtype_post_processor( "pixel_values" ), vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output, @@ -148,7 +148,7 @@ prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:", max_model_len=4096, auto_cls=AutoModelForVision2Seq, - postprocess_inputs=model_utils.get_key_type_post_processor( + postprocess_inputs=model_utils.cast_dtype_post_processor( "pixel_values" ), # For chameleon, we only compare the sequences @@ -264,7 +264,7 @@ prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 num_video_frames=16, max_model_len=16384, - postprocess_inputs=model_utils.get_key_type_post_processor( + postprocess_inputs=model_utils.cast_dtype_post_processor( "pixel_values_videos" ), auto_cls=AutoModelForVision2Seq, @@ -295,6 +295,20 @@ ) ], ), + "mantis": VLMTestInfo( + models=["TIGER-Lab/Mantis-8B-siglip-llama3"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 + max_model_len=4096, + postprocess_inputs=model_utils.cast_dtype_post_processor( + "pixel_values" + ), + vllm_runner_kwargs={"hf_overrides": {"architectures": ["MantisForConditionalGeneration"]}}, # noqa: E501 + get_stop_token_ids=lambda tok: [128009], + auto_cls=AutoModelForVision2Seq, + vllm_output_post_proc=model_utils.mantis_vllm_to_hf_output, + patch_hf_runner=model_utils.mantis_patch_hf_runner, + ), "minicpmv": VLMTestInfo( models=["openbmb/MiniCPM-Llama3-V-2_5"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), @@ -318,7 +332,7 @@ # max_num_seqs=2, # task="generate", # # use eager mode for hf runner since phi3v didn't work with flash_attn - # model_kwargs={"_attn_implementation": "eager"}, + # hf_model_kwargs={"_attn_implementation": "eager"}, # use_tokenizer_eos=True, # vllm_output_post_proc=model_utils.phi3v_vllm_to_hf_output, # num_logprobs=10, @@ -349,7 +363,7 @@ prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:", max_model_len=4096, auto_cls=AutoModelForVision2Seq, - postprocess_inputs=model_utils.get_key_type_post_processor( + postprocess_inputs=model_utils.cast_dtype_post_processor( "pixel_values" ), vllm_output_post_proc = lambda vllm_output, model: vllm_output[:2], @@ -418,7 +432,7 @@ test_type=VLMTestType.CUSTOM_INPUTS, max_model_len=16384, max_num_seqs=2, - postprocess_inputs=model_utils.get_key_type_post_processor( + postprocess_inputs=model_utils.cast_dtype_post_processor( "pixel_values" ), auto_cls=AutoModelForVision2Seq, diff --git a/tests/models/decoder_only/vision_language/vlm_utils/core.py b/tests/models/decoder_only/vision_language/vlm_utils/core.py index 7e8c6dabb15af..c161036312d89 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/core.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/core.py @@ -3,7 +3,7 @@ import torch from PIL.Image import Image -from transformers import AutoTokenizer, BatchEncoding +from transformers import AutoTokenizer, BatchEncoding, PreTrainedTokenizerBase from transformers.models.auto.auto_factory import _BaseAutoModelClass from .....conftest import HfRunner, VllmRunner @@ -28,9 +28,11 @@ def run_test( use_tokenizer_eos: bool, postprocess_inputs: Callable[[BatchEncoding], BatchEncoding], comparator: Callable[..., None], - get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]], + get_stop_token_ids: Optional[Callable[[PreTrainedTokenizerBase], + List[int]]], limit_mm_per_prompt: Dict[str, int], - model_kwargs: Optional[Dict[str, Any]], + vllm_runner_kwargs: Optional[Dict[str, Any]], + hf_model_kwargs: Optional[Dict[str, Any]], patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]], task: str = "auto", runner_mm_key: str = "images", @@ -54,6 +56,9 @@ def run_test( if get_stop_token_ids is not None: vllm_kwargs["stop_token_ids"] = get_stop_token_ids(tokenizer) + if vllm_runner_kwargs is None: + vllm_runner_kwargs = {} + with vllm_runner(model, max_model_len=max_model_len, max_num_seqs=max_num_seqs, @@ -62,7 +67,8 @@ def run_test( tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, enforce_eager=enforce_eager, - task=task) as vllm_model: + task=task, + **vllm_runner_kwargs) as vllm_model: for prompts, media in vllm_inputs: vllm_kwargs[runner_mm_key] = media vllm_output = vllm_model.generate_greedy_logprobs( @@ -73,7 +79,7 @@ def run_test( dtype=dtype, auto_cls=auto_cls, postprocess_inputs=postprocess_inputs, - model_kwargs=model_kwargs) + model_kwargs=hf_model_kwargs) # Some models need to patch things like the model processor, e.g., internvl if patch_hf_runner is not None: diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py index 849857b4232e7..2f238b4e18a74 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py @@ -126,6 +126,16 @@ def llava_onevision_vllm_to_hf_output(vllm_output: RunnerOutput, return hf_output_ids, hf_output_str, out_logprobs +def mantis_vllm_to_hf_output(vllm_output: RunnerOutput, + model: str) -> RunnerOutput: + """Sanitize vllm output [mantis] to compare with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + + hf_output_str = output_str + "<|eot_id|>" + + return output_ids, hf_output_str, out_logprobs + + def phi3v_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput: """Sanitize vllm output [phi3v] to be comparable with hf output.""" @@ -184,7 +194,7 @@ def get_llava_embeddings(image_assets: _ImageAssets): ####### postprocessors to run on HF BatchEncoding -def get_key_type_post_processor( +def cast_dtype_post_processor( hf_inp_key: str) -> Callable[[BatchEncoding, str], BatchEncoding]: """Gets a handle to a post processor which converts a given key into a target data type.""" @@ -407,3 +417,26 @@ def _internvl_generate( ) return outputs + + +def mantis_patch_hf_runner(hf_model: HfRunner) -> HfRunner: + from mantis.models.mllava import MLlavaProcessor + + hf_model.processor = MLlavaProcessor.from_pretrained(hf_model.model_name) + + orig_generate = hf_model.model.generate + tokenizer = hf_model.processor.tokenizer + + def _generate(self, *args, **kwargs): + return orig_generate( + *args, + **kwargs, + eos_token_id=[ + tokenizer.eos_token_id, + tokenizer.convert_tokens_to_ids("<|eot_id|>"), + ], + ) + + hf_model.model.generate = types.MethodType(_generate, hf_model.model) + + return hf_model diff --git a/tests/models/decoder_only/vision_language/vlm_utils/types.py b/tests/models/decoder_only/vision_language/vlm_utils/types.py index 8459476dc2d07..ad5095a3eaee1 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/types.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/types.py @@ -7,7 +7,8 @@ import torch from PIL.Image import Image from pytest import MarkDecorator -from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding +from transformers import (AutoModelForCausalLM, BatchEncoding, + PreTrainedTokenizerBase) from transformers.models.auto.auto_factory import _BaseAutoModelClass from vllm.sequence import SampleLogprobs @@ -66,7 +67,7 @@ class ImageSizeWrapper(NamedTuple): class VLMTestInfo(NamedTuple): """Holds the configuration for 1+ tests for one model architecture.""" - models: Union[List[str]] + models: List[str] test_type: Union[VLMTestType, Iterable[VLMTestType]] # Should be None only if this is a CUSTOM_INPUTS test @@ -94,13 +95,15 @@ class VLMTestInfo(NamedTuple): max_num_seqs: int = 256 task: str = "auto" tensor_parallel_size: int = 1 + vllm_runner_kwargs: Optional[Dict[str, Any]] = None # Optional callable which gets a list of token IDs from the model tokenizer - get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]] = None + get_stop_token_ids: Optional[Callable[[PreTrainedTokenizerBase], + List[int]]] = None # Exposed options for HF runner - model_kwargs: Optional[Dict[str, Any]] = None - # Indicates we should explicitly pass the EOS from the tokeniezr + hf_model_kwargs: Optional[Dict[str, Any]] = None + # Indicates we should explicitly pass the EOS from the tokenizer use_tokenizer_eos: bool = False auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM # Callable to pass to the HF runner to run on inputs; for now, we also pass @@ -159,6 +162,7 @@ def get_non_parametrized_runner_kwargs(self): "max_num_seqs": self.max_num_seqs, "task": self.task, "tensor_parallel_size": self.tensor_parallel_size, + "vllm_runner_kwargs": self.vllm_runner_kwargs, "hf_output_post_proc": self.hf_output_post_proc, "vllm_output_post_proc": self.vllm_output_post_proc, "auto_cls": self.auto_cls, @@ -166,7 +170,7 @@ def get_non_parametrized_runner_kwargs(self): "postprocess_inputs": self.postprocess_inputs, "comparator": self.comparator, "get_stop_token_ids": self.get_stop_token_ids, - "model_kwargs": self.model_kwargs, + "hf_model_kwargs": self.hf_model_kwargs, "patch_hf_runner": self.patch_hf_runner, } diff --git a/tests/models/registry.py b/tests/models/registry.py index 865e90b3f8b0e..2fa61e5dd0095 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -173,6 +173,7 @@ class _HfExamplesInfo: "LlavaNextForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-v1.6-mistral-7b-hf"), # noqa: E501 "LlavaNextVideoForConditionalGeneration": _HfExamplesInfo("llava-hf/LLaVA-NeXT-Video-7B-hf"), # noqa: E501 "LlavaOnevisionForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501 + "MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3"), # noqa: E501 "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5", trust_remote_code=True), "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924", diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index c1159ab8e4589..86918aec4c43d 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -7,7 +7,7 @@ from PIL.Image import Image from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig, PixtralVisionConfig, PretrainedConfig, - SiglipVisionConfig) + ProcessorMixin, SiglipVisionConfig) from vllm.attention import AttentionMetadata from vllm.config import VllmConfig @@ -21,7 +21,7 @@ from vllm.multimodal.processing import (InputProcessingContext, ModalityProcessingMetadata, MultiModalProcessingMetadata, - PromptReplacement) + MultiModalProcessor, PromptReplacement) from vllm.sequence import IntermediateTensors from .clip import CLIPVisionModel, get_max_clip_image_tokens @@ -499,3 +499,31 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) + + +class MantisProcessor(MultiModalProcessor): + + def _get_hf_processor(self) -> ProcessorMixin: + try: + from mantis.models.mllava import MLlavaProcessor + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "You need to `pip install " + "git+https://github.com/TIGER-AI-Lab/Mantis.git` " + "to use this model") from exc + + processor = MLlavaProcessor.from_pretrained( + self.ctx.model_config.tokenizer) + assert isinstance(processor, ProcessorMixin) + return processor + + +# To use this model, please use +# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) +@MULTIMODAL_REGISTRY.register_processor(lambda ctx: MantisProcessor( + ctx=ctx, + metadata=create_metadata_for_llava(ctx), +)) +class MantisForConditionalGeneration(LlavaForConditionalGeneration): + pass diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index f5a02a5b25ca2..bae4acb03b711 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -147,6 +147,7 @@ "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501 "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), # noqa: E501 + "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"), # noqa: E501 "MiniCPMV": ("minicpmv", "MiniCPMV"), "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), "NVLM_D": ("nvlm_d", "NVLM_D_Model"), From cfbece41aeae2b482f8598758d47a5261d0a9b78 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 27 Nov 2024 13:42:57 +0000 Subject: [PATCH 17/21] Update docs Signed-off-by: DarkLight1337 --- docs/source/models/supported_models.rst | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index b5cbe6915d581..495c434cabf90 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -532,7 +532,7 @@ Text Generation * - :code:`LlavaForConditionalGeneration` - LLaVA-1.5 - T + I\ :sup:`E+` - - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc. + - :code:`llava-hf/llava-1.5-7b-hf`, :code:`TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc. - - ✅︎ * - :code:`LlavaNextForConditionalGeneration` @@ -626,6 +626,10 @@ Text Generation .. note:: vLLM currently only supports adding LoRA to the language backbone of multimodal models. +.. note:: + To use :code:`TIGER-Lab/Mantis-8B-siglip-llama3`, you have to install their GitHub repo (:code:`pip install git+https://github.com/TIGER-AI-Lab/Mantis.git`) + and pass :code:`--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM. + .. note:: The official :code:`openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 From 70c87d143dd0796bde294b792443263da89b2067 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 7 Dec 2024 09:47:10 +0000 Subject: [PATCH 18/21] Fix type error Signed-off-by: DarkLight1337 --- tests/models/decoder_only/vision_language/vlm_utils/types.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/decoder_only/vision_language/vlm_utils/types.py b/tests/models/decoder_only/vision_language/vlm_utils/types.py index 04b7f0afd355d..e2e0c6390fcb9 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/types.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/types.py @@ -11,6 +11,7 @@ PreTrainedTokenizerBase) from transformers.models.auto.auto_factory import _BaseAutoModelClass +from vllm.config import TaskOption from vllm.sequence import SampleLogprobs from vllm.utils import identity @@ -93,7 +94,7 @@ class VLMTestInfo(NamedTuple): enforce_eager: bool = True max_model_len: int = 1024 max_num_seqs: int = 256 - task: str = "auto" + task: TaskOption = "auto" tensor_parallel_size: int = 1 vllm_runner_kwargs: Optional[Dict[str, Any]] = None From 27b276b373fdd1ab0ed3316584a6f990726eb67c Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 7 Dec 2024 09:47:40 +0000 Subject: [PATCH 19/21] Fix redundant code Signed-off-by: DarkLight1337 --- vllm/v1/engine/mm_input_mapper.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/v1/engine/mm_input_mapper.py b/vllm/v1/engine/mm_input_mapper.py index f056e60908cbb..45882f8f076d4 100644 --- a/vllm/v1/engine/mm_input_mapper.py +++ b/vllm/v1/engine/mm_input_mapper.py @@ -23,10 +23,6 @@ def process_inputs( mm_data: MultiModalDataDict, mm_processor_kwargs: Optional[Dict[str, Any]], ) -> List[MultiModalKwargs]: - # Skip this redundant step if merged processor has been applied - if isinstance(mm_data, MultiModalKwargs): - return [mm_data] - image_inputs = mm_data["image"] if not isinstance(image_inputs, list): image_inputs = [image_inputs] From c10c1ccfa4d30e5cff61a21353d0c278f4928bb3 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 7 Dec 2024 09:55:56 +0000 Subject: [PATCH 20/21] Remove convenience function as it makes things more complicated Signed-off-by: DarkLight1337 --- .../vllm_add_dummy_model/my_llava.py | 6 +-- vllm/model_executor/models/llava.py | 48 ++++++++++++++----- vllm/multimodal/processing.py | 4 +- vllm/multimodal/registry.py | 41 ++-------------- 4 files changed, 43 insertions(+), 56 deletions(-) diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py index f2fc0755cae01..2f4194a63fc25 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py @@ -3,16 +3,14 @@ import torch from vllm.model_executor.models.llava import (LlavaForConditionalGeneration, - create_metadata_for_llava, - dummy_mm_kwargs_for_llava, + LlavaProcessor, get_max_llava_image_tokens) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) -@MULTIMODAL_REGISTRY.register_processor_by_metadata(create_metadata_for_llava, - dummy_mm_kwargs_for_llava) +@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor) class MyLlava(LlavaForConditionalGeneration): def compute_logits( diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 7f3b28f333344..65c6bd07bfff0 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -22,10 +22,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors -from vllm.multimodal.processing import (InputProcessingContext, +from vllm.multimodal.processing import (BaseMultiModalProcessor, + InputProcessingContext, ModalityProcessingMetadata, MultiModalProcessingMetadata, - MultiModalProcessor, PromptReplacement) + PromptReplacement) from vllm.sequence import IntermediateTensors from .clip import (CLIPVisionModel, dummy_image_for_clip, @@ -163,7 +164,13 @@ def get_repl_count( } -class LlavaProcessor(MultiModalProcessor): +class LlavaProcessor(BaseMultiModalProcessor): + + def __init__(self, ctx: InputProcessingContext) -> None: + super().__init__( + ctx=ctx, + metadata=create_metadata_for_llava(ctx), + ) def _patch_pixtral_processor(self, hf_processor: PixtralProcessor): if getattr(hf_processor, "__is_patched__", False): @@ -193,7 +200,30 @@ def _get_dummy_mm_kwargs( self, mm_counts: Mapping[str, int], ) -> MultiModalKwargs: - return dummy_mm_kwargs_for_llava(self.ctx, mm_counts) + hf_config = self.ctx.get_hf_config(LlavaConfig) + vision_config = hf_config.vision_config + num_images = mm_counts["image"] + + if isinstance(vision_config, CLIPVisionConfig): + data = dummy_image_for_clip(vision_config, num_images) + elif isinstance(vision_config, SiglipVisionConfig): + data = dummy_image_for_siglip(vision_config, num_images) + elif isinstance(vision_config, PixtralVisionConfig): + data = dummy_image_for_pixtral_hf(vision_config, num_images) + else: + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + hf_processor = self._get_hf_processor() + image_processor = hf_processor.image_processor # type: ignore + hf_inputs = image_processor.preprocess(data['image'], + return_tensors="pt") + is_pixtral = isinstance(hf_processor, PixtralProcessor) + + return MultiModalKwargs( + **hf_inputs, + is_pixtral=torch.tensor(is_pixtral), + ) class LlavaLikeConfig(Protocol): @@ -277,10 +307,7 @@ def init_vision_tower_for_llava( @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) -@MULTIMODAL_REGISTRY.register_processor(lambda ctx: LlavaProcessor( - ctx=ctx, - metadata=create_metadata_for_llava(ctx), -)) +@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor) class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): # BitandBytes specific attributes bitsandbytes_stacked_params_mapping = { @@ -581,9 +608,6 @@ def _get_hf_processor(self) -> ProcessorMixin: # To use this model, please use # `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) -@MULTIMODAL_REGISTRY.register_processor(lambda ctx: MantisProcessor( - ctx=ctx, - metadata=create_metadata_for_llava(ctx), -)) +@MULTIMODAL_REGISTRY.register_processor(MantisProcessor) class MantisForConditionalGeneration(LlavaForConditionalGeneration): pass diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 4a1737991534f..c3a95d60e6fe6 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -529,9 +529,9 @@ def iter_placeholders( yield placeholder -class MultiModalProcessor(ABC): +class BaseMultiModalProcessor(ABC): """ - Helper class to process multi-modal inputs to be used in vLLM. + Abstract base class to process multi-modal inputs to be used in vLLM. """ def __init__( diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index f51da8972d15b..6ab6c0fe2f12e 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -15,7 +15,7 @@ from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc from .image import ImagePlugin from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors -from .processing import MultiModalProcessingMetadata, MultiModalProcessor +from .processing import BaseMultiModalProcessor from .video import VideoPlugin if TYPE_CHECKING: @@ -26,7 +26,7 @@ N = TypeVar("N", bound=Type[nn.Module]) MultiModalProcessorFactory: TypeAlias = Callable[[InputProcessingContext], - MultiModalProcessor] + BaseMultiModalProcessor] """ Constructs a :class:`MultiModalProcessor` instance from the context. @@ -311,41 +311,6 @@ def wrapper(model_cls: N) -> N: return wrapper - def register_processor_by_metadata( - self, - metadata_factory: Callable[[InputProcessingContext], - MultiModalProcessingMetadata], - get_dummy_mm_kwargs: Callable[ - [InputProcessingContext, Mapping[str, int]], MultiModalKwargs], - ): - """ - Convenience method to register a multi-modal processor to a model class - according to a function that constructs its metadata. - - When the model receives multi-modal data, the provided function is - invoked to transform the data into a dictionary of model inputs. - - See also: - - :ref:`input_processing_pipeline` - - :ref:`enabling_multimodal_inputs` - """ - - class ConcreteMultiModalProcessor(MultiModalProcessor): - - def _get_dummy_mm_kwargs( - self, - mm_counts: Mapping[str, int], - ) -> MultiModalKwargs: - return get_dummy_mm_kwargs(self.ctx, mm_counts) - - def factory(ctx: InputProcessingContext): - return ConcreteMultiModalProcessor( - ctx=ctx, - metadata=metadata_factory(ctx), - ) - - return self.register_processor(factory) - def has_processor(self, model_config: "ModelConfig") -> bool: """ Test whether a multi-modal processor is defined for a specific model. @@ -360,7 +325,7 @@ def create_processor( self, model_config: "ModelConfig", tokenizer: AnyTokenizer, - ) -> MultiModalProcessor: + ) -> BaseMultiModalProcessor: """ Create a multi-modal processor for a specific model and tokenizer. """ From d77cadd1ce23aecf8a4e38f694e879d29562234f Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 7 Dec 2024 15:11:39 +0000 Subject: [PATCH 21/21] Fix commands Signed-off-by: DarkLight1337 --- .buildkite/test-pipeline.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9deec869a20ca..8f57006214c88 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -362,7 +362,7 @@ steps: - tests/models/embedding/vision_language - tests/models/encoder_decoder/vision_language commands: - - python -m pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model' - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model' - pytest -v -s models/embedding/vision_language -m core_model @@ -378,7 +378,7 @@ steps: - tests/models/embedding/vision_language - tests/models/encoder_decoder/vision_language commands: - - python -m pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model' # HACK - run phi3v tests separately to sidestep this transformers bug # https://github.com/huggingface/transformers/issues/34307