diff --git a/docs/source/conf.py b/docs/source/conf.py index e9d9ac68c9560..7c22b1df49b81 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -162,6 +162,7 @@ def linkcode_resolve(domain, info): # Mock out external dependencies here, otherwise the autodoc pages may be blank. autodoc_mock_imports = [ + "blake3", "compressed_tensors", "cpuinfo", "cv2", @@ -178,7 +179,7 @@ def linkcode_resolve(domain, info): "tensorizer", "pynvml", "outlines", - "xgrammar," + "xgrammar", "librosa", "soundfile", "gguf", diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index fb02627eb22bd..8c77e7840cb50 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -1,11 +1,12 @@ import functools from collections import UserDict from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Mapping, NamedTuple, +from typing import (TYPE_CHECKING, Any, Callable, Literal, Mapping, NamedTuple, Optional, Protocol, Union) from torch import nn from transformers import BatchFeature, PretrainedConfig, ProcessorMixin +from transformers.models.whisper import WhisperFeatureExtractor from typing_extensions import TypeVar, assert_never from vllm.logger import init_logger @@ -111,6 +112,39 @@ def get_hf_processor( return hf_processor + def get_modality_processor( + self, + hf_processor: ProcessorMixin, + modality_data_key: Literal["text", "images", "videos", "audios"], + ) -> Callable[..., BatchFeature]: + """ + Get the HuggingFace modality-specific processor which is + a child of a :class:`transformers.ProcessorMixin`, identified by + the corresponding keyword argument in its `__call__` method. + """ + if modality_data_key == "text": + attributes = ["tokenizer"] + elif modality_data_key == "images": + attributes = ["image_processor"] + elif modality_data_key == "videos": + attributes = ["video_processor", "image_processor"] + elif modality_data_key == "audios": + attributes = ["audio_processor", "feature_extractor"] + else: + assert_never(modality_data_key) + + modality_processor = next( + (getattr(hf_processor, attr) + for attr in attributes if hasattr(hf_processor, attr)), + None, + ) + if modality_processor is None: + raise AttributeError( + f"Cannot find HuggingFace processor for {modality_data_key} " + f"inside {type(hf_processor)}") + + return modality_processor + @dataclass(frozen=True) class InputProcessingContext(InputContext): @@ -131,34 +165,39 @@ def get_hf_processor( def call_hf_processor( self, - hf_processor: ProcessorMixin, - prompt: str, - processor_data: Mapping[str, object], - inference_kwargs: Mapping[str, object], + hf_processor: Union[ProcessorMixin, Callable[..., BatchFeature]], + data: Mapping[str, object], + kwargs: Optional[Mapping[str, object]] = None, ) -> BatchFeature: assert callable(hf_processor) + if kwargs is None: + kwargs = {} + base_kwargs = self.model_config.mm_processor_kwargs if base_kwargs is None: base_kwargs = {} merged_kwargs = resolve_mm_processor_kwargs( base_kwargs, - inference_kwargs, + kwargs, hf_processor, requires_kw_only=False, - allow_var_kwargs=True, + # Modality-specific processors should state each kwarg individually + allow_var_kwargs=isinstance(hf_processor, ProcessorMixin), ) + # WhisperFeatureExtractor accepts `raw_speech` + # but the parent HF processor accepts `audios` + # Making `audios` an alias of `raw_speech` simplifies the calling code + if (isinstance(hf_processor, WhisperFeatureExtractor) + and "raw_speech" not in data): + data = dict(data) + data["raw_speech"] = data.pop("audios") + try: - return 
hf_processor( - text=prompt, - **processor_data, - **merged_kwargs, - return_tensors="pt", - ) + return hf_processor(**data, **merged_kwargs, return_tensors="pt") except Exception as exc: - data = dict(text=prompt, **processor_data) msg = (f"Failed to apply {type(hf_processor).__name__} " f"on data={data} with kwargs={merged_kwargs}") diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 0662d90e79b92..f6e3cfdb3f561 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,5 +1,4 @@ from functools import cached_property -from types import MethodType from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set, Tuple, TypedDict, Union) @@ -116,36 +115,35 @@ def get_max_llava_image_tokens(ctx: InputContext): class LlavaMultiModalProcessor(BaseMultiModalProcessor): - def _patch_pixtral_processor(self, hf_processor: PixtralProcessor): - if getattr(hf_processor, "__is_patched__", False): - return # Already patched - - image_processor = hf_processor.image_processor # type: ignore - orig_preprocess = image_processor.preprocess - - def preprocess(__self, *args, **kwargs): - hf_inputs = orig_preprocess(*args, **kwargs) - hf_inputs["is_pixtral"] = torch.tensor(True) - return hf_inputs - - image_processor.preprocess = MethodType(preprocess, image_processor) + def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]: + return self.ctx.get_hf_processor((LlavaProcessor, PixtralProcessor)) - hf_processor.__is_patched__ = True # type: ignore + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) - def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]: - hf_processor = self.ctx.get_hf_processor( - (LlavaProcessor, PixtralProcessor)) + images = mm_data.get("images", []) + assert isinstance(images, list) - if isinstance(hf_processor, PixtralProcessor): - self._patch_pixtral_processor(hf_processor) + is_pixtral = isinstance(self._get_hf_processor(), PixtralProcessor) + processed_outputs["is_pixtral"] = \ + torch.tensor([is_pixtral] * len(images)) - return hf_processor + return processed_outputs def _get_prompt_replacements( self, mm_items: MultiModalDataItems, hf_inputs: BatchFeature, - mm_processor_kwargs: Mapping[str, object], + hf_mm_kwargs: Mapping[str, object], ) -> list[PromptReplacement]: hf_config = self.ctx.get_hf_config(LlavaConfig) image_token_id = hf_config.image_token_index @@ -218,7 +216,7 @@ def _get_dummy_mm_inputs( return ProcessorInputs( prompt_text=image_token * num_images, mm_data=data, - mm_processor_kwargs={}, + hf_mm_kwargs={}, ) @@ -379,8 +377,8 @@ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[LlavaImageInputs]: pixel_values = kwargs.pop("pixel_values", None) - is_pixtral = kwargs.pop("is_pixtral", torch.tensor([False])) image_embeds = kwargs.pop("image_embeds", None) + is_pixtral = kwargs.pop("is_pixtral", None) if pixel_values is None and image_embeds is None: return None diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index e2263f63f7bba..9e5760671d8df 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -306,11 +306,11 @@ def get_max_phi3v_image_tokens( *, num_crops: Optional[int] = None, ) 
-> int: - mm_processor_kwargs = {} + hf_mm_kwargs = {} if num_crops: - mm_processor_kwargs["num_crops"] = num_crops + hf_mm_kwargs["num_crops"] = num_crops - processor = ctx.get_hf_processor(**mm_processor_kwargs) + processor = ctx.get_hf_processor(**hf_mm_kwargs) return processor.calc_num_image_tokens_from_image_size( width=MAX_IMAGE_FEATURE_SIZE_WIDTH, @@ -331,16 +331,14 @@ def _get_hf_processor( def _call_hf_processor( self, - hf_processor: ProcessorMixin, prompt: str, - processor_data: Mapping[str, object], - mm_processor_kwargs: Mapping[str, object], + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], ) -> BatchFeature: processed_outputs = super()._call_hf_processor( - hf_processor, prompt=prompt, - processor_data=processor_data, - mm_processor_kwargs=mm_processor_kwargs, + mm_data=mm_data, + mm_kwargs=mm_kwargs, ) # Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids, @@ -356,7 +354,7 @@ def _get_prompt_replacements( self, mm_items: MultiModalDataItems, hf_inputs: BatchFeature, - mm_processor_kwargs: Mapping[str, object], + hf_mm_kwargs: Mapping[str, object], ) -> list[PromptReplacement]: hf_processor = self._get_hf_processor() image_tokens: list[str] = hf_processor.img_tokens # type: ignore @@ -401,7 +399,7 @@ def _get_dummy_mm_inputs( return ProcessorInputs( prompt_text="".join(image_tokens[:num_images]), mm_data=data, - mm_processor_kwargs={}, + hf_mm_kwargs={}, ) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 63d1374ab4092..baf955f6b515d 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -225,7 +225,7 @@ def __init__( d_model: int, n_head: int, mlp_ratio: float = 4.0, - norm_layer: Callable = nn.LayerNorm, + norm_layer: Callable[[int], nn.Module] = nn.LayerNorm, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -266,7 +266,7 @@ def __init__( layers: int, heads: int, mlp_ratio: float = 4.0, - norm_layer: Callable = nn.LayerNorm, + norm_layer: Callable[[int], nn.Module] = nn.LayerNorm, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 6259166a7fc57..7d52737400a40 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -26,7 +26,7 @@ import numpy as np import torch import torch.nn as nn -from transformers import BatchFeature, ProcessorMixin +from transformers import BatchFeature from transformers.models.qwen2_audio import (Qwen2AudioConfig, Qwen2AudioEncoder, Qwen2AudioProcessor) @@ -88,13 +88,20 @@ def get_max_qwen2_audio_audio_tokens(ctx: InputContext) -> int: class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor): - def _get_hf_processor(self) -> Qwen2AudioProcessor: + def _get_hf_processor( + self, + *, + # Ignored in initialization + sampling_rate: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + padding: Optional[str] = None, + ) -> Qwen2AudioProcessor: return self.ctx.get_hf_processor(Qwen2AudioProcessor) def _get_feature_extractor(self) -> WhisperFeatureExtractor: return self._get_hf_processor().feature_extractor # type: ignore - def _get_processor_data( + def _get_hf_mm_data( self, mm_items: MultiModalDataItems, ) -> tuple[dict[str, Any], dict[str, Any]]: @@ -102,42 +109,52 @@ def _get_processor_data( feature_extractor = self._get_feature_extractor() mm_items.resample_audios(feature_extractor.sampling_rate) - return 
super()._get_processor_data(mm_items) + return super()._get_hf_mm_data(mm_items) def _call_hf_processor( self, - hf_processor: ProcessorMixin, prompt: str, - processor_data: Mapping[str, object], - mm_processor_kwargs: Mapping[str, object], + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], ) -> BatchFeature: - processor_data = dict(processor_data) - audios = processor_data.pop("audios", []) + mm_data = dict(mm_data) + audios = mm_data.pop("audios", []) if audios: - processor_data["audios"] = audios + mm_data["audios"] = audios feature_extractor = self._get_feature_extractor() - mm_processor_kwargs = dict( - **mm_processor_kwargs, + mm_kwargs = dict( + **mm_kwargs, sampling_rate=feature_extractor.sampling_rate, + # When fine-grained caching is applied, + # the individual processors are called separately. + return_attention_mask=True, + padding="max_length", ) else: # NOTE: WhisperFeatureExtractor cannot handle empty list of audios pass - return super()._call_hf_processor( - hf_processor, + processed_outputs = super()._call_hf_processor( prompt=prompt, - processor_data=processor_data, - mm_processor_kwargs=mm_processor_kwargs, + mm_data=mm_data, + mm_kwargs=mm_kwargs, ) + # When fine-grained caching is applied, + # the individual processors are called separately. + if "attention_mask" in processed_outputs: + processed_outputs["feature_attention_mask"] = \ + processed_outputs.pop("attention_mask") + + return processed_outputs + def _get_prompt_replacements( self, mm_items: MultiModalDataItems, hf_inputs: BatchFeature, - mm_processor_kwargs: Mapping[str, object], + hf_mm_kwargs: Mapping[str, object], ) -> list[PromptReplacement]: hf_config = self.ctx.get_hf_config(Qwen2AudioConfig) placeholder = hf_config.audio_token_index @@ -175,7 +192,7 @@ def _get_dummy_mm_inputs( return ProcessorInputs( prompt_text="<|AUDIO|>" * audio_count, mm_data=data, - mm_processor_kwargs={}, + hf_mm_kwargs={}, ) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index b38ea923f0bf1..60e550f77c054 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -22,7 +22,7 @@ # limitations under the License. 
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" from functools import cached_property, partial -from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set, +from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, Tuple, Type, TypedDict, Union) import torch @@ -229,9 +229,9 @@ class Qwen2VisionAttention(nn.Module): def __init__( self, - embed_dim: Optional[int] = None, - num_heads: Optional[int] = None, - projection_size: Optional[int] = None, + embed_dim: int, + num_heads: int, + projection_size: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: @@ -264,7 +264,7 @@ def forward( self, x: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor = None, + rotary_pos_emb: torch.Tensor, ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) @@ -347,7 +347,7 @@ def __init__( num_heads: int, mlp_ratio: float, act_layer: Type[nn.Module] = QuickGELU, - norm_layer: Type[nn.Module] = None, + norm_layer: Optional[Callable[[int], nn.Module]] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: @@ -384,7 +384,7 @@ def __init__( self, patch_size: int = 14, temporal_patch_size: int = 2, - in_chans: int = 3, + in_channels: int = 3, embed_dim: int = 1152, ) -> None: super().__init__() @@ -392,8 +392,8 @@ def __init__( self.temporal_patch_size = temporal_patch_size self.embed_dim = embed_dim - kernel_size = [temporal_patch_size, patch_size, patch_size] - self.proj = nn.Conv3d(in_chans, + kernel_size = (temporal_patch_size, patch_size, patch_size) + self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, @@ -413,7 +413,7 @@ def __init__( self, d_model: int, context_dim: int, - norm_layer: Type[nn.Module] = None, + norm_layer: Optional[Callable[[int], nn.Module]] = None, spatial_merge_size: int = 2, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -489,15 +489,15 @@ def __init__( ) -> None: super().__init__() - patch_size: int = vision_config.patch_size - temporal_patch_size: int = vision_config.temporal_patch_size - spatial_merge_size: int = vision_config.spatial_merge_size - in_chans: int = vision_config.in_chans - hidden_size: int = vision_config.hidden_size - embed_dim: int = vision_config.embed_dim - depth: int = vision_config.depth - num_heads: int = vision_config.num_heads - mlp_ratio: float = vision_config.mlp_ratio + patch_size = vision_config.patch_size + temporal_patch_size = vision_config.temporal_patch_size + spatial_merge_size = vision_config.spatial_merge_size + in_channels = vision_config.in_channels + hidden_size = vision_config.hidden_size + embed_dim = vision_config.embed_dim + depth = vision_config.depth + num_heads = vision_config.num_heads + mlp_ratio = vision_config.mlp_ratio self.spatial_merge_size = spatial_merge_size self.num_heads = num_heads @@ -506,7 +506,7 @@ def __init__( self.patch_embed = Qwen2VisionPatchEmbed( patch_size=patch_size, temporal_patch_size=temporal_patch_size, - in_chans=in_chans, + in_channels=in_channels, embed_dim=embed_dim, ) @@ -592,7 +592,7 @@ def forward( return x def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -600,7 +600,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("qkv_proj", "v_proj", "v"), ] params_dict = 
dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() + loaded_params = set[str]() for name, loaded_weight in weights: for (param_name, weight_name, shard_id) in stacked_params_mapping: @@ -784,7 +784,7 @@ def _get_hf_processor( return hf_processor - def _get_processor_data( + def _get_hf_mm_data( self, mm_items: MultiModalDataItems, ) -> tuple[dict[str, Any], dict[str, Any]]: @@ -813,11 +813,31 @@ def _get_processor_data( return processor_data, passthrough_data + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + # Remove the extra dimension + if (not self.ctx.model_config.disable_mm_preprocessor_cache + and "pixel_values" in processed_outputs): + processed_outputs["pixel_values"] = \ + processed_outputs["pixel_values"].squeeze(0) + + return processed_outputs + def _get_prompt_replacements( self, mm_items: MultiModalDataItems, hf_inputs: BatchFeature, - mm_processor_kwargs: Mapping[str, object], + hf_mm_kwargs: Mapping[str, object], ) -> list[PromptReplacement]: hf_processor = self._get_hf_processor() image_processor = _get_image_processor(hf_processor) @@ -869,7 +889,7 @@ def _get_dummy_mm_inputs( return ProcessorInputs( prompt_text=image_token * num_images, mm_data=data, - mm_processor_kwargs={}, + hf_mm_kwargs={}, ) @@ -945,9 +965,7 @@ def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): return None return quant_config - def _validate_and_reshape_mm_tensor(self, - mm_input: Union[torch.Tensor, - List[torch.Tensor]], + def _validate_and_reshape_mm_tensor(self, mm_input: object, name: str) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): raise ValueError(f"Incorrect type of {name}. " @@ -957,7 +975,8 @@ def _validate_and_reshape_mm_tensor(self, return mm_input if mm_input.ndim != 3: raise ValueError(f"{name} should be 2D or batched 3D tensor. 
" - f"Got ndim: {mm_input.ndim}") + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})") return torch.concat(list(mm_input)) else: return torch.concat(mm_input) @@ -1189,7 +1208,7 @@ def sample( return self.language_model.sample(logits, sampling_metadata) def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + torch.Tensor]]) -> set[str]: hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "lm_head.": "language_model.lm_head.", diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index c60b208c3d27d..678b1c10cb86a 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -72,11 +72,19 @@ def get_ultravox_max_audio_tokens(ctx: InputContext): class UltravoxMultiModalProcessor(BaseMultiModalProcessor): + def _get_hf_processor( + self, + *, + # Ignored in initialization + sampling_rate: Optional[int] = None, + ) -> ProcessorMixin: + return self.ctx.get_hf_processor() + def _get_feature_extractor(self) -> WhisperFeatureExtractor: hf_processor = self._get_hf_processor() return hf_processor.audio_processor.feature_extractor # type: ignore - def _get_processor_data( + def _get_hf_mm_data( self, mm_items: MultiModalDataItems, ) -> tuple[dict[str, Any], dict[str, Any]]: @@ -84,33 +92,31 @@ def _get_processor_data( feature_extractor = self._get_feature_extractor() mm_items.resample_audios(feature_extractor.sampling_rate) - return super()._get_processor_data(mm_items) + return super()._get_hf_mm_data(mm_items) def _call_hf_processor( self, - hf_processor: ProcessorMixin, prompt: str, - processor_data: Mapping[str, object], - mm_processor_kwargs: Mapping[str, object], + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], ) -> BatchFeature: - processor_data = dict(processor_data) - audios = processor_data.pop("audios", []) + mm_data = dict(mm_data) + audios = mm_data.pop("audios", []) if not audios: return super()._call_hf_processor( - hf_processor, prompt=prompt, - processor_data=processor_data, - mm_processor_kwargs=mm_processor_kwargs, + mm_data=mm_data, + mm_kwargs=mm_kwargs, ) feature_extractor = self._get_feature_extractor() - mm_processor_kwargs = dict( - **mm_processor_kwargs, + mm_kwargs = dict( + **mm_kwargs, sampling_rate=feature_extractor.sampling_rate, ) - # Already resampled by _get_processor_data + # Already resampled by _get_hf_mm_data assert is_list_of(audios, np.ndarray) # Ultravox processor doesn't support multiple inputs, @@ -119,13 +125,12 @@ def _call_hf_processor( shared_outputs = {} for audio in audios: # NOTE: Ultravox processor accepts "audio" instead of "audios" - item_processor_data = dict(**processor_data, audio=audio) + item_processor_data = dict(**mm_data, audio=audio) item_outputs = super()._call_hf_processor( - hf_processor, prompt=prompt, - processor_data=item_processor_data, - mm_processor_kwargs=mm_processor_kwargs, + mm_data=item_processor_data, + mm_kwargs=mm_kwargs, ) audio_features.append(item_outputs.pop("audio_values")[0]) @@ -143,7 +148,7 @@ def _get_prompt_replacements( self, mm_items: MultiModalDataItems, hf_inputs: BatchFeature, - mm_processor_kwargs: Mapping[str, object], + hf_mm_kwargs: Mapping[str, object], ) -> list[PromptReplacement]: hf_processor = self._get_hf_processor() placeholder = hf_processor.audio_token_replacement # type: ignore @@ -175,7 +180,7 @@ def _get_dummy_mm_inputs( return ProcessorInputs( prompt_text="<|audio|>" * audio_count, mm_data=data, - mm_processor_kwargs={}, + hf_mm_kwargs={}, ) diff --git 
a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 138cc6a44c11a..e8f1a8f7bd228 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -1,6 +1,6 @@ from collections import UserDict, defaultdict -from typing import (Any, Dict, List, Literal, Mapping, Sequence, Tuple, - TypedDict, TypeVar, Union, cast, final) +from collections.abc import Mapping, Sequence +from typing import Any, Literal, TypedDict, TypeVar, Union, cast, final import numpy as np import torch @@ -44,7 +44,7 @@ """ # yapf: enable -MultiModalData: TypeAlias = Union[_T, List[_T]] +MultiModalData: TypeAlias = Union[_T, list[_T]] """ Either a single data item, or a list of data items. @@ -97,13 +97,13 @@ class PlaceholderRange(TypedDict): """The length of the placeholder.""" -NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor, - Tuple[torch.Tensor, ...]] +NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor, + tuple[torch.Tensor, ...]] """ Uses a list instead of a tensor if the dimensions of each element do not match. """ -BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors] +BatchedTensorInputs: TypeAlias = dict[str, NestedTensors] """ A dictionary containing nested tensors which have been batched via :meth:`MultiModalKwargs.batch`. @@ -139,7 +139,7 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: # Only tensors (not lists) can be stacked. return stacked - tensors_ = cast(List[torch.Tensor], stacked) + tensors_ = cast(list[torch.Tensor], stacked) if any(t.shape != tensors_[0].shape for t in tensors_): # The tensors have incompatible shapes and can't be stacked. return tensors_ @@ -147,7 +147,7 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: return torch.stack(tensors_) @staticmethod - def batch(inputs_list: List["MultiModalKwargs"]) -> BatchedTensorInputs: + def batch(inputs_list: list["MultiModalKwargs"]) -> BatchedTensorInputs: """ Batch multiple inputs together into a dictionary. @@ -162,7 +162,7 @@ def batch(inputs_list: List["MultiModalKwargs"]) -> BatchedTensorInputs: # We need to consider the case where each item in the batch # contains different modalities (i.e. different keys). 
- item_lists: Dict[str, List[NestedTensors]] = defaultdict(list) + item_lists = defaultdict[str, list[NestedTensors]](list) for inputs in inputs_list: for k, v in inputs.items(): @@ -207,16 +207,16 @@ class MultiModalInputsV2(TypedDict): prompt: str """The processed prompt text.""" - prompt_token_ids: List[int] + prompt_token_ids: list[int] """The processed token IDs which includes placeholder tokens.""" - token_type_ids: NotRequired[List[int]] + token_type_ids: NotRequired[list[int]] """The token type IDs of the prompt.""" mm_kwargs: MultiModalKwargs """Keyword arguments to be directly passed to the model after batching.""" - mm_hashes: NotRequired[List[str]] + mm_hashes: NotRequired[list[str]] """The hashes of the multi-modal data.""" mm_placeholders: MultiModalPlaceholderDict diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 6baf19d675d50..1751873318523 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1,13 +1,16 @@ +import pickle import re from abc import ABC, abstractmethod -from collections import UserDict +from collections import UserDict, defaultdict from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence from dataclasses import dataclass, field -from functools import lru_cache -from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union +from functools import lru_cache, partial +from typing import (Any, Literal, NamedTuple, Optional, Protocol, TypeVar, + Union, cast) import numpy as np import torch +from blake3 import blake3 from PIL.Image import Image from transformers import BatchFeature, ProcessorMixin from typing_extensions import assert_never @@ -15,12 +18,12 @@ from vllm.inputs import DummyData, InputProcessingContext from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import flatten_2d_lists, full_groupby, is_list_of +from vllm.utils import LRUCache, flatten_2d_lists, full_groupby, is_list_of from .audio import resample_audio from .inputs import (AudioItem, ImageItem, MultiModalDataDict, - MultiModalInputsV2, MultiModalKwargs, PlaceholderRange, - VideoItem) + MultiModalInputsV2, MultiModalKwargs, NestedTensors, + PlaceholderRange, VideoItem) logger = init_logger(__name__) @@ -584,10 +587,201 @@ def iter_placeholders( class ProcessorInputs(NamedTuple): - """Keyword arguments to :meth:`BaseMultiModalProcessor`""" + """Keyword arguments to :meth:`BaseMultiModalProcessor`.""" prompt_text: str mm_data: MultiModalDataDict - mm_processor_kwargs: Mapping[str, object] + hf_mm_kwargs: Mapping[str, object] + + +class ProcessingCache: + + def __init__(self, capacity: int) -> None: + super().__init__() + + # DEBUG: Set to None to disable + self.debug_cache_hit_ratio_steps: Optional[int] = None + + self._fine_text_cache = LRUCache[str, BatchFeature](capacity) + self._fine_mm_cache = LRUCache[str, BatchFeature](capacity) + self._coarse_cache = LRUCache[str, BatchFeature](capacity) + + def maybe_log_cache_stats(self, cache: LRUCache, name: str) -> None: + steps = self.debug_cache_hit_ratio_steps + if not steps: + return + + cache_stats = cache.stat() + if cache_stats.total % steps == 0: + logger.debug("ProcessingCache: %s.hit_ratio = %.2f", name, + cache_stats.hit_ratio) + + def _hash_item(self, obj: object) -> bytes: + # Simple cases + if isinstance(obj, str): + return obj.encode("utf-8") + if isinstance(obj, bytes): + return obj + if isinstance(obj, Image): + return obj.tobytes() + + # Convertible to NumPy arrays + 
if isinstance(obj, torch.Tensor): + obj = obj.numpy() + if isinstance(obj, (int, float)): + obj = np.array(obj) + if isinstance(obj, np.ndarray): + return obj.tobytes() + + logger.warning( + "No serialization method found for %s. " + "Falling back to pickle.", type(obj)) + + return pickle.dumps(obj) + + def _iter_bytes_to_hash( + self, + key: str, + obj: object, + ) -> Iterable[tuple[bytes, bytes]]: + # Recursive cases + if isinstance(obj, (list, tuple)): + for i, elem in enumerate(obj): + yield from self._iter_bytes_to_hash(f"{key}.{i}", elem) + elif isinstance(obj, dict): + for k, v in obj.items(): + yield from self._iter_bytes_to_hash(f"{key}.{k}", v) + else: + key_bytes = self._hash_item(key) + value_bytes = self._hash_item(obj) + yield key_bytes, value_bytes + + def _hash_kwargs(self, **kwargs: object) -> str: + hasher = blake3() + + for k, v in kwargs.items(): + for k_bytes, v_bytes in self._iter_bytes_to_hash(k, v): + hasher.update(k_bytes) + hasher.update(v_bytes) + + return hasher.hexdigest() + + def _cached_call_fine( + self, + ctx: InputProcessingContext, + hf_processor: ProcessorMixin, + text: str, + mm_data: Mapping[Literal["images", "videos", "audios"], list[Any]], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + self.maybe_log_cache_stats(self._fine_text_cache, "fine_text_cache") + + processed_text = self._fine_text_cache.get_or_put( + text, + default_factory=partial( + ctx.call_hf_processor, + ctx.get_modality_processor(hf_processor, "text"), + dict(text=text), + ), + ) + + processed_data = dict(**processed_text) + for data_key, items in mm_data.items(): + processed_modal_items = defaultdict[str, Union[ + list[torch.Tensor], list[NestedTensors]]](list) + + for item in items: + self.maybe_log_cache_stats(self._fine_mm_cache, + "fine_mm_cache") + + modal_item = cast(Mapping[str, object], {data_key: item}) + processed_modal_item = self._fine_mm_cache.get_or_put( + self._hash_kwargs(**modal_item, **mm_kwargs), + default_factory=partial( + ctx.call_hf_processor, + ctx.get_modality_processor(hf_processor, data_key), + modal_item, + mm_kwargs, + ), + ) + + for k, v in processed_modal_item.items(): + # Remove the extra batch dimension (if it exists) + # NOTE: v may be a list instead of a tensor + if len(v) == 1: + v = v[0] + + processed_modal_items[k].append(v) + + for k, vs in processed_modal_items.items(): + # Try to merge elements into a single tensor + if is_list_of(vs, torch.Tensor, check="all") and len(vs) > 0: + first_shape = vs[0].shape + if all(v.shape == first_shape for v in vs): + vs = torch.stack(vs) + + processed_data[k] = vs + + return BatchFeature(processed_data) + + def _cached_call_coarse( + self, + ctx: InputProcessingContext, + hf_processor: ProcessorMixin, + text: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + self.maybe_log_cache_stats(self._coarse_cache, "coarse_cache") + + processed_data = self._coarse_cache.get_or_put( + self._hash_kwargs(text=text, **mm_data, **mm_kwargs), + default_factory=partial( + ctx.call_hf_processor, + hf_processor, + dict(text=text, **mm_data), + mm_kwargs, + ), + ) + + # Shallow copy to avoid footgun when downstream methods + # mutate the returned dictionary (since the result is cached) + return BatchFeature(processed_data) # type: ignore[arg-type] + + def call_hf_processor( + self, + ctx: InputProcessingContext, + # Assumes that hf_processor has been initialized according to kwargs + hf_processor: ProcessorMixin, + text: str, + mm_data: Mapping[str, object], + mm_kwargs: 
Mapping[str, object], + ) -> BatchFeature: + # Try to cache each item separately to improve hit rate + extra_keys = mm_data.keys() - {"images", "videos", "audios"} + if (mm_data and not extra_keys + and all(isinstance(v, list) for v in mm_data.values())): + try: + return self._cached_call_fine( + ctx, + hf_processor, + text=text, + mm_data=mm_data, # type: ignore[arg-type] + mm_kwargs=mm_kwargs, + ) + except Exception: + logger.exception( + "Failed to apply processor on each item separately! " + "Falling back to coarse caching.", + stack_info=True, + ) + + return self._cached_call_coarse( + ctx, + hf_processor, + text=text, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) class BaseMultiModalProcessor(ABC): @@ -595,18 +789,24 @@ class BaseMultiModalProcessor(ABC): Abstract base class to process multi-modal inputs to be used in vLLM. """ - def __init__(self, ctx: InputProcessingContext) -> None: + def __init__( + self, + ctx: InputProcessingContext, + *, + cache: Optional[ProcessingCache] = None, + ) -> None: super().__init__() self.ctx = ctx + self.cache = cache def __call__( self, prompt: str, mm_data: MultiModalDataDict, - mm_processor_kwargs: Mapping[str, object], + hf_mm_kwargs: Mapping[str, object], ) -> MultiModalInputsV2: - return self.apply(prompt, mm_data, mm_processor_kwargs) + return self.apply(prompt, mm_data, hf_mm_kwargs) def _get_hf_processor(self) -> ProcessorMixin: """ @@ -629,7 +829,7 @@ def _get_prompt_replacements( self, mm_items: MultiModalDataItems, hf_inputs: BatchFeature, - mm_processor_kwargs: Mapping[str, object], + hf_mm_kwargs: Mapping[str, object], ) -> list[PromptReplacement]: """ Given the original multi-modal items for this modality @@ -651,7 +851,7 @@ def _find_placeholders( return list( iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts)) - def _get_processor_data( + def _get_hf_mm_data( self, mm_items: MultiModalDataItems, ) -> tuple[dict[str, Any], dict[str, Any]]: @@ -679,39 +879,43 @@ def _get_processor_data( def _call_hf_processor( self, - hf_processor: ProcessorMixin, prompt: str, - processor_data: Mapping[str, object], - mm_processor_kwargs: Mapping[str, object], + # Not to be confused with `mm_data` in `self.apply`. + # This refers to the data to be passed to HF processor. 
+ mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], ) -> BatchFeature: - return self.ctx.call_hf_processor( - hf_processor, - prompt, - processor_data, - mm_processor_kwargs, + if self.cache is None: + return self.ctx.call_hf_processor( + self._get_hf_processor(**mm_kwargs), + dict(text=prompt, **mm_data), + mm_kwargs, + ) + + return self.cache.call_hf_processor( + self.ctx, + self._get_hf_processor(**mm_kwargs), + text=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, ) def _apply_hf_processor( self, prompt: str, mm_items: MultiModalDataItems, - mm_processor_kwargs: Mapping[str, object], + hf_mm_kwargs: Mapping[str, object], ) -> BatchFeature: - # some mm_processor_kwargs may be used in processor initialization - # instead of processor call - hf_processor = self._get_hf_processor(**mm_processor_kwargs) - - processor_data, passthrough_data = self._get_processor_data(mm_items) + processor_data, passthrough_data = self._get_hf_mm_data(mm_items) - hf_inputs = self._call_hf_processor( - hf_processor, + processed_data = self._call_hf_processor( prompt=prompt, - processor_data=processor_data, - mm_processor_kwargs=mm_processor_kwargs, + mm_data=processor_data, + mm_kwargs=hf_mm_kwargs, ) - hf_inputs.update(passthrough_data) + processed_data.update(passthrough_data) - return hf_inputs + return processed_data def _bind_prompt_replacements( self, @@ -730,6 +934,10 @@ def _apply_prompt_replacements( tokenizer = self._get_tokenizer() token_matches = find_token_matches(token_ids, prompt_repls) + mm_match_counts = { + modality: len(matches) + for modality, matches in full_groupby_modality(token_matches) + } # If the search text does not represent a special token, # it may have different token IDs in the prompt, because @@ -742,8 +950,8 @@ def _apply_prompt_replacements( # of the search text in the prompt, we instead perform string # replacement on the decoded token IDs, then encode them back. if all( - len(matches) >= mm_item_counts[modality] - for modality, matches in full_groupby_modality(token_matches) + mm_match_counts.get(modality, 0) >= item_count + for modality, item_count in mm_item_counts.items() ): # yapf: disable token_ids = replace_token_matches( token_ids, @@ -775,7 +983,7 @@ def apply( self, prompt_text: str, mm_data: MultiModalDataDict, - mm_processor_kwargs: Mapping[str, object], + hf_mm_kwargs: Mapping[str, object], ) -> MultiModalInputsV2: """ Process multi-modal inputs to be used in vLLM. 
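
The fine-grained cache path above keys each multi-modal item by a blake3 digest of its raw bytes together with the processor kwargs, so two requests carrying identical image/audio data and identical kwargs resolve to the same cache entry. A minimal, self-contained sketch of that hashing step, assuming flat string/ndarray inputs (the helper name `demo_hash` is illustrative and not part of the patch; it mirrors what ProcessingCache._hash_kwargs does for the simple cases):

    from blake3 import blake3

    import numpy as np


    def demo_hash(**kwargs: object) -> str:
        """Illustrative only: flattens kwargs into (key, value) bytes."""
        hasher = blake3()
        for key, value in kwargs.items():
            hasher.update(key.encode("utf-8"))        # key bytes
            if isinstance(value, str):
                hasher.update(value.encode("utf-8"))  # simple case
            elif isinstance(value, np.ndarray):
                hasher.update(value.tobytes())        # raw array buffer
            else:
                raise NotImplementedError(type(value))
        return hasher.hexdigest()


    audio = np.zeros(16_000, dtype=np.float32)
    # Identical bytes and kwargs -> identical key -> cache hit on the 2nd call
    assert demo_hash(audios=audio, sampling_rate="16000") == \
        demo_hash(audios=audio.copy(), sampling_rate="16000")
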
@@ -793,12 +1001,12 @@ def apply( mm_items = self._get_mm_items(mm_data) hf_inputs = self._apply_hf_processor(prompt_text, mm_items, - mm_processor_kwargs) + hf_mm_kwargs) prompt_ids, = hf_inputs.pop("input_ids").tolist() mm_kwargs = MultiModalKwargs(hf_inputs) prompt_repls = self._get_prompt_replacements(mm_items, hf_inputs, - mm_processor_kwargs) + hf_mm_kwargs) all_prompt_repls = self._bind_prompt_replacements(prompt_repls) # If HF processor already inserts placeholder tokens, diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 6cd79d414c978..17255359ab61c 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -1,10 +1,9 @@ import functools from collections import UserDict -from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional, +from typing import (TYPE_CHECKING, Any, Dict, Mapping, Optional, Protocol, Sequence, Type, TypeVar) import torch.nn as nn -from typing_extensions import TypeAlias from vllm.inputs import InputProcessingContext from vllm.logger import init_logger @@ -15,7 +14,7 @@ from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc from .image import ImagePlugin from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors -from .processing import BaseMultiModalProcessor +from .processing import BaseMultiModalProcessor, ProcessingCache from .video import VideoPlugin if TYPE_CHECKING: @@ -25,13 +24,17 @@ N = TypeVar("N", bound=Type[nn.Module]) -MultiModalProcessorFactory: TypeAlias = Callable[[InputProcessingContext], - BaseMultiModalProcessor] -""" -Constructs a :class:`MultiModalProcessor` instance from the context. -The processing metadata should be derived from the context. -""" +class MultiModalProcessorFactory(Protocol): + """Constructs a :class:`MultiModalProcessor` instance from the context.""" + + def __call__( + self, + ctx: InputProcessingContext, + *, + cache: Optional[ProcessingCache] = None, + ) -> BaseMultiModalProcessor: + ... class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]): @@ -71,6 +74,8 @@ def __init__( self._limits_by_model = _MultiModalLimits() + self._processing_cache = ProcessingCache(256) # MM_CACHE_SIZE + def register_plugin(self, plugin: MultiModalPlugin) -> None: """ Register a multi-modal plugin so it can be recognized by vLLM. 
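
With the registry now owning a shared ProcessingCache, a model integration only needs to satisfy the new factory protocol: accept the keyword-only `cache` argument and forward it to the processor. A sketch under that assumption, using the LlavaMultiModalProcessor from this same patch (the standalone factory function is illustrative; since BaseMultiModalProcessor.__init__ already matches the protocol, the processor class itself can typically serve as the factory):

    from typing import Optional

    from vllm.inputs import InputProcessingContext
    from vllm.model_executor.models.llava import LlavaMultiModalProcessor
    from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                            ProcessingCache)


    def build_llava_processor(
        ctx: InputProcessingContext,
        *,
        cache: Optional[ProcessingCache] = None,
    ) -> BaseMultiModalProcessor:
        # create_processor() passes the registry's shared ProcessingCache(256)
        # unless model_config.disable_mm_preprocessor_cache is set, in which
        # case `cache` is None and _call_hf_processor skips caching entirely.
        return LlavaMultiModalProcessor(ctx, cache=cache)
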
@@ -354,4 +359,7 @@ def create_processor( processor_factory = self._processor_factories[model_cls] ctx = InputProcessingContext(model_config, tokenizer) - return processor_factory(ctx) + cache = (None if model_config.disable_mm_preprocessor_cache else + self._processing_cache) + + return processor_factory(ctx, cache=cache) diff --git a/vllm/utils.py b/vllm/utils.py index 1b90eca1cd6cc..10dac6898dd20 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -22,11 +22,11 @@ import weakref from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from collections import OrderedDict, UserDict, defaultdict -from collections.abc import Iterable, Mapping +from collections.abc import Hashable, Iterable, Mapping from dataclasses import dataclass, field from functools import lru_cache, partial, wraps from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, - Dict, Generator, Generic, Hashable, List, Literal, + Dict, Generator, Generic, List, Literal, NamedTuple, Optional, Tuple, Type, TypeVar, Union, overload) from uuid import uuid4 @@ -191,13 +191,29 @@ def reset(self) -> None: self.counter = 0 +class CacheInfo(NamedTuple): + hits: int + total: int + + @property + def hit_ratio(self) -> float: + if self.total == 0: + return 0 + + return self.hits / self.total + + class LRUCache(Generic[_K, _V]): + """Note: This class is not thread safe!""" def __init__(self, capacity: int) -> None: self.cache = OrderedDict[_K, _V]() self.pinned_items = set[_K]() self.capacity = capacity + self._hits = 0 + self._total = 0 + def __contains__(self, key: _K) -> bool: return key in self.cache @@ -215,6 +231,9 @@ def __setitem__(self, key: _K, value: _V) -> None: def __delitem__(self, key: _K) -> None: self.pop(key) + def stat(self) -> CacheInfo: + return CacheInfo(hits=self._hits, total=self._total) + def touch(self, key: _K) -> None: self.cache.move_to_end(key) @@ -223,8 +242,12 @@ def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]: if key in self.cache: value = self.cache[key] self.cache.move_to_end(key) + + self._hits += 1 else: value = default + + self._total += 1 return value def put(self, key: _K, value: _V) -> None: @@ -232,6 +255,19 @@ def put(self, key: _K, value: _V) -> None: self.cache.move_to_end(key) self._remove_old_if_needed() + def get_or_put(self, key: _K, default_factory: Callable[[], _V]) -> _V: + if key in self.cache: + value = self.cache[key] + self.cache.move_to_end(key) + + self._hits += 1 + else: + value = default_factory() + self.put(key, value) + + self._total += 1 + return value + def pin(self, key: _K) -> None: """ Pins a key in the cache preventing it from being
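
Finally, the LRUCache additions in vllm/utils.py are what ProcessingCache builds on: get_or_put only invokes the factory on a miss, and stat() reports the hit ratio that maybe_log_cache_stats logs. A small usage sketch, assuming the class is used directly:

    from vllm.utils import LRUCache

    cache = LRUCache[str, int](capacity=2)


    def expensive(x: str) -> int:
        # Stand-in for an HF processor call keyed by a blake3 digest
        return len(x)


    cache.get_or_put("a", lambda: expensive("a"))   # miss: computes and stores
    cache.get_or_put("a", lambda: expensive("a"))   # hit: factory not called
    cache.get_or_put("b", lambda: expensive("bb"))  # miss

    info = cache.stat()
    print(f"hit_ratio={info.hit_ratio:.2f}")  # 1 hit out of 3 lookups -> 0.33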