diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 936e284d9675a..8f57006214c88 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -362,6 +362,7 @@ steps: - tests/models/embedding/vision_language - tests/models/encoder_decoder/vision_language commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model' - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model' - pytest -v -s models/embedding/vision_language -m core_model @@ -377,6 +378,7 @@ steps: - tests/models/embedding/vision_language - tests/models/encoder_decoder/vision_language commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model' # HACK - run phi3v tests separately to sidestep this transformers bug # https://github.com/huggingface/transformers/issues/34307 diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index d915def588e08..c9b3fa8485ff1 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -555,7 +555,7 @@ Text Generation * - :code:`LlavaForConditionalGeneration` - LLaVA-1.5 - T + I\ :sup:`E+` - - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc. + - :code:`llava-hf/llava-1.5-7b-hf`, :code:`TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc. - - ✅︎ * - :code:`LlavaNextForConditionalGeneration` @@ -664,6 +664,10 @@ Text Generation .. note:: vLLM currently only supports adding LoRA to the language backbone of multimodal models. +.. note:: + To use :code:`TIGER-Lab/Mantis-8B-siglip-llama3`, you have to install their GitHub repo (:code:`pip install git+https://github.com/TIGER-AI-Lab/Mantis.git`) + and pass :code:`--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM. + .. note:: The official :code:`openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 56209c3c36ed4..c6a274ee5894b 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -419,6 +419,22 @@ def run_aria(question: str, modality: str): return llm, prompt, stop_token_ids +# Mantis +def run_mantis(question: str, modality: str): + assert modality == "image" + + llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa: E501 + prompt = llama3_template.format(f"{question}\n") + + llm = LLM( + model="TIGER-Lab/Mantis-8B-siglip-llama3", + max_model_len=4096, + hf_overrides={"architectures": ["MantisForConditionalGeneration"]}, + ) + stop_token_ids = [128009] + return llm, prompt, stop_token_ids + + model_example_map = { "llava": run_llava, "llava-next": run_llava_next, @@ -441,6 +457,7 @@ def run_aria(question: str, modality: str): "glm4v": run_glm4v, "idefics3": run_idefics3, "aria": run_aria, + "mantis": run_mantis, } diff --git a/requirements-test.in b/requirements-test.in index 44972866ddc4b..c0b228148ab31 100644 --- a/requirements-test.in +++ b/requirements-test.in @@ -24,9 +24,6 @@ mistral_common[opencv] >= 1.5.0 # required for pixtral test datamodel_code_generator # required for minicpm3 test lm-eval[api]==0.4.4 # required for model evaluation test -# TODO: Add this after fully implementing llava(mantis) -# git+https://github.com/TIGER-AI-Lab/Mantis.git # required for llava(mantis) test - # quantization bitsandbytes>=0.44.0 buildkite-test-collector==0.1.9 diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 924f19c4448b8..ed8f34a677f84 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -34,7 +34,7 @@ "dtype": "half", "max_tokens": 5, "tensor_parallel_size": 2, - "model_kwargs": {"device_map": "auto"}, + "hf_model_kwargs": {"device_map": "auto"}, "image_size_factors": [(.25, 0.5, 1.0)], "distributed_executor_backend": ( "ray", @@ -108,7 +108,7 @@ "cherry_blossom": "What is in the picture?", }), auto_cls=AutoModelForVision2Seq, - postprocess_inputs=model_utils.get_key_type_post_processor( + postprocess_inputs=model_utils.cast_dtype_post_processor( "pixel_values" ), vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output, @@ -151,7 +151,7 @@ "cherry_blossom": "Please infer the season with reason.", }), multi_image_prompt="Describe the two images shortly.", # noqa: E501 - postprocess_inputs=model_utils.get_key_type_post_processor("pixel_values"), + postprocess_inputs=model_utils.cast_dtype_post_processor("pixel_values"), stop_str=["<|im_end|>"], image_size_factors=[(0.10, 0.15)], max_tokens=64, @@ -177,7 +177,7 @@ prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:", max_model_len=4096, auto_cls=AutoModelForVision2Seq, - postprocess_inputs=model_utils.get_key_type_post_processor( + postprocess_inputs=model_utils.cast_dtype_post_processor( "pixel_values" ), # For chameleon, we only compare the sequences @@ -281,7 +281,7 @@ prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 num_video_frames=16, max_model_len=16384, - postprocess_inputs=model_utils.get_key_type_post_processor( + postprocess_inputs=model_utils.cast_dtype_post_processor( "pixel_values_videos" ), auto_cls=AutoModelForVision2Seq, @@ -306,6 +306,20 @@ vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output, image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))], ), + "mantis": VLMTestInfo( + models=["TIGER-Lab/Mantis-8B-siglip-llama3"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 + max_model_len=4096, + postprocess_inputs=model_utils.cast_dtype_post_processor( + "pixel_values" + ), + vllm_runner_kwargs={"hf_overrides": {"architectures": ["MantisForConditionalGeneration"]}}, # noqa: E501 + get_stop_token_ids=lambda tok: [128009], + auto_cls=AutoModelForVision2Seq, + vllm_output_post_proc=model_utils.mantis_vllm_to_hf_output, + patch_hf_runner=model_utils.mantis_patch_hf_runner, + ), "minicpmv_25": VLMTestInfo( models=["openbmb/MiniCPM-Llama3-V-2_5"], test_type=VLMTestType.IMAGE, @@ -342,7 +356,7 @@ # max_num_seqs=2, # task="generate", # # use eager mode for hf runner since phi3v didn't work with flash_attn - # model_kwargs={"_attn_implementation": "eager"}, + # hf_model_kwargs={"_attn_implementation": "eager"}, # use_tokenizer_eos=True, # vllm_output_post_proc=model_utils.phi3v_vllm_to_hf_output, # num_logprobs=10, @@ -373,7 +387,7 @@ prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:", max_model_len=4096, auto_cls=AutoModelForVision2Seq, - postprocess_inputs=model_utils.get_key_type_post_processor( + postprocess_inputs=model_utils.cast_dtype_post_processor( "pixel_values" ), vllm_output_post_proc = lambda vllm_output, model: vllm_output[:2], @@ -438,7 +452,7 @@ test_type=VLMTestType.CUSTOM_INPUTS, max_model_len=16384, max_num_seqs=2, - postprocess_inputs=model_utils.get_key_type_post_processor( + postprocess_inputs=model_utils.cast_dtype_post_processor( "pixel_values" ), auto_cls=AutoModelForVision2Seq, diff --git a/tests/models/decoder_only/vision_language/vlm_utils/core.py b/tests/models/decoder_only/vision_language/vlm_utils/core.py index 88349ef9a3a69..54b7b0733210f 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/core.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/core.py @@ -3,9 +3,11 @@ import torch from PIL.Image import Image -from transformers import AutoTokenizer, BatchEncoding +from transformers import AutoTokenizer, BatchEncoding, PreTrainedTokenizerBase from transformers.models.auto.auto_factory import _BaseAutoModelClass +from vllm.config import TaskOption + from .....conftest import HfRunner, VllmRunner from .types import RunnerOutput @@ -28,13 +30,15 @@ def run_test( use_tokenizer_eos: bool, postprocess_inputs: Callable[[BatchEncoding], BatchEncoding], comparator: Callable[..., None], - get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]], + get_stop_token_ids: Optional[Callable[[PreTrainedTokenizerBase], + List[int]]], stop_str: Optional[List[str]], tokenizer_mode: str, limit_mm_per_prompt: Dict[str, int], - model_kwargs: Optional[Dict[str, Any]], + vllm_runner_kwargs: Optional[Dict[str, Any]], + hf_model_kwargs: Optional[Dict[str, Any]], patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]], - task: str = "auto", + task: TaskOption = "auto", runner_mm_key: str = "images", distributed_executor_backend: Optional[str] = None, tensor_parallel_size: int = 1, @@ -58,6 +62,9 @@ def run_test( if stop_str: vllm_kwargs["stop"] = stop_str + if vllm_runner_kwargs is None: + vllm_runner_kwargs = {} + with vllm_runner(model, tokenizer_mode=tokenizer_mode, max_model_len=max_model_len, @@ -67,7 +74,8 @@ def run_test( tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, enforce_eager=enforce_eager, - task=task) as vllm_model: + task=task, + **vllm_runner_kwargs) as vllm_model: for prompts, media in vllm_inputs: vllm_kwargs[runner_mm_key] = media vllm_output = vllm_model.generate_greedy_logprobs( @@ -78,7 +86,7 @@ def run_test( dtype=dtype, auto_cls=auto_cls, postprocess_inputs=postprocess_inputs, - model_kwargs=model_kwargs) + model_kwargs=hf_model_kwargs) # Some models need to patch things like the model processor, e.g., internvl if patch_hf_runner is not None: diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py index 15f15dd7d8030..3eca8fb9dcb1a 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py @@ -126,6 +126,16 @@ def llava_onevision_vllm_to_hf_output(vllm_output: RunnerOutput, return hf_output_ids, hf_output_str, out_logprobs +def mantis_vllm_to_hf_output(vllm_output: RunnerOutput, + model: str) -> RunnerOutput: + """Sanitize vllm output [mantis] to compare with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + + hf_output_str = output_str + "<|eot_id|>" + + return output_ids, hf_output_str, out_logprobs + + def phi3v_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput: """Sanitize vllm output [phi3v] to be comparable with hf output.""" @@ -184,7 +194,7 @@ def get_llava_embeddings(image_assets: _ImageAssets): ####### postprocessors to run on HF BatchEncoding -def get_key_type_post_processor( +def cast_dtype_post_processor( hf_inp_key: str) -> Callable[[BatchEncoding, str], BatchEncoding]: """Gets a handle to a post processor which converts a given key into a target data type.""" @@ -418,3 +428,26 @@ def _internvl_generate( ) return outputs + + +def mantis_patch_hf_runner(hf_model: HfRunner) -> HfRunner: + from mantis.models.mllava import MLlavaProcessor + + hf_model.processor = MLlavaProcessor.from_pretrained(hf_model.model_name) + + orig_generate = hf_model.model.generate + tokenizer = hf_model.processor.tokenizer + + def _generate(self, *args, **kwargs): + return orig_generate( + *args, + **kwargs, + eos_token_id=[ + tokenizer.eos_token_id, + tokenizer.convert_tokens_to_ids("<|eot_id|>"), + ], + ) + + hf_model.model.generate = types.MethodType(_generate, hf_model.model) + + return hf_model diff --git a/tests/models/decoder_only/vision_language/vlm_utils/types.py b/tests/models/decoder_only/vision_language/vlm_utils/types.py index d410fa8c653ce..e2e0c6390fcb9 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/types.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/types.py @@ -7,9 +7,11 @@ import torch from PIL.Image import Image from pytest import MarkDecorator -from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding +from transformers import (AutoModelForCausalLM, BatchEncoding, + PreTrainedTokenizerBase) from transformers.models.auto.auto_factory import _BaseAutoModelClass +from vllm.config import TaskOption from vllm.sequence import SampleLogprobs from vllm.utils import identity @@ -66,7 +68,7 @@ class ImageSizeWrapper(NamedTuple): class VLMTestInfo(NamedTuple): """Holds the configuration for 1+ tests for one model architecture.""" - models: Union[List[str]] + models: List[str] test_type: Union[VLMTestType, Iterable[VLMTestType]] # Should be None only if this is a CUSTOM_INPUTS test @@ -92,18 +94,20 @@ class VLMTestInfo(NamedTuple): enforce_eager: bool = True max_model_len: int = 1024 max_num_seqs: int = 256 - task: str = "auto" + task: TaskOption = "auto" tensor_parallel_size: int = 1 + vllm_runner_kwargs: Optional[Dict[str, Any]] = None # Optional callable which gets a list of token IDs from the model tokenizer - get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]] = None + get_stop_token_ids: Optional[Callable[[PreTrainedTokenizerBase], + List[int]]] = None # Optional list of strings to stop generation, useful when stop tokens are # not special tokens in the tokenizer stop_str: Optional[List[str]] = None # Exposed options for HF runner - model_kwargs: Optional[Dict[str, Any]] = None - # Indicates we should explicitly pass the EOS from the tokeniezr + hf_model_kwargs: Optional[Dict[str, Any]] = None + # Indicates we should explicitly pass the EOS from the tokenizer use_tokenizer_eos: bool = False auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM # Callable to pass to the HF runner to run on inputs; for now, we also pass @@ -164,6 +168,7 @@ def get_non_parametrized_runner_kwargs(self): "max_num_seqs": self.max_num_seqs, "task": self.task, "tensor_parallel_size": self.tensor_parallel_size, + "vllm_runner_kwargs": self.vllm_runner_kwargs, "hf_output_post_proc": self.hf_output_post_proc, "vllm_output_post_proc": self.vllm_output_post_proc, "auto_cls": self.auto_cls, @@ -171,8 +176,8 @@ def get_non_parametrized_runner_kwargs(self): "postprocess_inputs": self.postprocess_inputs, "comparator": self.comparator, "get_stop_token_ids": self.get_stop_token_ids, + "hf_model_kwargs": self.hf_model_kwargs, "stop_str": self.stop_str, - "model_kwargs": self.model_kwargs, "patch_hf_runner": self.patch_hf_runner, "tokenizer_mode": self.tokenizer_mode } diff --git a/tests/models/registry.py b/tests/models/registry.py index 461f453d8b1c3..a89518820045f 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -176,6 +176,7 @@ class _HfExamplesInfo: "LlavaNextForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-v1.6-mistral-7b-hf"), # noqa: E501 "LlavaNextVideoForConditionalGeneration": _HfExamplesInfo("llava-hf/LLaVA-NeXT-Video-7B-hf"), # noqa: E501 "LlavaOnevisionForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501 + "MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3"), # noqa: E501 "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5", trust_remote_code=True), "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924", diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py index f2fc0755cae01..2f4194a63fc25 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py @@ -3,16 +3,14 @@ import torch from vllm.model_executor.models.llava import (LlavaForConditionalGeneration, - create_metadata_for_llava, - dummy_mm_kwargs_for_llava, + LlavaProcessor, get_max_llava_image_tokens) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) -@MULTIMODAL_REGISTRY.register_processor_by_metadata(create_metadata_for_llava, - dummy_mm_kwargs_for_llava) +@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor) class MyLlava(LlavaForConditionalGeneration): def compute_logits( diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 953b89f1842af..65c6bd07bfff0 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -22,10 +22,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors -from vllm.multimodal.processing import (InputProcessingContext, +from vllm.multimodal.processing import (BaseMultiModalProcessor, + InputProcessingContext, ModalityProcessingMetadata, MultiModalProcessingMetadata, - MultiModalProcessor, PromptReplacement) + PromptReplacement) from vllm.sequence import IntermediateTensors from .clip import (CLIPVisionModel, dummy_image_for_clip, @@ -163,7 +164,13 @@ def get_repl_count( } -class LlavaProcessor(MultiModalProcessor): +class LlavaProcessor(BaseMultiModalProcessor): + + def __init__(self, ctx: InputProcessingContext) -> None: + super().__init__( + ctx=ctx, + metadata=create_metadata_for_llava(ctx), + ) def _patch_pixtral_processor(self, hf_processor: PixtralProcessor): if getattr(hf_processor, "__is_patched__", False): @@ -193,7 +200,30 @@ def _get_dummy_mm_kwargs( self, mm_counts: Mapping[str, int], ) -> MultiModalKwargs: - return dummy_mm_kwargs_for_llava(self.ctx, mm_counts) + hf_config = self.ctx.get_hf_config(LlavaConfig) + vision_config = hf_config.vision_config + num_images = mm_counts["image"] + + if isinstance(vision_config, CLIPVisionConfig): + data = dummy_image_for_clip(vision_config, num_images) + elif isinstance(vision_config, SiglipVisionConfig): + data = dummy_image_for_siglip(vision_config, num_images) + elif isinstance(vision_config, PixtralVisionConfig): + data = dummy_image_for_pixtral_hf(vision_config, num_images) + else: + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + hf_processor = self._get_hf_processor() + image_processor = hf_processor.image_processor # type: ignore + hf_inputs = image_processor.preprocess(data['image'], + return_tensors="pt") + is_pixtral = isinstance(hf_processor, PixtralProcessor) + + return MultiModalKwargs( + **hf_inputs, + is_pixtral=torch.tensor(is_pixtral), + ) class LlavaLikeConfig(Protocol): @@ -277,10 +307,7 @@ def init_vision_tower_for_llava( @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) -@MULTIMODAL_REGISTRY.register_processor(lambda ctx: LlavaProcessor( - ctx=ctx, - metadata=create_metadata_for_llava(ctx), -)) +@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor) class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): # BitandBytes specific attributes bitsandbytes_stacked_params_mapping = { @@ -559,3 +586,28 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) + + +class MantisProcessor(LlavaProcessor): + + def _get_hf_processor(self) -> ProcessorMixin: + try: + from mantis.models.mllava import MLlavaProcessor + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "You need to `pip install " + "git+https://github.com/TIGER-AI-Lab/Mantis.git` " + "to use this model") from exc + + processor = MLlavaProcessor.from_pretrained( + self.ctx.model_config.tokenizer) + assert isinstance(processor, ProcessorMixin) + return processor + + +# To use this model, please use +# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) +@MULTIMODAL_REGISTRY.register_processor(MantisProcessor) +class MantisForConditionalGeneration(LlavaForConditionalGeneration): + pass diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index c66fbce018a62..e69596aa915b5 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -152,6 +152,7 @@ "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501 "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), # noqa: E501 + "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"), # noqa: E501 "MiniCPMV": ("minicpmv", "MiniCPMV"), "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), "NVLM_D": ("nvlm_d", "NVLM_D_Model"), diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 4a1737991534f..c3a95d60e6fe6 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -529,9 +529,9 @@ def iter_placeholders( yield placeholder -class MultiModalProcessor(ABC): +class BaseMultiModalProcessor(ABC): """ - Helper class to process multi-modal inputs to be used in vLLM. + Abstract base class to process multi-modal inputs to be used in vLLM. """ def __init__( diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index f51da8972d15b..6ab6c0fe2f12e 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -15,7 +15,7 @@ from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc from .image import ImagePlugin from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors -from .processing import MultiModalProcessingMetadata, MultiModalProcessor +from .processing import BaseMultiModalProcessor from .video import VideoPlugin if TYPE_CHECKING: @@ -26,7 +26,7 @@ N = TypeVar("N", bound=Type[nn.Module]) MultiModalProcessorFactory: TypeAlias = Callable[[InputProcessingContext], - MultiModalProcessor] + BaseMultiModalProcessor] """ Constructs a :class:`MultiModalProcessor` instance from the context. @@ -311,41 +311,6 @@ def wrapper(model_cls: N) -> N: return wrapper - def register_processor_by_metadata( - self, - metadata_factory: Callable[[InputProcessingContext], - MultiModalProcessingMetadata], - get_dummy_mm_kwargs: Callable[ - [InputProcessingContext, Mapping[str, int]], MultiModalKwargs], - ): - """ - Convenience method to register a multi-modal processor to a model class - according to a function that constructs its metadata. - - When the model receives multi-modal data, the provided function is - invoked to transform the data into a dictionary of model inputs. - - See also: - - :ref:`input_processing_pipeline` - - :ref:`enabling_multimodal_inputs` - """ - - class ConcreteMultiModalProcessor(MultiModalProcessor): - - def _get_dummy_mm_kwargs( - self, - mm_counts: Mapping[str, int], - ) -> MultiModalKwargs: - return get_dummy_mm_kwargs(self.ctx, mm_counts) - - def factory(ctx: InputProcessingContext): - return ConcreteMultiModalProcessor( - ctx=ctx, - metadata=metadata_factory(ctx), - ) - - return self.register_processor(factory) - def has_processor(self, model_config: "ModelConfig") -> bool: """ Test whether a multi-modal processor is defined for a specific model. @@ -360,7 +325,7 @@ def create_processor( self, model_config: "ModelConfig", tokenizer: AnyTokenizer, - ) -> MultiModalProcessor: + ) -> BaseMultiModalProcessor: """ Create a multi-modal processor for a specific model and tokenizer. """