Skip to content

Commit

Permalink
[VLM] Move supported limits and max tokens to merged multi-modal proc…
Browse files Browse the repository at this point in the history
…essor (#11669)

Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
  • Loading branch information
DarkLight1337 and Isotr0py authored Jan 1, 2025
1 parent 7300144 commit a115ac4
Show file tree
Hide file tree
Showing 16 changed files with 340 additions and 350 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from transformers import AutoTokenizer

from vllm.inputs import InputContext, InputProcessingContext
from vllm.inputs import InputProcessingContext
from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID

from .....conftest import _ImageAssets
Expand All @@ -20,42 +20,6 @@ def processor_for_phi3v():
return Phi3VMultiModalProcessor


@pytest.fixture()
def get_max_phi3v_image_tokens():
from vllm.model_executor.models.phi3v import get_max_phi3v_image_tokens
return get_max_phi3v_image_tokens


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("num_crops,expected_max_tokens", [
(4, 781),
(16, 2653),
])
def test_max_tokens_override(get_max_phi3v_image_tokens, model: str,
num_crops: int, expected_max_tokens: int):
"""Ensure get_max_phi3v_image_tokens handles num_crops properly."""
# NOTE: mm_processor_kwargs on the context in this test is unused, since
# this is testing the mapper directly. In practice, the processor kwargs
# are wrapped in a closure when calling the max tokens func. We explicitly
# do NOT use the mm_processor_kwargs in the model context here to ensure
# that the max image tokens implementation is referencing a mix of the
# kwargs to the function and the original mm_processor_kwargs in case
# values are somehow updated and end up in a bad state.
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=None,
)

actual_max_tokens = get_max_phi3v_image_tokens(
InputContext(ctx.model_config),
num_crops=num_crops,
)

assert expected_max_tokens == actual_max_tokens


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"num_crops,expected_toks_per_img",
Expand All @@ -77,6 +41,7 @@ def test_processor_override(processor_for_phi3v, image_assets: _ImageAssets,
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
from transformers import AutoTokenizer

from vllm.inputs import InputContext, InputProcessingContext
from vllm.inputs import InputProcessingContext

from .....conftest import _ImageAssets
from ....utils import build_model_context
Expand All @@ -22,39 +22,6 @@ def processor_for_qwen2_vl():
return Qwen2VLMultiModalProcessor


@pytest.fixture()
def get_max_qwen2_vl_image_tokens():
from vllm.model_executor.models.qwen2_vl import (
get_max_qwen2_vl_image_tokens)
return get_max_qwen2_vl_image_tokens


@pytest.mark.parametrize("mm_processor_kwargs,expected_max_tokens", [
({}, 16384),
({
MIN_PIXELS: 64**2,
MAX_PIXELS: 512**2
}, 324),
])
@pytest.mark.parametrize("model", [MODEL])
def test_qwen2_vl_max_image_tokens(
get_max_qwen2_vl_image_tokens,
model: str,
mm_processor_kwargs: Dict[str, Any],
expected_max_tokens: int,
):
"""Ensure that the max token calc handles min/max pixels properly."""
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
mm_processor_kwargs=None,
)

actual_max_tokens = get_max_qwen2_vl_image_tokens(
InputContext(ctx.model_config), **mm_processor_kwargs)
assert actual_max_tokens == expected_max_tokens


@pytest.mark.parametrize(
"mm_processor_kwargs, expected_toks_per_img, expected_pixels_shape", [
({}, 1426, (5704, 1176)),
Expand Down Expand Up @@ -82,6 +49,7 @@ def test_processor_override(
model_name=model,
tokenizer_name=model,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
Expand Down
14 changes: 8 additions & 6 deletions tests/multimodal/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,11 @@ def _test_processing_cache_correctness(
else:
hf_overrides = {}

limit_mm_per_prompt = {
modality: 3 if supports_multi else 1
for modality, supports_multi in modalities.items()
}

model_config = ModelConfig(
model_id,
task="auto",
Expand All @@ -548,6 +553,7 @@ def _test_processing_cache_correctness(
dtype="float16",
revision=None,
hf_overrides=hf_overrides,
limit_mm_per_prompt=limit_mm_per_prompt,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)

Expand Down Expand Up @@ -580,18 +586,14 @@ def _test_processing_cache_correctness(
min_wh=128,
max_wh=256),
"audio":
partial(_rand_audio, rng, min_len=256, max_len=512, sr=16000),
}
input_max_count = {
modality: 3 if supports_multi else 1
for modality, supports_multi in modalities.items()
partial(_rand_audio, rng, min_len=512, max_len=1024, sr=16000),
}

for batch_idx in range(num_batches):
mm_data = {
k:
[(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
for _ in range(rng.randint(input_max_count[k]))]
for _ in range(rng.randint(limit_mm_per_prompt[k]))]
for k in modalities
}

Expand Down
8 changes: 1 addition & 7 deletions vllm/inputs/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,13 +331,7 @@ def dummy_data_for_profiling(
trust_remote_code=model_config.trust_remote_code,
)
processor = mm_registry.create_processor(model_config, tokenizer)

mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
mm_max_tokens = mm_registry.get_max_tokens_by_modality(
model_config)

dummy_data = processor.get_dummy_data(seq_len, mm_counts,
mm_max_tokens)
dummy_data = processor.get_dummy_data(seq_len)
else:
model_cls, _ = get_model_architecture(model_config)
if is_encoder_data:
Expand Down
75 changes: 42 additions & 33 deletions vllm/model_executor/models/aria.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)
from typing import (Callable, Iterable, List, Mapping, Optional, Set, Tuple,
TypedDict, Union)

import torch
import torch.nn as nn
Expand All @@ -9,7 +9,6 @@
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.inputs import InputContext
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
Expand Down Expand Up @@ -87,8 +86,8 @@ def __init__(
def forward(
self,
pixel_values: torch.Tensor,
pixel_mask: Optional[torch.BoolTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.BoolTensor]]:
pixel_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)

vit_oup = self.vision_model(
Expand All @@ -100,7 +99,8 @@ def forward(

return vit_oup, image_atts

def _create_patch_attention_mask(self, pixel_mask):
def _create_patch_attention_mask(
self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor:
if pixel_mask is None:
return None

Expand All @@ -115,7 +115,8 @@ def _create_patch_attention_mask(self, pixel_mask):
)
return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

def _create_image_attention_mask(self, patch_attention_mask):
def _create_image_attention_mask(
self, patch_attention_mask: torch.Tensor) -> torch.Tensor:
if patch_attention_mask is None:
return None

Expand All @@ -125,13 +126,13 @@ def _create_image_attention_mask(self, patch_attention_mask):

class FFN(nn.Module):

def __init__(self, embed_dim, ff_dim, output_dim):
def __init__(self, embed_dim: int, ff_dim: int, output_dim: int) -> None:
super().__init__()
self.linear_in = ColumnParallelLinear(embed_dim, ff_dim, bias=False)
self.linear_out = RowParallelLinear(ff_dim, output_dim, bias=False)
self.act = get_act_fn("gelu_new")

def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.linear_in(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.linear_out(hidden_states)
Expand All @@ -140,7 +141,7 @@ def forward(self, hidden_states):

class CrossAttention(nn.Module):

def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0):
def __init__(self, kv_dim: int, embed_dim: int, num_heads: int) -> None:
super().__init__()
self.num_heads = num_heads
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
Expand All @@ -149,12 +150,16 @@ def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0):

self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
self.linear = nn.Linear(embed_dim, embed_dim)
self.dropout = nn.Dropout(drop_out_rate)

self.layer_norm = nn.LayerNorm(embed_dim)
self.ln_kv = nn.LayerNorm(kv_dim)

def forward(self, x, hidden_states, attn_mask=None, add_residual=False):
def forward(
self,
x: torch.Tensor,
hidden_states: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
normed_hidden_states = self.layer_norm(hidden_states)
query = self.q_proj(normed_hidden_states).permute(1, 0, 2)

Expand All @@ -169,11 +174,7 @@ def forward(self, x, hidden_states, attn_mask=None, add_residual=False):

attn_output = attn_output.permute(1, 0, 2)

if add_residual:
attn_output = hidden_states + self.dropout(
self.linear(attn_output))
else:
attn_output = self.dropout(self.linear(attn_output))
attn_output = self.linear(attn_output)

return attn_output

Expand Down Expand Up @@ -201,14 +202,14 @@ class AriaProjector(nn.Module):

def __init__(
self,
patch_to_query_dict,
embed_dim,
num_heads,
kv_dim,
ff_dim,
output_dim,
norm_layer=nn.LayerNorm,
):
patch_to_query_dict: dict[int, int],
embed_dim: int,
num_heads: int,
kv_dim: int,
ff_dim: int,
output_dim: int,
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
) -> None:
super().__init__()
self.patch_to_query_dict = patch_to_query_dict
self.embed_dim = embed_dim
Expand All @@ -224,7 +225,11 @@ def __init__(
self.ln_ffn = norm_layer(embed_dim)
self.ffn = FFN(embed_dim, ff_dim, output_dim)

def forward(self, x, attn_mask=None):
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
bs = x.shape[0]
queries = self.query.unsqueeze(0).repeat(bs, 1, 1)

Expand Down Expand Up @@ -442,12 +447,17 @@ def build_mm_projector(config: PretrainedConfig):
)


def get_max_aria_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config()
return max(hf_config.projector_patch_to_query_dict.values())
class AriaMultiModalProcessor(BaseMultiModalProcessor):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}

def _get_num_image_tokens(self) -> int:
hf_config = self.ctx.get_hf_config()
return max(hf_config.projector_patch_to_query_dict.values())

class AriaMultiModalProcessor(BaseMultiModalProcessor):
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()}

def _get_mm_fields_config(
self,
Expand All @@ -468,13 +478,13 @@ def _get_prompt_replacements(
hf_config = self.ctx.get_hf_config()
image_token_id = hf_config.image_token_index

max_image_tokens = get_max_aria_image_tokens(self.ctx)
num_image_tokens = self._get_num_image_tokens()

return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=[image_token_id] * max_image_tokens,
replacement=[image_token_id] * num_image_tokens,
)
]

Expand Down Expand Up @@ -504,7 +514,6 @@ def _get_dummy_mm_inputs(
)


@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_aria_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor)
class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
"""
Expand Down
Loading

0 comments on commit a115ac4

Please sign in to comment.