Skip to content

Commit

Permalink
[VLM] Merged multi-modal processor for LLaVA-NeXT (vllm-project#11682)
Browse files Browse the repository at this point in the history
Signed-off-by: DarkLight1337 <[email protected]>
  • Loading branch information
DarkLight1337 authored Jan 2, 2025
1 parent b6087a6 commit 8c38ee7
Show file tree
Hide file tree
Showing 14 changed files with 605 additions and 551 deletions.

This file was deleted.

118 changes: 0 additions & 118 deletions tests/multimodal/test_mapper.py

This file was deleted.

97 changes: 97 additions & 0 deletions tests/multimodal/test_processing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from contextlib import nullcontext
from functools import partial
from typing import cast
from unittest.mock import MagicMock

import numpy as np
import pytest
Expand Down Expand Up @@ -526,6 +528,100 @@ def _rand_audio(
return rng.rand(audio_len), sr


@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
("limit", "num_supported", "is_valid"),
[(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
(2, 1, False), (2, 2, True)],
)
def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
limit_mm_per_prompt = {"image": limit}

model_config = ModelConfig(
model=model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="half",
revision=None,
limit_mm_per_prompt=limit_mm_per_prompt,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)

processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
ctx = InputProcessingContext(
model_config,
tokenizer=cached_get_tokenizer(model_config.tokenizer),
)

processor = processor_factory(ctx, cache=None)

mock_supported_mm_limits = MagicMock(return_value={"image": num_supported})
processor.get_supported_mm_limits = mock_supported_mm_limits

if is_valid:
exc_ctx = nullcontext()
else:
exc_ctx = pytest.raises(ValueError, match="this model only supports")

with exc_ctx:
processor._get_and_validate_dummy_mm_counts()


@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
("num_images", "limit", "is_valid"),
[(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
(2, 1, False), (2, 2, True)],
)
def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
limit_mm_per_prompt = {"image": limit}

model_config = ModelConfig(
model=model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="half",
revision=None,
limit_mm_per_prompt=limit_mm_per_prompt,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)

processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
ctx = InputProcessingContext(
model_config,
tokenizer=cached_get_tokenizer(model_config.tokenizer),
)

processor = processor_factory(ctx, cache=None)

rng = np.random.RandomState(0)
image = _rand_img(rng, min_wh=128, max_wh=256)
if num_images == 0:
mm_data = {}
elif num_images == 1:
mm_data = {"image": image}
else:
mm_data = {"image": [image] * num_images}

if is_valid:
exc_ctx = nullcontext()
else:
exc_ctx = pytest.raises(ValueError, match=f"passed {num_images} image")

with exc_ctx:
processor.apply(
"<image>" * num_images,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)


def _test_processing_cache_correctness(
model_id: str,
modalities: dict[str, bool],
Expand Down Expand Up @@ -631,6 +727,7 @@ def _test_processing_cache_correctness(
("facebook/chameleon-7b", {"image": False}),
("adept/fuyu-8b", {"image": False}),
("llava-hf/llava-1.5-7b-hf", {"image": True}),
("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}),
("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
("mistral-community/pixtral-12b", {"image": True}),
("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
import torch

from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
LlavaMultiModalProcessor,
get_max_llava_image_tokens)
LlavaMultiModalProcessor)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY


@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
class MyLlava(LlavaForConditionalGeneration):

Expand Down
25 changes: 25 additions & 0 deletions vllm/model_executor/models/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
resolve_visual_encoder_outputs)
from vllm.sequence import SequenceData

from .vision import VisionEncoderInfo


def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0
Expand Down Expand Up @@ -149,6 +151,29 @@ def input_processor_for_clip(
multi_modal_placeholders={"image": ranges})


class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):

def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
return get_clip_image_feature_size(self.vision_config)

def get_max_image_tokens(self) -> int:
return get_max_clip_image_tokens(self.vision_config)

def get_num_patches(self) -> int:
return get_clip_patch_grid_length(
image_size=self.vision_config.image_size,
patch_size=self.vision_config.patch_size,
)

def get_image_size(self) -> int:
return self.vision_config.image_size


# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class CLIPVisionEmbeddings(nn.Module):

Expand Down
6 changes: 3 additions & 3 deletions vllm/model_executor/models/fuyu.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _get_image_target_size(self) -> ImageSize:
return ImageSize(width=target_size["width"],
height=target_size["height"])

def _get_image_grid_size(
def _get_image_feature_grid_size(
self,
*,
image_width: int,
Expand All @@ -99,7 +99,7 @@ def _get_image_grid_size(
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
target_width, target_height = self._get_image_target_size()

max_ncols, max_nrows = self._get_image_grid_size(
max_ncols, max_nrows = self._get_image_feature_grid_size(
image_width=target_width,
image_height=target_height,
)
Expand Down Expand Up @@ -172,7 +172,7 @@ def get_replacement_fuyu(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)

ncols, nrows = self._get_image_grid_size(
ncols, nrows = self._get_image_feature_grid_size(
image_width=image_size.width,
image_height=image_size.height,
)
Expand Down
Loading

0 comments on commit 8c38ee7

Please sign in to comment.