From faa9b841753d8fac1cf85b7551e33d021bcb1953 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 19 Dec 2024 17:42:17 +0000 Subject: [PATCH 01/26] Refactor multi-modal processor to support caching Signed-off-by: DarkLight1337 --- vllm/inputs/registry.py | 18 +- vllm/model_executor/models/llava.py | 4 +- vllm/model_executor/models/phi3v.py | 20 +- vllm/model_executor/models/qwen2_audio.py | 30 ++- vllm/model_executor/models/qwen2_vl.py | 6 +- vllm/model_executor/models/ultravox.py | 37 ++-- vllm/multimodal/processing.py | 255 +++++++++++++++++++--- vllm/multimodal/registry.py | 25 ++- vllm/utils.py | 40 +++- 9 files changed, 328 insertions(+), 107 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index fb02627eb22bd..adb11989fc478 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -132,10 +132,12 @@ def get_hf_processor( def call_hf_processor( self, hf_processor: ProcessorMixin, - prompt: str, - processor_data: Mapping[str, object], - inference_kwargs: Mapping[str, object], + data: Mapping[str, object], + kwargs: Optional[Mapping[str, object]] = None, ) -> BatchFeature: + if kwargs is None: + kwargs = {} + assert callable(hf_processor) base_kwargs = self.model_config.mm_processor_kwargs @@ -144,21 +146,15 @@ def call_hf_processor( merged_kwargs = resolve_mm_processor_kwargs( base_kwargs, - inference_kwargs, + kwargs, hf_processor, requires_kw_only=False, allow_var_kwargs=True, ) try: - return hf_processor( - text=prompt, - **processor_data, - **merged_kwargs, - return_tensors="pt", - ) + return hf_processor(**data, **merged_kwargs, return_tensors="pt") except Exception as exc: - data = dict(text=prompt, **processor_data) msg = (f"Failed to apply {type(hf_processor).__name__} " f"on data={data} with kwargs={merged_kwargs}") diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 0662d90e79b92..2be5dcd4a88de 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -145,7 +145,7 @@ def _get_prompt_replacements( self, mm_items: MultiModalDataItems, hf_inputs: BatchFeature, - mm_processor_kwargs: Mapping[str, object], + hf_mm_kwargs: Mapping[str, object], ) -> list[PromptReplacement]: hf_config = self.ctx.get_hf_config(LlavaConfig) image_token_id = hf_config.image_token_index @@ -218,7 +218,7 @@ def _get_dummy_mm_inputs( return ProcessorInputs( prompt_text=image_token * num_images, mm_data=data, - mm_processor_kwargs={}, + hf_mm_kwargs={}, ) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index e2263f63f7bba..9e5760671d8df 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -306,11 +306,11 @@ def get_max_phi3v_image_tokens( *, num_crops: Optional[int] = None, ) -> int: - mm_processor_kwargs = {} + hf_mm_kwargs = {} if num_crops: - mm_processor_kwargs["num_crops"] = num_crops + hf_mm_kwargs["num_crops"] = num_crops - processor = ctx.get_hf_processor(**mm_processor_kwargs) + processor = ctx.get_hf_processor(**hf_mm_kwargs) return processor.calc_num_image_tokens_from_image_size( width=MAX_IMAGE_FEATURE_SIZE_WIDTH, @@ -331,16 +331,14 @@ def _get_hf_processor( def _call_hf_processor( self, - hf_processor: ProcessorMixin, prompt: str, - processor_data: Mapping[str, object], - mm_processor_kwargs: Mapping[str, object], + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], ) -> BatchFeature: processed_outputs = super()._call_hf_processor( - hf_processor, prompt=prompt, - processor_data=processor_data, - mm_processor_kwargs=mm_processor_kwargs, + mm_data=mm_data, + mm_kwargs=mm_kwargs, ) # Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids, @@ -356,7 +354,7 @@ def _get_prompt_replacements( self, mm_items: MultiModalDataItems, hf_inputs: BatchFeature, - mm_processor_kwargs: Mapping[str, object], + hf_mm_kwargs: Mapping[str, object], ) -> list[PromptReplacement]: hf_processor = self._get_hf_processor() image_tokens: list[str] = hf_processor.img_tokens # type: ignore @@ -401,7 +399,7 @@ def _get_dummy_mm_inputs( return ProcessorInputs( prompt_text="".join(image_tokens[:num_images]), mm_data=data, - mm_processor_kwargs={}, + hf_mm_kwargs={}, ) diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 6259166a7fc57..28a0b16732c2b 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -26,7 +26,7 @@ import numpy as np import torch import torch.nn as nn -from transformers import BatchFeature, ProcessorMixin +from transformers import BatchFeature from transformers.models.qwen2_audio import (Qwen2AudioConfig, Qwen2AudioEncoder, Qwen2AudioProcessor) @@ -94,7 +94,7 @@ def _get_hf_processor(self) -> Qwen2AudioProcessor: def _get_feature_extractor(self) -> WhisperFeatureExtractor: return self._get_hf_processor().feature_extractor # type: ignore - def _get_processor_data( + def _get_hf_mm_data( self, mm_items: MultiModalDataItems, ) -> tuple[dict[str, Any], dict[str, Any]]: @@ -102,24 +102,23 @@ def _get_processor_data( feature_extractor = self._get_feature_extractor() mm_items.resample_audios(feature_extractor.sampling_rate) - return super()._get_processor_data(mm_items) + return super()._get_hf_mm_data(mm_items) def _call_hf_processor( self, - hf_processor: ProcessorMixin, prompt: str, - processor_data: Mapping[str, object], - mm_processor_kwargs: Mapping[str, object], + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], ) -> BatchFeature: - processor_data = dict(processor_data) - audios = processor_data.pop("audios", []) + mm_data = dict(mm_data) + audios = mm_data.pop("audios", []) if audios: - processor_data["audios"] = audios + mm_data["audios"] = audios feature_extractor = self._get_feature_extractor() - mm_processor_kwargs = dict( - **mm_processor_kwargs, + mm_kwargs = dict( + **mm_kwargs, sampling_rate=feature_extractor.sampling_rate, ) else: @@ -127,17 +126,16 @@ def _call_hf_processor( pass return super()._call_hf_processor( - hf_processor, prompt=prompt, - processor_data=processor_data, - mm_processor_kwargs=mm_processor_kwargs, + mm_data=mm_data, + mm_kwargs=mm_kwargs, ) def _get_prompt_replacements( self, mm_items: MultiModalDataItems, hf_inputs: BatchFeature, - mm_processor_kwargs: Mapping[str, object], + hf_mm_kwargs: Mapping[str, object], ) -> list[PromptReplacement]: hf_config = self.ctx.get_hf_config(Qwen2AudioConfig) placeholder = hf_config.audio_token_index @@ -175,7 +173,7 @@ def _get_dummy_mm_inputs( return ProcessorInputs( prompt_text="<|AUDIO|>" * audio_count, mm_data=data, - mm_processor_kwargs={}, + hf_mm_kwargs={}, ) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index b38ea923f0bf1..1da2b79d3f2d4 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -784,7 +784,7 @@ def _get_hf_processor( return hf_processor - def _get_processor_data( + def _get_hf_mm_data( self, mm_items: MultiModalDataItems, ) -> tuple[dict[str, Any], dict[str, Any]]: @@ -817,7 +817,7 @@ def _get_prompt_replacements( self, mm_items: MultiModalDataItems, hf_inputs: BatchFeature, - mm_processor_kwargs: Mapping[str, object], + hf_mm_kwargs: Mapping[str, object], ) -> list[PromptReplacement]: hf_processor = self._get_hf_processor() image_processor = _get_image_processor(hf_processor) @@ -869,7 +869,7 @@ def _get_dummy_mm_inputs( return ProcessorInputs( prompt_text=image_token * num_images, mm_data=data, - mm_processor_kwargs={}, + hf_mm_kwargs={}, ) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index c60b208c3d27d..c3588215ff0c2 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -11,7 +11,7 @@ import torch.utils.checkpoint from torch import nn from torch.nn import functional as F -from transformers import BatchFeature, ProcessorMixin +from transformers import BatchFeature from transformers.models.whisper import WhisperFeatureExtractor from transformers.models.whisper.modeling_whisper import WhisperEncoder @@ -76,7 +76,7 @@ def _get_feature_extractor(self) -> WhisperFeatureExtractor: hf_processor = self._get_hf_processor() return hf_processor.audio_processor.feature_extractor # type: ignore - def _get_processor_data( + def _get_hf_mm_data( self, mm_items: MultiModalDataItems, ) -> tuple[dict[str, Any], dict[str, Any]]: @@ -84,33 +84,31 @@ def _get_processor_data( feature_extractor = self._get_feature_extractor() mm_items.resample_audios(feature_extractor.sampling_rate) - return super()._get_processor_data(mm_items) + return super()._get_hf_mm_data(mm_items) def _call_hf_processor( self, - hf_processor: ProcessorMixin, prompt: str, - processor_data: Mapping[str, object], - mm_processor_kwargs: Mapping[str, object], + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], ) -> BatchFeature: - processor_data = dict(processor_data) - audios = processor_data.pop("audios", []) + mm_data = dict(mm_data) + audios = mm_data.pop("audios", []) if not audios: return super()._call_hf_processor( - hf_processor, prompt=prompt, - processor_data=processor_data, - mm_processor_kwargs=mm_processor_kwargs, + mm_data=mm_data, + mm_kwargs=mm_kwargs, ) feature_extractor = self._get_feature_extractor() - mm_processor_kwargs = dict( - **mm_processor_kwargs, + mm_kwargs = dict( + **mm_kwargs, sampling_rate=feature_extractor.sampling_rate, ) - # Already resampled by _get_processor_data + # Already resampled by _get_hf_mm_data assert is_list_of(audios, np.ndarray) # Ultravox processor doesn't support multiple inputs, @@ -119,13 +117,12 @@ def _call_hf_processor( shared_outputs = {} for audio in audios: # NOTE: Ultravox processor accepts "audio" instead of "audios" - item_processor_data = dict(**processor_data, audio=audio) + item_processor_data = dict(**mm_data, audio=audio) item_outputs = super()._call_hf_processor( - hf_processor, prompt=prompt, - processor_data=item_processor_data, - mm_processor_kwargs=mm_processor_kwargs, + mm_data=item_processor_data, + mm_kwargs=mm_kwargs, ) audio_features.append(item_outputs.pop("audio_values")[0]) @@ -143,7 +140,7 @@ def _get_prompt_replacements( self, mm_items: MultiModalDataItems, hf_inputs: BatchFeature, - mm_processor_kwargs: Mapping[str, object], + hf_mm_kwargs: Mapping[str, object], ) -> list[PromptReplacement]: hf_processor = self._get_hf_processor() placeholder = hf_processor.audio_token_replacement # type: ignore @@ -175,7 +172,7 @@ def _get_dummy_mm_inputs( return ProcessorInputs( prompt_text="<|audio|>" * audio_count, mm_data=data, - mm_processor_kwargs={}, + hf_mm_kwargs={}, ) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 6baf19d675d50..06345c8cfa306 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1,13 +1,14 @@ import re from abc import ABC, abstractmethod -from collections import UserDict +from collections import UserDict, defaultdict from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence from dataclasses import dataclass, field -from functools import lru_cache -from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union +from functools import lru_cache, partial +from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union, cast import numpy as np import torch +from blake3 import blake3 from PIL.Image import Image from transformers import BatchFeature, ProcessorMixin from typing_extensions import assert_never @@ -15,7 +16,7 @@ from vllm.inputs import DummyData, InputProcessingContext from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import flatten_2d_lists, full_groupby, is_list_of +from vllm.utils import LRUCache, flatten_2d_lists, full_groupby, is_list_of from .audio import resample_audio from .inputs import (AudioItem, ImageItem, MultiModalDataDict, @@ -587,7 +588,194 @@ class ProcessorInputs(NamedTuple): """Keyword arguments to :meth:`BaseMultiModalProcessor`""" prompt_text: str mm_data: MultiModalDataDict - mm_processor_kwargs: Mapping[str, object] + hf_mm_kwargs: Mapping[str, object] + + +class ProcessingCache: + + def __init__(self, capacity: int) -> None: + super().__init__() + + # DEBUG: Set to None to disable + self.debug_cache_hit_ratio_steps: Optional[int] = None + + self._text_cache = LRUCache[str, BatchFeature](capacity) + self._mm_cache = LRUCache[str, BatchFeature](capacity) + self._coarse_cache = LRUCache[str, BatchFeature](capacity) + + def maybe_log_text_cache_stats(self) -> None: + steps = self.debug_cache_hit_ratio_steps + if not steps: + return + + text_cache_stats = self._text_cache.stat() + if text_cache_stats.total % steps == 0: + logger.debug("ProcessingCache: text_cache.hit_ratio = %.2f", + text_cache_stats.hit_ratio) + + def maybe_log_mm_cache_stats(self) -> None: + steps = self.debug_cache_hit_ratio_steps + if not steps: + return + + mm_cache_stats = self._mm_cache.stat() + if mm_cache_stats.total % steps == 0: + logger.debug("ProcessingCache: mm_cache.hit_ratio = %.2f", + mm_cache_stats.hit_ratio) + + def maybe_log_coarse_cache_stats(self) -> None: + steps = self.debug_cache_hit_ratio_steps + if not steps: + return + + coarse_cache_stats = self._mm_cache.stat() + if coarse_cache_stats.total % steps == 0: + logger.debug("ProcessingCache: coarse_cache.hit_ratio = %.2f", + coarse_cache_stats.hit_ratio) + + def _iter_bytes_to_hash(self, key: str, obj: object) -> Iterable[bytes]: + # Recursive cases + if isinstance(obj, (list, tuple)): + for elem in obj: + yield from self._iter_bytes_to_hash(key, elem) + return + if isinstance(obj, dict): + for k, v in obj.items(): + yield from self._iter_bytes_to_hash(f"{key}_{k}", v) + return + + # Simple cases + if isinstance(obj, str): + yield key.encode("utf-8") + yield obj.encode("utf-8") + return + if isinstance(obj, bytes): + yield key.encode("utf-8") + yield obj + return + if isinstance(obj, Image): + yield key.encode("utf-8") + yield obj.tobytes() + return + + # Convertible to NumPy arrays + if isinstance(obj, torch.Tensor): + obj = obj.numpy() + if isinstance(obj, (int, float)): + obj = np.array(obj) + if isinstance(obj, np.ndarray): + yield key.encode("utf-8") + yield obj.tobytes() + return + + msg = f"Unable to hash object of type {type(obj)}" + raise NotImplementedError(msg) + + def _hash_kwargs(self, **kwargs: object) -> str: + hasher = blake3() + + for k, v in kwargs.items(): + for item_bytes in self._iter_bytes_to_hash(k, v): + hasher.update(item_bytes) + + return hasher.hexdigest() + + def _call_cache_fine( + self, + ctx: InputProcessingContext, + hf_processor: ProcessorMixin, + prompt: str, + mm_data: Mapping[str, list[object]], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + processed_mm_items = defaultdict[str, list[torch.Tensor]]() + + num_items = len(next(iter(mm_data.values()))) + for idx in range(num_items): + mm_item = {k: [v[idx]] for k, v in mm_data.items()} + + self.maybe_log_mm_cache_stats() + + processed_mm_item = self._mm_cache.get_or_put( + self._hash_kwargs(**mm_item, **mm_kwargs), + default_factory=partial( + ctx.call_hf_processor, + hf_processor, + mm_item, + mm_kwargs, + ), + ) + + for k, v in processed_mm_item.items(): + processed_mm_items[k].append(v) + + # NOTE: Some processors do not accept mm-only input, in which case + # we have to fallback to processing `prompt` and `mm_data` together + # Therefore, we place the text processing last to avoid redundant + # computation + self.maybe_log_text_cache_stats() + + processed_text = self._text_cache.get_or_put( + prompt, + default_factory=partial( + ctx.call_hf_processor, + hf_processor, + dict(text=prompt), + ), + ) + + processed_data = dict(**processed_text, **processed_mm_items) + return BatchFeature(processed_data) + + def _call_cache_coarse( + self, + ctx: InputProcessingContext, + hf_processor: ProcessorMixin, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + self.maybe_log_coarse_cache_stats() + + return self._coarse_cache.get_or_put( + self._hash_kwargs(text=prompt, **mm_data, **mm_kwargs), + default_factory=partial( + ctx.call_hf_processor, + hf_processor, + dict(text=prompt, **mm_data), + mm_kwargs, + ), + ) + + def call_hf_processor( + self, + ctx: InputProcessingContext, + # Assumes that hf_processor has been initialized according to kwargs + hf_processor: ProcessorMixin, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + # Try to cache each item separately to improve hit rate + if mm_data and all(isinstance(v, list) for v in mm_data.values()): + try: + return self._call_cache_fine( + ctx, + hf_processor, + prompt, + cast(Mapping[str, list[object]], mm_data), + mm_kwargs, + ) + except Exception: + pass + + return self._call_cache_coarse( + ctx, + hf_processor, + prompt, + mm_data, + mm_kwargs, + ) class BaseMultiModalProcessor(ABC): @@ -595,18 +783,24 @@ class BaseMultiModalProcessor(ABC): Abstract base class to process multi-modal inputs to be used in vLLM. """ - def __init__(self, ctx: InputProcessingContext) -> None: + def __init__( + self, + ctx: InputProcessingContext, + *, + cache: ProcessingCache, + ) -> None: super().__init__() self.ctx = ctx + self.cache = cache def __call__( self, prompt: str, mm_data: MultiModalDataDict, - mm_processor_kwargs: Mapping[str, object], + hf_mm_kwargs: Mapping[str, object], ) -> MultiModalInputsV2: - return self.apply(prompt, mm_data, mm_processor_kwargs) + return self.apply(prompt, mm_data, hf_mm_kwargs) def _get_hf_processor(self) -> ProcessorMixin: """ @@ -629,7 +823,7 @@ def _get_prompt_replacements( self, mm_items: MultiModalDataItems, hf_inputs: BatchFeature, - mm_processor_kwargs: Mapping[str, object], + hf_mm_kwargs: Mapping[str, object], ) -> list[PromptReplacement]: """ Given the original multi-modal items for this modality @@ -651,7 +845,7 @@ def _find_placeholders( return list( iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts)) - def _get_processor_data( + def _get_hf_mm_data( self, mm_items: MultiModalDataItems, ) -> tuple[dict[str, Any], dict[str, Any]]: @@ -679,39 +873,36 @@ def _get_processor_data( def _call_hf_processor( self, - hf_processor: ProcessorMixin, prompt: str, - processor_data: Mapping[str, object], - mm_processor_kwargs: Mapping[str, object], + # Not to be confused with `mm_data` in `self.apply`. + # This refers to the data to be passed to HF processor. + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], ) -> BatchFeature: - return self.ctx.call_hf_processor( - hf_processor, + return self.cache.call_hf_processor( + self.ctx, + self._get_hf_processor(**mm_kwargs), prompt, - processor_data, - mm_processor_kwargs, + mm_data, + mm_kwargs, ) def _apply_hf_processor( self, prompt: str, mm_items: MultiModalDataItems, - mm_processor_kwargs: Mapping[str, object], + hf_mm_kwargs: Mapping[str, object], ) -> BatchFeature: - # some mm_processor_kwargs may be used in processor initialization - # instead of processor call - hf_processor = self._get_hf_processor(**mm_processor_kwargs) - - processor_data, passthrough_data = self._get_processor_data(mm_items) + processor_data, passthrough_data = self._get_hf_mm_data(mm_items) - hf_inputs = self._call_hf_processor( - hf_processor, + processed_data = self._call_hf_processor( prompt=prompt, - processor_data=processor_data, - mm_processor_kwargs=mm_processor_kwargs, + mm_data=processor_data, + mm_kwargs=hf_mm_kwargs, ) - hf_inputs.update(passthrough_data) + processed_data.update(passthrough_data) - return hf_inputs + return processed_data def _bind_prompt_replacements( self, @@ -775,7 +966,7 @@ def apply( self, prompt_text: str, mm_data: MultiModalDataDict, - mm_processor_kwargs: Mapping[str, object], + hf_mm_kwargs: Mapping[str, object], ) -> MultiModalInputsV2: """ Process multi-modal inputs to be used in vLLM. @@ -793,12 +984,12 @@ def apply( mm_items = self._get_mm_items(mm_data) hf_inputs = self._apply_hf_processor(prompt_text, mm_items, - mm_processor_kwargs) + hf_mm_kwargs) prompt_ids, = hf_inputs.pop("input_ids").tolist() mm_kwargs = MultiModalKwargs(hf_inputs) prompt_repls = self._get_prompt_replacements(mm_items, hf_inputs, - mm_processor_kwargs) + hf_mm_kwargs) all_prompt_repls = self._bind_prompt_replacements(prompt_repls) # If HF processor already inserts placeholder tokens, diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 6cd79d414c978..dcbe941b24b0d 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -1,10 +1,9 @@ import functools from collections import UserDict -from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional, +from typing import (TYPE_CHECKING, Any, Dict, Mapping, Optional, Protocol, Sequence, Type, TypeVar) import torch.nn as nn -from typing_extensions import TypeAlias from vllm.inputs import InputProcessingContext from vllm.logger import init_logger @@ -15,7 +14,7 @@ from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc from .image import ImagePlugin from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors -from .processing import BaseMultiModalProcessor +from .processing import BaseMultiModalProcessor, ProcessingCache from .video import VideoPlugin if TYPE_CHECKING: @@ -25,13 +24,17 @@ N = TypeVar("N", bound=Type[nn.Module]) -MultiModalProcessorFactory: TypeAlias = Callable[[InputProcessingContext], - BaseMultiModalProcessor] -""" -Constructs a :class:`MultiModalProcessor` instance from the context. -The processing metadata should be derived from the context. -""" +class MultiModalProcessorFactory(Protocol): + """Constructs a :class:`MultiModalProcessor` instance from the context.""" + + def __call__( + self, + ctx: InputProcessingContext, + *, + cache: ProcessingCache, + ) -> BaseMultiModalProcessor: + ... class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]): @@ -71,6 +74,8 @@ def __init__( self._limits_by_model = _MultiModalLimits() + self._processing_cache = ProcessingCache(256) # MM_CACHE_SIZE + def register_plugin(self, plugin: MultiModalPlugin) -> None: """ Register a multi-modal plugin so it can be recognized by vLLM. @@ -354,4 +359,4 @@ def create_processor( processor_factory = self._processor_factories[model_cls] ctx = InputProcessingContext(model_config, tokenizer) - return processor_factory(ctx) + return processor_factory(ctx, cache=self._processing_cache) diff --git a/vllm/utils.py b/vllm/utils.py index 3934903385ad4..8d9901c343067 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -22,11 +22,11 @@ import weakref from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from collections import OrderedDict, UserDict, defaultdict -from collections.abc import Iterable, Mapping +from collections.abc import Hashable, Iterable, Mapping from dataclasses import dataclass, field from functools import lru_cache, partial, wraps from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, - Dict, Generator, Generic, Hashable, List, Literal, + Dict, Generator, Generic, List, Literal, NamedTuple, Optional, Tuple, Type, TypeVar, Union, overload) from uuid import uuid4 @@ -191,13 +191,29 @@ def reset(self) -> None: self.counter = 0 +class CacheInfo(NamedTuple): + hits: int + total: int + + @property + def hit_ratio(self) -> float: + if self.total == 0: + return 0 + + return self.hits / self.total + + class LRUCache(Generic[_K, _V]): + """Note: This class is not thread safe!""" def __init__(self, capacity: int) -> None: self.cache = OrderedDict[_K, _V]() self.pinned_items = set[_K]() self.capacity = capacity + self._hits = 0 + self._total = 0 + def __contains__(self, key: _K) -> bool: return key in self.cache @@ -215,6 +231,9 @@ def __setitem__(self, key: _K, value: _V) -> None: def __delitem__(self, key: _K) -> None: self.pop(key) + def stat(self) -> CacheInfo: + return CacheInfo(hits=self._hits, total=self._total) + def touch(self, key: _K) -> None: self.cache.move_to_end(key) @@ -223,8 +242,12 @@ def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]: if key in self.cache: value = self.cache[key] self.cache.move_to_end(key) + + self._hits += 1 else: value = default + + self._total += 1 return value def put(self, key: _K, value: _V) -> None: @@ -232,6 +255,19 @@ def put(self, key: _K, value: _V) -> None: self.cache.move_to_end(key) self._remove_old_if_needed() + def get_or_put(self, key: _K, default_factory: Callable[[], _V]) -> _V: + if key in self.cache: + value = self.cache[key] + self.cache.move_to_end(key) + + self._hits += 1 + else: + value = default_factory() + self.put(key, value) + + self._total += 1 + return value + def pin(self, key: _K) -> None: """ Pins a key in the cache preventing it from being From 9711a1556560ac1bc53985fc271e95fb2c762c75 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 19 Dec 2024 17:48:55 +0000 Subject: [PATCH 02/26] Clean up Signed-off-by: DarkLight1337 --- vllm/multimodal/processing.py | 38 +++++++++-------------------------- 1 file changed, 9 insertions(+), 29 deletions(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 06345c8cfa306..7748fade0fbd7 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -603,35 +603,15 @@ def __init__(self, capacity: int) -> None: self._mm_cache = LRUCache[str, BatchFeature](capacity) self._coarse_cache = LRUCache[str, BatchFeature](capacity) - def maybe_log_text_cache_stats(self) -> None: + def maybe_log_cache_stats(self, cache: LRUCache, name: str) -> None: steps = self.debug_cache_hit_ratio_steps if not steps: return - text_cache_stats = self._text_cache.stat() - if text_cache_stats.total % steps == 0: - logger.debug("ProcessingCache: text_cache.hit_ratio = %.2f", - text_cache_stats.hit_ratio) - - def maybe_log_mm_cache_stats(self) -> None: - steps = self.debug_cache_hit_ratio_steps - if not steps: - return - - mm_cache_stats = self._mm_cache.stat() - if mm_cache_stats.total % steps == 0: - logger.debug("ProcessingCache: mm_cache.hit_ratio = %.2f", - mm_cache_stats.hit_ratio) - - def maybe_log_coarse_cache_stats(self) -> None: - steps = self.debug_cache_hit_ratio_steps - if not steps: - return - - coarse_cache_stats = self._mm_cache.stat() - if coarse_cache_stats.total % steps == 0: - logger.debug("ProcessingCache: coarse_cache.hit_ratio = %.2f", - coarse_cache_stats.hit_ratio) + cache_stats = cache.stat() + if cache_stats.total % steps == 0: + logger.debug("ProcessingCache: %s.hit_ratio = %.2f", + name, cache_stats.hit_ratio) def _iter_bytes_to_hash(self, key: str, obj: object) -> Iterable[bytes]: # Recursive cases @@ -694,7 +674,7 @@ def _call_cache_fine( for idx in range(num_items): mm_item = {k: [v[idx]] for k, v in mm_data.items()} - self.maybe_log_mm_cache_stats() + self.maybe_log_cache_stats(self._mm_cache, "mm_cache") processed_mm_item = self._mm_cache.get_or_put( self._hash_kwargs(**mm_item, **mm_kwargs), @@ -713,7 +693,7 @@ def _call_cache_fine( # we have to fallback to processing `prompt` and `mm_data` together # Therefore, we place the text processing last to avoid redundant # computation - self.maybe_log_text_cache_stats() + self.maybe_log_cache_stats(self._text_cache, "text_cache") processed_text = self._text_cache.get_or_put( prompt, @@ -735,7 +715,7 @@ def _call_cache_coarse( mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], ) -> BatchFeature: - self.maybe_log_coarse_cache_stats() + self.maybe_log_cache_stats(self._coarse_cache, "coarse_cache") return self._coarse_cache.get_or_put( self._hash_kwargs(text=prompt, **mm_data, **mm_kwargs), @@ -767,7 +747,7 @@ def call_hf_processor( mm_kwargs, ) except Exception: - pass + pass # See NOTE in `_call_cache_fine` return self._call_cache_coarse( ctx, From 29e3fcdcd96aa82b12fe78ba7dfa322df1892dcf Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 19 Dec 2024 18:02:12 +0000 Subject: [PATCH 03/26] Fix cached result being mutated Signed-off-by: DarkLight1337 --- vllm/multimodal/processing.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 7748fade0fbd7..29544c4407612 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -610,8 +610,8 @@ def maybe_log_cache_stats(self, cache: LRUCache, name: str) -> None: cache_stats = cache.stat() if cache_stats.total % steps == 0: - logger.debug("ProcessingCache: %s.hit_ratio = %.2f", - name, cache_stats.hit_ratio) + logger.debug("ProcessingCache: %s.hit_ratio = %.2f", name, + cache_stats.hit_ratio) def _iter_bytes_to_hash(self, key: str, obj: object) -> Iterable[bytes]: # Recursive cases @@ -717,7 +717,7 @@ def _call_cache_coarse( ) -> BatchFeature: self.maybe_log_cache_stats(self._coarse_cache, "coarse_cache") - return self._coarse_cache.get_or_put( + processed_data = self._coarse_cache.get_or_put( self._hash_kwargs(text=prompt, **mm_data, **mm_kwargs), default_factory=partial( ctx.call_hf_processor, @@ -727,6 +727,10 @@ def _call_cache_coarse( ), ) + # Shallow copy to avoid footgun when downstream methods + # mutate the returned dictionary (since the result is cached) + return BatchFeature(processed_data) # type: ignore + def call_hf_processor( self, ctx: InputProcessingContext, From ab64e85e79649987d3dc1005ac09b16a5ad53ba2 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 19 Dec 2024 18:03:13 +0000 Subject: [PATCH 04/26] Rename Signed-off-by: DarkLight1337 --- vllm/multimodal/processing.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 29544c4407612..b91aacc84b398 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -660,7 +660,7 @@ def _hash_kwargs(self, **kwargs: object) -> str: return hasher.hexdigest() - def _call_cache_fine( + def _cached_call_fine( self, ctx: InputProcessingContext, hf_processor: ProcessorMixin, @@ -707,7 +707,7 @@ def _call_cache_fine( processed_data = dict(**processed_text, **processed_mm_items) return BatchFeature(processed_data) - def _call_cache_coarse( + def _cached_call_coarse( self, ctx: InputProcessingContext, hf_processor: ProcessorMixin, @@ -743,7 +743,7 @@ def call_hf_processor( # Try to cache each item separately to improve hit rate if mm_data and all(isinstance(v, list) for v in mm_data.values()): try: - return self._call_cache_fine( + return self._cached_call_fine( ctx, hf_processor, prompt, @@ -751,9 +751,9 @@ def call_hf_processor( mm_kwargs, ) except Exception: - pass # See NOTE in `_call_cache_fine` + pass # See NOTE in `_cached_call_fine` - return self._call_cache_coarse( + return self._cached_call_coarse( ctx, hf_processor, prompt, From 81215a22425e1ee67adcbcb5fcf5cecd600e7846 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 19 Dec 2024 18:07:16 +0000 Subject: [PATCH 05/26] Fix docs Signed-off-by: DarkLight1337 --- docs/source/conf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/conf.py b/docs/source/conf.py index e9d9ac68c9560..91cf31fa2a579 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -162,6 +162,7 @@ def linkcode_resolve(domain, info): # Mock out external dependencies here, otherwise the autodoc pages may be blank. autodoc_mock_imports = [ + "blake3", "compressed_tensors", "cpuinfo", "cv2", From cf52b3bb58900b35c55e366e37c64e02490663db Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 19 Dec 2024 18:08:43 +0000 Subject: [PATCH 06/26] Fix a typo Signed-off-by: DarkLight1337 --- docs/source/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 91cf31fa2a579..7c22b1df49b81 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -179,7 +179,7 @@ def linkcode_resolve(domain, info): "tensorizer", "pynvml", "outlines", - "xgrammar," + "xgrammar", "librosa", "soundfile", "gguf", From a4a8eb9e9b4927d2ae1db2e9e4d113127f054445 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 19 Dec 2024 18:12:44 +0000 Subject: [PATCH 07/26] Fix unhandled sampling rate in initialization Signed-off-by: DarkLight1337 --- vllm/model_executor/models/qwen2_audio.py | 6 +++++- vllm/model_executor/models/ultravox.py | 9 ++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 28a0b16732c2b..22f52529168c8 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -88,7 +88,11 @@ def get_max_qwen2_audio_audio_tokens(ctx: InputContext) -> int: class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor): - def _get_hf_processor(self) -> Qwen2AudioProcessor: + def _get_hf_processor( + self, + *, + sampling_rate: Optional[int] = None, # Ignored in initialization + ) -> Qwen2AudioProcessor: return self.ctx.get_hf_processor(Qwen2AudioProcessor) def _get_feature_extractor(self) -> WhisperFeatureExtractor: diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index c3588215ff0c2..41939fda859b6 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -11,7 +11,7 @@ import torch.utils.checkpoint from torch import nn from torch.nn import functional as F -from transformers import BatchFeature +from transformers import BatchFeature, ProcessorMixin from transformers.models.whisper import WhisperFeatureExtractor from transformers.models.whisper.modeling_whisper import WhisperEncoder @@ -72,6 +72,13 @@ def get_ultravox_max_audio_tokens(ctx: InputContext): class UltravoxMultiModalProcessor(BaseMultiModalProcessor): + def _get_hf_processor( + self, + *, + sampling_rate: Optional[int] = None, # Ignored in initialization + ) -> ProcessorMixin: + return self.ctx.get_hf_processor() + def _get_feature_extractor(self) -> WhisperFeatureExtractor: hf_processor = self._get_hf_processor() return hf_processor.audio_processor.feature_extractor # type: ignore From c48f7c5dcf8705aabea713e8a5eeafbef07804ed Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 19 Dec 2024 18:17:16 +0000 Subject: [PATCH 08/26] format Signed-off-by: DarkLight1337 --- vllm/model_executor/models/qwen2_audio.py | 3 ++- vllm/model_executor/models/ultravox.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 22f52529168c8..4031a7a7626b2 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -91,7 +91,8 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor): def _get_hf_processor( self, *, - sampling_rate: Optional[int] = None, # Ignored in initialization + # Ignored in initialization + sampling_rate: Optional[int] = None, ) -> Qwen2AudioProcessor: return self.ctx.get_hf_processor(Qwen2AudioProcessor) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 41939fda859b6..678b1c10cb86a 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -75,7 +75,8 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor): def _get_hf_processor( self, *, - sampling_rate: Optional[int] = None, # Ignored in initialization + # Ignored in initialization + sampling_rate: Optional[int] = None, ) -> ProcessorMixin: return self.ctx.get_hf_processor() From b84ff4299332c612201d9e6b426d5485150f65c8 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 19 Dec 2024 18:18:19 +0000 Subject: [PATCH 09/26] Change the delimiter Signed-off-by: DarkLight1337 --- vllm/multimodal/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index b91aacc84b398..306b7f11f413b 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -621,7 +621,7 @@ def _iter_bytes_to_hash(self, key: str, obj: object) -> Iterable[bytes]: return if isinstance(obj, dict): for k, v in obj.items(): - yield from self._iter_bytes_to_hash(f"{key}_{k}", v) + yield from self._iter_bytes_to_hash(f"{key}.{k}", v) return # Simple cases From c3f1bde10fb232057798456aa948b8d4f937a5d0 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 19 Dec 2024 18:24:18 +0000 Subject: [PATCH 10/26] Fix extra dimension Signed-off-by: DarkLight1337 --- vllm/multimodal/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 306b7f11f413b..78a031032337a 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -687,7 +687,7 @@ def _cached_call_fine( ) for k, v in processed_mm_item.items(): - processed_mm_items[k].append(v) + processed_mm_items[k].append(v[0]) # NOTE: Some processors do not accept mm-only input, in which case # we have to fallback to processing `prompt` and `mm_data` together From 32e5197524c950a8102edb4c450d7a09b2ce21c2 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 19 Dec 2024 18:33:11 +0000 Subject: [PATCH 11/26] Update Signed-off-by: DarkLight1337 --- vllm/multimodal/processing.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 78a031032337a..bb611a72dedad 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -599,8 +599,8 @@ def __init__(self, capacity: int) -> None: # DEBUG: Set to None to disable self.debug_cache_hit_ratio_steps: Optional[int] = None - self._text_cache = LRUCache[str, BatchFeature](capacity) - self._mm_cache = LRUCache[str, BatchFeature](capacity) + self._fine_text_cache = LRUCache[str, BatchFeature](capacity) + self._fine_mm_cache = LRUCache[str, BatchFeature](capacity) self._coarse_cache = LRUCache[str, BatchFeature](capacity) def maybe_log_cache_stats(self, cache: LRUCache, name: str) -> None: @@ -674,9 +674,9 @@ def _cached_call_fine( for idx in range(num_items): mm_item = {k: [v[idx]] for k, v in mm_data.items()} - self.maybe_log_cache_stats(self._mm_cache, "mm_cache") + self.maybe_log_cache_stats(self._fine_mm_cache, "fine_mm_cache") - processed_mm_item = self._mm_cache.get_or_put( + processed_mm_item = self._fine_mm_cache.get_or_put( self._hash_kwargs(**mm_item, **mm_kwargs), default_factory=partial( ctx.call_hf_processor, @@ -687,15 +687,16 @@ def _cached_call_fine( ) for k, v in processed_mm_item.items(): + # Remove the extra batch dimension processed_mm_items[k].append(v[0]) - # NOTE: Some processors do not accept mm-only input, in which case - # we have to fallback to processing `prompt` and `mm_data` together - # Therefore, we place the text processing last to avoid redundant - # computation - self.maybe_log_cache_stats(self._text_cache, "text_cache") + # NOTE: Some processors (e.g. llava) do not accept mm-only input, + # in which case we have to fallback to processing `prompt` and `mm_data` + # together. Therefore, we place the text processing last to avoid + # redundant computation + self.maybe_log_cache_stats(self._fine_text_cache, "fine_text_cache") - processed_text = self._text_cache.get_or_put( + processed_text = self._fine_text_cache.get_or_put( prompt, default_factory=partial( ctx.call_hf_processor, @@ -751,7 +752,12 @@ def call_hf_processor( mm_kwargs, ) except Exception: - pass # See NOTE in `_cached_call_fine` + # Failures are expected; see NOTE in `_cached_call_fine` + logger.debug( + "Failed to apply processor on each item separately", + stack_info=True, + ) + pass return self._cached_call_coarse( ctx, From 7264d4e1647d6313727b6a7e4ee8cf1fd2509df7 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Dec 2024 03:54:40 +0000 Subject: [PATCH 12/26] Use the inner processor to enable fine-grained caching Signed-off-by: DarkLight1337 --- vllm/inputs/registry.py | 41 +++++++++++-- vllm/multimodal/processing.py | 106 +++++++++++++++++----------------- 2 files changed, 90 insertions(+), 57 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index adb11989fc478..23a1e351caca3 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -1,7 +1,7 @@ import functools from collections import UserDict from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Mapping, NamedTuple, +from typing import (TYPE_CHECKING, Any, Callable, Literal, Mapping, NamedTuple, Optional, Protocol, Union) from torch import nn @@ -111,6 +111,39 @@ def get_hf_processor( return hf_processor + def get_modality_processor( + self, + hf_processor: ProcessorMixin, + modality_data_key: Literal["text", "images", "videos", "audios"], + ) -> Callable[..., BatchFeature]: + """ + Get the HuggingFace modality-specific processor which is + a child of a :class:`transformers.ProcessorMixin`, identified by + the corresponding keyword argument in its `__call__` method. + """ + if modality_data_key == "text": + attributes = ["tokenizer"] + elif modality_data_key == "images": + attributes = ["image_processor"] + elif modality_data_key == "videos": + attributes = ["video_processor"] + elif modality_data_key == "audios": + attributes = ["audio_processor", "feature_extractor"] + else: + assert_never(modality_data_key) + + modality_processor = next( + (getattr(hf_processor, attr) + for attr in attributes if hasattr(hf_processor, attr)), + None, + ) + if modality_processor is None: + raise AttributeError( + f"Cannot found HuggingFace processor for " + f"{modality_data_key} inside {type(hf_processor)}") + + return modality_processor + @dataclass(frozen=True) class InputProcessingContext(InputContext): @@ -131,15 +164,15 @@ def get_hf_processor( def call_hf_processor( self, - hf_processor: ProcessorMixin, + hf_processor: Union[ProcessorMixin, Callable[..., BatchFeature]], data: Mapping[str, object], kwargs: Optional[Mapping[str, object]] = None, ) -> BatchFeature: + assert callable(hf_processor) + if kwargs is None: kwargs = {} - assert callable(hf_processor) - base_kwargs = self.model_config.mm_processor_kwargs if base_kwargs is None: base_kwargs = {} diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index bb611a72dedad..979104aaed8c5 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -4,7 +4,8 @@ from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence from dataclasses import dataclass, field from functools import lru_cache, partial -from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union, cast +from typing import (Any, Literal, NamedTuple, Optional, Protocol, TypeVar, + Union, cast) import numpy as np import torch @@ -616,8 +617,8 @@ def maybe_log_cache_stats(self, cache: LRUCache, name: str) -> None: def _iter_bytes_to_hash(self, key: str, obj: object) -> Iterable[bytes]: # Recursive cases if isinstance(obj, (list, tuple)): - for elem in obj: - yield from self._iter_bytes_to_hash(key, elem) + for i, elem in enumerate(obj): + yield from self._iter_bytes_to_hash(f"{key}.{i}", elem) return if isinstance(obj, dict): for k, v in obj.items(): @@ -664,66 +665,64 @@ def _cached_call_fine( self, ctx: InputProcessingContext, hf_processor: ProcessorMixin, - prompt: str, - mm_data: Mapping[str, list[object]], + text: str, + mm_data: Mapping[Literal["images", "videos", "audios"], list[Any]], mm_kwargs: Mapping[str, object], ) -> BatchFeature: - processed_mm_items = defaultdict[str, list[torch.Tensor]]() - - num_items = len(next(iter(mm_data.values()))) - for idx in range(num_items): - mm_item = {k: [v[idx]] for k, v in mm_data.items()} - - self.maybe_log_cache_stats(self._fine_mm_cache, "fine_mm_cache") - - processed_mm_item = self._fine_mm_cache.get_or_put( - self._hash_kwargs(**mm_item, **mm_kwargs), - default_factory=partial( - ctx.call_hf_processor, - hf_processor, - mm_item, - mm_kwargs, - ), - ) - - for k, v in processed_mm_item.items(): - # Remove the extra batch dimension - processed_mm_items[k].append(v[0]) - - # NOTE: Some processors (e.g. llava) do not accept mm-only input, - # in which case we have to fallback to processing `prompt` and `mm_data` - # together. Therefore, we place the text processing last to avoid - # redundant computation self.maybe_log_cache_stats(self._fine_text_cache, "fine_text_cache") processed_text = self._fine_text_cache.get_or_put( - prompt, + text, default_factory=partial( ctx.call_hf_processor, - hf_processor, - dict(text=prompt), + ctx.get_modality_processor(hf_processor, "text"), + dict(text=text), ), ) - processed_data = dict(**processed_text, **processed_mm_items) + processed_data = dict(**processed_text) + for data_key, items in mm_data.items(): + processed_modal_items = defaultdict[str, list[torch.Tensor]](list) + + for item in items: + self.maybe_log_cache_stats(self._fine_mm_cache, + "fine_mm_cache") + + modal_item = cast(Mapping[str, object], {data_key: item}) + processed_modal_item = self._fine_mm_cache.get_or_put( + self._hash_kwargs(**modal_item, **mm_kwargs), + default_factory=partial( + ctx.call_hf_processor, + ctx.get_modality_processor(hf_processor, data_key), + modal_item, + mm_kwargs, + ), + ) + + for k, v in processed_modal_item.items(): + # Remove the extra batch dimension + processed_modal_items[k].append(v[0]) + + processed_data.update(processed_modal_items) + return BatchFeature(processed_data) def _cached_call_coarse( self, ctx: InputProcessingContext, hf_processor: ProcessorMixin, - prompt: str, + text: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], ) -> BatchFeature: self.maybe_log_cache_stats(self._coarse_cache, "coarse_cache") processed_data = self._coarse_cache.get_or_put( - self._hash_kwargs(text=prompt, **mm_data, **mm_kwargs), + self._hash_kwargs(text=text, **mm_data, **mm_kwargs), default_factory=partial( ctx.call_hf_processor, hf_processor, - dict(text=prompt, **mm_data), + dict(text=text, **mm_data), mm_kwargs, ), ) @@ -737,34 +736,35 @@ def call_hf_processor( ctx: InputProcessingContext, # Assumes that hf_processor has been initialized according to kwargs hf_processor: ProcessorMixin, - prompt: str, + text: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], ) -> BatchFeature: # Try to cache each item separately to improve hit rate - if mm_data and all(isinstance(v, list) for v in mm_data.values()): + extra_keys = mm_data.keys() - {"images", "videos", "audios"} + if (mm_data and not extra_keys + and all(isinstance(v, list) for v in mm_data.values())): try: return self._cached_call_fine( ctx, hf_processor, - prompt, - cast(Mapping[str, list[object]], mm_data), - mm_kwargs, + text=text, + mm_data=mm_data, # type: ignore[arg-type] + mm_kwargs=mm_kwargs, ) except Exception: - # Failures are expected; see NOTE in `_cached_call_fine` - logger.debug( - "Failed to apply processor on each item separately", + logger.exception( + "Failed to apply processor on each item separately! " + "Falling back to coarse caching.", stack_info=True, ) - pass return self._cached_call_coarse( ctx, hf_processor, - prompt, - mm_data, - mm_kwargs, + text=text, + mm_data=mm_data, + mm_kwargs=mm_kwargs, ) @@ -872,9 +872,9 @@ def _call_hf_processor( return self.cache.call_hf_processor( self.ctx, self._get_hf_processor(**mm_kwargs), - prompt, - mm_data, - mm_kwargs, + text=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, ) def _apply_hf_processor( From 02ea82951622647f599bb3fbec3acd3d631d36fd Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Dec 2024 04:00:21 +0000 Subject: [PATCH 13/26] Make the cache optional Signed-off-by: DarkLight1337 --- vllm/multimodal/processing.py | 9 ++++++++- vllm/multimodal/registry.py | 7 +++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 979104aaed8c5..c63ebe1637324 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -777,7 +777,7 @@ def __init__( self, ctx: InputProcessingContext, *, - cache: ProcessingCache, + cache: Optional[ProcessingCache] = None, ) -> None: super().__init__() @@ -869,6 +869,13 @@ def _call_hf_processor( mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], ) -> BatchFeature: + if self.cache is None: + return self.ctx.call_hf_processor( + self._get_hf_processor(**mm_kwargs), + dict(text=prompt, **mm_data), + mm_kwargs, + ) + return self.cache.call_hf_processor( self.ctx, self._get_hf_processor(**mm_kwargs), diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index dcbe941b24b0d..17255359ab61c 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -32,7 +32,7 @@ def __call__( self, ctx: InputProcessingContext, *, - cache: ProcessingCache, + cache: Optional[ProcessingCache] = None, ) -> BaseMultiModalProcessor: ... @@ -359,4 +359,7 @@ def create_processor( processor_factory = self._processor_factories[model_cls] ctx = InputProcessingContext(model_config, tokenizer) - return processor_factory(ctx, cache=self._processing_cache) + cache = (None if model_config.disable_mm_preprocessor_cache else + self._processing_cache) + + return processor_factory(ctx, cache=cache) From b981a9dbab1d7c1f870419766ee8bf5d52db2c53 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Dec 2024 04:08:34 +0000 Subject: [PATCH 14/26] Fix invalid kwargs being passed to tokenizer Signed-off-by: DarkLight1337 --- vllm/inputs/registry.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 23a1e351caca3..f99f5f424ab28 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -182,7 +182,8 @@ def call_hf_processor( kwargs, hf_processor, requires_kw_only=False, - allow_var_kwargs=True, + # Modality-specific processors should state each kwarg individually + allow_var_kwargs=isinstance(hf_processor, ProcessorMixin), ) try: From 5dde7d00ceede0166a3294661e9b8caf09d240e0 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Dec 2024 04:25:15 +0000 Subject: [PATCH 15/26] Fix Phi3V prompt replacement Signed-off-by: DarkLight1337 --- vllm/multimodal/processing.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index c63ebe1637324..6f2cc7805ccba 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -918,6 +918,10 @@ def _apply_prompt_replacements( tokenizer = self._get_tokenizer() token_matches = find_token_matches(token_ids, prompt_repls) + mm_match_counts = { + modality: len(matches) + for modality, matches in full_groupby_modality(token_matches) + } # If the search text does not represent a special token, # it may have different token IDs in the prompt, because @@ -930,8 +934,8 @@ def _apply_prompt_replacements( # of the search text in the prompt, we instead perform string # replacement on the decoded token IDs, then encode them back. if all( - len(matches) >= mm_item_counts[modality] - for modality, matches in full_groupby_modality(token_matches) + mm_match_counts.get(modality, 0) >= item_count + for modality, item_count in mm_item_counts.items() ): # yapf: disable token_ids = replace_token_matches( token_ids, From 7339ab83e5c618e3c2eec86a08f767f6838cf33c Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Dec 2024 04:27:17 +0000 Subject: [PATCH 16/26] Refine Signed-off-by: DarkLight1337 --- vllm/multimodal/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 6f2cc7805ccba..d9b6037a343ad 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -729,7 +729,7 @@ def _cached_call_coarse( # Shallow copy to avoid footgun when downstream methods # mutate the returned dictionary (since the result is cached) - return BatchFeature(processed_data) # type: ignore + return BatchFeature(processed_data) # type: ignore[arg-type] def call_hf_processor( self, From 509411dd07a0c1d9dcba3a406099707b459a39f8 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Dec 2024 04:56:29 +0000 Subject: [PATCH 17/26] Enable fine-grained caching for audio models Signed-off-by: DarkLight1337 --- vllm/inputs/registry.py | 9 +++++++++ vllm/model_executor/models/qwen2_audio.py | 16 +++++++++++++++- vllm/multimodal/processing.py | 16 ++++++++++++---- 3 files changed, 36 insertions(+), 5 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index f99f5f424ab28..6855b4c017cfe 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -6,6 +6,7 @@ from torch import nn from transformers import BatchFeature, PretrainedConfig, ProcessorMixin +from transformers.models.whisper import WhisperFeatureExtractor from typing_extensions import TypeVar, assert_never from vllm.logger import init_logger @@ -186,6 +187,14 @@ def call_hf_processor( allow_var_kwargs=isinstance(hf_processor, ProcessorMixin), ) + # WhisperFeatureExtractor accepts `raw_speech` + # but the parent HF processor accepts `audios` + # Making `audios` an alias of `raw_speech` simplifies the calling code + if (isinstance(hf_processor, WhisperFeatureExtractor) + and "raw_speech" not in data): + data = dict(data) + data["raw_speech"] = data.pop("audios") + try: return hf_processor(**data, **merged_kwargs, return_tensors="pt") except Exception as exc: diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 4031a7a7626b2..7d52737400a40 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -93,6 +93,8 @@ def _get_hf_processor( *, # Ignored in initialization sampling_rate: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + padding: Optional[str] = None, ) -> Qwen2AudioProcessor: return self.ctx.get_hf_processor(Qwen2AudioProcessor) @@ -125,17 +127,29 @@ def _call_hf_processor( mm_kwargs = dict( **mm_kwargs, sampling_rate=feature_extractor.sampling_rate, + # When fine-grained caching is applied, + # the individual processors are called separately. + return_attention_mask=True, + padding="max_length", ) else: # NOTE: WhisperFeatureExtractor cannot handle empty list of audios pass - return super()._call_hf_processor( + processed_outputs = super()._call_hf_processor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, ) + # When fine-grained caching is applied, + # the individual processors are called separately. + if "attention_mask" in processed_outputs: + processed_outputs["feature_attention_mask"] = \ + processed_outputs.pop("attention_mask") + + return processed_outputs + def _get_prompt_replacements( self, mm_items: MultiModalDataItems, diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index d9b6037a343ad..5c351d31fa91c 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -21,8 +21,8 @@ from .audio import resample_audio from .inputs import (AudioItem, ImageItem, MultiModalDataDict, - MultiModalInputsV2, MultiModalKwargs, PlaceholderRange, - VideoItem) + MultiModalInputsV2, MultiModalKwargs, NestedTensors, + PlaceholderRange, VideoItem) logger = init_logger(__name__) @@ -682,7 +682,8 @@ def _cached_call_fine( processed_data = dict(**processed_text) for data_key, items in mm_data.items(): - processed_modal_items = defaultdict[str, list[torch.Tensor]](list) + processed_modal_items = defaultdict[str, Union[ + list[torch.Tensor], list[NestedTensors]]](list) for item in items: self.maybe_log_cache_stats(self._fine_mm_cache, @@ -703,7 +704,14 @@ def _cached_call_fine( # Remove the extra batch dimension processed_modal_items[k].append(v[0]) - processed_data.update(processed_modal_items) + for k, vs in processed_modal_items.items(): + # Try to merge elements into a single tensor + if is_list_of(vs, torch.Tensor, check="all") and len(vs) > 0: + first_shape = vs[0].shape + if all(v.shape == first_shape for v in vs): + vs = torch.stack(vs) + + processed_data[k] = vs return BatchFeature(processed_data) From c0454f52f2c7579c0c53fa69a207dc1c7844ee96 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Dec 2024 05:00:32 +0000 Subject: [PATCH 18/26] Add fallback Signed-off-by: DarkLight1337 --- vllm/multimodal/processing.py | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 5c351d31fa91c..753aa5394e410 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1,3 +1,4 @@ +import pickle import re from abc import ABC, abstractmethod from collections import UserDict, defaultdict @@ -614,7 +615,11 @@ def maybe_log_cache_stats(self, cache: LRUCache, name: str) -> None: logger.debug("ProcessingCache: %s.hit_ratio = %.2f", name, cache_stats.hit_ratio) - def _iter_bytes_to_hash(self, key: str, obj: object) -> Iterable[bytes]: + def _iter_bytes_to_hash( + self, + key: str, + obj: object, + ) -> Iterable[tuple[bytes, bytes]]: # Recursive cases if isinstance(obj, (list, tuple)): for i, elem in enumerate(obj): @@ -625,18 +630,17 @@ def _iter_bytes_to_hash(self, key: str, obj: object) -> Iterable[bytes]: yield from self._iter_bytes_to_hash(f"{key}.{k}", v) return + key_bytes = key.encode("utf-8") + # Simple cases if isinstance(obj, str): - yield key.encode("utf-8") - yield obj.encode("utf-8") + yield key_bytes, obj.encode("utf-8") return if isinstance(obj, bytes): - yield key.encode("utf-8") - yield obj + yield key_bytes, obj return if isinstance(obj, Image): - yield key.encode("utf-8") - yield obj.tobytes() + yield key_bytes, obj.tobytes() return # Convertible to NumPy arrays @@ -645,19 +649,22 @@ def _iter_bytes_to_hash(self, key: str, obj: object) -> Iterable[bytes]: if isinstance(obj, (int, float)): obj = np.array(obj) if isinstance(obj, np.ndarray): - yield key.encode("utf-8") - yield obj.tobytes() + yield key_bytes, obj.tobytes() return - msg = f"Unable to hash object of type {type(obj)}" - raise NotImplementedError(msg) + logger.warning( + "No serialization method found for %s. " + "Falling back to pickle.", type(obj)) + + yield key_bytes, pickle.dumps(obj) def _hash_kwargs(self, **kwargs: object) -> str: hasher = blake3() for k, v in kwargs.items(): - for item_bytes in self._iter_bytes_to_hash(k, v): - hasher.update(item_bytes) + for k_bytes, v_bytes in self._iter_bytes_to_hash(k, v): + hasher.update(k_bytes) + hasher.update(v_bytes) return hasher.hexdigest() From d50ef031c2849c3607e8438c38821dd911bf65c0 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Dec 2024 05:10:37 +0000 Subject: [PATCH 19/26] Fix typo Signed-off-by: DarkLight1337 --- vllm/inputs/registry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 6855b4c017cfe..5bf44bd95acb9 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -140,8 +140,8 @@ def get_modality_processor( ) if modality_processor is None: raise AttributeError( - f"Cannot found HuggingFace processor for " - f"{modality_data_key} inside {type(hf_processor)}") + f"Cannot find HuggingFace processor for {modality_data_key} " + f"inside {type(hf_processor)}") return modality_processor From 81f7d617cfe2717974a0e95ac0be93917abfd3a4 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Dec 2024 05:25:59 +0000 Subject: [PATCH 20/26] Fix video processor for Qwen2-VL Signed-off-by: DarkLight1337 --- vllm/inputs/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 5bf44bd95acb9..8c77e7840cb50 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -127,7 +127,7 @@ def get_modality_processor( elif modality_data_key == "images": attributes = ["image_processor"] elif modality_data_key == "videos": - attributes = ["video_processor"] + attributes = ["video_processor", "image_processor"] elif modality_data_key == "audios": attributes = ["audio_processor", "feature_extractor"] else: From affbc5c09c5fd1c137b50cf84f5fcee125d147a4 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Dec 2024 13:41:19 +0000 Subject: [PATCH 21/26] Fix a bunch of type errors Signed-off-by: DarkLight1337 --- vllm/model_executor/models/qwen.py | 4 +-- vllm/model_executor/models/qwen2_vl.py | 46 +++++++++++++------------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 63d1374ab4092..baf955f6b515d 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -225,7 +225,7 @@ def __init__( d_model: int, n_head: int, mlp_ratio: float = 4.0, - norm_layer: Callable = nn.LayerNorm, + norm_layer: Callable[[int], nn.Module] = nn.LayerNorm, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -266,7 +266,7 @@ def __init__( layers: int, heads: int, mlp_ratio: float = 4.0, - norm_layer: Callable = nn.LayerNorm, + norm_layer: Callable[[int], nn.Module] = nn.LayerNorm, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 1da2b79d3f2d4..e39893cdf425c 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -22,7 +22,7 @@ # limitations under the License. """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" from functools import cached_property, partial -from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set, +from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, Tuple, Type, TypedDict, Union) import torch @@ -229,9 +229,9 @@ class Qwen2VisionAttention(nn.Module): def __init__( self, - embed_dim: Optional[int] = None, - num_heads: Optional[int] = None, - projection_size: Optional[int] = None, + embed_dim: int, + num_heads: int, + projection_size: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: @@ -264,7 +264,7 @@ def forward( self, x: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor = None, + rotary_pos_emb: torch.Tensor, ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) @@ -347,7 +347,7 @@ def __init__( num_heads: int, mlp_ratio: float, act_layer: Type[nn.Module] = QuickGELU, - norm_layer: Type[nn.Module] = None, + norm_layer: Optional[Callable[[int], nn.Module]] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: @@ -384,7 +384,7 @@ def __init__( self, patch_size: int = 14, temporal_patch_size: int = 2, - in_chans: int = 3, + in_channels: int = 3, embed_dim: int = 1152, ) -> None: super().__init__() @@ -392,8 +392,8 @@ def __init__( self.temporal_patch_size = temporal_patch_size self.embed_dim = embed_dim - kernel_size = [temporal_patch_size, patch_size, patch_size] - self.proj = nn.Conv3d(in_chans, + kernel_size = (temporal_patch_size, patch_size, patch_size) + self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, @@ -413,7 +413,7 @@ def __init__( self, d_model: int, context_dim: int, - norm_layer: Type[nn.Module] = None, + norm_layer: Optional[Callable[[int], nn.Module]] = None, spatial_merge_size: int = 2, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -489,15 +489,15 @@ def __init__( ) -> None: super().__init__() - patch_size: int = vision_config.patch_size - temporal_patch_size: int = vision_config.temporal_patch_size - spatial_merge_size: int = vision_config.spatial_merge_size - in_chans: int = vision_config.in_chans - hidden_size: int = vision_config.hidden_size - embed_dim: int = vision_config.embed_dim - depth: int = vision_config.depth - num_heads: int = vision_config.num_heads - mlp_ratio: float = vision_config.mlp_ratio + patch_size = vision_config.patch_size + temporal_patch_size = vision_config.temporal_patch_size + spatial_merge_size = vision_config.spatial_merge_size + in_channels = vision_config.in_channels + hidden_size = vision_config.hidden_size + embed_dim = vision_config.embed_dim + depth = vision_config.depth + num_heads = vision_config.num_heads + mlp_ratio = vision_config.mlp_ratio self.spatial_merge_size = spatial_merge_size self.num_heads = num_heads @@ -506,7 +506,7 @@ def __init__( self.patch_embed = Qwen2VisionPatchEmbed( patch_size=patch_size, temporal_patch_size=temporal_patch_size, - in_chans=in_chans, + in_channels=in_channels, embed_dim=embed_dim, ) @@ -592,7 +592,7 @@ def forward( return x def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -600,7 +600,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("qkv_proj", "v_proj", "v"), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() + loaded_params = set[str]() for name, loaded_weight in weights: for (param_name, weight_name, shard_id) in stacked_params_mapping: @@ -1189,7 +1189,7 @@ def sample( return self.language_model.sample(logits, sampling_metadata) def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + torch.Tensor]]) -> set[str]: hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "lm_head.": "language_model.lm_head.", From b4ddfb15f1f33f52c552f95d29d45c0a464ecfa3 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Dec 2024 14:30:42 +0000 Subject: [PATCH 22/26] Fix qwen2-vl Signed-off-by: DarkLight1337 --- vllm/multimodal/processing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 753aa5394e410..e604aef554825 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -708,8 +708,8 @@ def _cached_call_fine( ) for k, v in processed_modal_item.items(): - # Remove the extra batch dimension - processed_modal_items[k].append(v[0]) + # Remove the extra batch dimension (if it exists) + processed_modal_items[k].append(v.squeeze(0)) for k, vs in processed_modal_items.items(): # Try to merge elements into a single tensor From 4b3db3281d5ef9a232f475e8446abddad679edba Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Dec 2024 14:47:59 +0000 Subject: [PATCH 23/26] Fix Signed-off-by: DarkLight1337 --- vllm/model_executor/models/qwen2_vl.py | 27 ++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index e39893cdf425c..60e550f77c054 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -813,6 +813,26 @@ def _get_hf_mm_data( return processor_data, passthrough_data + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + # Remove the extra dimension + if (not self.ctx.model_config.disable_mm_preprocessor_cache + and "pixel_values" in processed_outputs): + processed_outputs["pixel_values"] = \ + processed_outputs["pixel_values"].squeeze(0) + + return processed_outputs + def _get_prompt_replacements( self, mm_items: MultiModalDataItems, @@ -945,9 +965,7 @@ def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): return None return quant_config - def _validate_and_reshape_mm_tensor(self, - mm_input: Union[torch.Tensor, - List[torch.Tensor]], + def _validate_and_reshape_mm_tensor(self, mm_input: object, name: str) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): raise ValueError(f"Incorrect type of {name}. " @@ -957,7 +975,8 @@ def _validate_and_reshape_mm_tensor(self, return mm_input if mm_input.ndim != 3: raise ValueError(f"{name} should be 2D or batched 3D tensor. " - f"Got ndim: {mm_input.ndim}") + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})") return torch.concat(list(mm_input)) else: return torch.concat(mm_input) From dafbc7fb0dc57c7cb46c7a26d150d104b756c570 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 21 Dec 2024 06:21:41 +0000 Subject: [PATCH 24/26] Simplify Pixtral-HF Signed-off-by: DarkLight1337 --- vllm/model_executor/models/llava.py | 42 ++++++++++++++--------------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 2be5dcd4a88de..f6e3cfdb3f561 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,5 +1,4 @@ from functools import cached_property -from types import MethodType from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set, Tuple, TypedDict, Union) @@ -116,30 +115,29 @@ def get_max_llava_image_tokens(ctx: InputContext): class LlavaMultiModalProcessor(BaseMultiModalProcessor): - def _patch_pixtral_processor(self, hf_processor: PixtralProcessor): - if getattr(hf_processor, "__is_patched__", False): - return # Already patched - - image_processor = hf_processor.image_processor # type: ignore - orig_preprocess = image_processor.preprocess - - def preprocess(__self, *args, **kwargs): - hf_inputs = orig_preprocess(*args, **kwargs) - hf_inputs["is_pixtral"] = torch.tensor(True) - return hf_inputs - - image_processor.preprocess = MethodType(preprocess, image_processor) + def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]: + return self.ctx.get_hf_processor((LlavaProcessor, PixtralProcessor)) - hf_processor.__is_patched__ = True # type: ignore + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) - def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]: - hf_processor = self.ctx.get_hf_processor( - (LlavaProcessor, PixtralProcessor)) + images = mm_data.get("images", []) + assert isinstance(images, list) - if isinstance(hf_processor, PixtralProcessor): - self._patch_pixtral_processor(hf_processor) + is_pixtral = isinstance(self._get_hf_processor(), PixtralProcessor) + processed_outputs["is_pixtral"] = \ + torch.tensor([is_pixtral] * len(images)) - return hf_processor + return processed_outputs def _get_prompt_replacements( self, @@ -379,8 +377,8 @@ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[LlavaImageInputs]: pixel_values = kwargs.pop("pixel_values", None) - is_pixtral = kwargs.pop("is_pixtral", torch.tensor([False])) image_embeds = kwargs.pop("image_embeds", None) + is_pixtral = kwargs.pop("is_pixtral", None) if pixel_values is None and image_embeds is None: return None From 38aaff87d3b518e212172b121a1db39ac26faa18 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 21 Dec 2024 06:24:09 +0000 Subject: [PATCH 25/26] Cleanup Signed-off-by: DarkLight1337 --- vllm/multimodal/inputs.py | 24 ++++++++--------- vllm/multimodal/processing.py | 51 +++++++++++++++++------------------ 2 files changed, 36 insertions(+), 39 deletions(-) diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 138cc6a44c11a..e8f1a8f7bd228 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -1,6 +1,6 @@ from collections import UserDict, defaultdict -from typing import (Any, Dict, List, Literal, Mapping, Sequence, Tuple, - TypedDict, TypeVar, Union, cast, final) +from collections.abc import Mapping, Sequence +from typing import Any, Literal, TypedDict, TypeVar, Union, cast, final import numpy as np import torch @@ -44,7 +44,7 @@ """ # yapf: enable -MultiModalData: TypeAlias = Union[_T, List[_T]] +MultiModalData: TypeAlias = Union[_T, list[_T]] """ Either a single data item, or a list of data items. @@ -97,13 +97,13 @@ class PlaceholderRange(TypedDict): """The length of the placeholder.""" -NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor, - Tuple[torch.Tensor, ...]] +NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor, + tuple[torch.Tensor, ...]] """ Uses a list instead of a tensor if the dimensions of each element do not match. """ -BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors] +BatchedTensorInputs: TypeAlias = dict[str, NestedTensors] """ A dictionary containing nested tensors which have been batched via :meth:`MultiModalKwargs.batch`. @@ -139,7 +139,7 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: # Only tensors (not lists) can be stacked. return stacked - tensors_ = cast(List[torch.Tensor], stacked) + tensors_ = cast(list[torch.Tensor], stacked) if any(t.shape != tensors_[0].shape for t in tensors_): # The tensors have incompatible shapes and can't be stacked. return tensors_ @@ -147,7 +147,7 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: return torch.stack(tensors_) @staticmethod - def batch(inputs_list: List["MultiModalKwargs"]) -> BatchedTensorInputs: + def batch(inputs_list: list["MultiModalKwargs"]) -> BatchedTensorInputs: """ Batch multiple inputs together into a dictionary. @@ -162,7 +162,7 @@ def batch(inputs_list: List["MultiModalKwargs"]) -> BatchedTensorInputs: # We need to consider the case where each item in the batch # contains different modalities (i.e. different keys). - item_lists: Dict[str, List[NestedTensors]] = defaultdict(list) + item_lists = defaultdict[str, list[NestedTensors]](list) for inputs in inputs_list: for k, v in inputs.items(): @@ -207,16 +207,16 @@ class MultiModalInputsV2(TypedDict): prompt: str """The processed prompt text.""" - prompt_token_ids: List[int] + prompt_token_ids: list[int] """The processed token IDs which includes placeholder tokens.""" - token_type_ids: NotRequired[List[int]] + token_type_ids: NotRequired[list[int]] """The token type IDs of the prompt.""" mm_kwargs: MultiModalKwargs """Keyword arguments to be directly passed to the model after batching.""" - mm_hashes: NotRequired[List[str]] + mm_hashes: NotRequired[list[str]] """The hashes of the multi-modal data.""" mm_placeholders: MultiModalPlaceholderDict diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index e604aef554825..2c30ae98a5b38 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -587,7 +587,7 @@ def iter_placeholders( class ProcessorInputs(NamedTuple): - """Keyword arguments to :meth:`BaseMultiModalProcessor`""" + """Keyword arguments to :meth:`BaseMultiModalProcessor`.""" prompt_text: str mm_data: MultiModalDataDict hf_mm_kwargs: Mapping[str, object] @@ -615,33 +615,14 @@ def maybe_log_cache_stats(self, cache: LRUCache, name: str) -> None: logger.debug("ProcessingCache: %s.hit_ratio = %.2f", name, cache_stats.hit_ratio) - def _iter_bytes_to_hash( - self, - key: str, - obj: object, - ) -> Iterable[tuple[bytes, bytes]]: - # Recursive cases - if isinstance(obj, (list, tuple)): - for i, elem in enumerate(obj): - yield from self._iter_bytes_to_hash(f"{key}.{i}", elem) - return - if isinstance(obj, dict): - for k, v in obj.items(): - yield from self._iter_bytes_to_hash(f"{key}.{k}", v) - return - - key_bytes = key.encode("utf-8") - + def _hash_item(self, obj: object) -> bytes: # Simple cases if isinstance(obj, str): - yield key_bytes, obj.encode("utf-8") - return + return obj.encode("utf-8") if isinstance(obj, bytes): - yield key_bytes, obj - return + return obj if isinstance(obj, Image): - yield key_bytes, obj.tobytes() - return + return obj.tobytes() # Convertible to NumPy arrays if isinstance(obj, torch.Tensor): @@ -649,14 +630,30 @@ def _iter_bytes_to_hash( if isinstance(obj, (int, float)): obj = np.array(obj) if isinstance(obj, np.ndarray): - yield key_bytes, obj.tobytes() - return + return obj.tobytes() logger.warning( "No serialization method found for %s. " "Falling back to pickle.", type(obj)) - yield key_bytes, pickle.dumps(obj) + return pickle.dumps(obj) + + def _iter_bytes_to_hash( + self, + key: str, + obj: object, + ) -> Iterable[tuple[bytes, bytes]]: + # Recursive cases + if isinstance(obj, (list, tuple)): + for i, elem in enumerate(obj): + yield from self._iter_bytes_to_hash(f"{key}.{i}", elem) + elif isinstance(obj, dict): + for k, v in obj.items(): + yield from self._iter_bytes_to_hash(f"{key}.{k}", v) + else: + key_bytes = self._hash_item(key) + value_bytes = self._hash_item(obj) + yield key_bytes, value_bytes def _hash_kwargs(self, **kwargs: object) -> str: hasher = blake3() From 5fcb5d6451eef08cab4d7f407ff1abc0c688a625 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 21 Dec 2024 06:31:30 +0000 Subject: [PATCH 26/26] Fix Pixtral-HF Signed-off-by: DarkLight1337 --- vllm/multimodal/processing.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 2c30ae98a5b38..1751873318523 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -706,7 +706,11 @@ def _cached_call_fine( for k, v in processed_modal_item.items(): # Remove the extra batch dimension (if it exists) - processed_modal_items[k].append(v.squeeze(0)) + # NOTE: v may be a list instead of a tensor + if len(v) == 1: + v = v[0] + + processed_modal_items[k].append(v) for k, vs in processed_modal_items.items(): # Try to merge elements into a single tensor