Fix error when input list is empty
Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 committed Dec 18, 2024
1 parent 0164939 commit 5fed009
Showing 3 changed files with 51 additions and 9 deletions.
vllm/model_executor/models/qwen2_audio.py (39 additions & 4 deletions)
@@ -26,8 +26,10 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from transformers import BatchFeature, Qwen2AudioConfig, Qwen2AudioProcessor
-from transformers.models.qwen2_audio import Qwen2AudioEncoder
+from transformers import BatchFeature, ProcessorMixin
+from transformers.models.qwen2_audio import (Qwen2AudioConfig,
+                                             Qwen2AudioEncoder,
+                                             Qwen2AudioProcessor)
 from transformers.models.whisper import WhisperFeatureExtractor
 
 from vllm.attention import AttentionMetadata
@@ -102,6 +104,35 @@ def _get_processor_data(
 
         return super()._get_processor_data(mm_items)
 
+    def _call_hf_processor(
+        self,
+        hf_processor: ProcessorMixin,
+        prompt: str,
+        processor_data: Mapping[str, object],
+        mm_processor_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        processor_data = dict(processor_data)
+        audios = processor_data.pop("audios", [])
+
+        if audios:
+            processor_data["audios"] = audios
+
+            feature_extractor = self._get_feature_extractor()
+            mm_processor_kwargs = dict(
+                **mm_processor_kwargs,
+                sampling_rate=feature_extractor.sampling_rate,
+            )
+        else:
+            # NOTE: WhisperFeatureExtractor cannot handle empty list of audios
+            pass
+
+        return super()._call_hf_processor(
+            hf_processor,
+            prompt=prompt,
+            processor_data=processor_data,
+            mm_processor_kwargs=mm_processor_kwargs,
+        )
+
     def _get_prompt_replacements(
         self,
         mm_items: MultiModalDataItems,
@@ -111,8 +142,12 @@ def _get_prompt_replacements(
         hf_config = self.ctx.get_hf_config(Qwen2AudioConfig)
         placeholder = hf_config.audio_token_index
 
-        _, audio_output_lengths = _get_feat_extract_output_lengths(
-            hf_inputs.feature_attention_mask.sum(-1))
+        feature_attention_mask = hf_inputs.get("feature_attention_mask")
+        if feature_attention_mask is None:
+            audio_output_lengths = []
+        else:
+            _, audio_output_lengths = _get_feat_extract_output_lengths(
+                feature_attention_mask.sum(-1))
 
         def get_replacement_qwen2_audio(item_idx: int):
             return [placeholder] * audio_output_lengths[item_idx]
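
The NOTE in the new _call_hf_processor is the heart of this fix: WhisperFeatureExtractor cannot handle an empty list of audios, so "audios" and the extractor's sampling_rate are only forwarded when the list is non-empty. A minimal sketch of the failure being avoided, assuming a default-constructed extractor (the commit does not name the exact exception raised):

    from transformers import WhisperFeatureExtractor

    feature_extractor = WhisperFeatureExtractor()
    try:
        # The case the guard above skips: no audio items at all
        feature_extractor([], sampling_rate=feature_extractor.sampling_rate)
    except Exception as exc:
        print(f"Empty audio list rejected: {exc!r}")

The _get_prompt_replacements change is the matching read-side guard: when no audio was processed, the HF outputs carry no feature_attention_mask, so audio_output_lengths falls back to an empty list and no placeholder expansion is attempted.
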
vllm/model_executor/models/ultravox.py (10 additions & 4 deletions)
@@ -93,18 +93,24 @@ def _call_hf_processor(
         processor_data: Mapping[str, object],
         mm_processor_kwargs: Mapping[str, object],
     ) -> BatchFeature:
-        if "audios" not in processor_data:
+        processor_data = dict(processor_data)
+        audios = processor_data.pop("audios", [])
+
+        if not audios:
             return super()._call_hf_processor(
                 hf_processor,
                 prompt=prompt,
                 processor_data=processor_data,
                 mm_processor_kwargs=mm_processor_kwargs,
             )
 
-        shared_processor_data = dict(processor_data)
+        feature_extractor = self._get_feature_extractor()
+        mm_processor_kwargs = dict(
+            **mm_processor_kwargs,
+            sampling_rate=feature_extractor.sampling_rate,
+        )
 
         # Already resampled by _get_processor_data
-        audios = shared_processor_data.pop("audios")
         assert is_list_of(audios, np.ndarray)
 
         # Ultravox processor doesn't support multiple inputs,
@@ -113,7 +119,7 @@ def _call_hf_processor(
         shared_outputs = {}
         for audio in audios:
             # NOTE: Ultravox processor accepts "audio" instead of "audios"
-            item_processor_data = dict(**shared_processor_data, audio=audio)
+            item_processor_data = dict(**processor_data, audio=audio)
 
             item_outputs = super()._call_hf_processor(
                 hf_processor,
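
The ultravox change also tightens the emptiness check itself: the old "audios" not in processor_data test misses a key that is present but maps to an empty list, whereas popping with a default and testing truthiness routes both the missing and the empty case to the early return. A small illustration of the difference, using hypothetical dict contents:

    data_missing = {"text": "hi"}
    data_empty = {"text": "hi", "audios": []}

    # Old check: only the missing key takes the early-return path
    assert "audios" not in data_missing
    assert not ("audios" not in data_empty)  # empty list slipped through

    # New check: both cases are treated as "no audio"
    for data in (data_missing, data_empty):
        audios = dict(data).pop("audios", [])
        assert not audios

The second hunk follows from the first: with "audios" already popped from processor_data, the per-item loop can splat processor_data directly instead of maintaining a separate shared_processor_data copy.
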
vllm/multimodal/processing.py (2 additions & 1 deletion)
@@ -661,7 +661,8 @@ def _get_processor_data(
         if isinstance(v, torch.Tensor) and v.ndim == 3:
             # Pass through embedding inputs (single)
             passthrough_data[f"{k}_embeds"] = [v]
-        elif is_list_of(v, torch.Tensor) and v[0].ndim == 2:
+        elif (is_list_of(v, torch.Tensor) and len(v) > 0
+              and v[0].ndim == 2):
             # Pass through embedding inputs (multi)
             passthrough_data[f"{k}_embeds"] = v
         else:
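
The processing.py guard closes the same hole at a lower level: a list type check of this kind is vacuously true for an empty list (every element of nothing passes the isinstance test), so the old elif would reach v[0].ndim and raise IndexError when an empty list of embeddings came through. A quick check, assuming is_list_of is the helper from vllm.utils:

    import torch
    from vllm.utils import is_list_of

    v = []
    assert is_list_of(v, torch.Tensor)  # vacuously true for an empty list
    # v[0].ndim                          # old condition: IndexError here
    assert not (is_list_of(v, torch.Tensor) and len(v) > 0 and v[0].ndim == 2)

With the len(v) > 0 short-circuit, an empty list simply falls through to the else branch instead of crashing.
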
