[VLM] Merged multimodal processor for Qwen2-Audio #11303

Merged
6 changes: 6 additions & 0 deletions examples/offline_inference_audio_language.py
@@ -18,6 +18,10 @@
2: "What sport and what nursery rhyme are referenced?"
}

# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.


# Ultravox 0.3
def run_ultravox(question: str, audio_count: int):
@@ -33,6 +37,8 @@ def run_ultravox(question: str, audio_count: int):
add_generation_prompt=True)

llm = LLM(model=model_name,
max_model_len=4096,
max_num_seqs=5,
trust_remote_code=True,
limit_mm_per_prompt={"audio": audio_count})
stop_token_ids = None
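The lowered max_model_len and max_num_seqs above carry over naturally to the Qwen2-Audio runner this PR targets. Below is a minimal sketch of such a runner in the same style as run_ultravox; the model name and chat template are assumptions modeled on the Hugging Face Qwen2-Audio format, not lines taken from this diff.

# Qwen2-Audio (sketch; prompt template and model name are assumptions)
def run_qwen2_audio(question: str, audio_count: int):
    model_name = "Qwen/Qwen2-Audio-7B-Instruct"

    llm = LLM(model=model_name,
              max_model_len=4096,
              max_num_seqs=5,
              limit_mm_per_prompt={"audio": audio_count})

    # One <|AUDIO|> placeholder per input audio clip.
    audio_in_prompt = "".join(
        f"Audio {idx + 1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
        for idx in range(audio_count))

    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
              "<|im_start|>user\n"
              f"{audio_in_prompt}{question}<|im_end|>\n"
              "<|im_start|>assistant\n")
    stop_token_ids = None
    return llm, prompt, stop_token_ids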
9 changes: 4 additions & 5 deletions tests/models/decoder_only/audio_language/test_ultravox.py
@@ -5,6 +5,7 @@
import pytest_asyncio
from transformers import AutoModel, AutoTokenizer, BatchEncoding

from vllm.multimodal.audio import resample_audio
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

@@ -130,16 +131,14 @@ def process(hf_inputs: BatchEncoding, **kwargs):
dtype=dtype,
postprocess_inputs=process,
auto_cls=AutoModel) as hf_model:
import librosa

hf_outputs_per_audio = [
hf_model.generate_greedy_logprobs_limit(
[hf_prompt],
max_tokens,
num_logprobs=num_logprobs,
audios=[(librosa.resample(audio[0],
orig_sr=audio[1],
target_sr=16000), 16000)])
audios=[(resample_audio(audio[0],
orig_sr=audio[1],
target_sr=16000), 16000)])
for _, hf_prompt, audio in prompts_and_audios
]

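The test now resamples through vllm.multimodal.audio.resample_audio rather than calling librosa directly. A minimal standalone sketch of that helper's usage, assuming it keeps the positional-waveform plus keyword orig_sr/target_sr signature seen in the diff and returns a NumPy array:

# Usage sketch for resample_audio (assumed librosa-style behaviour).
import numpy as np

from vllm.multimodal.audio import resample_audio

waveform = np.random.default_rng(0).standard_normal(44_100).astype(np.float32)
resampled = resample_audio(waveform, orig_sr=44100, target_sr=16000)
assert resampled.shape[0] < waveform.shape[0]  # downsampled towards 16 kHz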
100 changes: 70 additions & 30 deletions vllm/inputs/registry.py
@@ -1,11 +1,11 @@
import functools
from collections import UserDict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple,
Optional, Protocol, Type)
from typing import (TYPE_CHECKING, Any, Callable, Mapping, NamedTuple,
Optional, Protocol, Union)

from torch import nn
from transformers import PretrainedConfig, ProcessorMixin
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from typing_extensions import TypeVar, assert_never

from vllm.logger import init_logger
@@ -26,6 +26,7 @@
logger = init_logger(__name__)

C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)
P = TypeVar("P", bound=ProcessorMixin, default=ProcessorMixin)


@dataclass(frozen=True)
@@ -38,24 +39,28 @@ class InputContext:
model_config: "ModelConfig"
"""The configuration of the model."""

def get_hf_config(self, hf_config_type: Type[C] = PretrainedConfig) -> C:
def get_hf_config(
self,
typ: Union[type[C], tuple[type[C], ...]] = PretrainedConfig,
/,
) -> C:
"""
Get the HuggingFace configuration
(:class:`transformers.PretrainedConfig`) of the model,
additionally checking its type.

Raises:
TypeError: If the model is not of the specified type.
TypeError: If the configuration is not of the specified type.
"""
hf_config = self.model_config.hf_config
if not isinstance(hf_config, hf_config_type):
if not isinstance(hf_config, typ):
raise TypeError("Invalid type of HuggingFace config. "
f"Expected type: {hf_config_type}, but "
f"Expected type: {typ}, but "
f"found type: {type(hf_config)}")

return hf_config

def get_hf_image_processor_config(self) -> Dict[str, Any]:
def get_hf_image_processor_config(self) -> dict[str, Any]:
"""
Get the HuggingFace image processor configuration of the model.
"""
@@ -74,58 +79,93 @@ def get_mm_config(self):

return mm_config

def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
def get_hf_processor(
self,
typ: Union[type[P], tuple[type[P], ...]] = ProcessorMixin,
/,
**kwargs: object,
) -> P:
"""
Get the HuggingFace processor
(:class:`transformers.ProcessorMixin`) of the model,
additionally checking its type.

Raises:
TypeError: If the processor is not of the specified type.
"""
base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}

merged_kwargs = {**base_kwargs, **kwargs}

return cached_get_processor(
hf_processor = cached_get_processor(
self.model_config.model,
trust_remote_code=self.model_config.trust_remote_code,
**merged_kwargs,
)
if not isinstance(hf_processor, typ):
raise TypeError("Invalid type of HuggingFace processor. "
f"Expected type: {typ}, but "
f"found type: {type(hf_processor)}")

return hf_processor


@dataclass(frozen=True)
class InputProcessingContext(InputContext):
tokenizer: AnyTokenizer
"""The tokenizer used to tokenize the inputs."""

def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}

merged_kwargs = {**base_kwargs, **kwargs}

return cached_get_processor(
self.model_config.model,
tokenizer=self.tokenizer, # Override the tokenizer with ours
trust_remote_code=self.model_config.trust_remote_code,
**merged_kwargs,
def get_hf_processor(
self,
typ: Union[type[P], tuple[type[P], ...]] = ProcessorMixin,
/,
**kwargs: object,
) -> P:
return super().get_hf_processor(
typ,
tokenizer=self.tokenizer,
**kwargs,
)

def resolve_hf_processor_call_kwargs(
def call_hf_processor(
self,
hf_processor: ProcessorMixin,
prompt: str,
processor_data: Mapping[str, object],
inference_kwargs: Mapping[str, object],
) -> Mapping[str, object]:
) -> BatchFeature:
assert callable(hf_processor)

base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}

return resolve_mm_processor_kwargs(
merged_kwargs = resolve_mm_processor_kwargs(
base_kwargs,
inference_kwargs,
hf_processor,
requires_kw_only=False,
allow_var_kwargs=True,
)

try:
return hf_processor(
text=prompt,
**processor_data,
**merged_kwargs,
return_tensors="pt",
)
except Exception as exc:
data = dict(text=prompt, **processor_data)
msg = (f"Failed to apply {type(hf_processor).__name__} "
f"on data={data} with kwargs={merged_kwargs}")

raise RuntimeError(msg) from exc


N = TypeVar("N", bound=Type[nn.Module])
N = TypeVar("N", bound=type[nn.Module])


class DummyData(NamedTuple):
@@ -232,7 +272,7 @@ def wrapper(model_cls: N) -> N:

return wrapper

def _get_dummy_data_factory(self, model_cls: Type[nn.Module]):
def _get_dummy_data_factory(self, model_cls: type[nn.Module]):
return self._dummy_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory)

@@ -257,7 +297,7 @@ def wrapper(model_cls: N) -> N:

return wrapper

def _get_dummy_encoder_data_factory(self, model_cls: Type[nn.Module]):
def _get_dummy_encoder_data_factory(self, model_cls: type[nn.Module]):
return self._dummy_encoder_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory)

@@ -368,14 +408,14 @@ def wrapper(model_cls: N) -> N:

return wrapper

def _get_model_input_processor(self, model_cls: Type[nn.Module]):
def _get_model_input_processor(self, model_cls: type[nn.Module]):
return self._input_processors_by_model_type \
.get(model_cls, self._default_input_processor)

def _ensure_mm_kwargs(
self,
inputs: SingletonInputs,
mm_processor_kwargs: Dict[str, Any],
mm_processor_kwargs: dict[str, Any],
):
if inputs["type"] == "token":
# In case the input processor for that model fails to set it
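Taken together, the typed get_hf_config/get_hf_processor helpers and the new call_hf_processor give model-specific multimodal processors one validated path to the Hugging Face processor. The sketch below shows how a Qwen2-Audio processor might use them; the class and import names are illustrative assumptions, not code from this PR.

# Illustrative only: class and method names here are assumptions.
from typing import Mapping

from transformers import BatchFeature, ProcessorMixin, Qwen2AudioProcessor

from vllm.multimodal.processing import BaseMultiModalProcessor


class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):

    def _get_hf_processor(self) -> Qwen2AudioProcessor:
        # Raises TypeError if the cached processor is not the expected type.
        return self.ctx.get_hf_processor(Qwen2AudioProcessor)

    def _call_hf_processor(
        self,
        hf_processor: ProcessorMixin,
        prompt: str,
        processor_data: Mapping[str, object],
        mm_processor_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        # call_hf_processor merges the model-level mm_processor_kwargs with
        # the per-request kwargs and wraps any failure in a RuntimeError
        # that reports the offending inputs.
        return self.ctx.call_hf_processor(hf_processor, prompt,
                                          processor_data,
                                          mm_processor_kwargs)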
4 changes: 2 additions & 2 deletions vllm/model_executor/models/llava.py
@@ -133,8 +133,8 @@ def preprocess(__self, *args, **kwargs):
hf_processor.__is_patched__ = True # type: ignore

def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
hf_processor = self.ctx.get_hf_processor()
assert isinstance(hf_processor, (LlavaProcessor, PixtralProcessor))
hf_processor = self.ctx.get_hf_processor(
(LlavaProcessor, PixtralProcessor))

if isinstance(hf_processor, PixtralProcessor):
self._patch_pixtral_processor(hf_processor)
16 changes: 11 additions & 5 deletions vllm/model_executor/models/phi3v.py
@@ -34,7 +34,6 @@
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataDict,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors
@@ -330,20 +329,27 @@ def _get_hf_processor(
return self.ctx.get_hf_processor(num_crops=num_crops)
return self.ctx.get_hf_processor()

def _apply_hf_processor(
def _call_hf_processor(
self,
hf_processor: ProcessorMixin,
prompt: str,
mm_data: MultiModalDataDict,
processor_data: Mapping[str, object],
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._apply_hf_processor(
prompt, mm_data, mm_processor_kwargs)
processed_outputs = super()._call_hf_processor(
hf_processor,
prompt=prompt,
processor_data=processor_data,
mm_processor_kwargs=mm_processor_kwargs,
)

# The Phi3v processor inserts -1, -2, etc. as placeholders in prompt_ids,
# which causes an OverflowError when decoding the prompt_ids.
# Therefore, we need to do an early replacement here
token_ids = processed_outputs['input_ids']
token_ids[token_ids < 0] = _IMAGE_TOKEN_ID
processed_outputs['input_ids'] = token_ids

return processed_outputs

def _get_prompt_replacements(
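The early replacement of negative placeholder ids is easiest to see on a concrete tensor. A small standalone illustration follows; the token ids and the _IMAGE_TOKEN_ID value are assumptions chosen for the example.

import torch

_IMAGE_TOKEN_ID = 32044  # assumed value, for illustration only

token_ids = torch.tensor([1, 450, -1, -2, -2, 29871])
token_ids[token_ids < 0] = _IMAGE_TOKEN_ID
print(token_ids)  # tensor([    1,   450, 32044, 32044, 32044, 29871])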