From 5ca5bace38629ea612ad12334944097a2a9ab01a Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 18 Dec 2024 16:24:10 +0000 Subject: [PATCH] Fix error when input list is empty Signed-off-by: DarkLight1337 --- vllm/model_executor/models/qwen2_audio.py | 27 +++++++++++++++++++++-- vllm/model_executor/models/ultravox.py | 8 ++++++- vllm/multimodal/processing.py | 3 ++- 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 1cd70a5bf667e..8699990ee814b 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -26,8 +26,10 @@ import numpy as np import torch import torch.nn as nn -from transformers import BatchFeature, Qwen2AudioConfig, Qwen2AudioProcessor -from transformers.models.qwen2_audio import Qwen2AudioEncoder +from transformers import BatchFeature, ProcessorMixin +from transformers.models.qwen2_audio import (Qwen2AudioConfig, + Qwen2AudioEncoder, + Qwen2AudioProcessor) from transformers.models.whisper import WhisperFeatureExtractor from vllm.attention import AttentionMetadata @@ -102,6 +104,27 @@ def _get_processor_data( return super()._get_processor_data(mm_items) + def _call_hf_processor( + self, + hf_processor: ProcessorMixin, + prompt: str, + processor_data: Mapping[str, object], + mm_processor_kwargs: Mapping[str, object], + ) -> BatchFeature: + if processor_data.get("audios"): + feature_extractor = self._get_feature_extractor() + mm_processor_kwargs = dict( + **mm_processor_kwargs, + sampling_rate=feature_extractor.sampling_rate, + ) + + return super()._call_hf_processor( + hf_processor, + prompt=prompt, + processor_data=processor_data, + mm_processor_kwargs=mm_processor_kwargs, + ) + def _get_prompt_replacements( self, mm_items: MultiModalDataItems, diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 8a3e56ec32a8a..0243b9c9f8b01 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -93,7 +93,7 @@ def _call_hf_processor( processor_data: Mapping[str, object], mm_processor_kwargs: Mapping[str, object], ) -> BatchFeature: - if "audios" not in processor_data: + if not processor_data.get("audios"): return super()._call_hf_processor( hf_processor, prompt=prompt, @@ -103,6 +103,12 @@ def _call_hf_processor( shared_processor_data = dict(processor_data) + feature_extractor = self._get_feature_extractor() + mm_processor_kwargs = dict( + **mm_processor_kwargs, + sampling_rate=feature_extractor.sampling_rate, + ) + # Already resampled by _get_processor_data audios = shared_processor_data.pop("audios") assert is_list_of(audios, np.ndarray) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 0abbdcd70db88..741a7e7461795 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -661,7 +661,8 @@ def _get_processor_data( if isinstance(v, torch.Tensor) and v.ndim == 3: # Pass through embedding inputs (single) passthrough_data[f"{k}_embeds"] = [v] - elif is_list_of(v, torch.Tensor) and v[0].ndim == 2: + elif (is_list_of(v, torch.Tensor) and len(v) > 0 + and v[0].ndim == 2): # Pass through embedding inputs (multi) passthrough_data[f"{k}_embeds"] = v else: