Skip to content

Commit

Permalink
[3/N] Support and implement merged input processor for LLaVA model (vllm-project#10676)
Browse files Browse the repository at this point in the history

Signed-off-by: DarkLight1337 <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
  • Loading branch information
2 people authored and weilong.yu committed Dec 13, 2024
1 parent 91efa90 commit d5be8c4
Show file tree
Hide file tree
Showing 10 changed files with 626 additions and 421 deletions.
49 changes: 3 additions & 46 deletions tests/multimodal/test_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np
import pytest
from transformers import CLIPImageProcessor, LlavaNextImageProcessor
from transformers import LlavaNextImageProcessor

from vllm.config import ModelConfig
from vllm.multimodal import MultiModalRegistry
Expand All @@ -14,49 +14,6 @@ def mm_registry():
return MultiModalRegistry()


@pytest.mark.parametrize("dtype", ["half", "float"])
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
def test_clip_image_processor(image_assets, mm_registry, dtype, size_factor):
    """Compare ``mm_registry.map_input`` against ``CLIPImageProcessor``.

    Each image asset is rescaled by ``size_factor`` and preprocessed twice:
    once directly with the Hugging Face processor and once through the
    registry. Both results must expose the same keys, and every pair of
    tensors must agree in shape and be element-wise close.
    """
    MODEL_NAME = "llava-hf/llava-1.5-7b-hf"

    reference_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
    assert isinstance(reference_processor, CLIPImageProcessor)

    config_kwargs = dict(
        model=MODEL_NAME,
        task="auto",
        tokenizer=MODEL_NAME,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype=dtype,
        revision=None,
        limit_mm_per_prompt={"image": 1},
    )
    model_config = ModelConfig(**config_kwargs)

    mm_registry.init_mm_limits_per_prompt(model_config)

    # Lazily rescale so each image is resized right before it is processed,
    # exactly as a plain loop body would do.
    scaled_images = (
        rescale_image_size(asset.pil_image, size_factor)
        for asset in image_assets
    )

    for image in scaled_images:
        expected = reference_processor.preprocess(image, return_tensors="pt")
        actual = mm_registry.map_input(model_config, {"image": image})

        assert expected.keys() == actual.keys()

        for key, expected_tensor in expected.items():
            expected_arr: np.ndarray = expected_tensor.numpy()
            actual_arr: np.ndarray = actual[key].numpy()
            failure = f"Failed for key={key}"

            assert expected_arr.shape == actual_arr.shape, failure
            assert np.allclose(expected_arr, actual_arr), failure


@pytest.mark.parametrize("dtype", ["half", "float"])
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
def test_llava_next_image_processor(image_assets, mm_registry, dtype,
Expand Down Expand Up @@ -107,7 +64,7 @@ def test_llava_next_image_processor(image_assets, mm_registry, dtype,
(2, 1, False), (2, 2, True)],
)
def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid):
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"

model_config = ModelConfig(
model=MODEL_NAME,
Expand Down Expand Up @@ -138,7 +95,7 @@ def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid):
# NOTE: We don't test zero images since the HF processor doesn't support it
@pytest.mark.parametrize("num_images", [1, 2])
def test_image_mapper_multi(image_assets, mm_registry, num_images):
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"

model_config = ModelConfig(
model=MODEL_NAME,
Expand Down
Loading

0 comments on commit d5be8c4

Please sign in to comment.