[VLM][Bugfix] Multi-modal processor compatible with V1 multi-input (#11674)

Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 authored Jan 2, 2025
1 parent a115ac4 commit 23c1b10
Showing 3 changed files with 151 additions and 168 deletions.
252 changes: 116 additions & 136 deletions vllm/multimodal/inputs.py
@@ -2,7 +2,8 @@
from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
-from typing import Any, Literal, TypedDict, TypeVar, Union, cast, final
+from typing import (Any, Literal, Optional, TypedDict, TypeVar, Union, cast,
+                    final)

import numpy as np
import torch
@@ -11,7 +12,7 @@
from transformers import BatchFeature
from typing_extensions import NotRequired, TypeAlias

-from vllm.utils import JSONTree, is_list_of, json_map_leaves
+from vllm.utils import JSONTree, full_groupby, is_list_of, json_map_leaves

_T = TypeVar("_T")

@@ -160,11 +161,8 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:


@dataclass(frozen=True)
-class MultiModalFieldItem:
-    """
-    Contains metadata and data in :class:`MultiModalKwargs`
-    corresponding to a data item in :class:`MultiModalDataItems`.
-    """
+class MultiModalFieldElem:
+    """Contains metadata and data of an item in :class:`MultiModalKwargs`."""
field: "BaseMultiModalField"
data: NestedTensors

@@ -186,34 +184,34 @@ class BaseMultiModalField(ABC):
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
raise NotImplementedError

-    def _build_item(self, data: NestedTensors) -> MultiModalFieldItem:
-        return MultiModalFieldItem(self, data)
+    def _build_elem(self, data: NestedTensors) -> MultiModalFieldElem:
+        return MultiModalFieldElem(self, data)

-    def reduce(self, batch: list[MultiModalFieldItem]) -> MultiModalFieldItem:
-        """Merge multiple instances of :class:`MultiModalFieldItem` together."""
+    def reduce(self, batch: list[MultiModalFieldElem]) -> MultiModalFieldElem:
+        """Merge multiple instances of :class:`MultiModalFieldElem` together."""
fields = [item.field for item in batch]
if len(set(fields)) > 1:
raise ValueError(f"Cannot merge different {fields=}")

data = self._reduce_data([item.data for item in batch])

-        return self._build_item(data)
+        return self._build_elem(data)


@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
"""
-    A :class:`BaseMultiModalField` implementation where an item is obtained by
-    directly indexing into the first dimension of the underlying data.
+    A :class:`BaseMultiModalField` implementation where an element in the batch
+    is obtained by indexing into the first dimension of the underlying data.
"""

-    def build_items(self, batch: NestedTensors) -> list[MultiModalFieldItem]:
-        return [self._build_item(item) for item in batch]
+    def build_elems(self, batch: NestedTensors) -> list[MultiModalFieldElem]:
+        return [self._build_elem(item) for item in batch]

def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
first_shape = batch[0].shape
-            if all(item.shape == first_shape for item in batch):
+            if all(elem.shape == first_shape for elem in batch):
return torch.stack(batch)

return batch
@@ -222,24 +220,24 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
"""
-    A :class:`BaseMultiModalField` implementation where an item is obtained by
-    slicing along the first dimension of the underlying data.
+    A :class:`BaseMultiModalField` implementation where an element in the batch
+    is obtained by slicing along the first dimension of the underlying data.
"""

-    def build_items(
+    def build_elems(
self,
batch: NestedTensors,
slices: Sequence[slice],
-    ) -> list[MultiModalFieldItem]:
-        return [self._build_item(batch[slice_]) for slice_ in slices]
+    ) -> list[MultiModalFieldElem]:
+        return [self._build_elem(batch[slice_]) for slice_ in slices]

def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
first_shape = batch[0].shape
-            if all(item.shape[1:] == first_shape[1:] for item in batch):
+            if all(elem.shape[1:] == first_shape[1:] for elem in batch):
return torch.concat(batch)

-        return [elem for item in batch for elem in item]
+        return [e for elem in batch for e in elem]


class MultiModalFieldConfig:
@@ -267,115 +265,111 @@ def __init__(
) -> None:
super().__init__()

-        self._field_cls = field_cls
-        self._modality = modality
-        self._field_config = field_config
+        self.field_cls = field_cls
+        self.modality = modality
+        self.field_config = field_config

-    def build_items(
+    def build_elems(
self,
key: str,
batch: NestedTensors,
-    ) -> list[MultiModalFieldItem]:
-        field = self._field_cls(key=key, modality=self._modality)
-        return field.build_items(batch, **self._field_config)  # type: ignore
+    ) -> Sequence[MultiModalFieldElem]:
+        field = self.field_cls(key=key, modality=self.modality)
+        return field.build_elems(batch, **self.field_config)  # type: ignore


-class MultiModalKwargs(UserDict[str, NestedTensors]):
-    """
-    A dictionary that represents the keyword arguments to
-    :meth:`~torch.nn.Module.forward`.
-
-    The metadata :code:`items_by_key` defines how to split batched keyword
-    arguments corresponding to each data item in :class:`MultiModalDataItems`:
-
-    - For a keyword argument, we can access the :code:`i` th item in the batch
-      via :code:`items_by_key[key][i]`.
-    - We can gather the keyword arguments belonging to a modality by finding
-      the keys with items that belong to that modality, then accessing
-      the :code:`i` th item in the batch for each such key.
-
-    Example:
-
-        .. code-block:: python
-
-            # All items belong to the "image" modality
-            items_by_key={
-                "pixel_values": [a, b, c, d],  # "image" modality
-                "image_grid_thw": [e, f, g, h],  # "image" modality
-                "pixel_values_video": [h, i, j],  # "video" modality
-                "video_grid_thw": [k, l, m],  # "video" modality
-            }
-
-    - The keyword arguments belonging to the first image are
-      :code:`{"pixel_values": a, "image_grid_thw": e}`.
-    - The keyword arguments belonging to the second video are
-      :code:`{"pixel_values_video": i, "video_grid_thw": l}`.
-    """
+class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
+    """
+    A collection of :class:`MultiModalFieldElem`
+    corresponding to a data item in :class:`MultiModalDataItems`.
+    """
+
+    @staticmethod
+    def from_elems(elems: Sequence[MultiModalFieldElem]):
+        return MultiModalKwargsItem({elem.field.key: elem for elem in elems})
+
+    @property
+    def modality(self) -> str:
+        modalities = {elem.field.modality for elem in self.data.values()}
+        assert len(modalities) == 1, f"Found different modalities={modalities}"
+        return next(iter(modalities))
+
+
+# NOTE: UserDict is for V0 compatibility.
+# V1 should access individual items via `get_item`.
+class MultiModalKwargs(UserDict[str, NestedTensors]):
+    """
+    A dictionary that represents the keyword arguments to
+    :meth:`~torch.nn.Module.forward`.
+
+    The metadata :code:`items` enables us to obtain the keyword arguments
+    corresponding to each data item in :class:`MultiModalDataItems`, via
+    :meth:`get_item` and :meth:`get_items`.
+    """

@staticmethod
def from_hf_inputs(
hf_inputs: BatchFeature,
config_by_key: Mapping[str, MultiModalFieldConfig],
-        *,
-        enable_sanity_checks: bool = False,
):
# NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
# We assume that those fields are not used in vLLM
-        items_by_key = {
-            key: config.build_items(key, batch)
-            for key, config in config_by_key.items()
-            if (batch := hf_inputs.get(key)) is not None
-        }
-
-        return MultiModalKwargs.from_items_by_key(
-            items_by_key,
-            enable_sanity_checks=enable_sanity_checks,
-        )
+        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
+        keys_by_modality = defaultdict[str, set[str]](set)
+        for key, config in config_by_key.items():
+            batch = hf_inputs.get(key)
+            if batch is not None:
+                elems = config.build_elems(key, batch)
+                if len(elems) > 0:
+                    elems_by_key[key] = elems
+                    keys_by_modality[config.modality].add(key)
+
+        items = list[MultiModalKwargsItem]()
+        for modality, keys in keys_by_modality.items():
+            elems_in_modality = {k: elems_by_key[k] for k in keys}
+            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}
+
+            if len(set(batch_sizes.values())) > 1:
+                raise ValueError(
+                    f"Cannot merge different batch sizes for {modality=}! "
+                    f"Found: {batch_sizes=}")
+
+            batch_size = next(iter(batch_sizes.values()))
+            for item_idx in range(batch_size):
+                elems = [v[item_idx] for v in elems_in_modality.values()]
+                items.append(MultiModalKwargsItem.from_elems(elems))
+
+        return MultiModalKwargs.from_items(items)

@staticmethod
-    def from_items_by_key(
-        items_by_key: Mapping[str, list[MultiModalFieldItem]],
-        *,
-        enable_sanity_checks: bool = False,
-    ) -> "MultiModalKwargs":
+    def from_items(items: Sequence[MultiModalKwargsItem]):
+        """Construct a new :class:`MultiModalKwargs` from multiple items."""
+        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
+        for item in items:
+            for key, elem in item.items():
+                elems_by_key[key].append(elem)
+
         data = {
-            key: items[0].field.reduce(items).data
-            for key, items in items_by_key.items() if len(items) > 0
+            key: elems[0].field.reduce(elems).data
+            for key, elems in elems_by_key.items() if len(elems) > 0
        }

-        return MultiModalKwargs(data,
-                                items_by_key=items_by_key,
-                                enable_sanity_checks=enable_sanity_checks)
+        return MultiModalKwargs(data, items=items)

def __init__(
self,
data: Mapping[str, NestedTensors],
*,
-        items_by_key: Mapping[str, list[MultiModalFieldItem]] = {},
-        enable_sanity_checks: bool = False,
+        items: Optional[Sequence[MultiModalKwargsItem]] = None,
) -> None:
super().__init__(data)

-        # Shallow copy to avoid footgun in case a defaultdict is passed in
-        self._items_by_key = dict(items_by_key)
+        items_by_modality = full_groupby(items or [], key=lambda x: x.modality)
+        self._items_by_modality = dict(items_by_modality)

-        keys_by_modality = defaultdict[str, set[str]](set)
-        for key, items in items_by_key.items():
-            for item in items:
-                keys_by_modality[item.field.modality].add(key)
-
-        self._keys_by_modality = dict(keys_by_modality)
-
-        if enable_sanity_checks:
-            for modality, keys in keys_by_modality.items():
-                items_in_modality = {k: items_by_key[k] for k in keys}
-                batch_sizes = {k: len(v) for k, v in items_in_modality.items()}
-                batch_size = next(iter(batch_sizes.values()), 0)
-                assert all(bs == batch_size
-                           for bs in batch_sizes.values()), dict(
-                               modality=modality,
-                               batch_sizes=batch_sizes,
-                               items_by_key=items_by_key)
+    @property
+    def modalities(self):
+        return self._items_by_modality.keys()

@staticmethod
def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
Expand Down Expand Up @@ -452,58 +446,44 @@ def as_kwargs(
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return False
-        if self._items_by_key != other._items_by_key:
+        if self._items_by_modality != other._items_by_modality:
return False

ks = self.keys()
return (ks == other.keys()
and all(nested_tensors_equal(self[k], other[k]) for k in ks))

-    def get_item(self, key: str, item_index: int) -> MultiModalFieldItem:
-        return self._items_by_key[key][item_index]
+    def _validate_modality(self, method_name: str, modality: str) -> None:
+        if not self._items_by_modality:
+            raise RuntimeError(
+                f"`{method_name}` is not supported when "
+                "MultiModalKwargs is not initialized with `items`")

-    def get_items_by_modality(
-        self,
-        modality: str,
-        item_index: int,
-    ) -> Mapping[str, MultiModalFieldItem]:
-        """
-        Get the keyword arguments corresponding to an item identified by
-        its modality and index.
-        """
-        if modality not in self._keys_by_modality:
-            available_modalities = set(self._keys_by_modality.keys())
+        if modality not in self._items_by_modality:
+            available_modalities = set(self._items_by_modality.keys())
raise KeyError(f"Modality {modality!r} not found. "
f"Available modalities: {available_modalities}")

-        keys_to_gather = self._keys_by_modality[modality]
+    def get_item_count(self, modality: str) -> int:
+        """Get the number of items belonging to a modality."""
+        self._validate_modality("get_item_count", modality)
+        return len(self._items_by_modality[modality])

-        return {
-            key: self.get_item(key, item_index)
-            for key in keys_to_gather if key in self
-        }
+    def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem:
+        """
+        Get the keyword arguments corresponding to an item identified by
+        its modality and index.
+        """
+        self._validate_modality("get_item", modality)
+        return self._items_by_modality[modality][item_index]

-    @staticmethod
-    def from_items_by_modality(
-        items_by_modality: Mapping[str, list[Mapping[str,
-                                                     MultiModalFieldItem]]],
-        *,
-        enable_sanity_checks: bool = False,
-    ) -> "MultiModalKwargs":
-        """
-        Construct a new :class:`MultiModalKwargs` from multiple items returned
-        by :meth:`get_fields_by_modality`.
-        """
-        items_by_key = defaultdict[str, list[MultiModalFieldItem]](list)
-        for fields in items_by_modality.values():
-            for field in fields:
-                for k, v in field.items():
-                    items_by_key[k].append(v)
-
-        return MultiModalKwargs.from_items_by_key(
-            items_by_key,
-            enable_sanity_checks=enable_sanity_checks,
-        )
+    def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
+        """
+        Get the keyword arguments corresponding to each item belonging to
+        a modality.
+        """
+        self._validate_modality("get_items", modality)
+        return self._items_by_modality[modality]


MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
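
Usage sketch for the renamed field API above (not part of the diff): a minimal example of how batched and flat fields split a processor tensor into per-item elements and merge them back, assuming vLLM at this commit; the key/modality names and tensor shapes are made up for illustration.

import torch

from vllm.multimodal.inputs import (MultiModalBatchedField,
                                    MultiModalFlatField, MultiModalKwargsItem)

# Batched field: each element is one index along dim 0 of the batched tensor.
field = MultiModalBatchedField(key="pixel_values", modality="image")
elems = field.build_elems(torch.rand(2, 3, 4))
assert len(elems) == 2 and elems[0].data.shape == (3, 4)
# `reduce` merges the per-item elements back into a single batched tensor.
assert field.reduce(elems).data.shape == (2, 3, 4)

# Flat field: elements are slices along dim 0 (e.g. variable-length patches).
flat = MultiModalFlatField(key="patch_embeds", modality="image")
flat_elems = flat.build_elems(torch.rand(5, 8),
                              slices=[slice(0, 2), slice(2, 5)])
assert [e.data.shape[0] for e in flat_elems] == [2, 3]
assert flat.reduce(flat_elems).data.shape == (5, 8)

# Elements that share a modality form one kwargs item.
item = MultiModalKwargsItem.from_elems([elems[0]])
assert item.modality == "image" and item["pixel_values"].data.shape == (3, 4)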
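
End-to-end, the point of the new `items` metadata is per-item access for V1 multi-input requests. A hedged sketch of the `from_hf_inputs` -> `get_item`/`get_items`/`from_items` round trip, again with illustrative keys and shapes, and assuming the `MultiModalFieldConfig.batched` helper available at this point:

import torch
from transformers import BatchFeature

from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs

# Pretend HF processor output for two images, batched along dim 0.
hf_inputs = BatchFeature({
    "pixel_values": torch.rand(2, 3, 336, 336),
    "image_grid_thw": torch.ones(2, 3, dtype=torch.int64),
})

mm_kwargs = MultiModalKwargs.from_hf_inputs(
    hf_inputs,
    config_by_key={
        "pixel_values": MultiModalFieldConfig.batched("image"),
        "image_grid_thw": MultiModalFieldConfig.batched("image"),
    },
)

# V0-style access: the dict itself still holds the batched tensors.
assert mm_kwargs["pixel_values"].shape[0] == 2

# V1-style access: per-item kwargs, grouped by modality.
assert mm_kwargs.get_item_count("image") == 2
first_image = mm_kwargs.get_item("image", 0)  # MultiModalKwargsItem
assert first_image["pixel_values"].data.shape == (3, 336, 336)

# Items can be regrouped into a fresh MultiModalKwargs (e.g. after scheduling).
rebuilt = MultiModalKwargs.from_items(mm_kwargs.get_items("image"))
assert torch.equal(rebuilt["pixel_values"], mm_kwargs["pixel_values"])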