[Model] Refactor Ultravox to use merged input processor (#11198)
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Isotr0py and DarkLight1337 authored Dec 16, 2024
1 parent bddbbcb commit d927dbc
Showing 7 changed files with 129 additions and 154 deletions.
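For context, a minimal offline-inference sketch of how Ultravox is invoked after this change, condensed from the updated example script below; the audio asset and sampling settings here are illustrative and not part of the commit:

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

model_name = "fixie-ai/ultravox-v0_3"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# "<|audio|>" is now the user-facing placeholder; the merged processor expands it.
messages = [{"role": "user", "content": "<|audio|>\nWhat is in this recording?"}]
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)

llm = LLM(model=model_name,
          trust_remote_code=True,
          limit_mm_per_prompt={"audio": 1})

audio_and_sample_rate = AudioAsset("mary_had_lamb").audio_and_sample_rate
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"audio": audio_and_sample_rate},
    },
    SamplingParams(temperature=0.2, max_tokens=64),
)
print(outputs[0].outputs[0].text)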
10 changes: 5 additions & 5 deletions examples/offline_inference_audio_language.py
@@ -25,16 +25,16 @@ def run_ultravox(question: str, audio_count: int):
 
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     messages = [{
-        'role':
-        'user',
-        'content':
-        "<|reserved_special_token_0|>\n" * audio_count + question
+        'role': 'user',
+        'content': "<|audio|>\n" * audio_count + question
     }]
     prompt = tokenizer.apply_chat_template(messages,
                                            tokenize=False,
                                            add_generation_prompt=True)
 
-    llm = LLM(model=model_name, limit_mm_per_prompt={"audio": audio_count})
+    llm = LLM(model=model_name,
+              trust_remote_code=True,
+              limit_mm_per_prompt={"audio": audio_count})
     stop_token_ids = None
     return llm, prompt, stop_token_ids

2 changes: 1 addition & 1 deletion tests/distributed/test_pipeline_parallel.py
@@ -214,7 +214,7 @@ def iter_params(self, model_name: str):
     "Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True),
     "Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(),
     "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(),
-    "fixie-ai/ultravox-v0_3": PPTestSettings.fast(),
+    "fixie-ai/ultravox-v0_3": PPTestSettings.fast(trust_remote_code=True),
     # [Encoder-decoder]
     # TODO: Implement PP
     # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),
1 change: 1 addition & 0 deletions tests/entrypoints/openai/test_audio.py
@@ -25,6 +25,7 @@ def server():
         "--max-num-seqs",
         "5",
         "--enforce-eager",
+        "--trust-remote-code",
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
5 changes: 3 additions & 2 deletions tests/models/decoder_only/audio_language/test_ultravox.py
@@ -16,7 +16,7 @@
 
 AudioTuple = Tuple[np.ndarray, int]
 
-VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
+VLLM_PLACEHOLDER = "<|audio|>"
 HF_PLACEHOLDER = "<|audio|>"
 
 CHUNKED_PREFILL_KWARGS = {
@@ -46,7 +46,8 @@ def audio(request):
 def server(request, audio_assets):
     args = [
         "--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager",
-        f"--limit-mm-per-prompt=audio={len(audio_assets)}"
+        f"--limit-mm-per-prompt=audio={len(audio_assets)}",
+        "--trust-remote-code"
     ] + [
         f"--{key.replace('_','-')}={value}"
         for key, value in request.param.items()
2 changes: 1 addition & 1 deletion vllm/entrypoints/chat_utils.py
@@ -418,7 +418,7 @@ def _placeholder_str(self, modality: ModalityStr,
             raise TypeError(f"Unknown {modality} model type: {model_type}")
         elif modality == "audio":
             if model_type == "ultravox":
-                return "<|reserved_special_token_0|>"
+                return "<|audio|>"
             if model_type == "qwen2_audio":
                 return (f"Audio {current_count}: "
                         f"<|audio_bos|><|AUDIO|><|audio_eos|>")
244 changes: 104 additions & 140 deletions vllm/model_executor/models/ultravox.py
@@ -3,41 +3,39 @@
 
 import math
 from functools import cached_property, lru_cache
-from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
-                    TypedDict, Union, cast)
+from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
+                    Tuple, TypedDict, Union)
 
 import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import functional as F
+from transformers import BatchFeature
 from transformers.models.whisper import WhisperFeatureExtractor
 from transformers.models.whisper.modeling_whisper import WhisperEncoder
 
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
-                         InputContext, token_inputs)
+from vllm.inputs import InputContext
 from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
-                             NestedTensors)
-from vllm.multimodal.utils import (cached_get_tokenizer,
-                                   consecutive_placeholder_ranges,
-                                   repeat_and_pad_placeholder_tokens)
-from vllm.sequence import IntermediateTensors, SequenceData
+from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         MultiModalDataDict,
                                         MultiModalDataItems, ProcessorInputs,
                                         PromptReplacement)
+from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
-from vllm.utils import is_list_of
 
 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings_from_map)
 
-_AUDIO_PLACEHOLDER_TOKEN = 128002
 _AUDIO_TOKENS_PER_SECOND = 6.25


@@ -72,64 +70,18 @@ def get_ultravox_max_audio_tokens(ctx: InputContext):
     return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND)
 
 
-def dummy_seq_data_for_ultravox(
-    ctx: InputContext,
-    seq_len: int,
-    audio_count: int,
-):
-    audio_length = min(get_ultravox_max_audio_tokens(ctx),
-                       seq_len // audio_count)
-
-    return SequenceData.from_prompt_token_counts(
-        (_AUDIO_PLACEHOLDER_TOKEN, audio_length * audio_count),
-        (0, seq_len - audio_length * audio_count)), {
-            "audio":
-            consecutive_placeholder_ranges(num_items=audio_count,
-                                           item_size=audio_length)
-        }
-
-
-def dummy_audio_for_ultravox(
-    ctx: InputContext,
-    audio_count: int,
-):
-    feature_extractor = whisper_feature_extractor(ctx)
-    audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
-    return {"audio": [audio_and_sr] * audio_count}
-
-
-def dummy_data_for_ultravox(
-    ctx: InputContext,
-    seq_len: int,
-    mm_counts: Mapping[str, int],
-):
-    audio_count = mm_counts["audio"]
-    seq_data, ranges = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count)
-    mm_dict = dummy_audio_for_ultravox(ctx, audio_count)
-
-    return DummyData(seq_data, mm_dict, ranges)
-
-
-def input_mapper_for_ultravox(ctx: InputContext, data: object):
-    if not isinstance(data, list):
-        data = [data]
-
-    if len(data) == 0:
-        return MultiModalKwargs()
-
-    # If the audio inputs are embeddings, no need for preprocessing
-    if is_list_of(data, torch.Tensor, check="all"):
-        return MultiModalKwargs({"audio_embeds": data})
-
-    audio_features = []
-    for audio_input in data:
-        if not isinstance(audio_input, tuple):
-            raise NotImplementedError(
-                f"Unsupported data type: {type(audio_input)}")
-
-        (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], audio_input)
-        feature_extractor = whisper_feature_extractor(ctx)
+class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
+
+    def _get_feature_extractor(self) -> WhisperFeatureExtractor:
+        return self._get_hf_processor().audio_processor.feature_extractor
+
+    def _resample_audio(
+        self,
+        audio: np.ndarray,
+        sr: int,
+    ) -> Dict[str, Union[np.ndarray, int]]:
+        # resample audio to the model's sampling rate
+        feature_extractor = self._get_feature_extractor()
         if sr != feature_extractor.sampling_rate:
             try:
                 import librosa
@@ -140,78 +92,92 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
                                      orig_sr=sr,
                                      target_sr=feature_extractor.sampling_rate)
             sr = feature_extractor.sampling_rate
+        return {"audio": audio, "sampling_rate": sr}
 
-        minimum_audio_length = feature_extractor.n_fft // 2 + 1
-        if len(audio) < minimum_audio_length:
-            # Not enough audio; pad it.
-            audio = np.pad(audio, (0, minimum_audio_length - len(audio)))
-
-        single_audio_features = feature_extractor(
-            audio, sampling_rate=sr, padding="longest",
-            return_tensors="pt")["input_features"]
-
-        # Remove the batch dimension because we're wrapping it in a list.
-        audio_features.append(single_audio_features.squeeze(0))
-
-    return MultiModalKwargs({"audio_features": audio_features})
-
-
-def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs):
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None or "audio" not in multi_modal_data:
-        return inputs
-
-    if "multi_modal_placeholders" in inputs and "audio" in inputs[
-            "multi_modal_placeholders"]:
-        # The inputs already have placeholders.
-        return inputs
-
-    feature_extractor = whisper_feature_extractor(ctx)
-    audios = multi_modal_data["audio"]
-    if not isinstance(audios, list):
-        audios = [audios]
-
-    audio_token_counts = []
-    for audio in audios:
-        if isinstance(audio, torch.Tensor):
-            audio_num_tokens = audio.shape[1]
-            audio_token_counts.append(audio_num_tokens)
-        else:
-            audio_data, sample_rate = audio
-            audio_length = audio_data.shape[0]
-            if sample_rate != feature_extractor.sampling_rate:
-                # Account for resampling.
-                adjustment = feature_extractor.sampling_rate / sample_rate
-                audio_length = math.ceil(adjustment * audio_length)
-
-            feature_extractor_output_length = math.ceil(
-                (audio_length - (feature_extractor.hop_length - 1)) /
-                feature_extractor.hop_length)
-
-            uv_config = ctx.get_hf_config(UltravoxConfig)
-            audio_num_tokens = min(
-                max(
-                    1,
-                    math.ceil(feature_extractor_output_length /
-                              (uv_config.stack_factor * 2))),
-                get_ultravox_max_audio_tokens(ctx))
-            audio_token_counts.append(audio_num_tokens)
-
-    tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
-
-    new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
-        tokenizer,
-        inputs.get("prompt"),
-        inputs["prompt_token_ids"],
-        placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN,
-        repeat_count=audio_token_counts,
-    )
-
-    # NOTE: Create a defensive copy of the original inputs
-    return token_inputs(prompt_token_ids=new_token_ids,
-                        prompt=new_prompt,
-                        multi_modal_data=multi_modal_data,
-                        multi_modal_placeholders={"audio": ranges})
+    def _apply_hf_processor(
+        self,
+        prompt: str,
+        mm_data: MultiModalDataDict,
+        mm_processor_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        if not mm_data or not mm_data.get("audio", None):
+            return super()._apply_hf_processor(prompt, mm_data,
+                                               mm_processor_kwargs)
+
+        audio_data = mm_data["audio"]
+        if not isinstance(audio_data, list):
+            audio_data = [audio_data]
+
+        # Ultravox processor doesn't support multiple inputs,
+        # therefore we need to input text and audio one by one
+        tokenizer = self._get_tokenizer()
+        audio_features, audio_token_len = [], []
+        processed_inputs = {}
+        for audio, sr in audio_data:
+            data = self._resample_audio(audio, sr)
+            processed_inputs = super()._apply_hf_processor(
+                prompt, data, mm_processor_kwargs)
+            prompt = tokenizer.decode(processed_inputs["input_ids"][0],
+                                      skip_special_tokens=False)
+            audio_features.append(
+                processed_inputs.pop("audio_values").squeeze(0))
+            audio_token_len.append(
+                processed_inputs.pop("audio_token_len").item())
+
+        return dict(
+            **processed_inputs,
+            audio_features=audio_features,
+            audio_token_len=audio_token_len,
+        )
+
+    def _get_processor_data(
+        self,
+        mm_data: MultiModalDataDict,
+    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        # Ultravox uses "audio" instead of "audios" as calling keyword
+        processor_data, passthrough_data = super()._get_processor_data(mm_data)
+        if "audios" in processor_data:
+            processor_data["audio"] = processor_data.pop("audios")
+        return processor_data, passthrough_data
+
+    def _get_prompt_replacements(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_inputs: BatchFeature,
+        mm_processor_kwargs: Mapping[str, object],
+    ) -> list[PromptReplacement]:
+        hf_processor = self._get_hf_processor()
+        placeholder = hf_processor.audio_token_replacement
+
+        def get_replacement_ultravox(item_idx: int):
+            audio_token_len = hf_inputs["audio_token_len"][item_idx]
+            return placeholder * audio_token_len
+
+        return [
+            PromptReplacement(
+                modality="audio",
+                target="<|audio|>",
+                replacement=get_replacement_ultravox,
+            )
+        ]
+
+    def _get_dummy_mm_inputs(
+        self,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        feature_extractor = self._get_feature_extractor()
+        sampling_rate = feature_extractor.sampling_rate
+        audio_len = feature_extractor.chunk_length * sampling_rate
+
+        audio_count = mm_counts["audio"]
+        audio = np.zeros(audio_len)
+        data = {"audio": [(audio, sampling_rate)] * audio_count}
+
+        return ProcessorInputs(
+            prompt_text="<|audio|>" * audio_count,
+            mm_data=data,
+            mm_processor_kwargs={},
+        )
 
 
 class StackAudioFrames(nn.Module):
@@ -332,11 +298,9 @@ def forward(
         return hidden_states
 
 
-@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_ultravox)
 @MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
     "audio", get_ultravox_max_audio_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
+@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor)
 class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
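As a rough, standalone illustration of the replacement rule defined in UltravoxMultiModalProcessor._get_prompt_replacements above (not vLLM code; the placeholder string and token counts are assumed values):

# Each "<|audio|>" target is replaced, per audio item, by audio_token_len
# copies of the HF processor's audio replacement token.
audio_token_len = [4, 6]  # assumed lengths produced by the HF processor
placeholder = "<|reserved_special_token_0|>"  # assumed audio_token_replacement value

prompt = "<|audio|>\n<|audio|>\nWhat do you hear in these clips?"
for item_idx in range(len(audio_token_len)):
    prompt = prompt.replace("<|audio|>",
                            placeholder * audio_token_len[item_idx], 1)
print(prompt)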