
Abstract out parsing of multi-modal data
Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 committed Dec 30, 2024
1 parent 628ec6c commit b110c58
Showing 14 changed files with 538 additions and 299 deletions.
4 changes: 2 additions & 2 deletions vllm/model_executor/models/chatglm.py
@@ -33,7 +33,7 @@
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalData, MultiModalKwargs,
from vllm.multimodal.inputs import (ModalityData, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
@@ -54,7 +54,7 @@ def calculate_image_placeholder(vision_config):

def mm_input_mapper_for_glmv(
ctx: InputContext,
data: MultiModalData[object],
data: ModalityData[object],
) -> Dict:
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(
16 changes: 10 additions & 6 deletions vllm/model_executor/models/llava.py
@@ -20,11 +20,13 @@
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
MultiModalFieldConfig, MultiModalInputsV2,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import ImageProcessorInput
from vllm.multimodal.processing import (BaseMultiModalProcessor,
ProcessorInputs, PromptReplacement,
MultiModalDataItems, ProcessorInputs,
PromptReplacement,
full_groupby_modality)
from vllm.sequence import IntermediateTensors

@@ -179,7 +181,9 @@ def _get_prompt_replacements(
assert isinstance(vision_config, PixtralVisionConfig)

def get_replacement_pixtral(item_idx: int):
image_size = mm_items.get_image_size(item_idx)
images = mm_items.get_items("image", ImageProcessorInput)
image_size = images.get_image_size(item_idx)

(
num_width_tokens,
num_height_tokens,
@@ -591,7 +595,7 @@ def apply(

result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)

mm_items = self._get_mm_items(mm_data)
mm_items = self._to_mm_items(mm_data)
mm_item_counts = mm_items.get_item_counts()
mm_kwargs = result["mm_kwargs"]

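The LLaVA hunks above illustrate the new access pattern: instead of calling get_image_size on the items collection directly, a processor first fetches a typed view of one modality via get_items. Below is a minimal sketch of that pattern, using only methods that appear in this commit (get_items here; get_item_count and the .width/.height fields in the phi3v.py change further down); the helper function and its name are illustrative, not part of the commit.

from vllm.multimodal.parse import ImageProcessorInput
from vllm.multimodal.processing import MultiModalDataItems


def collect_image_sizes(mm_items: MultiModalDataItems) -> list[tuple[int, int]]:
    # Typed view over the parsed image items, as in get_replacement_pixtral above.
    images = mm_items.get_items("image", ImageProcessorInput)

    sizes: list[tuple[int, int]] = []
    for item_idx in range(mm_items.get_item_count("image")):
        size = images.get_image_size(item_idx)  # exposes .width and .height
        sizes.append((size.width, size.height))
    return sizes

Inside a processor, mm_items would come from self._to_mm_items(mm_data), as the apply() hunk above shows.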
19 changes: 12 additions & 7 deletions vllm/model_executor/models/phi3v.py
@@ -32,12 +32,13 @@
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
MultiModalFieldConfig, MultiModalInputsV2,
MultiModalKwargs, NestedTensors,
PlaceholderRange)
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import ImageProcessorInput
from vllm.multimodal.processing import (BaseMultiModalProcessor,
ProcessorInputs, PromptReplacement,
MultiModalDataItems, ProcessorInputs,
PromptReplacement,
_BoundPromptReplacement,
_PlaceholderInfo)
from vllm.sequence import IntermediateTensors
@@ -381,20 +382,24 @@ def _get_prompt_replacements(
assert isinstance(bos_token_id, int)

def get_replacement_phi3v(item_idx: int):
image_size = mm_items.get_image_size(item_idx)
images = mm_items.get_items("image", ImageProcessorInput)
image_size = images.get_image_size(item_idx)

num_tokens = image_processor.calc_num_image_tokens_from_image_size(
width=image_size.width,
height=image_size.height,
)

return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id]

num_images = mm_items.get_item_count("image")

return [
PromptReplacement(
modality="image",
target=image_token,
replacement=get_replacement_phi3v,
) for image_token in image_tokens[:len(mm_items.images)]
) for image_token in image_tokens[:num_images]
]

def _apply_prompt_replacements(
22 changes: 11 additions & 11 deletions vllm/model_executor/models/qwen2_audio.py
@@ -20,8 +20,8 @@
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from functools import cached_property
from typing import (Any, Iterable, List, Mapping, Optional, Set, Tuple,
TypedDict, Union)
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)

import numpy as np
import torch
@@ -38,10 +38,12 @@
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig,
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
ProcessorInputs, PromptReplacement)
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsMultiModal, SupportsPP
@@ -99,15 +101,13 @@ def _get_hf_processor(
def _get_feature_extractor(self) -> WhisperFeatureExtractor:
return self._get_hf_processor().feature_extractor # type: ignore

def _get_hf_mm_data(
def _to_mm_items(
self,
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
# resample audio to the model's sampling rate
mm_data: MultiModalDataDict,
) -> MultiModalDataItems:
feature_extractor = self._get_feature_extractor()
mm_items.resample_audios(feature_extractor.sampling_rate)

return super()._get_hf_mm_data(mm_items)
parser = MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
return parser.parse_mm_data(mm_data)

def _call_hf_processor(
self,
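The resampling that _get_hf_mm_data used to perform inline is now part of parsing: the processor overrides _to_mm_items and hands the feature extractor's sampling rate to MultiModalDataParser. A rough usage sketch, assuming the parser still accepts an (audio, sampling_rate) tuple under the "audio" key as earlier vLLM audio inputs did; the concrete rates and the silent waveform are made up.

import numpy as np

from vllm.multimodal.parse import MultiModalDataParser

# One second of silence at 48 kHz, standing in for a real recording.
waveform = np.zeros(48_000, dtype=np.float32)

# target_sr mirrors feature_extractor.sampling_rate in the hunk above.
parser = MultiModalDataParser(target_sr=16_000)
mm_items = parser.parse_mm_data({"audio": (waveform, 48_000)})

print(mm_items.get_item_counts())  # expected to report a single (resampled) audio item

Ultravox (further down) adopts the identical override, so the resampling logic now lives in exactly one place.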
152 changes: 73 additions & 79 deletions vllm/model_executor/models/qwen2_vl.py
@@ -25,7 +25,6 @@
from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
Set, Tuple, Type, TypedDict, Union)

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -55,15 +54,16 @@
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors, VideoItem)
from vllm.multimodal.parse import ModalityDataItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
ProcessorInputs, PromptReplacement)
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
from vllm.utils import is_list_of

from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend,
@@ -719,61 +719,84 @@ def get_max_qwen2_vl_mm_tokens(ctx: InputContext,
data_type_key="video")


class Qwen2VLMultiModalDataItems(MultiModalDataItems):
class Qwen2EmbeddingsInput(ModalityDataItems[dict[str, torch.Tensor],
dict[str, torch.Tensor]]):

@staticmethod
def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems":
"""
Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.
"""
multi_data = Qwen2VLMultiModalDataItems()

for k, v in data.items():
# TODO: Make a separate modality for embedding inputs
# to avoid confusion
# yapf: disable
if k == "video":
# Special case since even a single item can be a list
multi_data[k] = ( # type: ignore[index]
v if (
isinstance(v, (dict, torch.Tensor)) # type: ignore[assignment]
or is_list_of(v, list)
or isinstance(v[0], (np.ndarray, torch.Tensor))
and v[0].ndim == 4
) else [v]
)
elif k in ("image", "audio"):
multi_data[k] = ( # type: ignore[index]
v if isinstance(v, (dict, torch.Tensor, list)) else [v]
)
else:
multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index]
# yapf: enable
def __init__(self, data: dict, modality: str) -> None:
super().__init__(data)

return multi_data
self.modality = modality

def get_item_counts(self) -> Mapping[str, int]:
return {
m: (
len(items[f"{m}_grid_thw"]) # type: ignore
if isinstance(items, dict) else len(items))
for m, items in self.items()
}
grid_thw = data[f"{modality}_grid_thw"]
slice_idxs = [0] + grid_thw.prod(-1).cumsum_(0).tolist()
self._slices = [
slice(slice_idxs[i], slice_idxs[i + 1])
for i in range(len(grid_thw))
]

def __repr__(self) -> str:
return (f"{type(self).__name__}(modality={self.modality!r})")

def get_count(self) -> int:
return len(self.data[f"{self.modality}_grid_thw"])

def get(self, index: int) -> dict[str, torch.Tensor]:
out = {}
for k, v in self.data.items():
if v != f"{self.modality}_grid_thw":
v = v[self._slices[index]]

out[k] = v

return out

def get_processor_data(self) -> Mapping[str, object]:
return {}

def get_passthrough_data(self) -> Mapping[str, object]:
return self.data


class Qwen2ImageEmbeddingsInput(Qwen2EmbeddingsInput):

def __init__(self, data: dict) -> None:
super().__init__(data, "image")

def has_embedding_inputs(self) -> bool:
return any(
isinstance(items, dict) or any(
isinstance(item, torch.Tensor) for item in items)
for items in self.values())

class Qwen2VideoEmbeddingsInput(Qwen2EmbeddingsInput):

def __init__(self, data: dict) -> None:
super().__init__(data, "video")


class Qwen2MultiModalDataParser(MultiModalDataParser):

def _parse_image_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict):
return Qwen2ImageEmbeddingsInput(data)

return super()._parse_image_data(data)

def _parse_video_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict):
return Qwen2VideoEmbeddingsInput(data)

return super()._parse_video_data(data)


class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):

def _get_mm_items(
def _to_mm_items(
self,
mm_data: MultiModalDataDict,
) -> MultiModalDataItems:
return Qwen2VLMultiModalDataItems.from_dict(mm_data)
return Qwen2MultiModalDataParser().parse_mm_data(mm_data)

def _get_hf_processor(
self,
@@ -796,35 +819,6 @@ def _get_hf_processor(

return hf_processor

def _get_hf_mm_data(
self,
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
processor_data = dict[str, Any]()
passthrough_data = dict[str, Any]()

for k, v in mm_items.items():
# TODO: Make a separate modality for embedding inputs
# to avoid confusion
if k in ("image", "video", "audio"):
if isinstance(v, dict):
# Pass through embedding inputs (dict)
passthrough_data.update(v)
elif isinstance(v, torch.Tensor) and v.ndim == 3:
# Pass through embedding inputs (single)
passthrough_data[f"{k}_embeds"] = [v]
elif (is_list_of(v, torch.Tensor) and len(v) > 0
and v[0].ndim == 2):
# Pass through embedding inputs (multi)
passthrough_data[f"{k}_embeds"] = v
elif len(v) > 0:
# Map keys to plural form, e.g.: image -> images
processor_data[f"{k}s"] = v
else:
processor_data[k] = v

return processor_data, passthrough_data

def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
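The Qwen2-VL rewrite is the fullest example of the new abstraction: dict inputs (precomputed embeddings) are routed by a model-specific MultiModalDataParser subclass to ModalityDataItems subclasses, and the processor-vs-passthrough split that the removed _get_hf_mm_data computed by hand now lives on the items themselves. A sketch of what that routing implies for an image-embeddings dict; the tensor shapes and the image_embeds key are assumptions (the key name follows the f"{k}_embeds" convention of the removed code), not values taken from this commit.

import torch

from vllm.model_executor.models.qwen2_vl import Qwen2ImageEmbeddingsInput

# One image whose features were computed offline: t=1, h=4, w=4 patches (made-up sizes).
grid_thw = torch.tensor([[1, 4, 4]])
image_data = {
    "image_grid_thw": grid_thw,
    "image_embeds": torch.randn(4 * 4, 1280),  # hypothetical hidden size
}

# Qwen2MultiModalDataParser._parse_image_data routes any dict input to this class.
items = Qwen2ImageEmbeddingsInput(image_data)
assert items.get_count() == 1               # one item per image_grid_thw row
print(items.get_processor_data())           # {} -- nothing is sent to the HF processor
print(items.get_passthrough_data().keys())  # the whole dict is forwarded to the model

Video embeddings take the same path through Qwen2VideoEmbeddingsInput.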
22 changes: 11 additions & 11 deletions vllm/model_executor/models/ultravox.py
@@ -3,8 +3,8 @@

import math
from functools import cached_property, lru_cache
from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)

import numpy as np
import torch
@@ -24,10 +24,12 @@
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig,
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
ProcessorInputs, PromptReplacement)
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils import is_list_of
@@ -85,15 +87,13 @@ def _get_feature_extractor(self) -> WhisperFeatureExtractor:
hf_processor = self._get_hf_processor()
return hf_processor.audio_processor.feature_extractor # type: ignore

def _get_hf_mm_data(
def _to_mm_items(
self,
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
# resample audio to the model's sampling rate
mm_data: MultiModalDataDict,
) -> MultiModalDataItems:
feature_extractor = self._get_feature_extractor()
mm_items.resample_audios(feature_extractor.sampling_rate)

return super()._get_hf_mm_data(mm_items)
parser = MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
return parser.parse_mm_data(mm_data)

def _call_hf_processor(
self,
