Skip to content

Commit

Permalink
Simplify the code for getting multiple items
Browse files Browse the repository at this point in the history
Signed-off-by: DarkLight1337 <[email protected]>
  • Loading branch information
DarkLight1337 committed Jan 2, 2025
1 parent 415b3d4 commit f00efc8
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 21 deletions.
42 changes: 24 additions & 18 deletions vllm/multimodal/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def build_elems(
self,
key: str,
batch: NestedTensors,
) -> list[MultiModalFieldElem]:
) -> Sequence[MultiModalFieldElem]:
field = self.field_cls(key=key, modality=self.modality)
return field.build_elems(batch, **self.field_config) # type: ignore

Expand All @@ -285,7 +285,7 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
"""

@staticmethod
def from_elems(elems: list[MultiModalFieldElem]) -> "MultiModalKwargsItem":
def from_elems(elems: Sequence[MultiModalFieldElem]):
return MultiModalKwargsItem({elem.field.key: elem for elem in elems})

@property
Expand All @@ -304,7 +304,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
The metadata :code:`items` enables us to obtain the keyword arguments
corresponding to each data item in :class:`MultiModalDataItems`, via
:meth:`get_num_items` and :meth:`get_item`.
:meth:`get_item` and :meth:`get_items`.
"""

@staticmethod
Expand All @@ -314,7 +314,7 @@ def from_hf_inputs(
):
# NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
# We assume that those fields are not used in vLLM
elems_by_key = dict[str, list[MultiModalFieldElem]]()
elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
keys_by_modality = defaultdict[str, set[str]](set)
for key, config in config_by_key.items():
batch = hf_inputs.get(key)
Expand Down Expand Up @@ -342,7 +342,7 @@ def from_hf_inputs(
return MultiModalKwargs.from_items(items)

@staticmethod
def from_items(items: list[MultiModalKwargsItem]) -> "MultiModalKwargs":
def from_items(items: Sequence[MultiModalKwargsItem]):
"""Construct a new :class:`MultiModalKwargs` from multiple items."""
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
for item in items:
Expand Down Expand Up @@ -453,32 +453,38 @@ def __eq__(self, other: object) -> bool:
return (ks == other.keys()
and all(nested_tensors_equal(self[k], other[k]) for k in ks))

def get_num_items(self, modality: str) -> int:
"""Get the number of items belonging to a modality."""
def _validate_modality(self, method_name: str, modality: str) -> None:
if not self._items_by_modality:
raise RuntimeError(
"`get_num_items` is not supported when "
f"`{method_name}` is not supported when "
"MultiModalKwargs is not initialized with `items`")

if modality not in self._items_by_modality:
available_modalities = set(self._items_by_modality.keys())
raise KeyError(f"Modality {modality!r} not found. "
f"Available modalities: {available_modalities}")

def get_item_count(self, modality: str) -> int:
"""Get the number of items belonging to a modality."""
self._validate_modality("get_item_count", modality)
return len(self._items_by_modality[modality])

def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem:
"""
Get the keyword arguments corresponding to an item identified by
its modality and index.
"""
if not self._items_by_modality:
raise RuntimeError(
"`get_item` is not supported when "
"MultiModalKwargs is not initialized with `items`")

if modality not in self._items_by_modality:
available_modalities = set(self._items_by_modality.keys())
raise KeyError(f"Modality {modality!r} not found. "
f"Available modalities: {available_modalities}")

self._validate_modality("get_item", modality)
return self._items_by_modality[modality][item_index]

def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
"""
Get the keyword arguments corresponding to each item belonging to
a modality.
"""
self._validate_modality("get_items", modality)
return self._items_by_modality[modality]


MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
"""
Expand Down
5 changes: 2 additions & 3 deletions vllm/v1/engine/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,9 @@ def process_inputs(
decoder_mm_data = decoder_inputs.multi_modal_data
if isinstance(decoder_mm_data, MultiModalKwargs):
precomputed_mm_inputs = [
MultiModalKwargs.from_items(
[decoder_mm_data.get_item(modality, item_idx)])
MultiModalKwargs.from_items([item])
for modality in decoder_mm_data.modalities
for item_idx in range(decoder_mm_data.get_num_items(modality))
for item in decoder_mm_data.get_items(modality)
]

# Apply MM mapper
Expand Down

0 comments on commit f00efc8

Please sign in to comment.