Skip to content

Commit

Permalink
[Model] merged input processor for Phi-3-Vision models (vllm-project#…
Browse files Browse the repository at this point in the history
…10977)

Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
  • Loading branch information
Isotr0py and DarkLight1337 authored Dec 9, 2024
1 parent ca87149 commit a811dd6
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 409 deletions.
4 changes: 2 additions & 2 deletions tests/entrypoints/openai/test_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=772, total_tokens=782)
completion_tokens=10, prompt_tokens=775, total_tokens=785)

message = choice.message
message = chat_completion.choices[0].message
Expand Down Expand Up @@ -181,7 +181,7 @@ async def test_single_chat_session_image_base64encoded(
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=772, total_tokens=782)
completion_tokens=10, prompt_tokens=775, total_tokens=785)

message = choice.message
message = chat_completion.choices[0].message
Expand Down
4 changes: 2 additions & 2 deletions tests/entrypoints/openai/test_vision_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,5 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
assert len(embeddings["data"]) == 1
assert len(embeddings["data"][0]["embedding"]) == 3072
assert embeddings["usage"]["completion_tokens"] == 0
assert embeddings["usage"]["prompt_tokens"] == 762
assert embeddings["usage"]["total_tokens"] == 762
assert embeddings["usage"]["prompt_tokens"] == 765
assert embeddings["usage"]["total_tokens"] == 765
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@
from typing import Optional

import pytest
import torch
from transformers import AutoImageProcessor, AutoTokenizer
from transformers import AutoTokenizer

from vllm.inputs import InputContext, token_inputs
from vllm.inputs import InputContext, InputProcessingContext
from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
from vllm.multimodal import MultiModalRegistry

from .....conftest import _ImageAssets
from ....utils import build_model_context
Expand All @@ -17,15 +15,9 @@

# Wrap lazy imports to avoid initializing CUDA during test collection
@pytest.fixture()
def input_processor_for_phi3v():
from vllm.model_executor.models.phi3v import input_processor_for_phi3v
return input_processor_for_phi3v


@pytest.fixture()
def dummy_data_for_phi3v():
from vllm.model_executor.models.phi3v import dummy_data_for_phi3v
return dummy_data_for_phi3v
def processor_for_phi3v():
from vllm.model_executor.models.phi3v import Phi3VProcessor
return Phi3VProcessor


@pytest.fixture()
Expand All @@ -34,53 +26,6 @@ def get_max_phi3v_image_tokens():
return get_max_phi3v_image_tokens


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("num_crops", [4, 16, None])
def test_input_mapper_override(model: str, image_assets: _ImageAssets,
num_crops: Optional[int]):
"""Ensure that the [default] input mapper handles num_crops properly."""
# We pass the processor kwargs here since for this model, we fall back to
# the default mapper; this will fall back to the HF mapper and forward
# mm_processor_kwargs to it.
mm_processor_kwargs = {
"num_crops": num_crops
} if num_crops is not None else {}
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
)

hf_processor = AutoImageProcessor.from_pretrained(model,
trust_remote_code=True,
**mm_processor_kwargs)

mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)

image = image_assets[0].pil_image
hf_result = hf_processor.preprocess(
image,
return_tensors="pt",
)

vllm_result = mm_registry.map_input(
ctx.model_config,
{"image": image},
)

assert torch.all(hf_result["image_sizes"] == vllm_result["image_sizes"])
assert torch.all(
hf_result["num_img_tokens"] == vllm_result["num_img_tokens"])

# For pixel values, the second axis should be the num_crops + 1
# for the rescaled original image. The default value in VLLM falls
# back to the HF config, which is why we compare to the processor num_crops
assert torch.all(hf_result["pixel_values"] == vllm_result["pixel_values"])
assert vllm_result["pixel_values"].shape[1] == hf_processor.num_crops + 1


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("num_crops,expected_max_tokens", [
(4, 781),
Expand Down Expand Up @@ -112,48 +57,20 @@ def test_max_tokens_override(get_max_phi3v_image_tokens, model: str,


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("num_crops,toks_per_img,num_imgs", [
(4, 781, 1),
(4, 781, 2),
(16, 2653, 1),
(16, 2653, 2),
])
def test_dummy_data_override(dummy_data_for_phi3v, model: str, num_crops: int,
toks_per_img: int, num_imgs: int):
"""Ensure dummy_data_for_phi3v handles num_crops properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the dummy data func.
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=None,
)

dummy_data = dummy_data_for_phi3v(
ctx=ctx,
seq_len=8192, # Should be bigger than num_imgs * toks_per_img
mm_counts={"image": num_imgs},
num_crops=num_crops,
)
sequence_data = dummy_data.seq_data
# Ensure we have the right number of placeholders per num_crops size
img_tok_count = sequence_data.get_token_ids().count(_IMAGE_TOKEN_ID)
assert img_tok_count == toks_per_img * num_imgs


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("num_crops,expected_toks_per_img,num_imgs", [
(4, 757, 1),
(4, 757, 2),
(16, 1921, 1),
(16, 1921, 2),
])
def test_input_processor_override(input_processor_for_phi3v,
image_assets: _ImageAssets, model: str,
num_crops: int, expected_toks_per_img: int,
num_imgs: int):
@pytest.mark.parametrize(
"num_crops,expected_toks_per_img,num_imgs",
[
(4, 757, 1),
(4, 757, 2),
(16, 1921, 1),
(16, 1921, 2),
# the default num_crops of phi-3.5-vision is 4
(None, 757, 2),
(None, 757, 2),
])
def test_processor_override(processor_for_phi3v, image_assets: _ImageAssets,
model: str, num_crops: Optional[int],
expected_toks_per_img: int, num_imgs: int):
"""Ensure input_processor_for_phi3v handles num_crops properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
Expand All @@ -163,19 +80,20 @@ def test_input_processor_override(input_processor_for_phi3v,
tokenizer_name=model,
trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model)
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
# Build the image str / prompt based on the number of images we pass
img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)])
prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
images = [image_assets[0].pil_image] * num_imgs

inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
prompt=prompt,
multi_modal_data={"image": images})
mm_data = {"image": images}
mm_processor_kwargs = {}
if num_crops is not None:
mm_processor_kwargs = {"num_crops": num_crops}

processed_inputs = input_processor_for_phi3v(ctx,
inputs,
num_crops=num_crops)
processor = processor_for_phi3v(ctx)
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)

# Ensure we have the right number of placeholders per num_crops size
img_tok_count = processed_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID)
Expand Down
Loading

0 comments on commit a811dd6

Please sign in to comment.