Commit 84f02fb: Fix ultravox

Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 committed Dec 21, 2024
1 parent c01d38a commit 84f02fb

Showing 2 changed files with 22 additions and 13 deletions.

vllm/model_executor/models/ultravox.py (7 additions, 0 deletions)

@@ -106,6 +106,12 @@ def _call_hf_processor(
         audios = mm_data.pop("audios", [])
 
         if not audios:
+            if not mm_data:
+                # Text-only input not supported in composite processor
+                prompt_ids = self._get_tokenizer().encode(prompt)
+                return BatchFeature(dict(input_ids=[prompt_ids]),
+                                    tensor_type="pt")
+
             return super()._call_hf_processor(
                 prompt=prompt,
                 mm_data=mm_data,

@@ -153,6 +159,7 @@ def _get_mm_field_tags(
     ) -> Mapping[str, MultiModalFieldTag]:
         return dict(
             audio_features=MultiModalFieldTags.indexed("audio"),
+            audio_token_len=MultiModalFieldTags.indexed("audio"),
             audio_embeds=MultiModalFieldTags.indexed("audio"),
         )

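The ultravox.py change short-circuits _call_hf_processor when the request carries no multimodal data at all: the composite HF processor cannot handle text-only input, so the prompt is tokenized directly and returned early. Below is a minimal runnable sketch of that control flow; FakeTokenizer, call_hf_processor, and the plain dicts are toy stand-ins for vLLM's tokenizer and BatchFeature, not the real APIs.

from typing import Any, Mapping


class FakeTokenizer:
    # Toy stand-in for the result of self._get_tokenizer() in vLLM:
    # "encodes" one token id per character.
    def encode(self, text: str) -> list[int]:
        return [ord(ch) for ch in text]


def call_hf_processor(prompt: str, mm_data: Mapping[str, Any]) -> dict:
    mm_data = dict(mm_data)
    audios = mm_data.pop("audios", [])

    if not audios:
        if not mm_data:
            # Text-only input not supported in composite processor:
            # tokenize the prompt ourselves and return early.
            prompt_ids = FakeTokenizer().encode(prompt)
            return {"input_ids": [prompt_ids]}

        # The real code delegates to super()._call_hf_processor() here.
        raise NotImplementedError("non-audio multimodal path omitted")

    raise NotImplementedError("audio path omitted from this sketch")


print(call_hf_processor("hi", {}))  # {'input_ids': [[104, 105]]}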

vllm/multimodal/processing.py (15 additions, 13 deletions)

@@ -803,14 +803,14 @@ def _call_hf_processor(
 
     def _apply_hf_processor(
         self,
-        prompt: str,
+        prompt_text: str,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> tuple[list[int], MultiModalKwargs]:
         processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
 
         processed_data = self._call_hf_processor(
-            prompt=prompt,
+            prompt=prompt_text,
             mm_data=processor_data,
             mm_kwargs=hf_processor_mm_kwargs,
         )

@@ -827,15 +827,15 @@ def _apply_hf_processor(
 
     def _cached_apply_hf_processor(
         self,
-        prompt: str,
+        prompt_text: str,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> tuple[list[int], MultiModalKwargs]:
         cache = self.cache
 
         if cache is None:
             return self._apply_hf_processor(
-                prompt=prompt,
+                prompt_text=prompt_text,
                 mm_items=mm_items,
                 hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             )

@@ -860,15 +860,17 @@ def _cached_apply_hf_processor(
 
         # Rely on our placeholder replacement logic instead of HF
         # to insert the placeholder tokens
-        prompt_ids = _encode(self._get_tokenizer(),
-                             prompt,
-                             add_special_tokens=True)
-
-        _, mm_missing_kwargs = self._apply_hf_processor(
-            prompt=prompt,
-            mm_items=mm_missing_items,
-            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-        )
+        prompt_ids = self._get_tokenizer().encode(prompt_text)
+
+        if mm_missing_items:
+            _, mm_missing_kwargs = self._apply_hf_processor(
+                prompt_text=prompt_text,
+                mm_items=mm_missing_items,
+                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+            )
+        else:
+            # Avoid unnecessary tokenization of the prompt text
+            mm_missing_kwargs = MultiModalKwargs({})
 
         mm_missing_next_idx = {modality: 0 for modality in mm_missing_items}

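The processing.py change renames the prompt parameter to prompt_text, tokenizes it with the tokenizer's plain encode(), and, when every multimodal item was served from the cache, skips the HF processor call altogether. A minimal sketch of that caching guard under toy assumptions (character-level tokenizer, plain dict standing in for MultiModalKwargs; cached_apply is a made-up name, not vLLM's API):

def cached_apply(prompt_text: str, mm_missing_items: dict) -> tuple[list[int], dict]:
    # Tokenize the prompt once, up front; placeholder tokens are inserted
    # later by the caller's own replacement logic, not by HF.
    prompt_ids = [ord(ch) for ch in prompt_text]  # toy tokenizer

    if mm_missing_items:
        # Cache miss: run the (expensive) HF processor on the missing items only.
        mm_missing_kwargs = {"processed": sorted(mm_missing_items)}
    else:
        # Full cache hit: skip the HF processor call entirely.
        mm_missing_kwargs = {}

    return prompt_ids, mm_missing_kwargs


print(cached_apply("hi", {}))              # ([104, 105], {}): all items cached
print(cached_apply("hi", {"audio": [1]}))  # processor runs for the missing audio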
