[VLM] Support caching in merged multi-modal processor (#11396)
Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 authored Dec 27, 2024
1 parent 5ce4627 commit 1014180
Showing 20 changed files with 1,455 additions and 448 deletions.
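In short, a merged multi-modal processor can now be built with an optional `ProcessingCache`, and a cached processor is required to return exactly the same outputs as an uncached one. Below is a minimal sketch of that contract, mirroring the new test in `tests/multimodal/test_processing.py` (the model name is illustrative; `_get_model_cls` and `_processor_factories` are the same private registry helpers the test itself uses):

```python
from PIL import Image

from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import ProcessingCache
from vllm.multimodal.utils import cached_get_tokenizer

model_id = "llava-hf/llava-1.5-7b-hf"  # illustrative choice
model_config = ModelConfig(model_id, task="auto", tokenizer=model_id,
                           tokenizer_mode="auto", trust_remote_code=True,
                           seed=0, dtype="float16", revision=None,
                           hf_overrides={})
ctx = InputProcessingContext(
    model_config, tokenizer=cached_get_tokenizer(model_config.tokenizer))

model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]

# One processor without a cache, one sharing a ProcessingCache that is
# large enough to hold everything (capacity value taken from the test).
baseline_processor = processor_factory(ctx, cache=None)
cached_processor = processor_factory(ctx,
                                     cache=ProcessingCache(capacity=1 << 30))

prompt = baseline_processor._get_dummy_mm_inputs({"image": 1}).prompt_text
image = Image.new("RGB", (128, 128))

baseline_out = baseline_processor.apply(prompt, mm_data={"image": image},
                                        hf_processor_mm_kwargs={})
cached_out = cached_processor.apply(prompt, mm_data={"image": image},
                                    hf_processor_mm_kwargs={})
assert baseline_out == cached_out  # caching must not change the result
```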
3 changes: 2 additions & 1 deletion docs/source/conf.py
@@ -191,6 +191,7 @@ def linkcode_resolve(domain, info):

# Mock out external dependencies here, otherwise the autodoc pages may be blank.
autodoc_mock_imports = [
"blake3",
"compressed_tensors",
"cpuinfo",
"cv2",
@@ -207,7 +208,7 @@ def linkcode_resolve(domain, info):
"tensorizer",
"pynvml",
"outlines",
"xgrammar,"
"xgrammar",
"librosa",
"soundfile",
"gguf",
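`blake3` is added to `autodoc_mock_imports`, presumably because the multi-modal processing code now imports it to hash cached inputs. Purely as an illustration of that style of content hashing — the helper below is hypothetical, not vLLM's actual key function:

```python
import blake3
import numpy as np


def illustrative_cache_key(prompt: str, image: np.ndarray) -> str:
    # Hash the prompt text and the raw pixel bytes into a single hex digest.
    hasher = blake3.blake3()
    hasher.update(prompt.encode("utf-8"))
    hasher.update(image.tobytes())
    return hasher.hexdigest()


key = illustrative_cache_key("<image> What is shown?",
                             np.zeros((128, 128, 3), np.uint8))
```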
24 changes: 12 additions & 12 deletions docs/source/design/multimodal/multimodal_index.md
@@ -45,39 +45,39 @@ adding_multimodal_plugin
 ### Base Classes
 
 ```{eval-rst}
-.. autodata:: vllm.multimodal.NestedTensors
+.. automodule:: vllm.multimodal.base
+    :members:
+    :show-inheritance:
 ```
 
-```{eval-rst}
-.. autodata:: vllm.multimodal.BatchedTensorInputs
-```
+### Input Classes
 
 ```{eval-rst}
-.. autoclass:: vllm.multimodal.MultiModalDataBuiltins
+.. automodule:: vllm.multimodal.inputs
     :members:
     :show-inheritance:
 ```
 
-```{eval-rst}
-.. autodata:: vllm.multimodal.MultiModalDataDict
-```
+### Audio Classes
 
 ```{eval-rst}
-.. autoclass:: vllm.multimodal.MultiModalKwargs
+.. automodule:: vllm.multimodal.audio
     :members:
     :show-inheritance:
 ```
 
+### Image Classes
+
 ```{eval-rst}
-.. autoclass:: vllm.multimodal.MultiModalPlugin
+.. automodule:: vllm.multimodal.image
     :members:
     :show-inheritance:
 ```
 
-### Image Classes
+### Video Classes
 
 ```{eval-rst}
-.. automodule:: vllm.multimodal.image
+.. automodule:: vllm.multimodal.video
     :members:
     :show-inheritance:
 ```
3 changes: 1 addition & 2 deletions docs/source/models/supported_models.md
@@ -755,8 +755,7 @@ vLLM currently only supports adding LoRA to the language backbone of multimodal
```

```{note}
To use {code}`TIGER-Lab/Mantis-8B-siglip-llama3`, you have to install their GitHub repo ({code}`pip install git+https://github.com/TIGER-AI-Lab/Mantis.git`)
and pass {code}`--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
To use {code}`TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass {code}`--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
```
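For reference, the equivalent override in the offline Python API might look like this (a sketch; it assumes the `hf_overrides` keyword accepted by `vllm.LLM` mirrors the `--hf_overrides` CLI flag):

```python
from vllm import LLM

# Illustrative: force vLLM to load the checkpoint with its Mantis implementation
# by overriding the architecture reported in the HF config.
llm = LLM(
    model="TIGER-Lab/Mantis-8B-siglip-llama3",
    hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
)
```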

```{note}
4 changes: 2 additions & 2 deletions tests/entrypoints/openai/test_vision_embedding.py
@@ -91,5 +91,5 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
assert len(embeddings.data) == 1
assert len(embeddings.data[0].embedding) == 3072
assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens == 765
assert embeddings.usage.total_tokens == 765
assert embeddings.usage.prompt_tokens == 764
assert embeddings.usage.total_tokens == 764
@@ -30,7 +30,7 @@ def get_max_qwen2_vl_image_tokens():


@pytest.mark.parametrize("mm_processor_kwargs,expected_max_tokens", [
({}, 1225),
({}, 16384),
({
MIN_PIXELS: 64**2,
MAX_PIXELS: 512**2
4 changes: 3 additions & 1 deletion tests/models/decoder_only/vision_language/test_models.py
@@ -201,6 +201,7 @@
vllm_output_post_proc=model_utils.fuyu_vllm_to_hf_output,
num_logprobs=10,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[large_gpu_mark(min_gb=48)],
),
"glm4": VLMTestInfo(
models=["THUDM/glm-4v-9b"],
@@ -212,7 +213,7 @@
dtype="bfloat16",
get_stop_token_ids=lambda tok: [151329, 151336, 151338],
patch_hf_runner=model_utils.glm_patch_hf_runner,
marks=[large_gpu_mark(min_gb=48)],
marks=[large_gpu_mark(min_gb=32)],
),
"h2ovl": VLMTestInfo(
models = [
@@ -261,6 +262,7 @@
dtype="bfloat16",
use_tokenizer_eos=True,
patch_hf_runner=model_utils.internvl_patch_hf_runner,
marks=[large_gpu_mark(min_gb=32)],
),
"llava_next": VLMTestInfo(
models=["llava-hf/llava-v1.6-mistral-7b-hf"],
209 changes: 203 additions & 6 deletions tests/multimodal/test_processing.py
@@ -1,12 +1,20 @@
+from functools import partial
 from typing import cast
 
+import numpy as np
 import pytest
+from PIL import Image
 
-from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo,
-                                        find_text_matches, find_token_matches,
-                                        iter_placeholders, iter_token_matches,
+from vllm.config import ModelConfig
+from vllm.inputs import InputProcessingContext
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
+                                        _PlaceholderInfo, find_text_matches,
+                                        find_token_matches, iter_placeholders,
+                                        iter_token_matches,
                                         replace_text_matches,
                                         replace_token_matches)
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import full_groupby
@@ -457,6 +465,7 @@ def test_find_replace_tokens(
),
]
)
# yapf: enable
def test_iter_placeholders(
repl_by_key,
prompt,
@@ -475,11 +484,199 @@ def test_iter_placeholders(
prompt_repls,
prompt,
# Effectively match all occurrences in the prompt
-            {key: 3 for key in repl_by_key},
-        ))
+            {key: 3
+             for key in repl_by_key},
+        ))

# Only displayed on error
print("result:", result)

# Manually constructed results
assert result == expected


def _rand_img(rng: np.random.RandomState, min_wh: int, max_wh: int):
w, h = rng.randint(min_wh, max_wh, size=(2, ))
arr = rng.randint(0, 255, size=(w, h, 3), dtype=np.uint8)
return Image.fromarray(arr)


def _rand_video(
rng: np.random.RandomState,
min_frames: int,
max_frames: int,
min_wh: int,
max_wh: int,
):
# Temporary workaround for https://github.com/huggingface/transformers/issues/35412
num_frames = rng.randint(min_frames, max_frames)
num_frames = (num_frames // 2) * 2

w, h = rng.randint(min_wh, max_wh, size=(2, ))
return rng.randint(0, 255, size=(num_frames, w, h, 3), dtype=np.uint8)


def _rand_audio(
rng: np.random.RandomState,
min_len: int,
max_len: int,
sr: int,
):
audio_len = rng.randint(min_len, max_len)
return rng.rand(audio_len), sr


def _test_processing_cache_correctness(
model_id: str,
modalities: set[str],
hit_rate: float,
num_batches: int,
simplify_rate: float,
):
if model_id == "TIGER-Lab/Mantis-8B-siglip-llama3":
hf_overrides = {"architectures": ["MantisForConditionalGeneration"]}
else:
hf_overrides = {}

model_config = ModelConfig(
model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=True,
seed=0,
dtype="float16",
revision=None,
hf_overrides=hf_overrides,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)

processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
ctx = InputProcessingContext(
model_config,
tokenizer=cached_get_tokenizer(model_config.tokenizer),
)
# Ensure that it can fit all of the data
cache = ProcessingCache(capacity=1 << 30)

baseline_processor = processor_factory(ctx, cache=None)
cached_processor = processor_factory(ctx, cache=cache)

rng = np.random.RandomState(0)

input_to_hit = {
"image": Image.new("RGB", size=(128, 128)),
"video": np.zeros((4, 128, 128, 3), dtype=np.uint8),
"audio": (np.zeros((512, )), 16000),
}
input_factory = {
"image":
partial(_rand_img, rng, min_wh=128, max_wh=256),
"video":
partial(_rand_video,
rng,
min_frames=2,
max_frames=8,
min_wh=128,
max_wh=256),
"audio":
partial(_rand_audio, rng, min_len=256, max_len=512, sr=16000),
}
input_max_count = {
"image": 3,
"video": 3,
"audio": 3,
}

for batch_idx in range(num_batches):
mm_data = {
k:
[(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
for _ in range(rng.randint(input_max_count[k]))]
for k in modalities
}

mm_counts = {k: len(vs) for k, vs in mm_data.items()}
prompt = baseline_processor._get_dummy_mm_inputs(mm_counts).prompt_text

# Drop unnecessary keys and test single -> multi conversion
if rng.rand() < simplify_rate:
for k in list(mm_data.keys()):
if not mm_data[k]:
del mm_data[k]
elif len(mm_data[k]) == 1:
mm_data[k] = mm_data[k][0]

baseline_result = baseline_processor.apply(
prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
cached_result = cached_processor.apply(
prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)

assert baseline_result == cached_result, (
f"Failed ({batch_idx=}, {mm_data=})")


# yapf: disable
@pytest.mark.parametrize(("model_id", "modalities"), [
("llava-hf/llava-1.5-7b-hf", {"image"}),
("TIGER-Lab/Mantis-8B-siglip-llama3", {"image"}),
("mistral-community/pixtral-12b", {"image"}),
("Qwen/Qwen2-VL-2B-Instruct", {"image", "video"}),
("Qwen/Qwen2-Audio-7B-Instruct", {"audio"}),
("fixie-ai/ultravox-v0_3", {"audio"}),
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_cache_correctness(
model_id: str,
modalities: set[str],
hit_rate: float,
num_batches: int,
simplify_rate: float,
):
_test_processing_cache_correctness(
model_id,
modalities,
hit_rate=hit_rate,
num_batches=num_batches,
simplify_rate=simplify_rate,
)


# yapf: disable
@pytest.mark.parametrize(("model_id", "modalities"), [
("microsoft/Phi-3-vision-128k-instruct", {"image"}),
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_cache_correctness_phi3v(
model_id: str,
modalities: set[str],
hit_rate: float,
num_batches: int,
simplify_rate: float,
):
# HACK - this is an attempted workaround for the following bug
# https://github.com/huggingface/transformers/issues/34307
from transformers import AutoImageProcessor # noqa: F401
from transformers import AutoProcessor # noqa: F401

AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)

_test_processing_cache_correctness(
model_id,
modalities,
hit_rate=hit_rate,
num_batches=num_batches,
simplify_rate=simplify_rate,
)
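To run just the new cache-correctness tests, an invocation along the following lines should work (the `-k` filter is simply a convenient guess matching the test names above):

```python
import sys

import pytest

# Equivalent to: pytest tests/multimodal/test_processing.py -k processing_cache_correctness -v
sys.exit(pytest.main([
    "tests/multimodal/test_processing.py",
    "-k", "processing_cache_correctness",
    "-v",
]))
```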