From cfbd16412f4a480fdfabed3eb9552980aa3884a2 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Wed, 11 Dec 2024 22:00:54 +0000 Subject: [PATCH 01/44] add model def --- vllm/model_executor/models/registry.py | 1 + vllm/model_executor/models/whisper.py | 778 +++++++++++++++++++++++++ 2 files changed, 779 insertions(+) create mode 100644 vllm/model_executor/models/whisper.py diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 22c2e328bfb65..1ef6eaa90e5b8 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -147,6 +147,7 @@ "UltravoxModel": ("ultravox", "UltravoxModel"), # [Encoder-decoder] "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 + "WhisperForConditionalGeneration": ("whipser", "WhisperForConditionalGeneration"), # noqa: E501 } _SPECULATIVE_DECODING_MODELS = { diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py new file mode 100644 index 0000000000000..60b94a40cb539 --- /dev/null +++ b/vllm/model_executor/models/whisper.py @@ -0,0 +1,778 @@ +import math +from functools import lru_cache +from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union + +import librosa +import numpy as np +import torch +from torch import nn +from transformers import WhisperConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import FastGELU +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, kv_cache_scales_loader) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal.utils import consecutive_placeholder_ranges +from vllm.sequence import SequenceData +from xformers import ops as xops + +from .interfaces import SupportsMultiModal +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor: + """Returns sinusoids for positional embedding""" + if channels % 2 != 0: + raise ValueError( + f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels." 
+ ) + log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) + scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1) + return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1) + + +class WhisperPositionalEmbedding(nn.Embedding): + def __init__(self, num_positions: int, embedding_dim: int, + padding_idx: Optional[int] = None): + super().__init__(num_positions, embedding_dim) + + def forward(self, position_ids): + return self.weight[position_ids] + + +class WhisperAttention(nn.Module): + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ): + super().__init__() + tp_size = get_tensor_model_parallel_world_size() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.num_kv_heads = max(1, self.num_heads // tp_size) + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: " + f"{self.embed_dim} and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + + self.k_proj = RowParallelLinear( + input_size = embed_dim, + output_size = embed_dim, + bias = False, + quant_config=quant_config, + prefix=f"{prefix}.k_proj", + ) + self.v_proj = RowParallelLinear( + input_size = embed_dim, + output_size = embed_dim, + bias = bias, + quant_config=quant_config, + prefix=f"{prefix}.v_proj", + ) + self.q_proj = RowParallelLinear( + input_size = embed_dim, + output_size = embed_dim, + bias = bias, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) + self.out_proj = RowParallelLinear( + input_size = embed_dim, + output_size = embed_dim, + bias = bias, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() + + +class WhisperEncoderAttention(WhisperAttention): + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ): + super().__init__( + embed_dim=embed_dim, + num_heads=num_heads, + bias=bias, + quant_config=quant_config, + cache_config=cache_config, + prefix=prefix, + ) + + def forward( + self, + hidden_states: torch.Tensor, + ): + sizes = hidden_states.size() + if len(sizes) == 3: + bsz, tgt_len, _ = sizes + else: + tgt_len, _ = sizes + q, _ = self.q_proj(hidden_states) + k, _ = self.k_proj(hidden_states) + v, _ = self.v_proj(hidden_states) + + q = self._shape(q, -1, 1) + k = self._shape(k, -1, 1) + v = self._shape(v, -1, 1) + + attn_output = xops.memory_efficient_attention_forward( + q, + k, + v, + attn_bias=None, + p=0.0, + scale=None, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0], + ) + + attn_output = attn_output.reshape(-1, self.embed_dim) + output, _ = self.out_proj(attn_output) + return output + + +class WhisperDecoderAttention(WhisperAttention): + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ): + super().__init__( + embed_dim=embed_dim, + num_heads=num_heads, + bias=bias, + quant_config=quant_config, + 
cache_config=cache_config, + prefix=prefix, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config + ) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor = None, + attn_metadata: AttentionMetadata = None, + ): + sizes = hidden_states.size() + if len(sizes) == 3: + bsz, tgt_len, _ = sizes + else: + tgt_len, _ = sizes + + q, _ = self.q_proj(hidden_states) + k, _ = self.k_proj(hidden_states) + v, _ = self.v_proj(hidden_states) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + + output, _ = self.out_proj(attn_output) + + return output + + +class WhisperDecoderCrossAttention(WhisperAttention): + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ): + super().__init__( + embed_dim=embed_dim, + num_heads=num_heads, + bias=bias, + quant_config=quant_config, + cache_config=cache_config, + prefix=prefix, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata = None, + ): + sizes = hidden_states.size() + if len(sizes) == 3: + bsz, tgt_len, _ = sizes + else: + tgt_len, _ = sizes + + q, _ = self.q_proj(hidden_states) + k, _ = self.k_proj(encoder_hidden_states) + v, _ = self.v_proj(encoder_hidden_states) + + q = self._shape(q, -1, 1) + k = self._shape(k, -1, 1) + v = self._shape(v, -1, 1) + + attn_output = xops.memory_efficient_attention_forward( + q, + k, + v, + attn_bias=None, + p=0.0, + scale=None, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0], + ) + + attn_output = attn_output.reshape(-1, self.embed_dim) + output, _ = self.out_proj(attn_output) + return output + + +class WhisperEncoderLayer(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + cache_config = vllm_config.cache + + self.embed_dim = config.d_model + self.self_attn = WhisperEncoderAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + quant_config=quant_config, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.activation_fn = FastGELU() + self.fc1 = RowParallelLinear( + input_size = self.embed_dim, + output_size = config.encoder_ffn_dim, + bias = True, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + input_size = config.encoder_ffn_dim, + output_size = self.embed_dim, + bias = True, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + ): + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + ) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = 
torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + return hidden_states + + +class WhisperDecoderLayer(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + cache_config = vllm_config.cache_config + + self.embed_dim = config.d_model + self.self_attn = WhisperDecoderAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + quant_config=quant_config, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + self.activation_fn = FastGELU() + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = WhisperDecoderCrossAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + quant_config=quant_config, + cache_config=cache_config, + prefix=f"{prefix}.encoder_attn", + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = RowParallelLinear( + input_size = self.embed_dim, + output_size = config.decoder_ffn_dim, + bias = True, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + input_size = config.decoder_ffn_dim, + output_size = self.embed_dim, + bias = True, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ): + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + hidden_states = self.encoder_attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + attn_metadata=attn_metadata, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class WhisperEncoder(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + embed_dim = config.d_model + self.num_mel_bins = config.num_mel_bins + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_source_positions + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1) + self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim) + self.start_layer, self.end_layer, self.layers = make_layers( + config.encoder_layers, + lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config, + prefix=f"{prefix}.layers"), + prefix=f"{prefix}.layers", + ) + self.layer_norm = nn.LayerNorm(config.d_model) + + with torch.no_grad(): + self.embed_positions.weight.copy_(sinusoids(*self.embed_positions.weight.shape)) + + def forward( + self, + input_features, + ): + inputs_embeds = 
nn.functional.gelu(self.conv1(input_features)) + inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) + inputs_embeds = inputs_embeds.permute(1, 0) + + embed_pos = self.embed_positions.weight + + hidden_states = inputs_embeds + embed_pos + for idx, encoder_layer in enumerate(self.layers): + hidden_states = encoder_layer(hidden_states) + + hidden_states = self.layer_norm(hidden_states) + return hidden_states + + +class WhisperDecoder(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_target_positions + self.max_source_positions = config.max_source_positions + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model) + self.start_layer, self.end_layer, self.layers = make_layers( + config.decoder_layers, + lambda prefix: WhisperDecoderLayer(vllm_config=vllm_config, + prefix=f"{prefix}.layers"), + prefix=f"{prefix}.layers", + ) + self.layer_norm = nn.LayerNorm(config.d_model) + + def forward( + self, + input_ids, + positions: torch.Tensor, + encoder_hidden_states: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + past_key_values = None, + ): + inputs_embeds = self.embed_tokens(input_ids) + positions = self.embed_positions(positions) + hidden_states = inputs_embeds + positions + + for idx, decoder_layer in enumerate(self.layers): + hidden_states = decoder_layer( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata + ) + + hidden_states = self.layer_norm(hidden_states) + return hidden_states + +class WhisperModel(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.encoder = WhisperEncoder(vllm_config=vllm_config, + prefix=f"{prefix}.encoder") + self.decoder = WhisperDecoder(vllm_config=vllm_config, + prefix=f"{prefix}.decoder") + + def forward( + self, + input_features: torch.FloatTensor, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + encoder_outputs = self.encoder(input_features) + + decoder_outputs = self.decoder( + input_ids=input_ids, + positions=positions, + encoder_hidden_states=encoder_outputs, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) + return decoder_outputs + + +def dummy_data_for_whisper_audio(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + num_audios = mm_counts["audio"] + max_tokens_per_audio = get_max_whisper_audio_audio_tokens(ctx) + max_llm_audio_tokens = max_tokens_per_audio * num_audios + if seq_len - max_llm_audio_tokens - 2 < 0: + raise RuntimeError( + f"Qwen2-Audio cannot process {num_audios} audios in a prompt, " + "please increase max_model_len or reduce audio limit by " + "--limit-mm-per-prompt.") + + audio_token_index = ctx.model_config.hf_config.audio_token_index + + dummy_seqdata = SequenceData.from_prompt_token_counts( + (audio_token_index, max_llm_audio_tokens), + (0, seq_len - max_llm_audio_tokens), + ) + dummy_audio = np.full((max_llm_audio_tokens * 2 * 2 * 160, ), 0.) 
+ return DummyData( + dummy_seqdata, {"audio": [(dummy_audio, 16000)] * num_audios}, { + "audio": + consecutive_placeholder_ranges(num_items=num_audios, + item_size=max_tokens_per_audio) + }) + + +def get_processor( + processor_name: str, + *args, + trust_remote_code: bool = False, + **kwargs, +): + """Gets a processor for the given model name via HuggingFace. + + Derived from `vllm.transformers_utils.image_processor.get_image_processor`. + """ + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers import AutoProcessor + + try: + processor = AutoProcessor.from_pretrained( + processor_name, + *args, + trust_remote_code=trust_remote_code, + **kwargs) + except ValueError as e: + # If the error pertains to the processor class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. + # Unlike AutoTokenizer, AutoProcessor does not separate such errors + if not trust_remote_code: + err_msg = ( + "Failed to load the processor. If the processor is " + "a custom processor not yet available in the HuggingFace " + "transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + + return processor + + +cached_get_processor = lru_cache(get_processor) + + +def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers + and the output length of the audio encoder + """ + input_lengths = (input_lengths - 1) // 2 + 1 + output_lengths = (input_lengths - 2) // 2 + 1 + return input_lengths, output_lengths + + +def get_max_whisper_audio_audio_tokens(ctx: InputContext) -> int: + max_source_position = ( + ctx.model_config.hf_config.audio_config.max_source_positions) + output_lengths = (max_source_position - 2) // 2 + 1 + return output_lengths + + +def input_processor_for_whisper_audio( + ctx: InputContext, inputs: DecoderOnlyInputs) -> DecoderOnlyInputs: + multi_modal_data = inputs.get("multi_modal_data") + if multi_modal_data is None or "audio" not in multi_modal_data: + return inputs + + audios = multi_modal_data["audio"] + if not isinstance(audios, list): + audios = [audios] + + if len(audios) == 0: + return inputs + + processor = cached_get_processor(ctx.model_config.model) + resampled_audios = [ + librosa.resample(audio, + orig_sr=sampling_rate, + target_sr=processor.feature_extractor.sampling_rate) + for audio, sampling_rate in audios + ] + audio_input_lengths = np.array( + [min(3000, _.shape[0] // 160 + 1) for _ in resampled_audios]) + + audio_feat_lengths, audio_output_lengths = _get_feat_extract_output_lengths( + audio_input_lengths) + + audio_token_index = ctx.model_config.hf_config.audio_token_index + + input_ids = inputs['prompt_token_ids'] + + new_input_ids = [] + audio_num = input_ids.count(audio_token_index) + assert len(audio_input_lengths) == audio_num, \ + (f'The text input contains {audio_num} audio tokens, ' + f'but {len(audio_input_lengths)} audios provided') + start = 0 + for audio_idx in range(audio_num): + end = input_ids.index(audio_token_index, start) + new_input_ids.extend(input_ids[start:end]) # text part + + new_input_ids.extend([audio_token_index] * + audio_output_lengths[audio_idx]) + start = end + 1 + new_input_ids.extend(input_ids[start:]) + + return token_inputs( + prompt_token_ids=new_input_ids, + prompt=inputs['prompt'], + multi_modal_data=multi_modal_data, + ) + + +def 
input_mapper_for_whisper_audio( + ctx: InputContext, + multi_modal_data: Union[np.ndarray, List[np.ndarray]], +) -> MultiModalKwargs: + """Input mapper for Qwen2-Audio.""" + if not isinstance(multi_modal_data, list): + multi_modal_data = [multi_modal_data] + + if len(multi_modal_data) == 0: + return MultiModalKwargs() + + processor = cached_get_processor(ctx.model_config.model) + audio_feature_extractor = processor.feature_extractor + if audio_feature_extractor is None: + raise RuntimeError( + "No HuggingFace audio_feature_extractor is available " + "to process the audio object") + + try: + resampled_audios = [ + librosa.resample( + audio, + orig_sr=sampling_rate, + target_sr=processor.feature_extractor.sampling_rate) + for audio, sampling_rate in multi_modal_data + ] + batch_data = audio_feature_extractor(resampled_audios, + sampling_rate=16000, + return_attention_mask=True, + padding="max_length", + return_tensors="pt").data + batch_data["feature_attention_mask"] = batch_data.pop("attention_mask") + except Exception: + logger.error("Failed to process audio (%s)", multi_modal_data) + raise + + return MultiModalKwargs(batch_data) + + +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_whisper_audio) +@INPUT_REGISTRY.register_input_processor(input_processor_for_whisper_audio) +@MULTIMODAL_REGISTRY.register_input_mapper("audio", + input_mapper_for_whisper_audio) +@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( + "audio", get_max_whisper_audio_audio_tokens) +class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + + self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix) + self.unpadded_vocab_size = config.vocab_size + self.proj_out = RowParallelLinear( + input_size = config.d_model, + output_size = config.vocab_size, + bias = False, + quant_config=quant_config, + ) + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, logit_scale) + self.sampler = Sampler() + + def forward( + self, + whisper_data: torch.Tensor, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + + decoder_outputs = self.model( + input_features=whisper_data, + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) + return decoder_outputs + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.proj_out.weight, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + if name == 'model.decoder.embed_tokens.weight': + param = params_dict['proj_out.weight'] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + 
weight_loader(param, loaded_weight) + + param = params_dict[name] From ced01413bdaadd8c4be22c657b872be38ba218a4 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 12 Dec 2024 00:51:36 +0000 Subject: [PATCH 02/44] WIP --- examples/offline_inference_audio_language.py | 25 +++++++++++++++++++- vllm/model_executor/models/registry.py | 2 +- vllm/model_executor/models/whisper.py | 6 ++--- 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/examples/offline_inference_audio_language.py b/examples/offline_inference_audio_language.py index 050b791b62adb..cce91a602ca2c 100644 --- a/examples/offline_inference_audio_language.py +++ b/examples/offline_inference_audio_language.py @@ -61,7 +61,30 @@ def run_qwen2_audio(question: str, audio_count: int): return llm, prompt, stop_token_ids -model_example_map = {"ultravox": run_ultravox, "qwen2_audio": run_qwen2_audio} +# Whisper +def run_whisper(question: str, audio_count: int): + model_name = "openai/whisper-large-v3" + + llm = LLM(model=model_name, + max_model_len=4096, + max_num_seqs=5, + limit_mm_per_prompt={"audio": audio_count}) + + audio_in_prompt = "".join([ + f"Audio {idx+1}: " + f"<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count) + ]) + + prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + "<|im_start|>user\n" + f"{audio_in_prompt}{question}<|im_end|>\n" + "<|im_start|>assistant\n") + stop_token_ids = None + return llm, prompt, stop_token_ids + + +model_example_map = {"ultravox": run_ultravox, "qwen2_audio": run_qwen2_audio, + "whisper": run_whisper} def main(args): diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 1ef6eaa90e5b8..58eb52d065715 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -147,7 +147,7 @@ "UltravoxModel": ("ultravox", "UltravoxModel"), # [Encoder-decoder] "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 - "WhisperForConditionalGeneration": ("whipser", "WhisperForConditionalGeneration"), # noqa: E501 + "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501 } _SPECULATIVE_DECODING_MODELS = { diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 60b94a40cb539..194d63b77949d 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -278,7 +278,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - cache_config = vllm_config.cache + cache_config = vllm_config.cache_config self.embed_dim = config.d_model self.self_attn = WhisperEncoderAttention( @@ -527,7 +527,7 @@ def forward( def dummy_data_for_whisper_audio(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): + mm_counts: Mapping[str, int]): num_audios = mm_counts["audio"] max_tokens_per_audio = get_max_whisper_audio_audio_tokens(ctx) max_llm_audio_tokens = max_tokens_per_audio * num_audios @@ -605,7 +605,7 @@ def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor): def get_max_whisper_audio_audio_tokens(ctx: InputContext) -> int: max_source_position = ( - ctx.model_config.hf_config.audio_config.max_source_positions) + ctx.model_config.hf_config.max_source_positions) output_lengths = (max_source_position - 2) // 2 + 1 return output_lengths From 248bafbc4fa02506231612d844d002dcaeafc4cc Mon Sep 17 00:00:00 2001 From: 
Aurick Qiao Date: Thu, 12 Dec 2024 02:48:12 +0000 Subject: [PATCH 03/44] WIP, passes profile run --- examples/offline_inference_audio_language.py | 2 +- vllm/model_executor/models/whisper.py | 136 +++++++++++++++---- vllm/worker/enc_dec_model_runner.py | 2 + 3 files changed, 113 insertions(+), 27 deletions(-) diff --git a/examples/offline_inference_audio_language.py b/examples/offline_inference_audio_language.py index cce91a602ca2c..9fb490e0a40a0 100644 --- a/examples/offline_inference_audio_language.py +++ b/examples/offline_inference_audio_language.py @@ -67,7 +67,7 @@ def run_whisper(question: str, audio_count: int): llm = LLM(model=model_name, max_model_len=4096, - max_num_seqs=5, + max_num_seqs=1, limit_mm_per_prompt={"audio": audio_count}) audio_in_prompt = "".join([ diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 194d63b77949d..46fbc8bb9172a 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -6,7 +6,7 @@ import numpy as np import torch from torch import nn -from transformers import WhisperConfig +from transformers import WhisperConfig, WhisperProcessor from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, VllmConfig @@ -23,6 +23,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -439,9 +440,10 @@ def forward( self, input_features, ): + print(self.conv1.weight.dtype, self.conv1.bias.dtype, input_features.dtype) inputs_embeds = nn.functional.gelu(self.conv1(input_features)) inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) - inputs_embeds = inputs_embeds.permute(1, 0) + inputs_embeds = inputs_embeds.permute(0, 2, 1) embed_pos = self.embed_positions.weight @@ -483,6 +485,7 @@ def forward( attn_metadata: AttentionMetadata, past_key_values = None, ): + print(self.max_target_positions, positions, input_ids.shape, positions.shape) inputs_embeds = self.embed_tokens(input_ids) positions = self.embed_positions(positions) hidden_states = inputs_embeds + positions @@ -498,6 +501,7 @@ def forward( hidden_states = self.layer_norm(hidden_states) return hidden_states + class WhisperModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -526,24 +530,39 @@ def forward( return decoder_outputs +def _get_dummy_seq_data(seq_len: int, + whisper_config: WhisperConfig) -> SequenceData: + # '<|startoftranscript|><|en|><|transcribe|>' + token_ids = [50258, 50259, 50360] + return SequenceData(token_ids) + + +def _get_dummy_values(whisper_config: WhisperConfig) -> torch.Tensor: + values_dtype = torch.float16 + + return torch.zeros((30 * whisper_config.sample_rate), dtype=values_dtype) + + def dummy_data_for_whisper_audio(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): num_audios = mm_counts["audio"] max_tokens_per_audio = get_max_whisper_audio_audio_tokens(ctx) max_llm_audio_tokens = max_tokens_per_audio * num_audios - if seq_len - max_llm_audio_tokens - 2 < 0: - raise RuntimeError( - f"Qwen2-Audio cannot process {num_audios} audios in a prompt, " - "please increase max_model_len or reduce audio limit by " - "--limit-mm-per-prompt.") - - 
audio_token_index = ctx.model_config.hf_config.audio_token_index - - dummy_seqdata = SequenceData.from_prompt_token_counts( - (audio_token_index, max_llm_audio_tokens), - (0, seq_len - max_llm_audio_tokens), - ) + # if seq_len - max_llm_audio_tokens - 2 < 0: + # raise RuntimeError( + # f"Qwen2-Audio cannot process {num_audios} audios in a prompt, " + # "please increase max_model_len or reduce audio limit by " + # "--limit-mm-per-prompt.") + + audio_token_index = 0 # ctx.model_config.hf_config.audio_token_index + + dummy_seqdata = SequenceData.from_prompt_token_counts((0, seq_len)) + # dummy_seqdata = SequenceData.from_prompt_token_counts( + # (audio_token_index, max_llm_audio_tokens), + # (0, seq_len - max_llm_audio_tokens), + # ) dummy_audio = np.full((max_llm_audio_tokens * 2 * 2 * 160, ), 0.) + print("dummy_audio", dummy_audio.shape) return DummyData( dummy_seqdata, {"audio": [(dummy_audio, 16000)] * num_audios}, { "audio": @@ -552,6 +571,42 @@ def dummy_data_for_whisper_audio(ctx: InputContext, seq_len: int, }) +def get_whisper_processor( + processor_name: str, + *args, + trust_remote_code: bool = False, + revision: Optional[str] = None, + **kwargs, +) -> WhisperProcessor: + """Gets an whisper processor for the given model name via HuggingFace.""" + try: + processor: WhisperProcessor = WhisperProcessor.from_pretrained( + processor_name, + *args, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs) + except ValueError as e: + # If the error pertains to the processor class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. + # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors + if not trust_remote_code: + err_msg = ( + "Failed to load the whisper processor. If the whisper processor is " + "a custom processor not yet available in the HuggingFace " + "transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + + return processor + + +cached_get_whisper_processor = lru_cache(get_whisper_processor) + + def get_processor( processor_name: str, *args, @@ -600,6 +655,7 @@ def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor): """ input_lengths = (input_lengths - 1) // 2 + 1 output_lengths = (input_lengths - 2) // 2 + 1 + print("_get_feat_extract_output_lengths", input_lengths, output_lengths) return input_lengths, output_lengths @@ -607,12 +663,14 @@ def get_max_whisper_audio_audio_tokens(ctx: InputContext) -> int: max_source_position = ( ctx.model_config.hf_config.max_source_positions) output_lengths = (max_source_position - 2) // 2 + 1 + print("get_max_whisper_audio_audio_tokens", output_lengths) return output_lengths def input_processor_for_whisper_audio( ctx: InputContext, inputs: DecoderOnlyInputs) -> DecoderOnlyInputs: multi_modal_data = inputs.get("multi_modal_data") + print("input_processor_for_whisper_audio", multi_modal_data) if multi_modal_data is None or "audio" not in multi_modal_data: return inputs @@ -623,7 +681,15 @@ def input_processor_for_whisper_audio( if len(audios) == 0: return inputs - processor = cached_get_processor(ctx.model_config.model) + processor = cached_get_whisper_processor(ctx.model_config.model) + print("audios", audios) + + whisper_data = processor( + audio, + sampling_rate = self.whisper_config.sample_rate, + return_tensors = 'pt', + ) + whisper_data = whisper_data.to(self.model_config.dtype).input_features[0] resampled_audios = [ 
librosa.resample(audio, orig_sr=sampling_rate, @@ -667,13 +733,25 @@ def input_mapper_for_whisper_audio( multi_modal_data: Union[np.ndarray, List[np.ndarray]], ) -> MultiModalKwargs: """Input mapper for Qwen2-Audio.""" + print("input_mapper_for_whisper_audio", multi_modal_data) if not isinstance(multi_modal_data, list): multi_modal_data = [multi_modal_data] if len(multi_modal_data) == 0: return MultiModalKwargs() - processor = cached_get_processor(ctx.model_config.model) + processor = cached_get_whisper_processor(ctx.model_config.model) + + resampled_audios = [ + processor( + audio, + sampling_rate=sampling_rate, + return_tensors='pt', + ).to(torch.float16).input_features[0] + for audio, sampling_rate in multi_modal_data + ] + print([audio.shape for audio in resampled_audios]) + audio_feature_extractor = processor.feature_extractor if audio_feature_extractor is None: raise RuntimeError( @@ -688,6 +766,7 @@ def input_mapper_for_whisper_audio( target_sr=processor.feature_extractor.sampling_rate) for audio, sampling_rate in multi_modal_data ] + print([audio.shape for audio in resampled_audios]) batch_data = audio_feature_extractor(resampled_audios, sampling_rate=16000, return_attention_mask=True, @@ -697,7 +776,8 @@ def input_mapper_for_whisper_audio( except Exception: logger.error("Failed to process audio (%s)", multi_modal_data) raise - + batch_data["input_features"] = batch_data["input_features"].squeeze(dim=0) + print("input_mapper_for_whisper_audio", batch_data["input_features"].shape) return MultiModalKwargs(batch_data) @@ -718,12 +798,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix) self.unpadded_vocab_size = config.vocab_size - self.proj_out = RowParallelLinear( - input_size = config.d_model, - output_size = config.vocab_size, - bias = False, - quant_config=quant_config, - ) + # self.proj_out = RowParallelLinear( + # input_size = config.d_model, + # output_size = config.vocab_size, + # bias = False, + # quant_config=quant_config, + # ) + self.proj_out = ParallelLMHead(config.vocab_size, + config.d_model, + quant_config=quant_config) logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) @@ -731,15 +814,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def forward( self, - whisper_data: torch.Tensor, + #whisper_data: torch.Tensor, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + **kwargs, ) -> torch.Tensor: decoder_outputs = self.model( - input_features=whisper_data, + input_features=kwargs["input_features"].to(torch.float16), input_ids=input_ids, positions=positions, kv_caches=kv_caches, @@ -749,7 +833,7 @@ def forward( def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.proj_out.weight, hidden_states, + logits = self.logits_processor(self.proj_out, hidden_states, sampling_metadata) return logits diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 82824faa6629a..36d7d5114c9c8 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -175,6 +175,7 @@ def execute_model( } if self.has_inner_state else {} multi_modal_kwargs = model_input.multi_modal_kwargs or {} + print(multi_modal_kwargs.keys()) with set_forward_context(model_input.attn_metadata): 
hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, @@ -280,6 +281,7 @@ def profile_run(self) -> None: for group_id in range(max_num_seqs): seq_len = (max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs)) + seq_len = min(seq_len, 448) batch_size += seq_len decoder_dummy_data = self.input_registry \ From 7329b2d1d0cf9adaa131f064b1a261ab60c8c6b8 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 12 Dec 2024 21:35:02 +0000 Subject: [PATCH 04/44] WIP --- examples/offline_inference_audio_language.py | 1 + vllm/model_executor/models/whisper.py | 233 +++---------------- vllm/transformers_utils/config.py | 1 + vllm/worker/worker.py | 1 + 4 files changed, 34 insertions(+), 202 deletions(-) diff --git a/examples/offline_inference_audio_language.py b/examples/offline_inference_audio_language.py index 9fb490e0a40a0..91e5f994db32f 100644 --- a/examples/offline_inference_audio_language.py +++ b/examples/offline_inference_audio_language.py @@ -68,6 +68,7 @@ def run_whisper(question: str, audio_count: int): llm = LLM(model=model_name, max_model_len=4096, max_num_seqs=1, + enforce_eager=True, limit_mm_per_prompt={"audio": audio_count}) audio_in_prompt = "".join([ diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 46fbc8bb9172a..ba84f60d4f33d 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -194,7 +194,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config + quant_config=quant_config, + prefix=f"{prefix}.attn", ) def forward( @@ -530,47 +531,17 @@ def forward( return decoder_outputs -def _get_dummy_seq_data(seq_len: int, - whisper_config: WhisperConfig) -> SequenceData: - # '<|startoftranscript|><|en|><|transcribe|>' - token_ids = [50258, 50259, 50360] - return SequenceData(token_ids) - - -def _get_dummy_values(whisper_config: WhisperConfig) -> torch.Tensor: - values_dtype = torch.float16 - - return torch.zeros((30 * whisper_config.sample_rate), dtype=values_dtype) - - -def dummy_data_for_whisper_audio(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): - num_audios = mm_counts["audio"] - max_tokens_per_audio = get_max_whisper_audio_audio_tokens(ctx) - max_llm_audio_tokens = max_tokens_per_audio * num_audios - # if seq_len - max_llm_audio_tokens - 2 < 0: - # raise RuntimeError( - # f"Qwen2-Audio cannot process {num_audios} audios in a prompt, " - # "please increase max_model_len or reduce audio limit by " - # "--limit-mm-per-prompt.") - - audio_token_index = 0 # ctx.model_config.hf_config.audio_token_index - - dummy_seqdata = SequenceData.from_prompt_token_counts((0, seq_len)) - # dummy_seqdata = SequenceData.from_prompt_token_counts( - # (audio_token_index, max_llm_audio_tokens), - # (0, seq_len - max_llm_audio_tokens), - # ) - dummy_audio = np.full((max_llm_audio_tokens * 2 * 2 * 160, ), 0.) 
- print("dummy_audio", dummy_audio.shape) +def dummy_encoder_data_for_whisper(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + assert mm_counts["audio"] == 1 + sample_rate = 16000 return DummyData( - dummy_seqdata, {"audio": [(dummy_audio, 16000)] * num_audios}, { - "audio": - consecutive_placeholder_ranges(num_items=num_audios, - item_size=max_tokens_per_audio) - }) + SequenceData.from_prompt_token_counts((0, seq_len)), + {"audio": [(np.zeros(30 * sample_rate), sample_rate)]}, + ) +@lru_cache def get_whisper_processor( processor_name: str, *args, @@ -604,189 +575,46 @@ def get_whisper_processor( return processor -cached_get_whisper_processor = lru_cache(get_whisper_processor) - - -def get_processor( - processor_name: str, - *args, - trust_remote_code: bool = False, - **kwargs, -): - """Gets a processor for the given model name via HuggingFace. - - Derived from `vllm.transformers_utils.image_processor.get_image_processor`. - """ - # don't put this import at the top level - # it will call torch.cuda.device_count() - from transformers import AutoProcessor - - try: - processor = AutoProcessor.from_pretrained( - processor_name, - *args, - trust_remote_code=trust_remote_code, - **kwargs) - except ValueError as e: - # If the error pertains to the processor class not existing or not - # currently being imported, suggest using the --trust-remote-code flag. - # Unlike AutoTokenizer, AutoProcessor does not separate such errors - if not trust_remote_code: - err_msg = ( - "Failed to load the processor. If the processor is " - "a custom processor not yet available in the HuggingFace " - "transformers library, consider setting " - "`trust_remote_code=True` in LLM or using the " - "`--trust-remote-code` flag in the CLI.") - raise RuntimeError(err_msg) from e - else: - raise e - - return processor - - -cached_get_processor = lru_cache(get_processor) - - -def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor): - """ - Computes the output length of the convolutional layers - and the output length of the audio encoder - """ - input_lengths = (input_lengths - 1) // 2 + 1 - output_lengths = (input_lengths - 2) // 2 + 1 - print("_get_feat_extract_output_lengths", input_lengths, output_lengths) - return input_lengths, output_lengths - - -def get_max_whisper_audio_audio_tokens(ctx: InputContext) -> int: - max_source_position = ( - ctx.model_config.hf_config.max_source_positions) - output_lengths = (max_source_position - 2) // 2 + 1 - print("get_max_whisper_audio_audio_tokens", output_lengths) - return output_lengths - - -def input_processor_for_whisper_audio( - ctx: InputContext, inputs: DecoderOnlyInputs) -> DecoderOnlyInputs: - multi_modal_data = inputs.get("multi_modal_data") - print("input_processor_for_whisper_audio", multi_modal_data) - if multi_modal_data is None or "audio" not in multi_modal_data: - return inputs - - audios = multi_modal_data["audio"] - if not isinstance(audios, list): - audios = [audios] - - if len(audios) == 0: - return inputs - - processor = cached_get_whisper_processor(ctx.model_config.model) - print("audios", audios) - - whisper_data = processor( - audio, - sampling_rate = self.whisper_config.sample_rate, - return_tensors = 'pt', - ) - whisper_data = whisper_data.to(self.model_config.dtype).input_features[0] - resampled_audios = [ - librosa.resample(audio, - orig_sr=sampling_rate, - target_sr=processor.feature_extractor.sampling_rate) - for audio, sampling_rate in audios - ] - audio_input_lengths = np.array( - [min(3000, _.shape[0] // 160 + 1) for _ 
in resampled_audios]) - - audio_feat_lengths, audio_output_lengths = _get_feat_extract_output_lengths( - audio_input_lengths) - - audio_token_index = ctx.model_config.hf_config.audio_token_index - - input_ids = inputs['prompt_token_ids'] - - new_input_ids = [] - audio_num = input_ids.count(audio_token_index) - assert len(audio_input_lengths) == audio_num, \ - (f'The text input contains {audio_num} audio tokens, ' - f'but {len(audio_input_lengths)} audios provided') - start = 0 - for audio_idx in range(audio_num): - end = input_ids.index(audio_token_index, start) - new_input_ids.extend(input_ids[start:end]) # text part - - new_input_ids.extend([audio_token_index] * - audio_output_lengths[audio_idx]) - start = end + 1 - new_input_ids.extend(input_ids[start:]) - +def input_processor_for_whisper(ctx: InputContext, inputs: DecoderOnlyInputs) -> DecoderOnlyInputs: return token_inputs( - prompt_token_ids=new_input_ids, - prompt=inputs['prompt'], - multi_modal_data=multi_modal_data, + prompt_token_ids=inputs["decoder"]["prompt_token_ids"], + multi_modal_data=inputs["encoder"]["multi_modal_data"], ) -def input_mapper_for_whisper_audio( +def input_mapper_for_whisper( ctx: InputContext, multi_modal_data: Union[np.ndarray, List[np.ndarray]], ) -> MultiModalKwargs: - """Input mapper for Qwen2-Audio.""" print("input_mapper_for_whisper_audio", multi_modal_data) if not isinstance(multi_modal_data, list): multi_modal_data = [multi_modal_data] + assert len(multi_modal_data) == 1 + if len(multi_modal_data) == 0: return MultiModalKwargs() - processor = cached_get_whisper_processor(ctx.model_config.model) + processor = get_whisper_processor(ctx.model_config.model) + target_sampling_rate = processor.feature_extractor.sampling_rate resampled_audios = [ - processor( - audio, - sampling_rate=sampling_rate, - return_tensors='pt', - ).to(torch.float16).input_features[0] + librosa.resample(audio, orig_sr=sampling_rate, + target_sr=target_sampling_rate) for audio, sampling_rate in multi_modal_data ] - print([audio.shape for audio in resampled_audios]) - audio_feature_extractor = processor.feature_extractor - if audio_feature_extractor is None: - raise RuntimeError( - "No HuggingFace audio_feature_extractor is available " - "to process the audio object") + kwargs = processor(resampled_audios, sampling_rate=target_sampling_rate, + return_tensors="pt") + kwargs["input_features"] = kwargs["input_features"].squeeze(0) - try: - resampled_audios = [ - librosa.resample( - audio, - orig_sr=sampling_rate, - target_sr=processor.feature_extractor.sampling_rate) - for audio, sampling_rate in multi_modal_data - ] - print([audio.shape for audio in resampled_audios]) - batch_data = audio_feature_extractor(resampled_audios, - sampling_rate=16000, - return_attention_mask=True, - padding="max_length", - return_tensors="pt").data - batch_data["feature_attention_mask"] = batch_data.pop("attention_mask") - except Exception: - logger.error("Failed to process audio (%s)", multi_modal_data) - raise - batch_data["input_features"] = batch_data["input_features"].squeeze(dim=0) - print("input_mapper_for_whisper_audio", batch_data["input_features"].shape) - return MultiModalKwargs(batch_data) - - -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_whisper_audio) -@INPUT_REGISTRY.register_input_processor(input_processor_for_whisper_audio) -@MULTIMODAL_REGISTRY.register_input_mapper("audio", - input_mapper_for_whisper_audio) -@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( - "audio", get_max_whisper_audio_audio_tokens) + 
print("input_mapper_for_whisper_audio", kwargs["input_features"].shape) + return MultiModalKwargs(kwargs) + + +@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_whisper) +@INPUT_REGISTRY.register_input_processor(input_processor_for_whisper) +@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper) class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -821,6 +649,7 @@ def forward( attn_metadata: AttentionMetadata, **kwargs, ) -> torch.Tensor: + print("FORWARD", kwargs.keys()) decoder_outputs = self.model( input_features=kwargs["input_features"].to(torch.float16), diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 4529cf27ef565..61c365e4a87c3 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -152,6 +152,7 @@ def uses_mrope(config: PretrainedConfig) -> bool: def is_encoder_decoder(config: PretrainedConfig) -> bool: + return False """Detect if the model with this config is used as an encoder/decoder.""" text_config = getattr(config, "text_config", None) if text_config is not None: diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index a368bb9ee9a5b..212a58080ec35 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -79,6 +79,7 @@ def __init__( ModelRunnerClass = PoolingModelRunner elif self.model_config.is_encoder_decoder: ModelRunnerClass = EncoderDecoderModelRunner + print(ModelRunnerClass) self.model_runner: GPUModelRunnerBase = ModelRunnerClass( vllm_config=self.vllm_config, kv_cache_dtype=self.cache_config.cache_dtype, From 77ad7ed4a1bd9337bd2d4f406aab57132af71d82 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 12 Dec 2024 23:11:41 +0000 Subject: [PATCH 05/44] kinda working with encoder decoder --- examples/offline_inference_audio_language.py | 13 +++++-------- vllm/core/scheduler.py | 8 ++++++-- vllm/engine/llm_engine.py | 4 ++++ vllm/model_executor/models/whisper.py | 5 +++++ vllm/transformers_utils/config.py | 1 - vllm/worker/model_runner.py | 2 ++ vllm/worker/worker.py | 1 - vllm/worker/worker_base.py | 1 + 8 files changed, 23 insertions(+), 12 deletions(-) diff --git a/examples/offline_inference_audio_language.py b/examples/offline_inference_audio_language.py index 91e5f994db32f..b6cc0b3af49d1 100644 --- a/examples/offline_inference_audio_language.py +++ b/examples/offline_inference_audio_language.py @@ -71,15 +71,12 @@ def run_whisper(question: str, audio_count: int): enforce_eager=True, limit_mm_per_prompt={"audio": audio_count}) - audio_in_prompt = "".join([ - f"Audio {idx+1}: " - f"<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count) - ]) + # audio_in_prompt = "".join([ + # f"Audio {idx+1}: " + # f"<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count) + # ]) - prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" - "<|im_start|>user\n" - f"{audio_in_prompt}{question}<|im_end|>\n" - "<|im_start|>assistant\n") + prompt = "<|startoftranscript|>" stop_token_ids = None return llm, prompt, stop_token_ids diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c3bc6becf0995..8e36c8a9b0dff 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -434,6 +434,7 @@ def num_decoding_tokens_per_seq(self) -> int: def add_seq_group(self, seq_group: SequenceGroup) -> None: # Add sequence groups to the waiting queue. 
+        print("ADD", seq_group)
         self.waiting.append(seq_group)
 
     def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None:
@@ -1318,6 +1319,7 @@ def schedule(
             if seq_group.is_encoder_decoder():
                 # Encoder associated with SequenceGroup
                 encoder_seq = seq_group.get_encoder_seq()
+                print("encoder_seq", seq_group.encoder_seq.inputs.inputs)
                 assert encoder_seq is not None
                 encoder_seq_data = encoder_seq.data
                 # Block table for cross-attention
@@ -1325,6 +1327,7 @@ def schedule(
                 cross_block_table = self.block_manager.get_cross_block_table(
                     seq_group)
             else:
+                print("NOT encoder_seq")
                 encoder_seq_data = None
                 cross_block_table = None
 
@@ -1362,6 +1365,7 @@ def schedule(
             # It assumes the scheduled_seq_groups is ordered by
             # prefill < decoding.
             if is_first_prefill or not self.scheduler_config.send_delta_data:
+                print("SCHEDULER SEQGROUP", seq_group.multi_modal_data)
                 seq_group_metadata = SequenceGroupMetadata(
                     request_id=seq_group.request_id,
                     is_prompt=is_prompt,
@@ -1381,8 +1385,8 @@ def schedule(
                     # between engine and worker.
                     # the subsequent comms can still use delta, but
                     # `multi_modal_data` will be None.
-                    multi_modal_data=seq_group.multi_modal_data
-                    if scheduler_outputs.num_prefill_groups > 0 else None,
+                    multi_modal_data=seq_group.multi_modal_data or seq_group.encoder_seq.multi_modal_data
+                    ,#if scheduler_outputs.num_prefill_groups > 0 else None,
                     multi_modal_placeholders=seq_group.multi_modal_placeholders
                     if scheduler_outputs.num_prefill_groups > 0 else None,
                     mm_processor_kwargs=seq_group.mm_processor_kwargs,
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 9be30c635cb2c..5254966e86240 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -625,6 +625,7 @@ def _add_processed_request(
         seq_id = next(self.seq_counter)
         eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
 
+        print("PROCESSED INPUTS", processed_inputs)
         if is_encoder_decoder_inputs(processed_inputs):
             decoder_inputs = processed_inputs["decoder"]
             encoder_inputs = processed_inputs["encoder"]
@@ -638,6 +639,8 @@ def _add_processed_request(
         encoder_seq = (None if encoder_inputs is None else Sequence(
             seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
             prompt_adapter_request))
+
+        print("ENCODER_SEQ", encoder_seq.inputs.inputs)
 
         # Create a SequenceGroup based on SamplingParams or PoolingParams
         if isinstance(params, SamplingParams):
@@ -1398,6 +1401,7 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
                 # We use ExecuteModelRequest to pass the last sampled_token_ids
                 # to each of the non-last PP stages for in-place prepare_input.
                 last_sampled_token_ids=last_sampled_token_ids)
+            print("STEP", execute_model_req)
 
             if allow_async_output_proc:
                 execute_model_req.async_callback = self.async_callbacks[
diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py
index ba84f60d4f33d..25c14a57a8c5f 100644
--- a/vllm/model_executor/models/whisper.py
+++ b/vllm/model_executor/models/whisper.py
@@ -150,6 +150,7 @@ def forward(
         q, _ = self.q_proj(hidden_states)
         k, _ = self.k_proj(hidden_states)
         v, _ = self.v_proj(hidden_states)
+        print(q.shape, k.shape, v.shape, hidden_states.shape)
 
         q = self._shape(q, -1, 1)
         k = self._shape(k, -1, 1)
@@ -533,6 +534,7 @@ def forward(
 
 def dummy_encoder_data_for_whisper(ctx: InputContext, seq_len: int,
                                    mm_counts: Mapping[str, int]):
+    print("DUMMY DATA")
     assert mm_counts["audio"] == 1
     sample_rate = 16000
     return DummyData(
@@ -576,6 +578,8 @@ def get_whisper_processor(
 
 def input_processor_for_whisper(ctx: InputContext,
                                 inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
+    print("input_processor_for_whisper", inputs)
+    return inputs
     return token_inputs(
         prompt_token_ids=inputs["decoder"]["prompt_token_ids"],
         multi_modal_data=inputs["encoder"]["multi_modal_data"],
     )
@@ -672,6 +676,7 @@ def sample(
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(logits, sampling_metadata)
+        print("SAMPLE", next_tokens)
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 61c365e4a87c3..4529cf27ef565 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -152,7 +152,6 @@ def uses_mrope(config: PretrainedConfig) -> bool:
 def is_encoder_decoder(config: PretrainedConfig) -> bool:
-    return False
     """Detect if the model with this config is used as an encoder/decoder."""
     text_config = getattr(config, "text_config", None)
     if text_config is not None:
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 26fd486130ce6..48b5c4cc420d6 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -729,6 +729,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
         encoder_seq_len = 0
 
         if self.runner.model_config.is_encoder_decoder:
+            print(seq_group_metadata)
             encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len()
 
         inter_data = self.init_cached_inter_data(
@@ -1282,6 +1283,7 @@ def profile_run(self) -> None:
         for group_id in range(max_num_seqs):
             seq_len = (max_num_batched_tokens // max_num_seqs +
                        (group_id < max_num_batched_tokens % max_num_seqs))
+            seq_len = min(448, seq_len)
             batch_size += seq_len
 
             dummy_data = self.input_registry \
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 212a58080ec35..a368bb9ee9a5b 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -79,7 +79,6 @@ def __init__(
             ModelRunnerClass = PoolingModelRunner
         elif self.model_config.is_encoder_decoder:
             ModelRunnerClass = EncoderDecoderModelRunner
-        print(ModelRunnerClass)
         self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
             vllm_config=self.vllm_config,
             kv_cache_dtype=self.cache_config.cache_dtype,
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 6d00102e0a324..68dd4011a0d0a 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -262,6 +262,7 @@ def _get_driver_input_and_broadcast(
     ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
         """
         Get the driver input and broadcast it to other workers.
         """
         assert self.is_driver_worker
+        print("DRIVER", execute_model_req)
 
         worker_input: WorkerInput = self.prepare_worker_input(
             execute_model_req=execute_model_req)

From 755086b68a14d6ebec181cc6e279499239b6d2ae Mon Sep 17 00:00:00 2001
From: Aurick Qiao
Date: Thu, 12 Dec 2024 23:12:31 +0000
Subject: [PATCH 06/44] add whisper example

---
 examples/offline_inference_whisper.py | 62 +++++++++++++++++++++++++++
 1 file changed, 62 insertions(+)
 create mode 100644 examples/offline_inference_whisper.py

diff --git a/examples/offline_inference_whisper.py b/examples/offline_inference_whisper.py
new file mode 100644
index 0000000000000..5299ac29ecbe9
--- /dev/null
+++ b/examples/offline_inference_whisper.py
@@ -0,0 +1,62 @@
+'''
+Demonstrate prompting of text-to-text
+encoder/decoder models, specifically BART
+'''
+
+from vllm import LLM, SamplingParams
+from vllm.assets.audio import AudioAsset
+from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt
+
+audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
+
+dtype = "float"
+
+# Create a Whisper encoder/decoder model instance
+llm = LLM(
+    model="openai/whisper-large-v3",
+    max_model_len=448,
+    max_num_seqs=1,
+    enforce_eager=True,
+    limit_mm_per_prompt={"audio": 1}
+)
+
+prompts = [
+    ExplicitEncoderDecoderPrompt(
+        encoder_prompt=TextPrompt(
+            prompt="",
+            multi_modal_data={"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate}
+        ),
+        decoder_prompt="<|startoftranscript|>",
+    ),
+    ExplicitEncoderDecoderPrompt(
+        encoder_prompt=TextPrompt(
+            prompt="",
+            multi_modal_data={"audio": AudioAsset("winning_call").audio_and_sample_rate}
+        ),
+        decoder_prompt="<|startoftranscript|>",
+    ),
+]
+
+print(prompts)
+
+# Create a sampling params object.
+sampling_params = SamplingParams(
+    temperature=0,
+    top_p=1.0,
+    min_tokens=0,
+    max_tokens=20,
+)
+
+# Generate output tokens from the prompts. The output is a list of
+# RequestOutput objects that contain the prompt, generated
+# text, and other information.
+outputs = llm.generate(prompts, sampling_params)
+
+# Print the outputs.
+for output in outputs:
+    prompt = output.prompt
+    encoder_prompt = output.encoder_prompt
+    generated_text = output.outputs[0].text
+    print(f"Encoder prompt: {encoder_prompt!r}, "
+          f"Decoder prompt: {prompt!r}, "
+          f"Generated text: {generated_text!r}")

From b38f5b789bba994d9824539ff5d825bb977702a6 Mon Sep 17 00:00:00 2001
From: Aurick Qiao
Date: Thu, 12 Dec 2024 23:17:38 +0000
Subject: [PATCH 07/44] update

---
 examples/offline_inference_whisper.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/offline_inference_whisper.py b/examples/offline_inference_whisper.py
index 5299ac29ecbe9..616e342ecb93b 100644
--- a/examples/offline_inference_whisper.py
+++ b/examples/offline_inference_whisper.py
@@ -44,7 +44,7 @@
     temperature=0,
     top_p=1.0,
     min_tokens=0,
-    max_tokens=20,
+    max_tokens=200,
 )
 
 # Generate output tokens from the prompts.
The output is a list of From ff70bce5350630b4b1954cd5f593492ed96b7379 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Fri, 13 Dec 2024 00:49:03 +0000 Subject: [PATCH 08/44] cleanup a bit --- examples/offline_inference_whisper.py | 11 ++++++++--- vllm/core/scheduler.py | 6 +----- vllm/engine/llm_engine.py | 4 ---- vllm/model_executor/models/whisper.py | 16 +--------------- vllm/worker/model_runner.py | 1 - vllm/worker/worker_base.py | 1 - 6 files changed, 10 insertions(+), 29 deletions(-) diff --git a/examples/offline_inference_whisper.py b/examples/offline_inference_whisper.py index 616e342ecb93b..91d69b75c4472 100644 --- a/examples/offline_inference_whisper.py +++ b/examples/offline_inference_whisper.py @@ -2,6 +2,7 @@ Demonstrate prompting of text-to-text encoder/decoder models, specifically BART ''' +import time from vllm import LLM, SamplingParams from vllm.assets.audio import AudioAsset @@ -35,9 +36,7 @@ ), decoder_prompt="<|startoftranscript|>", ), -] - -print(prompts) +] * 10 # Create a sampling params object. sampling_params = SamplingParams( @@ -47,6 +46,8 @@ max_tokens=200, ) +start = time.time() + # Generate output tokens from the prompts. The output is a list of # RequestOutput objects that contain the prompt, generated # text, and other information. @@ -60,3 +61,7 @@ print(f"Encoder prompt: {encoder_prompt!r}, " f"Decoder prompt: {prompt!r}, " f"Generated text: {generated_text!r}") + +duration = time.time() - start +print("Duration:", duration) +print("RPS:", len(prompts) / duration) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 8e36c8a9b0dff..65eaf73d67e56 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -434,7 +434,6 @@ def num_decoding_tokens_per_seq(self) -> int: def add_seq_group(self, seq_group: SequenceGroup) -> None: # Add sequence groups to the waiting queue. - print("ADD", seq_group) self.waiting.append(seq_group) def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None: @@ -1319,7 +1318,6 @@ def schedule( if seq_group.is_encoder_decoder(): # Encoder associated with SequenceGroup encoder_seq = seq_group.get_encoder_seq() - print("encoder_seq", seq_group.encoder_seq.inputs.inputs) assert encoder_seq is not None encoder_seq_data = encoder_seq.data # Block table for cross-attention @@ -1327,7 +1325,6 @@ def schedule( cross_block_table = self.block_manager.get_cross_block_table( seq_group) else: - print("NOT encoder_seq") encoder_seq_data = None cross_block_table = None @@ -1365,7 +1362,6 @@ def schedule( # It assumes the scheduled_seq_groups is ordered by # prefill < decoding. if is_first_prefill or not self.scheduler_config.send_delta_data: - print("SCHEDULER SEQGROUP", seq_group.multi_modal_data) seq_group_metadata = SequenceGroupMetadata( request_id=seq_group.request_id, is_prompt=is_prompt, @@ -1385,7 +1381,7 @@ def schedule( # between engine and worker. # the subsequent comms can still use delta, but # `multi_modal_data` will be None. 
- multi_modal_data=seq_group.multi_modal_data or seq_group.encoder_seq.multi_modal_data + multi_modal_data=(seq_group.multi_modal_data or seq_group.encoder_seq.multi_modal_data) ,#if scheduler_outputs.num_prefill_groups > 0 else None, multi_modal_placeholders=seq_group.multi_modal_placeholders if scheduler_outputs.num_prefill_groups > 0 else None, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 5254966e86240..9be30c635cb2c 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -625,7 +625,6 @@ def _add_processed_request( seq_id = next(self.seq_counter) eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) - print("PROCESSED INPUTS", processed_inputs) if is_encoder_decoder_inputs(processed_inputs): decoder_inputs = processed_inputs["decoder"] encoder_inputs = processed_inputs["encoder"] @@ -639,8 +638,6 @@ def _add_processed_request( encoder_seq = (None if encoder_inputs is None else Sequence( seq_id, encoder_inputs, block_size, eos_token_id, lora_request, prompt_adapter_request)) - - print("ENCODER_SEQ", encoder_seq.inputs.inputs) # Create a SequenceGroup based on SamplingParams or PoolingParams if isinstance(params, SamplingParams): @@ -1401,7 +1398,6 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: # We use ExecuteModelRequest to pass the last sampled_token_ids # to each of the non-last PP stages for in-place prepare_input. last_sampled_token_ids=last_sampled_token_ids) - print("STEP", execute_model_req) if allow_async_output_proc: execute_model_req.async_callback = self.async_callbacks[ diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 25c14a57a8c5f..4b09006383359 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -150,7 +150,6 @@ def forward( q, _ = self.q_proj(hidden_states) k, _ = self.k_proj(hidden_states) v, _ = self.v_proj(hidden_states) - print(q.shape, k.shape, v.shape, hidden_states.shape) q = self._shape(q, -1, 1) k = self._shape(k, -1, 1) @@ -442,7 +441,6 @@ def forward( self, input_features, ): - print(self.conv1.weight.dtype, self.conv1.bias.dtype, input_features.dtype) inputs_embeds = nn.functional.gelu(self.conv1(input_features)) inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) inputs_embeds = inputs_embeds.permute(0, 2, 1) @@ -487,7 +485,6 @@ def forward( attn_metadata: AttentionMetadata, past_key_values = None, ): - print(self.max_target_positions, positions, input_ids.shape, positions.shape) inputs_embeds = self.embed_tokens(input_ids) positions = self.embed_positions(positions) hidden_states = inputs_embeds + positions @@ -534,7 +531,6 @@ def forward( def dummy_encoder_data_for_whisper(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): - print("DUMMY DATA") assert mm_counts["audio"] == 1 sample_rate = 16000 return DummyData( @@ -578,19 +574,13 @@ def get_whisper_processor( def input_processor_for_whisper(ctx: InputContext, inputs: DecoderOnlyInputs) -> DecoderOnlyInputs: - print("input_processor_for_whisper", inputs) return inputs - return token_inputs( - prompt_token_ids=inputs["decoder"]["prompt_token_ids"], - multi_modal_data=inputs["encoder"]["multi_modal_data"], - ) def input_mapper_for_whisper( ctx: InputContext, multi_modal_data: Union[np.ndarray, List[np.ndarray]], ) -> MultiModalKwargs: - print("input_mapper_for_whisper_audio", multi_modal_data) if not isinstance(multi_modal_data, list): multi_modal_data = [multi_modal_data] @@ -612,7 +602,6 @@ def 
input_mapper_for_whisper( return_tensors="pt") kwargs["input_features"] = kwargs["input_features"].squeeze(0) - print("input_mapper_for_whisper_audio", kwargs["input_features"].shape) return MultiModalKwargs(kwargs) @@ -652,9 +641,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, **kwargs, - ) -> torch.Tensor: - print("FORWARD", kwargs.keys()) - + ) -> torch.Tensor: decoder_outputs = self.model( input_features=kwargs["input_features"].to(torch.float16), input_ids=input_ids, @@ -676,7 +663,6 @@ def sample( sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) - print("SAMPLE", next_tokens) return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 48b5c4cc420d6..d8554c0639eb8 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -729,7 +729,6 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): encoder_seq_len = 0 if self.runner.model_config.is_encoder_decoder: - print(seq_group_metadata) encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len() inter_data = self.init_cached_inter_data( diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 68dd4011a0d0a..6d00102e0a324 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -262,7 +262,6 @@ def _get_driver_input_and_broadcast( ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]: """ Get the driver input and broadcast it to other workers. """ assert self.is_driver_worker - print("DRIVER", execute_model_req) worker_input: WorkerInput = self.prepare_worker_input( execute_model_req=execute_model_req) From 3fbd0671d985daf4e44990325d817157df336d98 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Fri, 13 Dec 2024 02:43:13 +0000 Subject: [PATCH 09/44] batching --- examples/offline_inference_whisper.py | 4 +- vllm/model_executor/models/whisper.py | 76 +++++++++++++++++---------- 2 files changed, 51 insertions(+), 29 deletions(-) diff --git a/examples/offline_inference_whisper.py b/examples/offline_inference_whisper.py index 91d69b75c4472..f4b4ef6b67b36 100644 --- a/examples/offline_inference_whisper.py +++ b/examples/offline_inference_whisper.py @@ -16,7 +16,7 @@ llm = LLM( model="openai/whisper-large-v3", max_model_len=448, - max_num_seqs=1, + max_num_seqs=64, enforce_eager=True, limit_mm_per_prompt={"audio": 1} ) @@ -36,7 +36,7 @@ ), decoder_prompt="<|startoftranscript|>", ), -] * 10 +] * 1000 # Create a sampling params object. 
sampling_params = SamplingParams( diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 4b09006383359..5726505699af1 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -30,6 +30,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.utils import consecutive_placeholder_ranges from vllm.sequence import SequenceData +from vllm.vllm_flash_attn import flash_attn_func from xformers import ops as xops from .interfaces import SupportsMultiModal @@ -142,18 +143,19 @@ def forward( self, hidden_states: torch.Tensor, ): - sizes = hidden_states.size() - if len(sizes) == 3: - bsz, tgt_len, _ = sizes - else: - tgt_len, _ = sizes + bsz, seq = hidden_states.size()[:2] + q, _ = self.q_proj(hidden_states) k, _ = self.k_proj(hidden_states) v, _ = self.v_proj(hidden_states) - q = self._shape(q, -1, 1) - k = self._shape(k, -1, 1) - v = self._shape(v, -1, 1) + # q = self._shape(q, -1, 1) + # k = self._shape(k, -1, 1) + # v = self._shape(v, -1, 1) + + q = q.view(bsz, seq, self.num_heads, self.head_dim) + k = k.view(bsz, seq, self.num_heads, self.head_dim) + v = v.view(bsz, seq, self.num_heads, self.head_dim) attn_output = xops.memory_efficient_attention_forward( q, @@ -164,8 +166,19 @@ def forward( scale=None, op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0], ) - - attn_output = attn_output.reshape(-1, self.embed_dim) + + # attn_output = flash_attn_func( + # q=q, + # k=k, + # v=v, + # softmax_scale=None, + # causal=False, + # window_size=(-1, -1), + # alibi_slopes=None, + # softcap=0, + # ) + + attn_output = attn_output.reshape(bsz, seq, self.embed_dim) output, _ = self.out_proj(attn_output) return output @@ -204,12 +217,6 @@ def forward( kv_cache: torch.Tensor = None, attn_metadata: AttentionMetadata = None, ): - sizes = hidden_states.size() - if len(sizes) == 3: - bsz, tgt_len, _ = sizes - else: - tgt_len, _ = sizes - q, _ = self.q_proj(hidden_states) k, _ = self.k_proj(hidden_states) v, _ = self.v_proj(hidden_states) @@ -245,19 +252,24 @@ def forward( encoder_hidden_states: torch.Tensor, attn_metadata: AttentionMetadata = None, ): - sizes = hidden_states.size() - if len(sizes) == 3: - bsz, tgt_len, _ = sizes - else: - tgt_len, _ = sizes + # HACK + query_lens = attn_metadata.query_start_loc.diff().tolist() + hidden_states = list(hidden_states.split(query_lens)) + padded_size = max(query_lens) + for i in range(len(hidden_states)): + hidden_states[i] = torch.nn.functional.pad(hidden_states[i], (0, 0, 0, padded_size - hidden_states[i].size(0))) + hidden_states = torch.stack(hidden_states, dim=0) + + bsz, seq = hidden_states.size()[:2] + bsz2, seq2 = encoder_hidden_states.size()[:2] q, _ = self.q_proj(hidden_states) k, _ = self.k_proj(encoder_hidden_states) v, _ = self.v_proj(encoder_hidden_states) - q = self._shape(q, -1, 1) - k = self._shape(k, -1, 1) - v = self._shape(v, -1, 1) + q = q.view(bsz, seq, self.num_heads, self.head_dim) + k = k.view(bsz2, seq2, self.num_heads, self.head_dim) + v = v.view(bsz2, seq2, self.num_heads, self.head_dim) attn_output = xops.memory_efficient_attention_forward( q, @@ -269,6 +281,12 @@ def forward( op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0], ) + # HACK + attn_output = list(torch.unbind(attn_output)) + for i in range(len(attn_output)): + attn_output[i] = attn_output[i][:query_lens[i], :] + attn_output = torch.cat(attn_output, dim=0) + attn_output = attn_output.reshape(-1, self.embed_dim) output, _ = self.out_proj(attn_output) return 
output @@ -448,7 +466,7 @@ def forward( embed_pos = self.embed_positions.weight hidden_states = inputs_embeds + embed_pos - for idx, encoder_layer in enumerate(self.layers): + for encoder_layer in self.layers: hidden_states = encoder_layer(hidden_states) hidden_states = self.layer_norm(hidden_states) @@ -481,9 +499,8 @@ def forward( input_ids, positions: torch.Tensor, encoder_hidden_states: torch.Tensor, - kv_caches: List[torch.Tensor], + kv_caches: torch.Tensor, attn_metadata: AttentionMetadata, - past_key_values = None, ): inputs_embeds = self.embed_tokens(input_ids) positions = self.embed_positions(positions) @@ -529,6 +546,10 @@ def forward( return decoder_outputs +def get_max_whisper_audio_tokens(ctx: InputContext) -> int: + return ctx.model_config.hf_config.max_source_positions + + def dummy_encoder_data_for_whisper(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): assert mm_counts["audio"] == 1 @@ -605,6 +626,7 @@ def input_mapper_for_whisper( return MultiModalKwargs(kwargs) +#@MULTIMODAL_REGISTRY.register_max_multimodal_tokens("audio", get_max_whisper_audio_tokens) @INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_whisper) @INPUT_REGISTRY.register_input_processor(input_processor_for_whisper) @MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper) From 9032aa14720a29bf18cba2b2b5a1efbc06e7e797 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Fri, 13 Dec 2024 03:06:29 +0000 Subject: [PATCH 10/44] flash_attn --- examples/offline_inference_whisper.py | 2 +- vllm/model_executor/models/whisper.py | 50 ++++++++++----------------- 2 files changed, 19 insertions(+), 33 deletions(-) diff --git a/examples/offline_inference_whisper.py b/examples/offline_inference_whisper.py index f4b4ef6b67b36..b46858450296c 100644 --- a/examples/offline_inference_whisper.py +++ b/examples/offline_inference_whisper.py @@ -36,7 +36,7 @@ ), decoder_prompt="<|startoftranscript|>", ), -] * 1000 +] * 128 # Create a sampling params object. 
sampling_params = SamplingParams( diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 5726505699af1..d462a4e701cb6 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -31,7 +31,6 @@ from vllm.multimodal.utils import consecutive_placeholder_ranges from vllm.sequence import SequenceData from vllm.vllm_flash_attn import flash_attn_func -from xformers import ops as xops from .interfaces import SupportsMultiModal from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, @@ -149,35 +148,21 @@ def forward( k, _ = self.k_proj(hidden_states) v, _ = self.v_proj(hidden_states) - # q = self._shape(q, -1, 1) - # k = self._shape(k, -1, 1) - # v = self._shape(v, -1, 1) - q = q.view(bsz, seq, self.num_heads, self.head_dim) k = k.view(bsz, seq, self.num_heads, self.head_dim) v = v.view(bsz, seq, self.num_heads, self.head_dim) - attn_output = xops.memory_efficient_attention_forward( - q, - k, - v, - attn_bias=None, - p=0.0, - scale=None, - op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0], + attn_output = flash_attn_func( + q=q, + k=k, + v=v, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + softcap=0, ) - # attn_output = flash_attn_func( - # q=q, - # k=k, - # v=v, - # softmax_scale=None, - # causal=False, - # window_size=(-1, -1), - # alibi_slopes=None, - # softcap=0, - # ) - attn_output = attn_output.reshape(bsz, seq, self.embed_dim) output, _ = self.out_proj(attn_output) return output @@ -271,14 +256,15 @@ def forward( k = k.view(bsz2, seq2, self.num_heads, self.head_dim) v = v.view(bsz2, seq2, self.num_heads, self.head_dim) - attn_output = xops.memory_efficient_attention_forward( - q, - k, - v, - attn_bias=None, - p=0.0, - scale=None, - op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0], + attn_output = flash_attn_func( + q=q, + k=k, + v=v, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + softcap=0, ) # HACK From ce3a87cabd7bdcad5e3fde06d7bf799b852d1548 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Fri, 13 Dec 2024 04:12:25 +0000 Subject: [PATCH 11/44] WIP (broken) --- examples/offline_inference_whisper.py | 2 +- vllm/attention/backends/flash_attn.py | 6 +- vllm/model_executor/models/whisper.py | 86 +++++++++++++++++---------- 3 files changed, 62 insertions(+), 32 deletions(-) diff --git a/examples/offline_inference_whisper.py b/examples/offline_inference_whisper.py index b46858450296c..0c91ebe7c9877 100644 --- a/examples/offline_inference_whisper.py +++ b/examples/offline_inference_whisper.py @@ -36,7 +36,7 @@ ), decoder_prompt="<|startoftranscript|>", ), -] * 128 +] #* 128 # Create a sampling params object. sampling_params = SamplingParams( diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index c69e12ad78c44..143e8213c4c08 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -712,13 +712,17 @@ def forward( (num_prefill_query_tokens, num_prefill_kv_tokens, num_decode_query_tokens) = \ get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) + print("ATTN_METADATA", attn_metadata) decode_query = query[num_prefill_query_tokens:] decode_output = output[num_prefill_query_tokens:] # QKV for prefill. 
query = query[:num_prefill_query_tokens] prefill_output = output[:num_prefill_query_tokens] assert query.shape[0] == num_prefill_query_tokens - assert decode_query.shape[0] == num_decode_query_tokens + assert decode_query.shape[0] == num_decode_query_tokens, ( + f"decode_query.shape: {decode_query.shape}, " + f"num_decode_query_tokens: {num_decode_query_tokens}" + ) if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index d462a4e701cb6..f2517c0b8ffd0 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -8,7 +8,7 @@ from torch import nn from transformers import WhisperConfig, WhisperProcessor -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -137,34 +137,49 @@ def __init__( cache_config=cache_config, prefix=prefix, ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ): bsz, seq = hidden_states.size()[:2] q, _ = self.q_proj(hidden_states) k, _ = self.k_proj(hidden_states) v, _ = self.v_proj(hidden_states) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata, + attn_type=AttentionType.ENCODER) - q = q.view(bsz, seq, self.num_heads, self.head_dim) - k = k.view(bsz, seq, self.num_heads, self.head_dim) - v = v.view(bsz, seq, self.num_heads, self.head_dim) + output, _ = self.out_proj(attn_output) - attn_output = flash_attn_func( - q=q, - k=k, - v=v, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - softcap=0, - ) + # q = q.view(bsz, seq, self.num_heads, self.head_dim) + # k = k.view(bsz, seq, self.num_heads, self.head_dim) + # v = v.view(bsz, seq, self.num_heads, self.head_dim) + + # attn_output = flash_attn_func( + # q=q, + # k=k, + # v=v, + # softmax_scale=None, + # causal=False, + # window_size=(-1, -1), + # alibi_slopes=None, + # softcap=0, + # ) - attn_output = attn_output.reshape(bsz, seq, self.embed_dim) - output, _ = self.out_proj(attn_output) + # attn_output = attn_output.reshape(bsz, seq, self.embed_dim) + # output, _ = self.out_proj(attn_output) return output @@ -315,11 +330,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def forward( self, hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ): residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn( hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, ) hidden_states = residual + hidden_states residual = hidden_states @@ -444,16 +463,23 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def forward( self, input_features, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ): inputs_embeds = nn.functional.gelu(self.conv1(input_features)) inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) inputs_embeds = inputs_embeds.permute(0, 2, 1) + print("INPUTS EMBEDS", inputs_embeds.size()) embed_pos = self.embed_positions.weight hidden_states = inputs_embeds + embed_pos - for 
encoder_layer in self.layers: - hidden_states = encoder_layer(hidden_states) + for idx, encoder_layer in enumerate(self.layers): + hidden_states = encoder_layer( + hidden_states, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata, + ) hidden_states = self.layer_norm(hidden_states) return hidden_states @@ -485,7 +511,7 @@ def forward( input_ids, positions: torch.Tensor, encoder_hidden_states: torch.Tensor, - kv_caches: torch.Tensor, + kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, ): inputs_embeds = self.embed_tokens(input_ids) @@ -520,7 +546,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, ) -> torch.Tensor: - encoder_outputs = self.encoder(input_features) + encoder_outputs = self.encoder( + input_features, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) decoder_outputs = self.decoder( input_ids=input_ids, @@ -581,6 +611,7 @@ def get_whisper_processor( def input_processor_for_whisper(ctx: InputContext, inputs: DecoderOnlyInputs) -> DecoderOnlyInputs: + inputs["encoder"]["prompt_token_ids"] = [0] * ctx.model_config.hf_config.max_source_positions return inputs @@ -612,10 +643,10 @@ def input_mapper_for_whisper( return MultiModalKwargs(kwargs) -#@MULTIMODAL_REGISTRY.register_max_multimodal_tokens("audio", get_max_whisper_audio_tokens) @INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_whisper) @INPUT_REGISTRY.register_input_processor(input_processor_for_whisper) @MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper) +@MULTIMODAL_REGISTRY.register_max_multimodal_tokens("audio", get_max_whisper_audio_tokens) class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -627,15 +658,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix) self.unpadded_vocab_size = config.vocab_size - # self.proj_out = RowParallelLinear( - # input_size = config.d_model, - # output_size = config.vocab_size, - # bias = False, - # quant_config=quant_config, - # ) self.proj_out = ParallelLMHead(config.vocab_size, - config.d_model, - quant_config=quant_config) + config.d_model, + quant_config=quant_config) logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) @@ -649,7 +674,8 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, **kwargs, - ) -> torch.Tensor: + ) -> torch.Tensor: + print(attn_metadata.encoder_seq_lens, attn_metadata.encoder_seq_start_loc) decoder_outputs = self.model( input_features=kwargs["input_features"].to(torch.float16), input_ids=input_ids, From 04a0ef44423878ffed03b27d102a07ca419de77c Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Mon, 16 Dec 2024 20:32:16 +0000 Subject: [PATCH 12/44] WIP --- vllm/attention/backends/flash_attn.py | 4 +++- vllm/model_executor/models/whisper.py | 4 +++- vllm/worker/enc_dec_model_runner.py | 11 +++++------ 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 143e8213c4c08..5ce8ce6284a75 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -712,7 +712,6 @@ def forward( (num_prefill_query_tokens, num_prefill_kv_tokens, num_decode_query_tokens) = \ get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) - print("ATTN_METADATA", 
attn_metadata) decode_query = query[num_prefill_query_tokens:] decode_output = output[num_prefill_query_tokens:] # QKV for prefill. @@ -775,7 +774,10 @@ def forward( out=prefill_output, ) + print("METADATA", attn_metadata) + if decode_meta := attn_metadata.decode_metadata: + print("DECODE_META", decode_meta) # Decoding run. # Use flash_attn_varlen_func kernel for speculative decoding # because different queries might have different lengths. diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index f2517c0b8ffd0..6070065f10eda 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -340,6 +340,7 @@ def forward( kv_cache=kv_cache, attn_metadata=attn_metadata, ) + hidden_states = hidden_states.view(residual.size()) ## HACK hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -570,8 +571,9 @@ def dummy_encoder_data_for_whisper(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): assert mm_counts["audio"] == 1 sample_rate = 16000 + num_tokens = ctx.model_config.hf_config.max_source_positions return DummyData( - SequenceData.from_prompt_token_counts((0, seq_len)), + SequenceData.from_prompt_token_counts((0, num_tokens)), {"audio": [(np.zeros(30 * sample_rate), sample_rate)]}, ) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 2916f92fff5eb..4762ff18fd001 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -288,12 +288,11 @@ def profile_run(self) -> None: seq_len, self.mm_registry, is_encoder_data=False) - encoder_dummy_data \ - = self.input_registry.dummy_data_for_profiling( - self.model_config, - seq_len, - self.mm_registry, - is_encoder_data=True) + encoder_dummy_data = self.input_registry \ + .dummy_data_for_profiling(self.model_config, + seq_len, + self.mm_registry, + is_encoder_data=True) # Having more tokens is over-conservative but otherwise fine assert len( From fd4ed14ae04b0fd998fa3da46bc20561ebbebeb0 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Mon, 16 Dec 2024 21:05:37 +0000 Subject: [PATCH 13/44] 13rps --- examples/offline_inference_whisper.py | 2 +- vllm/attention/backends/flash_attn.py | 3 -- vllm/core/scheduler.py | 2 +- vllm/model_executor/models/whisper.py | 62 ++++++++++++++++++++------- 4 files changed, 49 insertions(+), 20 deletions(-) diff --git a/examples/offline_inference_whisper.py b/examples/offline_inference_whisper.py index 0c91ebe7c9877..b46858450296c 100644 --- a/examples/offline_inference_whisper.py +++ b/examples/offline_inference_whisper.py @@ -36,7 +36,7 @@ ), decoder_prompt="<|startoftranscript|>", ), -] #* 128 +] * 128 # Create a sampling params object. sampling_params = SamplingParams( diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 5ce8ce6284a75..e08904c829f73 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -774,10 +774,7 @@ def forward( out=prefill_output, ) - print("METADATA", attn_metadata) - if decode_meta := attn_metadata.decode_metadata: - print("DECODE_META", decode_meta) # Decoding run. # Use flash_attn_varlen_func kernel for speculative decoding # because different queries might have different lengths. 
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 65eaf73d67e56..c031d56487139 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1382,7 +1382,7 @@ def schedule( # the subsequent comms can still use delta, but # `multi_modal_data` will be None. multi_modal_data=(seq_group.multi_modal_data or seq_group.encoder_seq.multi_modal_data) - ,#if scheduler_outputs.num_prefill_groups > 0 else None, + if scheduler_outputs.num_prefill_groups > 0 else None, multi_modal_placeholders=seq_group.multi_modal_placeholders if scheduler_outputs.num_prefill_groups > 0 else None, mm_processor_kwargs=seq_group.mm_processor_kwargs, diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 6070065f10eda..50ed48b8851bc 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -245,13 +245,39 @@ def __init__( cache_config=cache_config, prefix=prefix, ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata = None, + encoder_hidden_states: Optional[torch.Tensor], + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ): + q, _ = self.q_proj(hidden_states) + + if encoder_hidden_states is not None: + k, _ = self.k_proj(encoder_hidden_states) + v, _ = self.v_proj(encoder_hidden_states) + else: + k = v = None + + attn_output = self.attn(q, k, v, kv_cache, attn_metadata, + attn_type=AttentionType.ENCODER_DECODER) + + output, _ = self.out_proj(attn_output) + + return output + + output, _ = self.out_proj(attn_output) # HACK query_lens = attn_metadata.query_start_loc.diff().tolist() hidden_states = list(hidden_states.split(query_lens)) @@ -404,7 +430,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def forward( self, hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ): @@ -422,6 +448,7 @@ def forward( hidden_states = self.encoder_attn( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, + kv_cache=kv_cache, attn_metadata=attn_metadata, ) hidden_states = residual + hidden_states @@ -470,7 +497,6 @@ def forward( inputs_embeds = nn.functional.gelu(self.conv1(input_features)) inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) inputs_embeds = inputs_embeds.permute(0, 2, 1) - print("INPUTS EMBEDS", inputs_embeds.size()) embed_pos = self.embed_positions.weight @@ -511,7 +537,7 @@ def forward( self, input_ids, positions: torch.Tensor, - encoder_hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, ): @@ -524,7 +550,7 @@ def forward( hidden_states, encoder_hidden_states=encoder_hidden_states, kv_cache=kv_caches[idx], - attn_metadata=attn_metadata + attn_metadata=attn_metadata, ) hidden_states = self.layer_norm(hidden_states) @@ -541,17 +567,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def forward( self, - input_features: torch.FloatTensor, + input_features: Optional[torch.FloatTensor], input_ids: Optional[torch.Tensor], positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - encoder_outputs = self.encoder( - 
input_features, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - ) + ) -> torch.Tensor: + if input_features is not None: + # Prefill encoder kv-caches + encoder_outputs = self.encoder( + input_features, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) + else: + encoder_outputs = None decoder_outputs = self.decoder( input_ids=input_ids, @@ -677,9 +707,11 @@ def forward( attn_metadata: AttentionMetadata, **kwargs, ) -> torch.Tensor: - print(attn_metadata.encoder_seq_lens, attn_metadata.encoder_seq_start_loc) + input_features = kwargs.get("input_features") + if input_features is not None: + input_features = input_features.to(torch.float16) decoder_outputs = self.model( - input_features=kwargs["input_features"].to(torch.float16), + input_features=input_features, input_ids=input_ids, positions=positions, kv_caches=kv_caches, From 26cfede0313ddf7016962dab3ad05424634fa927 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Mon, 16 Dec 2024 22:05:57 +0000 Subject: [PATCH 14/44] fuse qkv --- vllm/model_executor/models/whisper.py | 305 ++++++++++---------------- 1 file changed, 115 insertions(+), 190 deletions(-) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 50ed48b8851bc..9fac38dc524df 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -1,6 +1,6 @@ import math from functools import lru_cache -from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union +from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union import librosa import numpy as np @@ -33,9 +33,7 @@ from vllm.vllm_flash_attn import flash_attn_func from .interfaces import SupportsMultiModal -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import AutoWeightsLoader, make_layers, maybe_prefix logger = init_logger(__name__) @@ -67,16 +65,30 @@ def __init__( embed_dim: int, num_heads: int, bias: bool = True, - quant_config: Optional[QuantizationConfig] = None, + attn_type: AttentionType = AttentionType.DECODER, cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): super().__init__() - tp_size = get_tensor_model_parallel_world_size() self.embed_dim = embed_dim - self.num_heads = num_heads - self.num_kv_heads = max(1, self.num_heads // tp_size) - self.head_dim = embed_dim // num_heads + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + if self.total_num_heads >= tp_size: + # Number of heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_heads % tp_size == 0 + else: + # Number of heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_heads == 0 + self.num_kv_heads = max(1, self.total_num_heads // tp_size) + self.head_dim = self.embed_dim // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.attn_type = attn_type if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -85,27 +97,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 - self.k_proj = RowParallelLinear( - input_size = embed_dim, - output_size = embed_dim, - bias = False, - quant_config=quant_config, - prefix=f"{prefix}.k_proj", - ) - self.v_proj = RowParallelLinear( - input_size = embed_dim, - output_size = embed_dim, - bias = bias, - quant_config=quant_config, - prefix=f"{prefix}.v_proj", - ) - self.q_proj = RowParallelLinear( - input_size = embed_dim, - output_size = embed_dim, - bias = bias, - quant_config=quant_config, - prefix=f"{prefix}.q_proj", - ) + self._init_qkv(embed_dim, bias, quant_config, prefix=prefix) self.out_proj = RowParallelLinear( input_size = embed_dim, output_size = embed_dim, @@ -113,30 +105,6 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.out_proj", ) - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() - - -class WhisperEncoderAttention(WhisperAttention): - - def __init__( - self, - embed_dim: int, - num_heads: int, - bias: bool = True, - quant_config: Optional[QuantizationConfig] = None, - cache_config: Optional[CacheConfig] = None, - prefix: str = "", - ): - super().__init__( - embed_dim=embed_dim, - num_heads=num_heads, - bias=bias, - quant_config=quant_config, - cache_config=cache_config, - prefix=prefix, - ) self.attn = Attention( self.num_heads, self.head_dim, @@ -146,113 +114,80 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.attn", ) - + + def _init_qkv(self, + embed_dim: int, + bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + self.qkv_proj = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + def forward( self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ): - bsz, seq = hidden_states.size()[:2] + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, _ = self.q_proj(hidden_states) - k, _ = self.k_proj(hidden_states) - v, _ = self.v_proj(hidden_states) attn_output = self.attn(q, k, v, kv_cache, attn_metadata, - attn_type=AttentionType.ENCODER) + attn_type=self.attn_type) output, _ = self.out_proj(attn_output) - # q = q.view(bsz, seq, self.num_heads, self.head_dim) - # k = k.view(bsz, seq, self.num_heads, self.head_dim) - # v = v.view(bsz, seq, self.num_heads, self.head_dim) - - # attn_output = flash_attn_func( - # q=q, - # k=k, - # v=v, - # softmax_scale=None, - # causal=False, - # window_size=(-1, -1), - # alibi_slopes=None, - # softcap=0, - # ) - - # attn_output = attn_output.reshape(bsz, seq, self.embed_dim) - # output, _ = self.out_proj(attn_output) return output -class WhisperDecoderAttention(WhisperAttention): +class WhisperCrossAttention(WhisperAttention): def __init__( self, embed_dim: int, num_heads: int, bias: bool = True, - quant_config: Optional[QuantizationConfig] = None, cache_config: Optional[CacheConfig] = None, + quant_config: 
Optional[QuantizationConfig] = None, prefix: str = "", ): super().__init__( embed_dim=embed_dim, num_heads=num_heads, bias=bias, - quant_config=quant_config, - cache_config=cache_config, - prefix=prefix, - ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.attn", + prefix=prefix, ) - - def forward( - self, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor = None, - attn_metadata: AttentionMetadata = None, - ): - q, _ = self.q_proj(hidden_states) - k, _ = self.k_proj(hidden_states) - v, _ = self.v_proj(hidden_states) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) - - output, _ = self.out_proj(attn_output) - - return output - -class WhisperDecoderCrossAttention(WhisperAttention): - def __init__( - self, + def _init_qkv(self, embed_dim: int, - num_heads: int, bias: bool = True, quant_config: Optional[QuantizationConfig] = None, - cache_config: Optional[CacheConfig] = None, prefix: str = "", - ): - super().__init__( - embed_dim=embed_dim, - num_heads=num_heads, - bias=bias, + ) -> None: + self.q_proj = RowParallelLinear( + input_size = embed_dim, + output_size = embed_dim, + bias = bias, quant_config=quant_config, - cache_config=cache_config, - prefix=prefix, + prefix=f"{prefix}.q_proj", ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, + self.kv_proj = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.head_dim, + total_num_heads=0, + total_num_kv_heads=self.total_num_heads, + bias=bias, quant_config=quant_config, - prefix=f"{prefix}.attn", + prefix=f"{prefix}.kv_proj", ) def forward( @@ -265,8 +200,8 @@ def forward( q, _ = self.q_proj(hidden_states) if encoder_hidden_states is not None: - k, _ = self.k_proj(encoder_hidden_states) - v, _ = self.v_proj(encoder_hidden_states) + kv, _ = self.kv_proj(encoder_hidden_states) + k, v = kv.split([self.kv_size, self.kv_size], dim=-1) else: k = v = None @@ -277,62 +212,22 @@ def forward( return output - output, _ = self.out_proj(attn_output) - # HACK - query_lens = attn_metadata.query_start_loc.diff().tolist() - hidden_states = list(hidden_states.split(query_lens)) - padded_size = max(query_lens) - for i in range(len(hidden_states)): - hidden_states[i] = torch.nn.functional.pad(hidden_states[i], (0, 0, 0, padded_size - hidden_states[i].size(0))) - hidden_states = torch.stack(hidden_states, dim=0) - - bsz, seq = hidden_states.size()[:2] - bsz2, seq2 = encoder_hidden_states.size()[:2] - - q, _ = self.q_proj(hidden_states) - k, _ = self.k_proj(encoder_hidden_states) - v, _ = self.v_proj(encoder_hidden_states) - - q = q.view(bsz, seq, self.num_heads, self.head_dim) - k = k.view(bsz2, seq2, self.num_heads, self.head_dim) - v = v.view(bsz2, seq2, self.num_heads, self.head_dim) - - attn_output = flash_attn_func( - q=q, - k=k, - v=v, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - softcap=0, - ) - - # HACK - attn_output = list(torch.unbind(attn_output)) - for i in range(len(attn_output)): - attn_output[i] = attn_output[i][:query_lens[i], :] - attn_output = torch.cat(attn_output, dim=0) - - attn_output = attn_output.reshape(-1, self.embed_dim) - output, _ = self.out_proj(attn_output) - return output - class WhisperEncoderLayer(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - 
quant_config = vllm_config.quant_config cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.embed_dim = config.d_model - self.self_attn = WhisperEncoderAttention( + self.self_attn = WhisperAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, - quant_config=quant_config, + attn_type=AttentionType.ENCODER, cache_config=cache_config, + quant_config=quant_config, prefix=f"{prefix}.self_attn", ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) @@ -389,25 +284,26 @@ class WhisperDecoderLayer(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.embed_dim = config.d_model - self.self_attn = WhisperDecoderAttention( + self.self_attn = WhisperAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, - quant_config=quant_config, + attn_type=AttentionType.DECODER, cache_config=cache_config, + quant_config=quant_config, prefix=f"{prefix}.self_attn", ) self.activation_fn = FastGELU() self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = WhisperDecoderCrossAttention( + self.encoder_attn = WhisperCrossAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, - quant_config=quant_config, cache_config=cache_config, + quant_config=quant_config, prefix=f"{prefix}.encoder_attn", ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) @@ -592,6 +488,43 @@ def forward( ) return decoder_outputs + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".self_attn.qkv_proj", ".self_attn.q_proj", "q"), + (".self_attn.qkv_proj", ".self_attn.k_proj", "k"), + (".self_attn.qkv_proj", ".self_attn.v_proj", "v"), + (".encoder_attn.kv_proj", ".encoder_attn.k_proj", "k"), + (".encoder_attn.kv_proj", ".encoder_attn.v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + def get_max_whisper_audio_tokens(ctx: InputContext) -> int: return ctx.model_config.hf_config.max_source_positions @@ -693,6 +626,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.proj_out = ParallelLMHead(config.vocab_size, config.d_model, quant_config=quant_config) + self.proj_out = self.proj_out.tie_weights( + self.model.decoder.embed_tokens) logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) @@ -733,18 +668,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - params_dict = dict(self.named_parameters()) - for name, loaded_weight in weights: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - - if name == 'model.decoder.embed_tokens.weight': - param = params_dict['proj_out.weight'] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - - param = params_dict[name] + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."]) + return loader.load_weights((name, loaded_weight) + for name, loaded_weight in weights) \ No newline at end of file From 34c5830377626379e4a419c9fc4b1420fbe11558 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Mon, 16 Dec 2024 22:42:25 +0000 Subject: [PATCH 15/44] clean --- vllm/model_executor/models/whisper.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 9fac38dc524df..d37dcf8fdbc20 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -261,7 +261,6 @@ def forward( kv_cache=kv_cache, attn_metadata=attn_metadata, ) - hidden_states = hidden_states.view(residual.size()) ## HACK hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -393,17 +392,18 @@ def forward( inputs_embeds = nn.functional.gelu(self.conv1(input_features)) inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) inputs_embeds = inputs_embeds.permute(0, 2, 1) - + embed_pos = self.embed_positions.weight hidden_states = inputs_embeds + embed_pos + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) for idx, encoder_layer in enumerate(self.layers): hidden_states = encoder_layer( hidden_states, kv_cache=kv_caches[idx], attn_metadata=attn_metadata, ) - + hidden_states = self.layer_norm(hidden_states) return hidden_states @@ -653,13 +653,13 @@ def forward( attn_metadata=attn_metadata, ) return decoder_outputs - + def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = self.logits_processor(self.proj_out, hidden_states, sampling_metadata) return logits - + def sample( self, logits: torch.Tensor, From bf111b2db803793b395b6a0181c6d70bc7678cf6 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Mon, 16 Dec 2024 23:28:22 +0000 Subject: [PATCH 16/44] 20 RPS --- 
examples/offline_inference_whisper.py | 7 ++++--- vllm/model_executor/models/whisper.py | 2 -- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/offline_inference_whisper.py b/examples/offline_inference_whisper.py index b46858450296c..45e50382472fa 100644 --- a/examples/offline_inference_whisper.py +++ b/examples/offline_inference_whisper.py @@ -16,8 +16,8 @@ llm = LLM( model="openai/whisper-large-v3", max_model_len=448, - max_num_seqs=64, - enforce_eager=True, + max_num_seqs=128, + #max_num_batched_tokens=16384, limit_mm_per_prompt={"audio": 1} ) @@ -36,7 +36,7 @@ ), decoder_prompt="<|startoftranscript|>", ), -] * 128 +] * 1024 # Create a sampling params object. sampling_params = SamplingParams( @@ -63,5 +63,6 @@ f"Generated text: {generated_text!r}") duration = time.time() - start + print("Duration:", duration) print("RPS:", len(prompts) / duration) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index d37dcf8fdbc20..3139c72ffb42a 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -30,7 +30,6 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.utils import consecutive_placeholder_ranges from vllm.sequence import SequenceData -from vllm.vllm_flash_attn import flash_attn_func from .interfaces import SupportsMultiModal from .utils import AutoWeightsLoader, make_layers, maybe_prefix @@ -635,7 +634,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def forward( self, - #whisper_data: torch.Tensor, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], From a21470bddce2617d4ee625f8d14f6a475e71e64a Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Tue, 17 Dec 2024 02:37:03 +0000 Subject: [PATCH 17/44] 26rps --- examples/offline_inference_whisper.py | 15 +++++++++------ vllm/core/scheduler.py | 1 + 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/examples/offline_inference_whisper.py b/examples/offline_inference_whisper.py index 45e50382472fa..42ac1d71ff4a8 100644 --- a/examples/offline_inference_whisper.py +++ b/examples/offline_inference_whisper.py @@ -16,9 +16,9 @@ llm = LLM( model="openai/whisper-large-v3", max_model_len=448, - max_num_seqs=128, - #max_num_batched_tokens=16384, - limit_mm_per_prompt={"audio": 1} + max_num_seqs=400, + limit_mm_per_prompt={"audio": 1}, + kv_cache_dtype="fp8", ) prompts = [ @@ -44,6 +44,9 @@ top_p=1.0, min_tokens=0, max_tokens=200, + # min_tokens=40, + # max_tokens=40, + # ignore_eos=True, ) start = time.time() @@ -58,9 +61,9 @@ prompt = output.prompt encoder_prompt = output.encoder_prompt generated_text = output.outputs[0].text - print(f"Encoder prompt: {encoder_prompt!r}, " - f"Decoder prompt: {prompt!r}, " - f"Generated text: {generated_text!r}") + # print(f"Encoder prompt: {encoder_prompt!r}, " + # f"Decoder prompt: {prompt!r}, " + # f"Generated text: {generated_text!r}") duration = time.time() - start diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c031d56487139..04dfb0a261bc7 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1579,6 +1579,7 @@ def _preempt_by_recompute( seq.status = SequenceStatus.WAITING self.free_seq(seq) seq.reset_state_for_recompute() + self._free_seq_group_cross_attn_blocks(seq_group) def _preempt_by_swap( self, From b457c016756372e04d7866b2e6cd07f222908643 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Tue, 17 Dec 2024 04:54:15 +0000 Subject: [PATCH 18/44] 41 rps --- 
examples/offline_inference_whisper.py | 10 +-- vllm/model_executor/models/whisper.py | 66 +++++++++++++------ .../tokenizer_group/tokenizer_group.py | 2 + 3 files changed, 53 insertions(+), 25 deletions(-) diff --git a/examples/offline_inference_whisper.py b/examples/offline_inference_whisper.py index 42ac1d71ff4a8..8b040851ce120 100644 --- a/examples/offline_inference_whisper.py +++ b/examples/offline_inference_whisper.py @@ -27,14 +27,14 @@ prompt="", multi_modal_data={"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate} ), - decoder_prompt="<|startoftranscript|>", + decoder_prompt="", ), ExplicitEncoderDecoderPrompt( encoder_prompt=TextPrompt( prompt="", multi_modal_data={"audio": AudioAsset("winning_call").audio_and_sample_rate} ), - decoder_prompt="<|startoftranscript|>", + decoder_prompt="", ), ] * 1024 @@ -61,9 +61,9 @@ prompt = output.prompt encoder_prompt = output.encoder_prompt generated_text = output.outputs[0].text - # print(f"Encoder prompt: {encoder_prompt!r}, " - # f"Decoder prompt: {prompt!r}, " - # f"Generated text: {generated_text!r}") + print(f"Encoder prompt: {encoder_prompt!r}, " + f"Decoder prompt: {prompt!r}, " + f"Generated text: {generated_text!r}") duration = time.time() - start diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 3139c72ffb42a..31c552baa5785 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -384,18 +384,35 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def forward( self, - input_features, + input_features: Union[torch.Tensor, List[torch.Tensor]], kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, ): - inputs_embeds = nn.functional.gelu(self.conv1(input_features)) - inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) - inputs_embeds = inputs_embeds.permute(0, 2, 1) - - embed_pos = self.embed_positions.weight - - hidden_states = inputs_embeds + embed_pos - hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + hs = [] + for t in input_features: + t = nn.functional.gelu(self.conv1(t)) + t = nn.functional.gelu(self.conv2(t)) + t = t.permute(1, 0) + h = t + self.embed_positions.weight[:t.size(0), :] + hs.append(h) + hidden_states = torch.cat(hs) + + # #inputs_embeds = nn.functional.gelu(self.conv1(input_features)) + # inputs_embeds = [ + # nn.functional.gelu(self.conv1(t)) for t in input_features + # ] + # inputs_embeds = [ + # nn.functional.gelu(self.conv2(t)) for t in inputs_embeds + # ] + # print(inputs_embeds[0].shape) + # inputs_embeds = torch.cat(inputs_embeds) + # #inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) + # inputs_embeds = inputs_embeds.permute(0, 2, 1) + + # embed_pos = self.embed_positions.weight + + # hidden_states = inputs_embeds + embed_pos + # hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) for idx, encoder_layer in enumerate(self.layers): hidden_states = encoder_layer( hidden_states, @@ -575,7 +592,19 @@ def get_whisper_processor( def input_processor_for_whisper(ctx: InputContext, inputs: DecoderOnlyInputs) -> DecoderOnlyInputs: - inputs["encoder"]["prompt_token_ids"] = [0] * ctx.model_config.hf_config.max_source_positions + audio, orig_sr = inputs["encoder"]["multi_modal_data"]["audio"] + processor = get_whisper_processor(ctx.model_config.model) + target_sr = processor.feature_extractor.sampling_rate + audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr) + if audio.size > 30 * target_sr: + # Truncate audio to 30 seconds + 
audio = audio[:30 * target_sr] + inputs["encoder"]["multi_modal_data"]["audio"] = (audio, target_sr) + # Calculate number of tokens after convolutions + num_tokens = (audio.size // 80 - 1) // 2 + 1 + num_tokens = (num_tokens - 2) // 2 + 1 + # Pre-allocate placeholder tokens in encoder sequence + inputs["encoder"]["prompt_token_ids"] = [0] * num_tokens return inputs @@ -592,17 +621,14 @@ def input_mapper_for_whisper( return MultiModalKwargs() processor = get_whisper_processor(ctx.model_config.model) - target_sampling_rate = processor.feature_extractor.sampling_rate + sampling_rate = processor.feature_extractor.sampling_rate - resampled_audios = [ - librosa.resample(audio, orig_sr=sampling_rate, - target_sr=target_sampling_rate) - for audio, sampling_rate in multi_modal_data - ] + audios = [audio for audio, _ in multi_modal_data] - kwargs = processor(resampled_audios, sampling_rate=target_sampling_rate, - return_tensors="pt") + kwargs = processor(audios, sampling_rate=sampling_rate, + padding=False, return_tensors="pt") kwargs["input_features"] = kwargs["input_features"].squeeze(0) + kwargs["input_features"] = kwargs["input_features"].to(torch.float16) return MultiModalKwargs(kwargs) @@ -641,8 +667,8 @@ def forward( **kwargs, ) -> torch.Tensor: input_features = kwargs.get("input_features") - if input_features is not None: - input_features = input_features.to(torch.float16) + # if input_features is not None: + # input_features = input_features.to(torch.float16) decoder_outputs = self.model( input_features=input_features, input_ids=input_ids, diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py index 761b07f34d2f9..9b0843a86b38f 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -58,6 +58,8 @@ def encode(self, lora_request: Optional[LoRARequest] = None) -> List[int]: tokenizer = self.get_lora_tokenizer(lora_request) ret = tokenizer.encode(prompt) + if ret[-1] == 50257: + ret = ret[:-1] self._raise_if_input_too_long(ret, lora_request) return ret From d81d2171ff2c287153537be124852b7fa28aa3c6 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Tue, 17 Dec 2024 19:12:34 +0000 Subject: [PATCH 19/44] fix tokenizer --- examples/offline_inference_audio_language.py | 10 ++--- examples/offline_inference_whisper.py | 8 ++-- vllm/inputs/preprocess.py | 12 +++-- vllm/model_executor/models/whisper.py | 45 +++++++------------ .../tokenizer_group/__init__.py | 7 +++ .../tokenizer_group/tokenizer_group.py | 18 +++++--- 6 files changed, 51 insertions(+), 49 deletions(-) diff --git a/examples/offline_inference_audio_language.py b/examples/offline_inference_audio_language.py index b6cc0b3af49d1..b7980d5d561c9 100644 --- a/examples/offline_inference_audio_language.py +++ b/examples/offline_inference_audio_language.py @@ -66,17 +66,12 @@ def run_whisper(question: str, audio_count: int): model_name = "openai/whisper-large-v3" llm = LLM(model=model_name, - max_model_len=4096, + max_model_len=448, max_num_seqs=1, enforce_eager=True, limit_mm_per_prompt={"audio": audio_count}) - # audio_in_prompt = "".join([ - # f"Audio {idx+1}: " - # f"<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count) - # ]) - - prompt = "<|startoftranscript|>" + prompt = "<|startoftranscript|><|en|>" stop_token_ids = None return llm, prompt, stop_token_ids @@ -111,6 +106,7 @@ def main(args): assert args.num_prompts > 0 inputs = {"prompt": prompt, "multi_modal_data": 
mm_data} + #inputs = {"encoder_prompt": {"prompt": "", "multi_modal_data": mm_data}, "decoder_prompt": prompt} if args.num_prompts > 1: # Batch inference inputs = [inputs] * args.num_prompts diff --git a/examples/offline_inference_whisper.py b/examples/offline_inference_whisper.py index 8b040851ce120..c5f318aa2dd4e 100644 --- a/examples/offline_inference_whisper.py +++ b/examples/offline_inference_whisper.py @@ -27,14 +27,14 @@ prompt="", multi_modal_data={"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate} ), - decoder_prompt="", + decoder_prompt="<|startoftranscript|>", ), ExplicitEncoderDecoderPrompt( encoder_prompt=TextPrompt( prompt="", multi_modal_data={"audio": AudioAsset("winning_call").audio_and_sample_rate} ), - decoder_prompt="", + decoder_prompt="<|startoftranscript|>", ), ] * 1024 @@ -44,9 +44,6 @@ top_p=1.0, min_tokens=0, max_tokens=200, - # min_tokens=40, - # max_tokens=40, - # ignore_eos=True, ) start = time.time() @@ -64,6 +61,7 @@ print(f"Encoder prompt: {encoder_prompt!r}, " f"Decoder prompt: {prompt!r}, " f"Generated text: {generated_text!r}") + print(output.outputs[0].token_ids) duration = time.time() - start diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 3d606817e90aa..4beffa3d3eb9f 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -118,7 +118,6 @@ def _get_default_enc_dec_decoder_prompt(self) -> List[int]: * prompt_token_ids ''' - bos_token_id = self.get_bos_token_id() assert bos_token_id is not None return [bos_token_id] @@ -439,8 +438,15 @@ def _build_enc_dec_llm_inputs( assert_never(encoder_inputs) if decoder_inputs is None: - dec_token_ids = self._prepare_decoder_input_ids_for_generation( - None) + if self.model_config.hf_config.model_type == "whisper": + # For Whisper models, the text prompt should go to the decoder. + # If no explicit encoder/decoder inputs, then copy the prompt + # from the encoder to the decoder. The encoder tokens are later + # overridden by the audio features. 
+ dec_token_ids = encoder_inputs["prompt_token_ids"].copy() + else: + dec_token_ids = self._prepare_decoder_input_ids_for_generation( + None) decoder_inputs = token_inputs(dec_token_ids) elif (decoder_inputs["type"] == "token" or decoder_inputs["type"] == "multimodal"): diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 31c552baa5785..b9c1c20a45e8b 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -388,31 +388,15 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, ): - hs = [] - for t in input_features: - t = nn.functional.gelu(self.conv1(t)) - t = nn.functional.gelu(self.conv2(t)) - t = t.permute(1, 0) - h = t + self.embed_positions.weight[:t.size(0), :] - hs.append(h) - hidden_states = torch.cat(hs) - - # #inputs_embeds = nn.functional.gelu(self.conv1(input_features)) - # inputs_embeds = [ - # nn.functional.gelu(self.conv1(t)) for t in input_features - # ] - # inputs_embeds = [ - # nn.functional.gelu(self.conv2(t)) for t in inputs_embeds - # ] - # print(inputs_embeds[0].shape) - # inputs_embeds = torch.cat(inputs_embeds) - # #inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) - # inputs_embeds = inputs_embeds.permute(0, 2, 1) - - # embed_pos = self.embed_positions.weight - - # hidden_states = inputs_embeds + embed_pos - # hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + hidden_states = [] + for features in input_features: + embeds = nn.functional.gelu(self.conv1(features)) + embeds = nn.functional.gelu(self.conv2(embeds)) + embeds = embeds.permute(1, 0) + embeds = embeds + self.embed_positions.weight[:embeds.size(0), :] + hidden_states.append(embeds) + hidden_states = torch.cat(hidden_states) + for idx, encoder_layer in enumerate(self.layers): hidden_states = encoder_layer( hidden_states, @@ -592,14 +576,19 @@ def get_whisper_processor( def input_processor_for_whisper(ctx: InputContext, inputs: DecoderOnlyInputs) -> DecoderOnlyInputs: - audio, orig_sr = inputs["encoder"]["multi_modal_data"]["audio"] + multi_modal_data = inputs["encoder"]["multi_modal_data"] + if isinstance(multi_modal_data["audio"], list): + assert len(multi_modal_data["audio"]) == 1 + multi_modal_data["audio"] = multi_modal_data["audio"][0] + # Resample and process audio + audio, orig_sr = multi_modal_data["audio"] processor = get_whisper_processor(ctx.model_config.model) target_sr = processor.feature_extractor.sampling_rate audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr) if audio.size > 30 * target_sr: # Truncate audio to 30 seconds audio = audio[:30 * target_sr] - inputs["encoder"]["multi_modal_data"]["audio"] = (audio, target_sr) + multi_modal_data["audio"] = (audio, target_sr) # Calculate number of tokens after convolutions num_tokens = (audio.size // 80 - 1) // 2 + 1 num_tokens = (num_tokens - 2) // 2 + 1 @@ -667,8 +656,6 @@ def forward( **kwargs, ) -> torch.Tensor: input_features = kwargs.get("input_features") - # if input_features is not None: - # input_features = input_features.to(torch.float16) decoder_outputs = self.model( input_features=input_features, input_ids=input_ids, diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index c0b3d2585a962..2071e1adf06f0 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -17,11 +17,18 @@ def init_tokenizer_from_configs(model_config: ModelConfig, 
scheduler_config: SchedulerConfig, parallel_config: ParallelConfig, lora_config: LoRAConfig): + add_special_tokens = None + if model_config.hf_config.model_type == "whisper": + # For Whisper models, the special tokens should be provided by the user + # based on the task and language of their request. Also needed to avoid + # appending an EOS token to the prompt which disrupts generation. + add_special_tokens = False init_kwargs = dict(tokenizer_id=model_config.tokenizer, enable_lora=bool(lora_config), max_num_seqs=scheduler_config.max_num_seqs, max_loras=lora_config.max_loras if lora_config else 0, max_input_length=None, + add_special_tokens=add_special_tokens, tokenizer_mode=model_config.tokenizer_mode, trust_remote_code=model_config.trust_remote_code, revision=model_config.tokenizer_revision) diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py index 9b0843a86b38f..1659408bd820a 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -15,11 +15,13 @@ class TokenizerGroup(BaseTokenizerGroup): """A group of tokenizers that can be used for LoRA adapters.""" def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, - max_input_length: Optional[int], **tokenizer_config): + max_input_length: Optional[int], + add_special_tokens: Optional[bool], **tokenizer_config): self.tokenizer_id = tokenizer_id self.tokenizer_config = tokenizer_config self.enable_lora = enable_lora self.max_input_length = max_input_length + self.add_special_tokens = add_special_tokens self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) max_loras = tokenizer_config.get("max_loras", 0) self.lora_tokenizers = LRUCache[AnyTokenizer]( @@ -57,9 +59,11 @@ def encode(self, request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None) -> List[int]: tokenizer = self.get_lora_tokenizer(lora_request) - ret = tokenizer.encode(prompt) - if ret[-1] == 50257: - ret = ret[:-1] + if self.add_special_tokens is not None: + ret = tokenizer.encode(prompt, + add_special_tokens=self.add_special_tokens) + else: + ret = tokenizer.encode(prompt) self._raise_if_input_too_long(ret, lora_request) return ret @@ -69,7 +73,11 @@ async def encode_async( request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None) -> List[int]: tokenizer = await self.get_lora_tokenizer_async(lora_request) - ret = tokenizer.encode(prompt) + if self.add_special_tokens is not None: + ret = tokenizer.encode(prompt, + add_special_tokens=self.add_special_tokens) + else: + ret = tokenizer.encode(prompt) self._raise_if_input_too_long(ret, lora_request) return ret From 17712a48e77a79c5f39a89c71a4b0ec1afa30b7e Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Tue, 17 Dec 2024 19:26:41 +0000 Subject: [PATCH 20/44] fix tp --- examples/offline_inference_whisper.py | 1 - vllm/model_executor/models/whisper.py | 8 ++++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/offline_inference_whisper.py b/examples/offline_inference_whisper.py index c5f318aa2dd4e..4e8b7fa4159bc 100644 --- a/examples/offline_inference_whisper.py +++ b/examples/offline_inference_whisper.py @@ -61,7 +61,6 @@ print(f"Encoder prompt: {encoder_prompt!r}, " f"Decoder prompt: {prompt!r}, " f"Generated text: {generated_text!r}") - print(output.outputs[0].token_ids) duration = time.time() - start diff --git a/vllm/model_executor/models/whisper.py 
b/vllm/model_executor/models/whisper.py index b9c1c20a45e8b..0c45af884395d 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -16,7 +16,7 @@ InputContext, token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.activation import FastGELU -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -172,7 +172,7 @@ def _init_qkv(self, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: - self.q_proj = RowParallelLinear( + self.q_proj = ColumnParallelLinear( input_size = embed_dim, output_size = embed_dim, bias = bias, @@ -231,7 +231,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.activation_fn = FastGELU() - self.fc1 = RowParallelLinear( + self.fc1 = ColumnParallelLinear( input_size = self.embed_dim, output_size = config.encoder_ffn_dim, bias = True, @@ -305,7 +305,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=f"{prefix}.encoder_attn", ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.fc1 = RowParallelLinear( + self.fc1 = ColumnParallelLinear( input_size = self.embed_dim, output_size = config.decoder_ffn_dim, bias = True, From b573fa92e1033dedb66f34f8917515538ce40bc8 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Tue, 17 Dec 2024 19:46:12 +0000 Subject: [PATCH 21/44] clean --- examples/offline_inference_whisper.py | 37 +++++++++++---------------- vllm/worker/enc_dec_model_runner.py | 2 +- vllm/worker/model_runner.py | 1 - 3 files changed, 16 insertions(+), 24 deletions(-) diff --git a/examples/offline_inference_whisper.py b/examples/offline_inference_whisper.py index 4e8b7fa4159bc..35326a435c76e 100644 --- a/examples/offline_inference_whisper.py +++ b/examples/offline_inference_whisper.py @@ -1,14 +1,7 @@ -''' -Demonstrate prompting of text-to-text -encoder/decoder models, specifically BART -''' import time from vllm import LLM, SamplingParams from vllm.assets.audio import AudioAsset -from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt - -audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")] dtype = "float" @@ -22,27 +15,27 @@ ) prompts = [ - ExplicitEncoderDecoderPrompt( - encoder_prompt=TextPrompt( - prompt="", - multi_modal_data={"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate} - ), - decoder_prompt="<|startoftranscript|>", - ), - ExplicitEncoderDecoderPrompt( - encoder_prompt=TextPrompt( - prompt="", - multi_modal_data={"audio": AudioAsset("winning_call").audio_and_sample_rate} - ), - decoder_prompt="<|startoftranscript|>", - ), + { + "prompt": "<|startoftranscript|>", + "multi_modal_data": { + "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate, + }, + }, + { # Test explicit encoder/decoder prompt + "encoder_prompt": { + "prompt": "", + "multi_modal_data": { + "audio": AudioAsset("winning_call").audio_and_sample_rate, + }, + }, + "decoder_prompt": "<|startoftranscript|>", + } ] * 1024 # Create a sampling params object. 
sampling_params = SamplingParams( temperature=0, top_p=1.0, - min_tokens=0, max_tokens=200, ) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 4762ff18fd001..c72c0137c7580 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -280,7 +280,7 @@ def profile_run(self) -> None: for group_id in range(max_num_seqs): seq_len = (max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs)) - seq_len = min(seq_len, 448) + seq_len = min(seq_len, self.model_config.max_model_len) batch_size += seq_len decoder_dummy_data = self.input_registry \ diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d8554c0639eb8..26fd486130ce6 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1282,7 +1282,6 @@ def profile_run(self) -> None: for group_id in range(max_num_seqs): seq_len = (max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs)) - seq_len = min(448, seq_len) batch_size += seq_len dummy_data = self.input_registry \ From 6d6cbd90044627ce2707d7747c5d457742d98190 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Tue, 17 Dec 2024 19:49:57 +0000 Subject: [PATCH 22/44] clean --- examples/offline_inference_audio_language.py | 19 +------------------ vllm/attention/backends/flash_attn.py | 5 +---- vllm/inputs/preprocess.py | 1 + 3 files changed, 3 insertions(+), 22 deletions(-) diff --git a/examples/offline_inference_audio_language.py b/examples/offline_inference_audio_language.py index b7980d5d561c9..050b791b62adb 100644 --- a/examples/offline_inference_audio_language.py +++ b/examples/offline_inference_audio_language.py @@ -61,23 +61,7 @@ def run_qwen2_audio(question: str, audio_count: int): return llm, prompt, stop_token_ids -# Whisper -def run_whisper(question: str, audio_count: int): - model_name = "openai/whisper-large-v3" - - llm = LLM(model=model_name, - max_model_len=448, - max_num_seqs=1, - enforce_eager=True, - limit_mm_per_prompt={"audio": audio_count}) - - prompt = "<|startoftranscript|><|en|>" - stop_token_ids = None - return llm, prompt, stop_token_ids - - -model_example_map = {"ultravox": run_ultravox, "qwen2_audio": run_qwen2_audio, - "whisper": run_whisper} +model_example_map = {"ultravox": run_ultravox, "qwen2_audio": run_qwen2_audio} def main(args): @@ -106,7 +90,6 @@ def main(args): assert args.num_prompts > 0 inputs = {"prompt": prompt, "multi_modal_data": mm_data} - #inputs = {"encoder_prompt": {"prompt": "", "multi_modal_data": mm_data}, "decoder_prompt": prompt} if args.num_prompts > 1: # Batch inference inputs = [inputs] * args.num_prompts diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index e08904c829f73..c69e12ad78c44 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -718,10 +718,7 @@ def forward( query = query[:num_prefill_query_tokens] prefill_output = output[:num_prefill_query_tokens] assert query.shape[0] == num_prefill_query_tokens - assert decode_query.shape[0] == num_decode_query_tokens, ( - f"decode_query.shape: {decode_query.shape}, " - f"num_decode_query_tokens: {num_decode_query_tokens}" - ) + assert decode_query.shape[0] == num_decode_query_tokens if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. 
diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py
index 4beffa3d3eb9f..2efbf220eb794 100644
--- a/vllm/inputs/preprocess.py
+++ b/vllm/inputs/preprocess.py
@@ -118,6 +118,7 @@ def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
         * prompt_token_ids
         '''
+
         bos_token_id = self.get_bos_token_id()
         assert bos_token_id is not None
         return [bos_token_id]

From 94a867b60a12be9376bc003f68ffca57993a3921 Mon Sep 17 00:00:00 2001
From: Aurick Qiao
Date: Tue, 17 Dec 2024 19:53:04 +0000
Subject: [PATCH 23/44] update

---
 vllm/core/scheduler.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 04dfb0a261bc7..ff2dbf83630fa 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -1381,7 +1381,8 @@ def schedule(
                 # between engine and worker.
                 # the subsequent comms can still use delta, but
                 # `multi_modal_data` will be None.
-                multi_modal_data=(seq_group.multi_modal_data or seq_group.encoder_seq.multi_modal_data)
+                multi_modal_data=(seq_group.multi_modal_data or
+                                  seq_group.encoder_seq.multi_modal_data)
                 if scheduler_outputs.num_prefill_groups > 0 else None,
                 multi_modal_placeholders=seq_group.multi_modal_placeholders
                 if scheduler_outputs.num_prefill_groups > 0 else None,

From 787708a65281b4cafd707eb172148dd5cad92b57 Mon Sep 17 00:00:00 2001
From: Aurick Qiao
Date: Wed, 18 Dec 2024 03:40:10 +0000
Subject: [PATCH 24/44] add test

---
 .../models/encoder_decoder/audio/__init__.py |   0
 .../encoder_decoder/audio/test_whisper.py    | 107 ++++++++++++++++++
 vllm/config.py                               |   1 +
 vllm/model_executor/models/whisper.py        |  83 +++++++-------
 4 files changed, 151 insertions(+), 40 deletions(-)
 create mode 100644 tests/models/encoder_decoder/audio/__init__.py
 create mode 100644 tests/models/encoder_decoder/audio/test_whisper.py

diff --git a/tests/models/encoder_decoder/audio/__init__.py b/tests/models/encoder_decoder/audio/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/models/encoder_decoder/audio/test_whisper.py b/tests/models/encoder_decoder/audio/test_whisper.py
new file mode 100644
index 0000000000000..6ddbf8f579a07
--- /dev/null
+++ b/tests/models/encoder_decoder/audio/test_whisper.py
@@ -0,0 +1,107 @@
+"""Compare the outputs of HF and vLLM for Whisper models using greedy sampling.
+
+Run `pytest tests/models/encoder_decoder/audio/test_whisper.py`.
+"""
+from typing import Optional
+
+import pytest
+
+from vllm import LLM, SamplingParams
+from vllm.assets.audio import AudioAsset
+
+from ....utils import fork_new_process_for_each_test, multi_gpu_test
+
+
+PROMPTS = [
+    {
+        "prompt":
+        "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
+        "multi_modal_data": {
+            "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
+        },
+    },
+    { # Test explicit encoder/decoder prompt
+        "encoder_prompt": {
+            "prompt": "",
+            "multi_modal_data": {
+                "audio": AudioAsset("winning_call").audio_and_sample_rate,
+            },
+        },
+        "decoder_prompt":
+        "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
+    }
+]
+
+EXPECTED = {
+    "openai/whisper-medium": [
+        " The first words I spoke in the original phonograph, a little piece"
+        " of practical poetry. Mary had a little lamb, its fleece was quite as"
+        " slow, and everywhere that Mary went the lamb was sure to go.",
+        " And the old one pitch on the way to Edgar Martinez swung on the line"
+        " down the left field line for Obeysmith. Here comes Joy. Here is"
+        " Jorgen at third base. They're gonna wave him in. The throw to the"
+        " plate will be late.
The Mariners are going to play for the American" + " League Championship. I don't believe it. It just continues. My, oh" + " my." + ], + "openai/whisper-large-v3": [ + " The first words I spoke in the original phonograph. A little piece" + " of practical poetry. Mary had a little lamb, its fleece was white as" + " snow, and everywhere that Mary went, the lamb was sure to go.", + " And the 0-1 pitch on the way to Edgar Martinez. Swung on the line," + " down the left field line for a base hit. Here comes Joy. Here is" + " Junior to third base. They're going to wave him in. The throw to the" + " plate will be late. The Mariners are going to play for the American" + " League Championship. I don't believe it. It just continues. My, oh," + " my." + ] +} + + +def run_test( + model: str, + *, + enforce_eager: bool, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +) -> None: + prompts = PROMPTS * 10 + expected = EXPECTED[model] * 10 + + llm = LLM( + model=model, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=enforce_eager, + ) + + sampling_params = SamplingParams( + temperature=0, + top_p=1.0, + max_tokens=200, + ) + + outputs = llm.generate(prompts, sampling_params) + + for output, expected in zip(outputs, expected): + print(output.outputs[0].text) + assert output.outputs[0].text == expected + + +@fork_new_process_for_each_test +@pytest.mark.parametrize( + "model", ["openai/whisper-medium", "openai/whisper-large-v3"] +) +@pytest.mark.parametrize("enforce_eager", [True, False]) +def test_models(model, enforce_eager) -> None: + run_test(model, enforce_eager=enforce_eager, tensor_parallel_size=1) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("model", ["openai/whisper-large-v3"]) +@pytest.mark.parametrize("enforce_eager", [True, False]) +@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) +def test_models_distributed(model, enforce_eager, + distributed_executor_backend) -> None: + run_test(model, enforce_eager=enforce_eager, tensor_parallel_size=2, + distributed_executor_backend=distributed_executor_backend) diff --git a/vllm/config.py b/vllm/config.py index 08a7b607630af..d481c8c31b5a3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1961,6 +1961,7 @@ def _get_and_verify_max_len( # Command-R "model_max_length", # Others + "max_length", "max_sequence_length", "max_seq_length", "seq_len", diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 0c45af884395d..ad9c1a7d71b5f 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -37,14 +37,18 @@ logger = init_logger(__name__) -def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor: +def sinusoids( + length: int, channels: int, max_timescale: float = 10000 +) -> torch.Tensor: """Returns sinusoids for positional embedding""" if channels % 2 != 0: raise ValueError( - f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels." + f"Number of channels has to be divisible by 2 for sinusoidal " + f"positional embeddings, got {channels} channels." 
) log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1) - inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) + inv_timescales = torch.exp(-log_timescale_increment * + torch.arange(channels // 2)) scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1) return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1) @@ -269,10 +273,12 @@ def forward( hidden_states = residual + hidden_states if hidden_states.dtype == torch.float16 and ( - torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + torch.isinf(hidden_states).any() or + torch.isnan(hidden_states).any() ): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 - hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, + max=clamp_value) return hidden_states @@ -366,11 +372,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.num_mel_bins = config.num_mel_bins self.padding_idx = config.pad_token_id self.max_source_positions = config.max_source_positions - self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - - self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1) - self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1) - self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim) + self.embed_scale = ( + math.sqrt(embed_dim) if config.scale_embedding else 1.0) + + self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, + padding=1) + self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, + padding=1) + self.embed_positions = nn.Embedding(self.max_source_positions, + embed_dim) self.start_layer, self.end_layer, self.layers = make_layers( config.encoder_layers, lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config, @@ -380,7 +390,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.layer_norm = nn.LayerNorm(config.d_model) with torch.no_grad(): - self.embed_positions.weight.copy_(sinusoids(*self.embed_positions.weight.shape)) + self.embed_positions.weight.copy_( + sinusoids(*self.embed_positions.weight.shape)) def forward( self, @@ -417,10 +428,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.padding_idx = config.pad_token_id self.max_target_positions = config.max_target_positions self.max_source_positions = config.max_source_positions - self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + self.embed_scale = ( + math.sqrt(config.d_model) if config.scale_embedding else 1.0) - self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) - self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model) + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, + self.padding_idx) + self.embed_positions = WhisperPositionalEmbedding( + self.max_target_positions, config.d_model) self.start_layer, self.end_layer, self.layers = make_layers( config.decoder_layers, lambda prefix: WhisperDecoderLayer(vllm_config=vllm_config, @@ -463,7 +477,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def forward( self, - input_features: Optional[torch.FloatTensor], + input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]], input_ids: Optional[torch.Tensor], positions: torch.Tensor, kv_caches: List[torch.Tensor], @@ -550,32 +564,16 @@ def get_whisper_processor( **kwargs, ) -> 
WhisperProcessor: """Gets an whisper processor for the given model name via HuggingFace.""" - try: - processor: WhisperProcessor = WhisperProcessor.from_pretrained( - processor_name, - *args, - trust_remote_code=trust_remote_code, - revision=revision, - **kwargs) - except ValueError as e: - # If the error pertains to the processor class not existing or not - # currently being imported, suggest using the --trust-remote-code flag. - # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors - if not trust_remote_code: - err_msg = ( - "Failed to load the whisper processor. If the whisper processor is " - "a custom processor not yet available in the HuggingFace " - "transformers library, consider setting " - "`trust_remote_code=True` in LLM or using the " - "`--trust-remote-code` flag in the CLI.") - raise RuntimeError(err_msg) from e - else: - raise e - - return processor + return WhisperProcessor.from_pretrained( + processor_name, + *args, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs, + ) -def input_processor_for_whisper(ctx: InputContext, inputs: DecoderOnlyInputs) -> DecoderOnlyInputs: +def input_processor_for_whisper(ctx: InputContext, inputs): multi_modal_data = inputs["encoder"]["multi_modal_data"] if isinstance(multi_modal_data["audio"], list): assert len(multi_modal_data["audio"]) == 1 @@ -625,7 +623,8 @@ def input_mapper_for_whisper( @INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_whisper) @INPUT_REGISTRY.register_input_processor(input_processor_for_whisper) @MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper) -@MULTIMODAL_REGISTRY.register_max_multimodal_tokens("audio", get_max_whisper_audio_tokens) +@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( + "audio", get_max_whisper_audio_tokens) class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -634,6 +633,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config + self.dtype = vllm_config.model_config.dtype self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix) self.unpadded_vocab_size = config.vocab_size @@ -655,7 +655,10 @@ def forward( attn_metadata: AttentionMetadata, **kwargs, ) -> torch.Tensor: + input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]] input_features = kwargs.get("input_features") + if input_features is not None: + input_features = [feat.to(self.dtype) for feat in input_features] decoder_outputs = self.model( input_features=input_features, input_ids=input_ids, From e943905c39ae483814df5df8ade2cc8e880ab960 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Wed, 18 Dec 2024 22:21:12 +0000 Subject: [PATCH 25/44] some cleanup --- requirements-common.txt | 1 + .../encoder_decoder/audio/test_whisper.py | 4 +- tests/models/registry.py | 1 + vllm/inputs/preprocess.py | 23 +++- vllm/model_executor/models/whisper.py | 128 +++++++++--------- .../tokenizer_group/__init__.py | 7 - .../tokenizer_group/ray_tokenizer_group.py | 26 ++-- .../tokenizer_group/tokenizer_group.py | 18 +-- vllm/worker/enc_dec_model_runner.py | 1 - 9 files changed, 107 insertions(+), 102 deletions(-) diff --git a/requirements-common.txt b/requirements-common.txt index 11984260c580d..4d6b2f5ccd57a 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -35,3 +35,4 @@ setuptools>=74.1.1; python_version > '3.11' # 
Setuptools is used by triton, we n einops # Required for Qwen2-VL. compressed-tensors == 0.8.0 # required for compressed-tensors depyf==0.18.0 # required for profiling and debugging torch.compile +librosa >= 0.10.2 # required for audio processing including Whisper diff --git a/tests/models/encoder_decoder/audio/test_whisper.py b/tests/models/encoder_decoder/audio/test_whisper.py index 6ddbf8f579a07..885e8c7035537 100644 --- a/tests/models/encoder_decoder/audio/test_whisper.py +++ b/tests/models/encoder_decoder/audio/test_whisper.py @@ -35,8 +35,8 @@ EXPECTED = { "openai/whisper-medium": [ " The first words I spoke in the original phonograph, a little piece" - " of practical poetry. Mary had a little lamb, its fleece was quite as" - " slow, and everywhere that Mary went the lamb was sure to go.", + " of practical poetry. Mary had a little lamb, its fleece was white as" + " snow, and everywhere that Mary went the lamb would shun it all.", " And the old one pitch on the way to Edgar Martinez swung on the line" " down the left field line for Obeysmith. Here comes Joy. Here is" " Jorgen at third base. They're gonna wave him in. The throw to the" diff --git a/tests/models/registry.py b/tests/models/registry.py index a89518820045f..379248c686bbe 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -124,6 +124,7 @@ class _HfExamplesInfo: # [Encoder-decoder] "BartModel": _HfExamplesInfo("facebook/bart-base"), "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"), + "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501 # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer # Therefore, we borrow the BartTokenizer from the original Bart model "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501 diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 2efbf220eb794..d4db5ebad8832 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -184,10 +184,16 @@ def _tokenize_prompt( corresponding token IDs. """ tokenizer = self.get_tokenizer_group() - + add_special_tokens = None + if self.model_config.hf_config.model_type == "whisper": + # For Whisper, special tokens should be provided by the user based + # on the task and language of their request. Also needed to avoid + # appending an EOS token to the prompt which disrupts generation. + add_special_tokens = False return tokenizer.encode(request_id=request_id, prompt=prompt, - lora_request=lora_request) + lora_request=lora_request, + add_special_tokens=add_special_tokens) async def _tokenize_prompt_async( self, @@ -197,10 +203,15 @@ async def _tokenize_prompt_async( ) -> List[int]: """Async version of :meth:`_tokenize_prompt`.""" tokenizer = self.get_tokenizer_group() - - return await tokenizer.encode_async(request_id=request_id, - prompt=prompt, - lora_request=lora_request) + add_special_tokens = None + if self.model_config.hf_config.model_type == "whisper": + # For Whisper, special tokens should be provided by the user based + # on the task and language of their request. Also needed to avoid + # appending an EOS token to the prompt which disrupts generation. 
+ add_special_tokens = False + return await tokenizer.encode_async( + request_id=request_id, prompt=prompt, lora_request=lora_request, + add_special_tokens=add_special_tokens) def _can_process_multimodal(self) -> bool: model_config = self.model_config diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index ad9c1a7d71b5f..37f39b79f6a65 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -6,16 +6,15 @@ import numpy as np import torch from torch import nn -from transformers import WhisperConfig, WhisperProcessor +from transformers import WhisperProcessor +from transformers.models.whisper.modeling_whisper import sinusoids from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext, token_inputs) +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.inputs import INPUT_REGISTRY, DummyData, InputContext from vllm.logger import init_logger -from vllm.model_executor.layers.activation import FastGELU +from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) @@ -24,35 +23,17 @@ QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, kv_cache_scales_loader) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.multimodal.utils import consecutive_placeholder_ranges from vllm.sequence import SequenceData from .interfaces import SupportsMultiModal -from .utils import AutoWeightsLoader, make_layers, maybe_prefix +from .utils import AutoWeightsLoader, make_layers, WeightsMapper logger = init_logger(__name__) -def sinusoids( - length: int, channels: int, max_timescale: float = 10000 -) -> torch.Tensor: - """Returns sinusoids for positional embedding""" - if channels % 2 != 0: - raise ValueError( - f"Number of channels has to be divisible by 2 for sinusoidal " - f"positional embeddings, got {channels} channels." 
- ) - log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1) - inv_timescales = torch.exp(-log_timescale_increment * - torch.arange(channels // 2)) - scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1) - return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1) - - class WhisperPositionalEmbedding(nn.Embedding): def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): @@ -216,6 +197,39 @@ def forward( return output +class WhisperMLP(nn.Module): + + def __init__( + self, + embed_dim: int, + ffn_dim: int, + act_fn: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.activation_fn = get_act_fn(act_fn) + self.fc1 = ColumnParallelLinear( + input_size=embed_dim, + output_size=ffn_dim, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + input_size=ffn_dim, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + + def forward(self, hidden_states: torch.Tensor): + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + class WhisperEncoderLayer(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -234,20 +248,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=f"{prefix}.self_attn", ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.activation_fn = FastGELU() - self.fc1 = ColumnParallelLinear( - input_size = self.embed_dim, - output_size = config.encoder_ffn_dim, - bias = True, + self.mlp = WhisperMLP( + embed_dim=config.d_model, + ffn_dim=config.encoder_ffn_dim, + act_fn=config.activation_function, quant_config=quant_config, - prefix=f"{prefix}.fc1", - ) - self.fc2 = RowParallelLinear( - input_size = config.encoder_ffn_dim, - output_size = self.embed_dim, - bias = True, - quant_config=quant_config, - prefix=f"{prefix}.fc2", + prefix=f"{prefix}.mlp", ) self.final_layer_norm = nn.LayerNorm(self.embed_dim) @@ -267,9 +273,7 @@ def forward( hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) - hidden_states, _ = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - hidden_states, _ = self.fc2(hidden_states) + hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states if hidden_states.dtype == torch.float16 and ( @@ -291,41 +295,31 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - self.embed_dim = config.d_model self.self_attn = WhisperAttention( - embed_dim=self.embed_dim, + embed_dim=config.d_model, num_heads=config.decoder_attention_heads, attn_type=AttentionType.DECODER, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", ) - self.activation_fn = FastGELU() - - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.self_attn_layer_norm = nn.LayerNorm(config.d_model) self.encoder_attn = WhisperCrossAttention( - embed_dim=self.embed_dim, + embed_dim=config.d_model, num_heads=config.decoder_attention_heads, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.encoder_attn", ) - self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.fc1 = ColumnParallelLinear( - input_size = self.embed_dim, - output_size = config.decoder_ffn_dim, - 
bias = True, - quant_config=quant_config, - prefix=f"{prefix}.fc1", - ) - self.fc2 = RowParallelLinear( - input_size = config.decoder_ffn_dim, - output_size = self.embed_dim, - bias = True, + self.encoder_attn_layer_norm = nn.LayerNorm(config.d_model) + self.mlp = WhisperMLP( + embed_dim=config.d_model, + ffn_dim=config.decoder_ffn_dim, + act_fn=config.activation_function, quant_config=quant_config, - prefix=f"{prefix}.fc2", + prefix=f"{prefix}.mlp", ) - self.final_layer_norm = nn.LayerNorm(self.embed_dim) + self.final_layer_norm = nn.LayerNorm(config.d_model) def forward( self, @@ -355,9 +349,7 @@ def forward( residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) - hidden_states, _ = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - hidden_states, _ = self.fc2(hidden_states) + hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states @@ -685,5 +677,7 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."]) - return loader.load_weights((name, loaded_weight) - for name, loaded_weight in weights) \ No newline at end of file + loaded_weights = [(name, loaded_weight) + for name, loaded_weight in weights] + mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}) + return loader.load_weights(loaded_weights, mapper=mapper) \ No newline at end of file diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index 2071e1adf06f0..c0b3d2585a962 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -17,18 +17,11 @@ def init_tokenizer_from_configs(model_config: ModelConfig, scheduler_config: SchedulerConfig, parallel_config: ParallelConfig, lora_config: LoRAConfig): - add_special_tokens = None - if model_config.hf_config.model_type == "whisper": - # For Whisper models, the special tokens should be provided by the user - # based on the task and language of their request. Also needed to avoid - # appending an EOS token to the prompt which disrupts generation. - add_special_tokens = False init_kwargs = dict(tokenizer_id=model_config.tokenizer, enable_lora=bool(lora_config), max_num_seqs=scheduler_config.max_num_seqs, max_loras=lora_config.max_loras if lora_config else 0, max_input_length=None, - add_special_tokens=add_special_tokens, tokenizer_mode=model_config.tokenizer_mode, trust_remote_code=model_config.trust_remote_code, revision=model_config.tokenizer_revision) diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py index 9a999a0d6067d..76a5c9b010fbb 100644 --- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py @@ -112,7 +112,8 @@ def _finalize_encode(self, actor: ray.ObjectRef, def encode(self, prompt: str, request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: + lora_request: Optional[LoRARequest] = None, + add_special_tokens: Optional[bool] = None) -> List[int]: """Encode a prompt using the tokenizer group. We pick an idle actor and use it to encode the prompt. 
@@ -132,7 +133,8 @@ def encode(self, ret = ray.get( actor.encode.remote(request_id=request_id, prompt=prompt, - lora_request=lora_request)) + lora_request=lora_request, + add_special_tokens=add_special_tokens)) except ActorDiedError as e: # If the actor is dead, we first try to reinitialize it. logger.warning("%s died with ActorDiedError, reinitializing.", @@ -143,7 +145,8 @@ def encode(self, ret = ray.get( actor.encode.remote(request_id=request_id, prompt=prompt, - lora_request=lora_request)) + lora_request=lora_request, + add_special_tokens=add_special_tokens)) except ActorDiedError as e: logger.error( "%s died for second time in a row, marking " @@ -160,7 +163,8 @@ async def encode_async( self, prompt: str, request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: + lora_request: Optional[LoRARequest] = None, + add_special_tokens: Optional[bool] = None) -> List[int]: """Encode a prompt using the tokenizer group. We pick an idle actor and use it to encode the prompt. @@ -177,9 +181,10 @@ async def encode_async( actor_is_alive = True original_actor = actor try: - ret = await actor.encode.remote(request_id=request_id, - prompt=prompt, - lora_request=lora_request) + ret = await actor.encode.remote( + request_id=request_id, prompt=prompt, + lora_request=lora_request, + add_special_tokens=add_special_tokens) except ActorDiedError as e: # If the actor is dead, we first try to reinitialize it. logger.warning("%s died with ActorDiedError, reinitializing.", @@ -187,9 +192,10 @@ async def encode_async( exc_info=e) actor = self._init_actor() try: - ret = await actor.encode.remote(request_id=request_id, - prompt=prompt, - lora_request=lora_request) + ret = await actor.encode.remote( + request_id=request_id, prompt=prompt, + lora_request=lora_request, + add_special_tokens=add_special_tokens) except ActorDiedError as e: logger.error( "%s died for second time in a row, marking " diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py index 1659408bd820a..d9b5c300331db 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -15,13 +15,11 @@ class TokenizerGroup(BaseTokenizerGroup): """A group of tokenizers that can be used for LoRA adapters.""" def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, - max_input_length: Optional[int], - add_special_tokens: Optional[bool], **tokenizer_config): + max_input_length: Optional[int], **tokenizer_config): self.tokenizer_id = tokenizer_id self.tokenizer_config = tokenizer_config self.enable_lora = enable_lora self.max_input_length = max_input_length - self.add_special_tokens = add_special_tokens self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) max_loras = tokenizer_config.get("max_loras", 0) self.lora_tokenizers = LRUCache[AnyTokenizer]( @@ -57,11 +55,12 @@ def _raise_if_input_too_long(self, def encode(self, prompt: str, request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: + lora_request: Optional[LoRARequest] = None, + add_special_tokens: Optional[bool] = None) -> List[int]: tokenizer = self.get_lora_tokenizer(lora_request) - if self.add_special_tokens is not None: + if add_special_tokens is not None: ret = tokenizer.encode(prompt, - add_special_tokens=self.add_special_tokens) + add_special_tokens=add_special_tokens) else: ret = tokenizer.encode(prompt) self._raise_if_input_too_long(ret, 
lora_request) @@ -71,11 +70,12 @@ async def encode_async( self, prompt: str, request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: + lora_request: Optional[LoRARequest] = None, + add_special_tokens: Optional[bool] = None) -> List[int]: tokenizer = await self.get_lora_tokenizer_async(lora_request) - if self.add_special_tokens is not None: + if add_special_tokens is not None: ret = tokenizer.encode(prompt, - add_special_tokens=self.add_special_tokens) + add_special_tokens=add_special_tokens) else: ret = tokenizer.encode(prompt) self._raise_if_input_too_long(ret, lora_request) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index c72c0137c7580..f3719ccb3e1f7 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -280,7 +280,6 @@ def profile_run(self) -> None: for group_id in range(max_num_seqs): seq_len = (max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs)) - seq_len = min(seq_len, self.model_config.max_model_len) batch_size += seq_len decoder_dummy_data = self.input_registry \ From 606642e31440e23e4a1b293da86adb27a576a7eb Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 19 Dec 2024 00:12:32 +0000 Subject: [PATCH 26/44] formatting --- .../encoder_decoder/audio/test_whisper.py | 17 ++-- vllm/inputs/preprocess.py | 4 +- vllm/model_executor/models/whisper.py | 89 +++++++++++-------- .../tokenizer_group/ray_tokenizer_group.py | 6 +- 4 files changed, 70 insertions(+), 46 deletions(-) diff --git a/tests/models/encoder_decoder/audio/test_whisper.py b/tests/models/encoder_decoder/audio/test_whisper.py index 885e8c7035537..d61cc29d959aa 100644 --- a/tests/models/encoder_decoder/audio/test_whisper.py +++ b/tests/models/encoder_decoder/audio/test_whisper.py @@ -65,8 +65,8 @@ def run_test( tensor_parallel_size: int, distributed_executor_backend: Optional[str] = None, ) -> None: - prompts = PROMPTS * 10 - expected = EXPECTED[model] * 10 + prompt_list = PROMPTS * 10 + expected_list = EXPECTED[model] * 10 llm = LLM( model=model, @@ -81,17 +81,16 @@ def run_test( max_tokens=200, ) - outputs = llm.generate(prompts, sampling_params) + outputs = llm.generate(prompt_list, sampling_params) - for output, expected in zip(outputs, expected): + for output, expected in zip(outputs, expected_list): print(output.outputs[0].text) assert output.outputs[0].text == expected @fork_new_process_for_each_test -@pytest.mark.parametrize( - "model", ["openai/whisper-medium", "openai/whisper-large-v3"] -) +@pytest.mark.parametrize("model", + ["openai/whisper-medium", "openai/whisper-large-v3"]) @pytest.mark.parametrize("enforce_eager", [True, False]) def test_models(model, enforce_eager) -> None: run_test(model, enforce_eager=enforce_eager, tensor_parallel_size=1) @@ -103,5 +102,7 @@ def test_models(model, enforce_eager) -> None: @pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) def test_models_distributed(model, enforce_eager, distributed_executor_backend) -> None: - run_test(model, enforce_eager=enforce_eager, tensor_parallel_size=2, + run_test(model, + enforce_eager=enforce_eager, + tensor_parallel_size=2, distributed_executor_backend=distributed_executor_backend) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index d4db5ebad8832..9ccae3ab11f75 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -210,7 +210,9 @@ async def _tokenize_prompt_async( # appending an EOS token to the prompt which disrupts 
generation. add_special_tokens = False return await tokenizer.encode_async( - request_id=request_id, prompt=prompt, lora_request=lora_request, + request_id=request_id, + prompt=prompt, + lora_request=lora_request, add_special_tokens=add_special_tokens) def _can_process_multimodal(self) -> bool: diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 37f39b79f6a65..bbccd8c2c33a2 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -35,7 +35,10 @@ class WhisperPositionalEmbedding(nn.Embedding): - def __init__(self, num_positions: int, embedding_dim: int, + + def __init__(self, + num_positions: int, + embedding_dim: int, padding_idx: Optional[int] = None): super().__init__(num_positions, embedding_dim) @@ -44,6 +47,7 @@ def forward(self, position_ids): class WhisperAttention(nn.Module): + def __init__( self, embed_dim: int, @@ -77,15 +81,14 @@ def __init__( if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: " - f"{self.embed_dim} and `num_heads`: {num_heads})." - ) + f"{self.embed_dim} and `num_heads`: {num_heads}).") self.scaling = self.head_dim**-0.5 self._init_qkv(embed_dim, bias, quant_config, prefix=prefix) self.out_proj = RowParallelLinear( - input_size = embed_dim, - output_size = embed_dim, - bias = bias, + input_size=embed_dim, + output_size=embed_dim, + bias=bias, quant_config=quant_config, prefix=f"{prefix}.out_proj", ) @@ -99,7 +102,8 @@ def __init__( prefix=f"{prefix}.attn", ) - def _init_qkv(self, + def _init_qkv( + self, embed_dim: int, bias: bool = True, quant_config: Optional[QuantizationConfig] = None, @@ -124,7 +128,11 @@ def forward( qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata, + attn_output = self.attn(q, + k, + v, + kv_cache, + attn_metadata, attn_type=self.attn_type) output, _ = self.out_proj(attn_output) @@ -133,6 +141,7 @@ def forward( class WhisperCrossAttention(WhisperAttention): + def __init__( self, embed_dim: int, @@ -151,16 +160,17 @@ def __init__( prefix=prefix, ) - def _init_qkv(self, + def _init_qkv( + self, embed_dim: int, bias: bool = True, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: self.q_proj = ColumnParallelLinear( - input_size = embed_dim, - output_size = embed_dim, - bias = bias, + input_size=embed_dim, + output_size=embed_dim, + bias=bias, quant_config=quant_config, prefix=f"{prefix}.q_proj", ) @@ -189,7 +199,11 @@ def forward( else: k = v = None - attn_output = self.attn(q, k, v, kv_cache, attn_metadata, + attn_output = self.attn(q, + k, + v, + kv_cache, + attn_metadata, attn_type=AttentionType.ENCODER_DECODER) output, _ = self.out_proj(attn_output) @@ -256,7 +270,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=f"{prefix}.mlp", ) self.final_layer_norm = nn.LayerNorm(self.embed_dim) - + def forward( self, hidden_states: torch.Tensor, @@ -277,11 +291,11 @@ def forward( hidden_states = residual + hidden_states if hidden_states.dtype == torch.float16 and ( - torch.isinf(hidden_states).any() or - torch.isnan(hidden_states).any() - ): + torch.isinf(hidden_states).any() + or torch.isnan(hidden_states).any()): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 - hidden_states = torch.clamp(hidden_states, min=-clamp_value, + hidden_states = torch.clamp(hidden_states, + min=-clamp_value, max=clamp_value) 
return hidden_states @@ -320,7 +334,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=f"{prefix}.mlp", ) self.final_layer_norm = nn.LayerNorm(config.d_model) - + def forward( self, hidden_states: torch.Tensor, @@ -330,11 +344,9 @@ def forward( ): residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states = self.self_attn( - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata - ) + hidden_states = self.self_attn(hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata) hidden_states = residual + hidden_states residual = hidden_states @@ -364,12 +376,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.num_mel_bins = config.num_mel_bins self.padding_idx = config.pad_token_id self.max_source_positions = config.max_source_positions - self.embed_scale = ( - math.sqrt(embed_dim) if config.scale_embedding else 1.0) + self.embed_scale = (math.sqrt(embed_dim) + if config.scale_embedding else 1.0) - self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, + self.conv1 = nn.Conv1d(self.num_mel_bins, + embed_dim, + kernel_size=3, padding=1) - self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, + self.conv2 = nn.Conv1d(embed_dim, + embed_dim, + kernel_size=3, + stride=2, padding=1) self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim) @@ -384,7 +401,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): with torch.no_grad(): self.embed_positions.weight.copy_( sinusoids(*self.embed_positions.weight.shape)) - + def forward( self, input_features: Union[torch.Tensor, List[torch.Tensor]], @@ -420,8 +437,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.padding_idx = config.pad_token_id self.max_target_positions = config.max_target_positions self.max_source_positions = config.max_source_positions - self.embed_scale = ( - math.sqrt(config.d_model) if config.scale_embedding else 1.0) + self.embed_scale = (math.sqrt(config.d_model) + if config.scale_embedding else 1.0) self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) @@ -434,7 +451,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=f"{prefix}.layers", ) self.layer_norm = nn.LayerNorm(config.d_model) - + def forward( self, input_ids, @@ -460,6 +477,7 @@ def forward( class WhisperModel(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.encoder = WhisperEncoder(vllm_config=vllm_config, @@ -604,8 +622,10 @@ def input_mapper_for_whisper( audios = [audio for audio, _ in multi_modal_data] - kwargs = processor(audios, sampling_rate=sampling_rate, - padding=False, return_tensors="pt") + kwargs = processor(audios, + sampling_rate=sampling_rate, + padding=False, + return_tensors="pt") kwargs["input_features"] = kwargs["input_features"].squeeze(0) kwargs["input_features"] = kwargs["input_features"].to(torch.float16) @@ -623,7 +643,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.dtype = vllm_config.model_config.dtype diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py index 76a5c9b010fbb..3f7627e11ae5e 100644 --- 
a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py @@ -182,7 +182,8 @@ async def encode_async( original_actor = actor try: ret = await actor.encode.remote( - request_id=request_id, prompt=prompt, + request_id=request_id, + prompt=prompt, lora_request=lora_request, add_special_tokens=add_special_tokens) except ActorDiedError as e: @@ -193,7 +194,8 @@ async def encode_async( actor = self._init_actor() try: ret = await actor.encode.remote( - request_id=request_id, prompt=prompt, + request_id=request_id, + prompt=prompt, lora_request=lora_request, add_special_tokens=add_special_tokens) except ActorDiedError as e: From fe8e245233cb6685c536fe57e72f13e7ab76be6a Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 19 Dec 2024 00:44:33 +0000 Subject: [PATCH 27/44] format --- tests/models/encoder_decoder/audio/test_whisper.py | 1 - vllm/model_executor/models/whisper.py | 2 +- .../tokenizer_group/base_tokenizer_group.py | 6 ++++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/models/encoder_decoder/audio/test_whisper.py b/tests/models/encoder_decoder/audio/test_whisper.py index d61cc29d959aa..81761d2a2b679 100644 --- a/tests/models/encoder_decoder/audio/test_whisper.py +++ b/tests/models/encoder_decoder/audio/test_whisper.py @@ -11,7 +11,6 @@ from ....utils import fork_new_process_for_each_test, multi_gpu_test - PROMPTS = [ { "prompt": diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index bbccd8c2c33a2..f9c25564750ae 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -29,7 +29,7 @@ from vllm.sequence import SequenceData from .interfaces import SupportsMultiModal -from .utils import AutoWeightsLoader, make_layers, WeightsMapper +from .utils import AutoWeightsLoader, WeightsMapper, make_layers logger = init_logger(__name__) diff --git a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py index 8f78ef65bbf1a..e6cc7cd4e2e3a 100644 --- a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py @@ -32,7 +32,8 @@ def get_max_input_len( def encode(self, prompt: str, request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: + lora_request: Optional[LoRARequest] = None, + add_special_tokens: Optional[bool] = None) -> List[int]: """Encode a prompt using the tokenizer group.""" pass @@ -41,7 +42,8 @@ async def encode_async( self, prompt: str, request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: + lora_request: Optional[LoRARequest] = None, + add_special_tokens: Optional[bool] = None) -> List[int]: """Encode a prompt using the tokenizer group.""" pass From b59fddb624f28a8a6535b086b8e16da518ac2f77 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 19 Dec 2024 01:11:44 +0000 Subject: [PATCH 28/44] mypy --- vllm/core/scheduler.py | 3 +-- vllm/sequence.py | 18 +++++++++++++++--- .../tokenizer_group/tokenizer_group.py | 14 ++++---------- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index ff2dbf83630fa..b3d396f9cedda 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1381,8 +1381,7 @@ def schedule( # between engine and worker. # the subsequent comms can still use delta, but # `multi_modal_data` will be None. 
- multi_modal_data=(seq_group.multi_modal_data or - seq_group.encoder_seq.multi_modal_data) + multi_modal_data=seq_group.multi_modal_data if scheduler_outputs.num_prefill_groups > 0 else None, multi_modal_placeholders=seq_group.multi_modal_placeholders if scheduler_outputs.num_prefill_groups > 0 else None, diff --git a/vllm/sequence.py b/vllm/sequence.py index b0f3c1cc3609f..86a3a0aed164b 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -697,15 +697,27 @@ def token_type_ids(self) -> Optional[List[int]]: @property def multi_modal_data(self) -> MultiModalDataDict: - return self.first_seq.multi_modal_data + if self.first_seq.multi_modal_data: + return self.first_seq.multi_modal_data + elif self.encoder_seq is not None: + return self.encoder_seq.multi_modal_data + return None @property def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: - return self.first_seq.multi_modal_placeholders + if self.first_seq.multi_modal_data: + return self.first_seq.multi_modal_placeholders + elif self.encoder_seq is not None: + return self.encoder_seq.multi_modal_placeholders + return None @property def mm_processor_kwargs(self) -> Dict[str, Any]: - return self.first_seq.mm_processor_kwargs + if self.first_seq.multi_modal_data: + return self.first_seq.mm_processor_kwargs + elif self.encoder_seq is not None: + return self.encoder_seq.mm_processor_kwargs + return None @property def lora_int_id(self) -> int: diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py index d9b5c300331db..b289e7e2430ae 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -58,11 +58,8 @@ def encode(self, lora_request: Optional[LoRARequest] = None, add_special_tokens: Optional[bool] = None) -> List[int]: tokenizer = self.get_lora_tokenizer(lora_request) - if add_special_tokens is not None: - ret = tokenizer.encode(prompt, - add_special_tokens=add_special_tokens) - else: - ret = tokenizer.encode(prompt) + ret = tokenizer.encode( # type: ignore[call-arg] + prompt, add_special_tokens=add_special_tokens) self._raise_if_input_too_long(ret, lora_request) return ret @@ -73,11 +70,8 @@ async def encode_async( lora_request: Optional[LoRARequest] = None, add_special_tokens: Optional[bool] = None) -> List[int]: tokenizer = await self.get_lora_tokenizer_async(lora_request) - if add_special_tokens is not None: - ret = tokenizer.encode(prompt, - add_special_tokens=add_special_tokens) - else: - ret = tokenizer.encode(prompt) + ret = tokenizer.encode( # type: ignore[call-arg] + prompt, add_special_tokens=add_special_tokens) self._raise_if_input_too_long(ret, lora_request) return ret From d66cd424298940473297af3f8743278966eae47c Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 19 Dec 2024 01:13:23 +0000 Subject: [PATCH 29/44] mypy --- vllm/sequence.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 86a3a0aed164b..f24ce1a853a45 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -701,7 +701,7 @@ def multi_modal_data(self) -> MultiModalDataDict: return self.first_seq.multi_modal_data elif self.encoder_seq is not None: return self.encoder_seq.multi_modal_data - return None + return {} @property def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: @@ -709,7 +709,7 @@ def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: return self.first_seq.multi_modal_placeholders elif 
self.encoder_seq is not None: return self.encoder_seq.multi_modal_placeholders - return None + return {} @property def mm_processor_kwargs(self) -> Dict[str, Any]: @@ -717,7 +717,7 @@ def mm_processor_kwargs(self) -> Dict[str, Any]: return self.first_seq.mm_processor_kwargs elif self.encoder_seq is not None: return self.encoder_seq.mm_processor_kwargs - return None + return {} @property def lora_int_id(self) -> int: From 6ba1afc6430a3659048c9b91c103cd11a668d62c Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 19 Dec 2024 01:18:05 +0000 Subject: [PATCH 30/44] format --- vllm/transformers_utils/tokenizer_group/tokenizer_group.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py index b289e7e2430ae..b333a7fdca5d8 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -59,7 +59,8 @@ def encode(self, add_special_tokens: Optional[bool] = None) -> List[int]: tokenizer = self.get_lora_tokenizer(lora_request) ret = tokenizer.encode( # type: ignore[call-arg] - prompt, add_special_tokens=add_special_tokens) + prompt, + add_special_tokens=add_special_tokens) self._raise_if_input_too_long(ret, lora_request) return ret @@ -71,7 +72,8 @@ async def encode_async( add_special_tokens: Optional[bool] = None) -> List[int]: tokenizer = await self.get_lora_tokenizer_async(lora_request) ret = tokenizer.encode( # type: ignore[call-arg] - prompt, add_special_tokens=add_special_tokens) + prompt, + add_special_tokens=add_special_tokens) self._raise_if_input_too_long(ret, lora_request) return ret From 26fd92afbb5577f6ff230962ff4696a710f61566 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 19 Dec 2024 04:04:30 +0000 Subject: [PATCH 31/44] fix tests --- vllm/config.py | 3 ++- .../tokenizer_group/tokenizer_group.py | 18 ++++++++++++------ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index d481c8c31b5a3..c169fa03be249 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1960,8 +1960,9 @@ def _get_and_verify_max_len( "seq_length", # Command-R "model_max_length", + # Whisper + "max_target_positions", # Others - "max_length", "max_sequence_length", "max_seq_length", "seq_len", diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py index b333a7fdca5d8..b0805ead1367d 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -58,9 +58,12 @@ def encode(self, lora_request: Optional[LoRARequest] = None, add_special_tokens: Optional[bool] = None) -> List[int]: tokenizer = self.get_lora_tokenizer(lora_request) - ret = tokenizer.encode( # type: ignore[call-arg] - prompt, - add_special_tokens=add_special_tokens) + if add_special_tokens is not None: + ret = tokenizer.encode( # type: ignore[call-arg] + prompt, + add_special_tokens=add_special_tokens) + else: + ret = tokenizer.encode(prompt) self._raise_if_input_too_long(ret, lora_request) return ret @@ -71,9 +74,12 @@ async def encode_async( lora_request: Optional[LoRARequest] = None, add_special_tokens: Optional[bool] = None) -> List[int]: tokenizer = await self.get_lora_tokenizer_async(lora_request) - ret = tokenizer.encode( # type: ignore[call-arg] - prompt, - add_special_tokens=add_special_tokens) + if add_special_tokens is not None: + ret = 
tokenizer.encode( # type: ignore[call-arg] + prompt, + add_special_tokens=add_special_tokens) + else: + ret = tokenizer.encode(prompt) self._raise_if_input_too_long(ret, lora_request) return ret From 4566b10c0809b0791656b6d0dd50847c31c5932b Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 19 Dec 2024 16:56:00 +0000 Subject: [PATCH 32/44] librosa --- requirements-common.txt | 1 - vllm/model_executor/models/whisper.py | 13 +++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/requirements-common.txt b/requirements-common.txt index 4d6b2f5ccd57a..11984260c580d 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -35,4 +35,3 @@ setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we n einops # Required for Qwen2-VL. compressed-tensors == 0.8.0 # required for compressed-tensors depyf==0.18.0 # required for profiling and debugging torch.compile -librosa >= 0.10.2 # required for audio processing including Whisper diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index f9c25564750ae..4ea23f23e0f84 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -2,7 +2,6 @@ from functools import lru_cache from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union -import librosa import numpy as np import torch from torch import nn @@ -583,7 +582,17 @@ def get_whisper_processor( ) +def _resample(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray: + try: + import librosa + except ImportError as exc: + raise ImportError( + "Please install vllm[audio] for audio support.") from exc + return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr) + + def input_processor_for_whisper(ctx: InputContext, inputs): + multi_modal_data = inputs["encoder"]["multi_modal_data"] if isinstance(multi_modal_data["audio"], list): assert len(multi_modal_data["audio"]) == 1 @@ -592,7 +601,7 @@ def input_processor_for_whisper(ctx: InputContext, inputs): audio, orig_sr = multi_modal_data["audio"] processor = get_whisper_processor(ctx.model_config.model) target_sr = processor.feature_extractor.sampling_rate - audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr) + audio = _resample(audio, orig_sr=orig_sr, target_sr=target_sr) if audio.size > 30 * target_sr: # Truncate audio to 30 seconds audio = audio[:30 * target_sr] From 1fe41fc59b649dcba143e8e3077d2e477bf73ca6 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 19 Dec 2024 22:01:13 +0000 Subject: [PATCH 33/44] small --- vllm/multimodal/processing.py | 26 ++++--------------- vllm/transformers_utils/tokenizer.py | 19 ++++++++++++++ .../tokenizer_group/tokenizer_group.py | 17 ++++-------- 3 files changed, 29 insertions(+), 33 deletions(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 6baf19d675d50..c09b4bfbc34d4 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -14,7 +14,7 @@ from vllm.inputs import DummyData, InputProcessingContext from vllm.logger import init_logger -from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer, encode_tokens from vllm.utils import flatten_2d_lists, full_groupby, is_list_of from .audio import resample_audio @@ -55,24 +55,6 @@ def bind(self, tokenizer: AnyTokenizer) -> "_BoundPromptReplacement": ) -def _encode( - tokenizer: AnyTokenizer, - text: str, - *, - add_special_tokens: bool = False, -) -> list[int]: - """ 
- Backend-agnostic equivalent of HF's - :code:`tokenizer.encode(text, add_special_tokens=...)`. - """ - if isinstance(tokenizer, MistralTokenizer): - return tokenizer.tokenizer.encode(text, - bos=add_special_tokens, - eos=add_special_tokens) - - return tokenizer.encode(text, add_special_tokens=add_special_tokens) - - @lru_cache(maxsize=2048) def _cached_encode( tokenizer: AnyTokenizer, @@ -80,7 +62,8 @@ def _cached_encode( *, add_special_tokens: bool = False, ) -> list[int]: - return _encode(tokenizer, text, add_special_tokens=add_special_tokens) + return encode_tokens(tokenizer, text, + add_special_tokens=add_special_tokens) def _decode( @@ -763,7 +746,8 @@ def _apply_prompt_replacements( mm_item_counts, ) - token_ids = _encode(tokenizer, text) + token_ids = encode_tokens(tokenizer, text, + add_special_tokens=False) matched_repls = [match.prompt_repl for match in text_matches] placeholders = self._find_placeholders(matched_repls, token_ids, diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index e6701f4c4b835..42b2f095bc543 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -21,6 +21,25 @@ MistralTokenizer] +def encode_tokens( + tokenizer: AnyTokenizer, + text: str, + *, + add_special_tokens: Optional[bool] = None, +) -> list[int]: + """ + Backend-agnostic equivalent of HF's + :code:`tokenizer.encode(text, add_special_tokens=...)`. + """ + if isinstance(tokenizer, MistralTokenizer): + return tokenizer.tokenizer.encode(text, + bos=add_special_tokens, + eos=add_special_tokens) + elif add_special_tokens is not None: + return tokenizer.encode(text, add_special_tokens=add_special_tokens) + return tokenizer.encode(text) + + def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: """Get tokenizer with cached properties. 
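For context, the encode_tokens helper added above is intended as the single backend-agnostic entry point for turning prompt text into token ids, so call sites no longer need to special-case MistralTokenizer or a missing add_special_tokens argument. A minimal usage sketch follows; the tokenizer name and prompt string are arbitrary placeholders, not taken from this series.

    # Illustrative sketch only: exercising the encode_tokens helper added in
    # this patch with an ordinary Hugging Face tokenizer.
    from transformers import AutoTokenizer

    from vllm.transformers_utils.tokenizer import encode_tokens

    tokenizer = AutoTokenizer.from_pretrained("openai/whisper-large-v3")

    # add_special_tokens=None keeps the tokenizer's own default, matching the
    # previous bare tokenizer.encode(prompt) call sites.
    ids_default = encode_tokens(tokenizer, "Mary had a little lamb")

    # Call sites that must suppress special tokens (as the Whisper prompt
    # preprocessing in this series does) forward add_special_tokens=False.
    ids_plain = encode_tokens(tokenizer,
                              "Mary had a little lamb",
                              add_special_tokens=False)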
diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py index fa3215684acc9..0b20be2ac071f 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -3,6 +3,7 @@ from vllm.config import TokenizerPoolConfig from vllm.lora.request import LoRARequest from vllm.transformers_utils.tokenizer import (AnyTokenizer, + encode_tokens, get_lora_tokenizer, get_lora_tokenizer_async, get_tokenizer) @@ -58,12 +59,8 @@ def encode(self, lora_request: Optional[LoRARequest] = None, add_special_tokens: Optional[bool] = None) -> List[int]: tokenizer = self.get_lora_tokenizer(lora_request) - if add_special_tokens is not None: - ret = tokenizer.encode( # type: ignore[call-arg] - prompt, - add_special_tokens=add_special_tokens) - else: - ret = tokenizer.encode(prompt) + ret = encode_tokens(tokenizer, prompt, + add_special_tokens=add_special_tokens) self._raise_if_input_too_long(ret, lora_request) return ret @@ -74,12 +71,8 @@ async def encode_async( lora_request: Optional[LoRARequest] = None, add_special_tokens: Optional[bool] = None) -> List[int]: tokenizer = await self.get_lora_tokenizer_async(lora_request) - if add_special_tokens is not None: - ret = tokenizer.encode( # type: ignore[call-arg] - prompt, - add_special_tokens=add_special_tokens) - else: - ret = tokenizer.encode(prompt) + ret = encode_tokens(tokenizer, prompt, + add_special_tokens=add_special_tokens) self._raise_if_input_too_long(ret, lora_request) return ret From 1c16ad2b5112b46604314839a0387961ff49caeb Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Fri, 20 Dec 2024 17:16:24 +0000 Subject: [PATCH 34/44] updates --- vllm/model_executor/models/whisper.py | 127 ++++++++++++++++---------- 1 file changed, 77 insertions(+), 50 deletions(-) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 4ea23f23e0f84..d2318ebc8bfca 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -1,11 +1,10 @@ import math -from functools import lru_cache -from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, + TypedDict, Union) import numpy as np import torch from torch import nn -from transformers import WhisperProcessor from transformers.models.whisper.modeling_whisper import sinusoids from vllm.attention import Attention, AttentionMetadata, AttentionType @@ -24,8 +23,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, + NestedTensors) +from vllm.multimodal.audio import resample_audio from vllm.sequence import SequenceData +from vllm.transformers_utils.processor import cached_get_processor from .interfaces import SupportsMultiModal from .utils import AutoWeightsLoader, WeightsMapper, make_layers @@ -33,6 +35,11 @@ logger = init_logger(__name__) +class WhisperAudioInputs(TypedDict): + input_features: NestedTensors + """Shape: `(batch_size, 128, M)`""" + + class WhisperPositionalEmbedding(nn.Embedding): def __init__(self, @@ -192,6 +199,8 @@ def forward( ): q, _ = self.q_proj(hidden_states) + # Encoder hidden states are only 
computed once during prefill phase. + # Afterwards, the keys and values should be available in the kv-cache. if encoder_hidden_states is not None: kv, _ = self.kv_proj(encoder_hidden_states) k, v = kv.split([self.kv_size, self.kv_size], dim=-1) @@ -459,7 +468,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, ): - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.get_input_embeddings(input_ids) positions = self.embed_positions(positions) hidden_states = inputs_embeds + positions @@ -474,6 +483,12 @@ def forward( hidden_states = self.layer_norm(hidden_states) return hidden_states + def get_input_embeddings( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + return self.embed_tokens(input_ids) + class WhisperModel(nn.Module): @@ -492,16 +507,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, ) -> torch.Tensor: - if input_features is not None: - # Prefill encoder kv-caches - encoder_outputs = self.encoder( - input_features, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - ) - else: - encoder_outputs = None - + encoder_outputs = self.get_encoder_outputs( + input_features, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) decoder_outputs = self.decoder( input_ids=input_ids, positions=positions, @@ -511,6 +521,20 @@ def forward( ) return decoder_outputs + def get_encoder_outputs( + self, + input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> Optional[torch.Tensor]: + if input_features is None: + return None + return self.encoder( + input_features, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ @@ -564,33 +588,6 @@ def dummy_encoder_data_for_whisper(ctx: InputContext, seq_len: int, ) -@lru_cache -def get_whisper_processor( - processor_name: str, - *args, - trust_remote_code: bool = False, - revision: Optional[str] = None, - **kwargs, -) -> WhisperProcessor: - """Gets an whisper processor for the given model name via HuggingFace.""" - return WhisperProcessor.from_pretrained( - processor_name, - *args, - trust_remote_code=trust_remote_code, - revision=revision, - **kwargs, - ) - - -def _resample(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray: - try: - import librosa - except ImportError as exc: - raise ImportError( - "Please install vllm[audio] for audio support.") from exc - return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr) - - def input_processor_for_whisper(ctx: InputContext, inputs): multi_modal_data = inputs["encoder"]["multi_modal_data"] @@ -599,9 +596,9 @@ def input_processor_for_whisper(ctx: InputContext, inputs): multi_modal_data["audio"] = multi_modal_data["audio"][0] # Resample and process audio audio, orig_sr = multi_modal_data["audio"] - processor = get_whisper_processor(ctx.model_config.model) + processor = cached_get_processor(ctx.model_config.model) target_sr = processor.feature_extractor.sampling_rate - audio = _resample(audio, orig_sr=orig_sr, target_sr=target_sr) + audio = resample_audio(audio, orig_sr=orig_sr, target_sr=target_sr) if audio.size > 30 * target_sr: # Truncate audio to 30 seconds audio = audio[:30 * target_sr] @@ -626,7 +623,7 @@ def input_mapper_for_whisper( if len(multi_modal_data) == 0: return MultiModalKwargs() - processor = get_whisper_processor(ctx.model_config.model) + processor = 
cached_get_processor(ctx.model_config.model) sampling_rate = processor.feature_extractor.sampling_rate audios = [audio for audio, _ in multi_modal_data] @@ -675,12 +672,9 @@ def forward( attn_metadata: AttentionMetadata, **kwargs, ) -> torch.Tensor: - input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]] - input_features = kwargs.get("input_features") - if input_features is not None: - input_features = [feat.to(self.dtype) for feat in input_features] + audio_input = self._parse_and_validate_audio_input(**kwargs) decoder_outputs = self.model( - input_features=input_features, + input_features=audio_input["input_features"], input_ids=input_ids, positions=positions, kv_caches=kv_caches, @@ -688,6 +682,39 @@ def forward( ) return decoder_outputs + def get_multimodal_embeddings( + self, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + **kwargs, + ) -> Optional[NestedTensors]: + audio_input = self._parse_and_validate_audio_input(**kwargs) + return self.model.get_encoder_outputs( + audio_input["input_features"], + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> torch.Tensor: + return self.model.decoder.get_input_embeddings(input_ids) + + def _parse_and_validate_audio_input( + self, **kwargs: object) -> WhisperAudioInputs: + input_features = kwargs.pop("input_features", None) + + if input_features is not None: + if not isinstance(input_features, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio features. " + f"Got type: {type(input_features)}") + input_features = [feat.to(self.dtype) for feat in input_features] + + return WhisperAudioInputs(input_features=input_features) + def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = self.logits_processor(self.proj_out, hidden_states, From 72822800ba090a7d77afdf3415d9c8683ea0db99 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Fri, 20 Dec 2024 17:28:56 +0000 Subject: [PATCH 35/44] lint --- vllm/model_executor/models/whisper.py | 7 +++---- vllm/multimodal/processing.py | 6 ++++-- .../tokenizer_group/tokenizer_group.py | 9 +++++---- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index d2318ebc8bfca..e8b8c456aeae7 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -1,6 +1,6 @@ import math -from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, - TypedDict, Union) +from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, + Union) import numpy as np import torch @@ -589,7 +589,6 @@ def dummy_encoder_data_for_whisper(ctx: InputContext, seq_len: int, def input_processor_for_whisper(ctx: InputContext, inputs): - multi_modal_data = inputs["encoder"]["multi_modal_data"] if isinstance(multi_modal_data["audio"], list): assert len(multi_modal_data["audio"]) == 1 @@ -710,7 +709,7 @@ def _parse_and_validate_audio_input( if input_features is not None: if not isinstance(input_features, (torch.Tensor, list)): raise ValueError("Incorrect type of audio features. 
" - f"Got type: {type(input_features)}") + f"Got type: {type(input_features)}") input_features = [feat.to(self.dtype) for feat in input_features] return WhisperAudioInputs(input_features=input_features) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index c09b4bfbc34d4..f75b835e8464b 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -62,7 +62,8 @@ def _cached_encode( *, add_special_tokens: bool = False, ) -> list[int]: - return encode_tokens(tokenizer, text, + return encode_tokens(tokenizer, + text, add_special_tokens=add_special_tokens) @@ -746,7 +747,8 @@ def _apply_prompt_replacements( mm_item_counts, ) - token_ids = encode_tokens(tokenizer, text, + token_ids = encode_tokens(tokenizer, + text, add_special_tokens=False) matched_repls = [match.prompt_repl for match in text_matches] diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py index 0b20be2ac071f..6dc2f90561873 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -2,8 +2,7 @@ from vllm.config import TokenizerPoolConfig from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizer import (AnyTokenizer, - encode_tokens, +from vllm.transformers_utils.tokenizer import (AnyTokenizer, encode_tokens, get_lora_tokenizer, get_lora_tokenizer_async, get_tokenizer) @@ -59,7 +58,8 @@ def encode(self, lora_request: Optional[LoRARequest] = None, add_special_tokens: Optional[bool] = None) -> List[int]: tokenizer = self.get_lora_tokenizer(lora_request) - ret = encode_tokens(tokenizer, prompt, + ret = encode_tokens(tokenizer, + prompt, add_special_tokens=add_special_tokens) self._raise_if_input_too_long(ret, lora_request) return ret @@ -71,7 +71,8 @@ async def encode_async( lora_request: Optional[LoRARequest] = None, add_special_tokens: Optional[bool] = None) -> List[int]: tokenizer = await self.get_lora_tokenizer_async(lora_request) - ret = encode_tokens(tokenizer, prompt, + ret = encode_tokens(tokenizer, + prompt, add_special_tokens=add_special_tokens) self._raise_if_input_too_long(ret, lora_request) return ret From 3442852fe9d9e15061768ce372a6fd1d048d2968 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Fri, 20 Dec 2024 18:42:40 +0000 Subject: [PATCH 36/44] add todos --- vllm/model_executor/models/whisper.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index e8b8c456aeae7..860e84fc955e6 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -687,6 +687,8 @@ def get_multimodal_embeddings( attn_metadata: AttentionMetadata, **kwargs, ) -> Optional[NestedTensors]: + # TODO: This method does not obey the interface for SupportsMultiModal. + # Refactor this once encoder/decoder support is implemented in V1. audio_input = self._parse_and_validate_audio_input(**kwargs) return self.model.get_encoder_outputs( audio_input["input_features"], @@ -700,6 +702,9 @@ def get_input_embeddings( multimodal_embeddings: Optional[NestedTensors] = None, attn_metadata: Optional[AttentionMetadata] = None, ) -> torch.Tensor: + # TODO: This method just returns the decoder sequence embeddings since + # Whisper does not have encoder text tokens. Refactor this once + # encoder/decoder support is implemented in V1. 
return self.model.decoder.get_input_embeddings(input_ids) def _parse_and_validate_audio_input( From e0cc63e857544da4510a019c740cd7424ddf847a Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Fri, 20 Dec 2024 23:13:24 +0000 Subject: [PATCH 37/44] bugfix --- vllm/model_executor/models/whisper.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 860e84fc955e6..4207c4249aae5 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -602,9 +602,8 @@ def input_processor_for_whisper(ctx: InputContext, inputs): # Truncate audio to 30 seconds audio = audio[:30 * target_sr] multi_modal_data["audio"] = (audio, target_sr) - # Calculate number of tokens after convolutions - num_tokens = (audio.size // 80 - 1) // 2 + 1 - num_tokens = (num_tokens - 2) // 2 + 1 + # Calculate number of tokens after feature extraction and convolutions + num_tokens = (audio.size // 160 - 1) // 2 + 1 # Pre-allocate placeholder tokens in encoder sequence inputs["encoder"]["prompt_token_ids"] = [0] * num_tokens return inputs From d73e0049b9a1ad7bdaf64e7fd0a576fb91c29d18 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Wed, 1 Jan 2025 01:12:16 +0000 Subject: [PATCH 38/44] fix repeating issue --- vllm/model_executor/models/whisper.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 4207c4249aae5..83d2ba867074c 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -598,13 +598,9 @@ def input_processor_for_whisper(ctx: InputContext, inputs): processor = cached_get_processor(ctx.model_config.model) target_sr = processor.feature_extractor.sampling_rate audio = resample_audio(audio, orig_sr=orig_sr, target_sr=target_sr) - if audio.size > 30 * target_sr: - # Truncate audio to 30 seconds - audio = audio[:30 * target_sr] multi_modal_data["audio"] = (audio, target_sr) - # Calculate number of tokens after feature extraction and convolutions - num_tokens = (audio.size // 160 - 1) // 2 + 1 # Pre-allocate placeholder tokens in encoder sequence + num_tokens = ctx.model_config.hf_config.max_source_positions inputs["encoder"]["prompt_token_ids"] = [0] * num_tokens return inputs @@ -628,7 +624,6 @@ def input_mapper_for_whisper( kwargs = processor(audios, sampling_rate=sampling_rate, - padding=False, return_tensors="pt") kwargs["input_features"] = kwargs["input_features"].squeeze(0) kwargs["input_features"] = kwargs["input_features"].to(torch.float16) From 9672af2cede2cd55624abcdd9933982421a2f830 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Wed, 1 Jan 2025 01:58:46 +0000 Subject: [PATCH 39/44] fix tests --- .../encoder_decoder/audio/test_whisper.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/models/encoder_decoder/audio/test_whisper.py b/tests/models/encoder_decoder/audio/test_whisper.py index 81761d2a2b679..7789d7afeccdf 100644 --- a/tests/models/encoder_decoder/audio/test_whisper.py +++ b/tests/models/encoder_decoder/audio/test_whisper.py @@ -34,21 +34,21 @@ EXPECTED = { "openai/whisper-medium": [ " The first words I spoke in the original phonograph, a little piece" - " of practical poetry. 
Mary had a little lamb, its fleece was white as" - " snow, and everywhere that Mary went the lamb would shun it all.", - " And the old one pitch on the way to Edgar Martinez swung on the line" - " down the left field line for Obeysmith. Here comes Joy. Here is" - " Jorgen at third base. They're gonna wave him in. The throw to the" + " of practical poetry. Mary had a little lamb, its fleece was quite as" + " slow, and everywhere that Mary went the lamb was sure to go.", + " And the 0-1 pitch on the way to Edgar Martinez swung on the line" + " down the left field line for Obeyshev. Here comes Joy. Here is" + " Jorgen at third base. They're going to wave him in. The throw to the" " plate will be late. The Mariners are going to play for the American" " League Championship. I don't believe it. It just continues. My, oh" " my." ], "openai/whisper-large-v3": [ - " The first words I spoke in the original phonograph. A little piece" - " of practical poetry. Mary had a little lamb, its fleece was white as" - " snow, and everywhere that Mary went, the lamb was sure to go.", - " And the 0-1 pitch on the way to Edgar Martinez. Swung on the line," - " down the left field line for a base hit. Here comes Joy. Here is" + " The first words I spoke in the original phonograph, a little piece" + " of practical poetry. Mary had a little lamb, its feet were quite as" + " slow, and everywhere that Mary went, the lamb was sure to go.", + " And the 0-1 pitch on the way to Edgar Martinez. Swung on the line." + " Now the left field line for a base hit. Here comes Joy. Here is" " Junior to third base. They're going to wave him in. The throw to the" " plate will be late. The Mariners are going to play for the American" " League Championship. I don't believe it. It just continues. 
My, oh," From 127f46ef2879a02abaeb6732398d2f0d44142924 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 2 Jan 2025 16:49:44 +0000 Subject: [PATCH 40/44] clean --- .buildkite/test-pipeline.yaml | 2 ++ .../encoder_decoder/audio/test_whisper.py | 2 ++ vllm/model_executor/models/whisper.py | 19 ++++++++++--------- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index c6f8316412e2f..193d8507f45fd 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -363,12 +363,14 @@ steps: - tests/models/decoder_only/audio_language - tests/models/decoder_only/vision_language - tests/models/embedding/vision_language + - tests/models/encoder_decoder/audio - tests/models/encoder_decoder/vision_language commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model' - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model' - pytest -v -s models/embedding/vision_language -m core_model + - pytest -v -s models/encoder_decoder/audio -m core_model - pytest -v -s models/encoder_decoder/language -m core_model - pytest -v -s models/encoder_decoder/vision_language -m core_model diff --git a/tests/models/encoder_decoder/audio/test_whisper.py b/tests/models/encoder_decoder/audio/test_whisper.py index 7789d7afeccdf..a66dee95dc6c8 100644 --- a/tests/models/encoder_decoder/audio/test_whisper.py +++ b/tests/models/encoder_decoder/audio/test_whisper.py @@ -88,6 +88,7 @@ def run_test( @fork_new_process_for_each_test +@pytest.mark.core_model @pytest.mark.parametrize("model", ["openai/whisper-medium", "openai/whisper-large-v3"]) @pytest.mark.parametrize("enforce_eager", [True, False]) @@ -96,6 +97,7 @@ def test_models(model, enforce_eager) -> None: @multi_gpu_test(num_gpus=2) +@pytest.mark.core_model @pytest.mark.parametrize("model", ["openai/whisper-large-v3"]) @pytest.mark.parametrize("enforce_eager", [True, False]) @pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 83d2ba867074c..cb54b4c3ba663 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -298,9 +298,7 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - if hidden_states.dtype == torch.float16 and ( - torch.isinf(hidden_states).any() - or torch.isnan(hidden_states).any()): + if hidden_states.isinf().any() or hidden_states.isnan().any(): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, @@ -580,11 +578,14 @@ def get_max_whisper_audio_tokens(ctx: InputContext) -> int: def dummy_encoder_data_for_whisper(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): assert mm_counts["audio"] == 1 - sample_rate = 16000 - num_tokens = ctx.model_config.hf_config.max_source_positions + num_tokens = get_max_whisper_audio_tokens(ctx) + processor = cached_get_processor(ctx.model_config.model) + chunk_length = processor.feature_extractor.chunk_length + sampling_rate = processor.feature_extractor.sampling_rate + num_samples = chunk_length * sampling_rate return DummyData( SequenceData.from_prompt_token_counts((0, num_tokens)), - {"audio": [(np.zeros(30 * sample_rate), sample_rate)]}, + {"audio": [(np.zeros(num_samples), 
sampling_rate)]}, ) @@ -600,7 +601,7 @@ def input_processor_for_whisper(ctx: InputContext, inputs): audio = resample_audio(audio, orig_sr=orig_sr, target_sr=target_sr) multi_modal_data["audio"] = (audio, target_sr) # Pre-allocate placeholder tokens in encoder sequence - num_tokens = ctx.model_config.hf_config.max_source_positions + num_tokens = get_max_whisper_audio_tokens(ctx) inputs["encoder"]["prompt_token_ids"] = [0] * num_tokens return inputs @@ -625,8 +626,8 @@ def input_mapper_for_whisper( kwargs = processor(audios, sampling_rate=sampling_rate, return_tensors="pt") - kwargs["input_features"] = kwargs["input_features"].squeeze(0) - kwargs["input_features"] = kwargs["input_features"].to(torch.float16) + kwargs["input_features"] = kwargs["input_features"].squeeze(0).to( + ctx.model_config.dtype) return MultiModalKwargs(kwargs) From edfec27b90c79acf3bef6e239ac1e129db775e6f Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 2 Jan 2025 18:19:13 +0000 Subject: [PATCH 41/44] reduce tests --- .buildkite/test-pipeline.yaml | 4 +- .../{audio => audio_language}/__init__.py | 0 .../{audio => audio_language}/test_whisper.py | 50 +++++++++++++++---- 3 files changed, 41 insertions(+), 13 deletions(-) rename tests/models/encoder_decoder/{audio => audio_language}/__init__.py (100%) rename tests/models/encoder_decoder/{audio => audio_language}/test_whisper.py (58%) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 193d8507f45fd..529daf54faecf 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -363,14 +363,14 @@ steps: - tests/models/decoder_only/audio_language - tests/models/decoder_only/vision_language - tests/models/embedding/vision_language - - tests/models/encoder_decoder/audio + - tests/models/encoder_decoder/audio_language - tests/models/encoder_decoder/vision_language commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model' - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model' - pytest -v -s models/embedding/vision_language -m core_model - - pytest -v -s models/encoder_decoder/audio -m core_model + - pytest -v -s models/encoder_decoder/audio_language -m core_model - pytest -v -s models/encoder_decoder/language -m core_model - pytest -v -s models/encoder_decoder/vision_language -m core_model diff --git a/tests/models/encoder_decoder/audio/__init__.py b/tests/models/encoder_decoder/audio_language/__init__.py similarity index 100% rename from tests/models/encoder_decoder/audio/__init__.py rename to tests/models/encoder_decoder/audio_language/__init__.py diff --git a/tests/models/encoder_decoder/audio/test_whisper.py b/tests/models/encoder_decoder/audio_language/test_whisper.py similarity index 58% rename from tests/models/encoder_decoder/audio/test_whisper.py rename to tests/models/encoder_decoder/audio_language/test_whisper.py index a66dee95dc6c8..4eb1cfee381dc 100644 --- a/tests/models/encoder_decoder/audio/test_whisper.py +++ b/tests/models/encoder_decoder/audio_language/test_whisper.py @@ -32,6 +32,28 @@ ] EXPECTED = { + "openai/whisper-tiny": [ + " He has birth words I spoke in the original corner of that. And a" + " little piece of black coat poetry. Mary had a little sandwich," + " sweet, with white and snow. And everyone had it very went the last" + " would sure to go.", + " >> And the old one, fit John the way to Edgar Martinez. 
>> One more" + " to line down the field line for our base camp. Here comes joy. Here" + " is June and the third base. They're going to wave him in. The throw" + " to the plate will be late. The Mariners are going to play for the" + " American League Championship. I don't believe it. It just continues" + " by all five." + ], + "openai/whisper-small": [ + " The first words I spoke in the original pornograph. A little piece" + " of practical poetry. Mary had a little lamb, its fleece was quite a" + " slow, and everywhere that Mary went the lamb was sure to go.", + " And the old one pitch on the way to Edgar Martinez one month. Here" + " comes joy. Here is Junior to third base. They're gonna wave him" + " in. The throw to the plate will be late. The Mariners are going to" + " play for the American League Championship. I don't believe it. It" + " just continues. My, oh my." + ], "openai/whisper-medium": [ " The first words I spoke in the original phonograph, a little piece" " of practical poetry. Mary had a little lamb, its fleece was quite as" @@ -53,6 +75,17 @@ " plate will be late. The Mariners are going to play for the American" " League Championship. I don't believe it. It just continues. My, oh," " my." + ], + "openai/whisper-large-v3-turbo": [ + " The first words I spoke in the original phonograph, a little piece" + " of practical poetry. Mary had a little lamb, its streets were quite" + " as slow, and everywhere that Mary went the lamb was sure to go.", + " And the 0-1 pitch on the way to Edgar Martinez. Swung on the line" + " down the left field line for a base hit. Here comes Joy. Here is" + " Junior to third base. They're going to wave him in. The throw to the" + " plate will be late. The Mariners are going to play for the American" + " League Championship. I don't believe it. It just continues. My, oh," + " my." 
] } @@ -60,7 +93,6 @@ def run_test( model: str, *, - enforce_eager: bool, tensor_parallel_size: int, distributed_executor_backend: Optional[str] = None, ) -> None: @@ -71,7 +103,6 @@ def run_test( model=model, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, - enforce_eager=enforce_eager, ) sampling_params = SamplingParams( @@ -90,20 +121,17 @@ def run_test( @fork_new_process_for_each_test @pytest.mark.core_model @pytest.mark.parametrize("model", - ["openai/whisper-medium", "openai/whisper-large-v3"]) -@pytest.mark.parametrize("enforce_eager", [True, False]) -def test_models(model, enforce_eager) -> None: - run_test(model, enforce_eager=enforce_eager, tensor_parallel_size=1) + ["openai/whisper-small", + "openai/whisper-large-v3-turbo"]) +def test_models(model) -> None: + run_test(model, tensor_parallel_size=1) @multi_gpu_test(num_gpus=2) @pytest.mark.core_model -@pytest.mark.parametrize("model", ["openai/whisper-large-v3"]) -@pytest.mark.parametrize("enforce_eager", [True, False]) +@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"]) @pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) -def test_models_distributed(model, enforce_eager, - distributed_executor_backend) -> None: +def test_models_distributed(model, distributed_executor_backend) -> None: run_test(model, - enforce_eager=enforce_eager, tensor_parallel_size=2, distributed_executor_backend=distributed_executor_backend) From ba308867f4a145878106c027b0165af46e0cf5a3 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 2 Jan 2025 18:22:46 +0000 Subject: [PATCH 42/44] format --- tests/models/encoder_decoder/audio_language/test_whisper.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/models/encoder_decoder/audio_language/test_whisper.py b/tests/models/encoder_decoder/audio_language/test_whisper.py index 4eb1cfee381dc..eb238c5332139 100644 --- a/tests/models/encoder_decoder/audio_language/test_whisper.py +++ b/tests/models/encoder_decoder/audio_language/test_whisper.py @@ -120,9 +120,8 @@ def run_test( @fork_new_process_for_each_test @pytest.mark.core_model -@pytest.mark.parametrize("model", - ["openai/whisper-small", - "openai/whisper-large-v3-turbo"]) +@pytest.mark.parametrize( + "model", ["openai/whisper-small", "openai/whisper-large-v3-turbo"]) def test_models(model) -> None: run_test(model, tensor_parallel_size=1) From dbd21a428bb133b3fc82f7bb848cc8d2aff50a2c Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 2 Jan 2025 17:17:26 -0500 Subject: [PATCH 43/44] Update offline_inference_whisper.py --- examples/offline_inference_whisper.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/offline_inference_whisper.py b/examples/offline_inference_whisper.py index 35326a435c76e..087ad4376fb2e 100644 --- a/examples/offline_inference_whisper.py +++ b/examples/offline_inference_whisper.py @@ -3,8 +3,6 @@ from vllm import LLM, SamplingParams from vllm.assets.audio import AudioAsset -dtype = "float" - # Create a Whisper encoder/decoder model instance llm = LLM( model="openai/whisper-large-v3", From ab674fa56144a033b6542cd768a02b65e160a8af Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Fri, 3 Jan 2025 04:03:59 +0000 Subject: [PATCH 44/44] move whisper test registry --- tests/models/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 4373317d60891..dcb8bfa0f9510 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ 
-128,7 +128,6 @@ class _HfExamplesInfo: # [Encoder-decoder] "BartModel": _HfExamplesInfo("facebook/bart-base"), "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"), - "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501 # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer # Therefore, we borrow the BartTokenizer from the original Bart model "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501 @@ -205,6 +204,7 @@ class _HfExamplesInfo: "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3"), # [Encoder-decoder] "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 + "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501 } _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
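Taken together, the series enables offline Whisper transcription through the standard LLM entry point. The sketch below is adapted loosely from the examples/offline_inference_whisper.py script touched earlier in the series; the audio asset name, prompt layout, and sampling settings are illustrative assumptions rather than the exact contents of that script.

    # Rough end-to-end sketch of the Whisper support added in this series.
    # The asset name and prompt fields are assumptions modeled on vLLM's
    # existing audio examples, not verbatim from the example script.
    import time

    from vllm import LLM, SamplingParams
    from vllm.assets.audio import AudioAsset

    # Create a Whisper encoder/decoder model instance.
    llm = LLM(model="openai/whisper-large-v3")

    # Audio is attached to the encoder sequence via multi_modal_data, while
    # the decoder is seeded with Whisper's start-of-transcript token.
    prompts = [{
        "prompt": "<|startoftranscript|>",
        "multi_modal_data": {
            "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
        },
    }]

    sampling_params = SamplingParams(temperature=0, max_tokens=200)

    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        print(output.outputs[0].text)
    print(f"Transcription took {time.time() - start:.2f}s")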