From 8f37be38ebfe0295a4925837c501c87149997a4d Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 7 Jan 2025 18:25:02 +0800 Subject: [PATCH] [Bugfix] Comprehensively test and fix LLaVA-NeXT feature size calculation (#11800) Signed-off-by: DarkLight1337 --- requirements-test.in | 1 + requirements-test.txt | 4 + .../processing/test_llava_next.py | 129 +++++++++++++++--- .../processing/test_llava_onevision.py | 129 +++++++++++++++--- vllm/model_executor/models/llava_next.py | 37 ++--- vllm/model_executor/models/llava_onevision.py | 42 +++--- 6 files changed, 253 insertions(+), 89 deletions(-) diff --git a/requirements-test.in b/requirements-test.in index fb4179c3d8423..4b4dc376d1fa5 100644 --- a/requirements-test.in +++ b/requirements-test.in @@ -13,6 +13,7 @@ einops # required for MPT, qwen-vl and Mamba httpx librosa # required for audio tests peft +pqdm ray[adag]==2.40.0 sentence-transformers # required for embedding tests soundfile # required for audio tests diff --git a/requirements-test.txt b/requirements-test.txt index 3771577fe8ed0..f576e42afcbbf 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -48,6 +48,8 @@ botocore==1.35.57 # awscli # boto3 # s3transfer +bounded-pool-executor==0.0.3 + # via pqdm buildkite-test-collector==0.1.9 # via -r requirements-test.in certifi==2024.8.30 @@ -342,6 +344,8 @@ pooch==1.8.2 # via librosa portalocker==2.10.1 # via sacrebleu +pqdm==0.2.0 + # via -r requirements-test.in propcache==0.2.0 # via yarl protobuf==5.28.3 diff --git a/tests/models/decoder_only/vision_language/processing/test_llava_next.py b/tests/models/decoder_only/vision_language/processing/test_llava_next.py index 37a6d334ee60c..9fa6a8a10a0f9 100644 --- a/tests/models/decoder_only/vision_language/processing/test_llava_next.py +++ b/tests/models/decoder_only/vision_language/processing/test_llava_next.py @@ -1,8 +1,13 @@ +import itertools +from functools import partial + import pytest from PIL import Image +from pqdm.threads import pqdm from transformers import AutoTokenizer from vllm.inputs import InputProcessingContext +from vllm.multimodal.parse import ImageSize from ....utils import build_model_context @@ -15,20 +20,69 @@ def processor_for_llava_next(): return LlavaNextMultiModalProcessor +def _validate_image_prompt_replacements_one( + processor, + num_imgs: int, + failed_size_excs: list[tuple[ImageSize, Exception]], + image_size: ImageSize, +) -> None: + prompt = "" * num_imgs + image = Image.new("RGB", size=image_size) + mm_data = {"image": [image] * num_imgs} + + try: + # The processor will throw an error if there is a mismatch + # in the prompt replacements + processed_inputs = processor.apply(prompt, mm_data, {}) + + image_placeholders = processed_inputs["mm_placeholders"]["image"] + assert len(image_placeholders) == num_imgs + + first_placeholder = image_placeholders[0] + + # NOTE: There is a BOS token + assert first_placeholder["offset"] == 1 + assert first_placeholder["length"] == ( + len(processed_inputs["prompt_token_ids"]) - 1) // num_imgs + + except Exception as exc: + failed_size_excs.append((image_size, exc)) + + +def _test_image_prompt_replacements( + processor, + *, + num_imgs: int, + image_sizes: list[ImageSize], +) -> None: + """ + Ensure LlavaNextMultiModalProcessor + handles prompt replacement properly for input images. 
+ """ + failed_size_excs = list[tuple[ImageSize, Exception]]() + + validate_one = partial( + _validate_image_prompt_replacements_one, + processor, + num_imgs, + failed_size_excs, + ) + pqdm(image_sizes, validate_one, n_jobs=8, desc="Validating image sizes") + + if failed_size_excs: + msg = "Found failing image sizes:" \ + + "\n========\n".join(f"[{size}]\n{exc}" + for size, exc in failed_size_excs) + raise AssertionError(msg) + + @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) -@pytest.mark.parametrize("image_size", [(1669, 2560), (2560, 1669), (183, 488), - (488, 183), (198, 176), (176, 198), - (161, 184), (184, 161)]) @pytest.mark.parametrize("num_imgs", [1, 2]) -def test_processor_prompt_replacements( +def test_processor_prompt_replacements_regression( processor_for_llava_next, model_id: str, - image_size: tuple[int, int], num_imgs: int, ): - """ - Ensure LlavaNextMultiModalProcessor handles prompt replacement properly. - """ ctx = build_model_context( model_name=model_id, tokenizer_name=model_id, @@ -37,22 +91,55 @@ def test_processor_prompt_replacements( ) tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) ctx = InputProcessingContext(ctx.model_config, tokenizer) + processor = processor_for_llava_next(ctx) + + image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328), + (488, 183), (2560, 1669)] + image_sizes = [ + size for w, h in image_ratios + for size in [ImageSize(w, h), ImageSize(h, w)] + ] + + _test_image_prompt_replacements( + processor, + num_imgs=num_imgs, + image_sizes=image_sizes, + ) - # Build the image str / prompt based on the number of images we pass - prompt = "" * num_imgs - mm_data = {"image": [Image.new("RGB", size=image_size)] * num_imgs} - # The processor will throw an error if there is a mismatch - # in the prompt replacements +@pytest.mark.skip("This test takes around 2 hours to run. 
" + "Comment this out to run it manually.") +@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) +@pytest.mark.parametrize("num_imgs", [1]) +def test_processor_prompt_replacements_all( + processor_for_llava_next, + model_id: str, + num_imgs: int, +): + ctx = build_model_context( + model_name=model_id, + tokenizer_name=model_id, + mm_processor_kwargs=None, + limit_mm_per_prompt={"image": num_imgs}, + ) + tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) + ctx = InputProcessingContext(ctx.model_config, tokenizer) processor = processor_for_llava_next(ctx) - processed_inputs = processor.apply(prompt, mm_data, {}) - image_placeholders = processed_inputs["mm_placeholders"]["image"] - assert len(image_placeholders) == num_imgs + seen_aspect_ratios = set[float]() + image_sizes = list[ImageSize]() - first_placeholder = image_placeholders[0] + # The aspect ratio of the grid layout is between 1 and 2 + # NOTE: Assumes that feature size calculation is the same if we + # swap the width and height of the image + for w, h in itertools.product(range(64, 1024), repeat=2): + aspect_ratio = w / h + if 1 <= aspect_ratio <= 2 and aspect_ratio not in seen_aspect_ratios: + image_sizes.append(ImageSize(w, h)) + seen_aspect_ratios.add(aspect_ratio) - # NOTE: There is a BOS token - assert first_placeholder["offset"] == 1 - assert first_placeholder["length"] == ( - len(processed_inputs["prompt_token_ids"]) - 1) // num_imgs + _test_image_prompt_replacements( + processor, + num_imgs=num_imgs, + image_sizes=image_sizes, + ) diff --git a/tests/models/decoder_only/vision_language/processing/test_llava_onevision.py b/tests/models/decoder_only/vision_language/processing/test_llava_onevision.py index ed3e2db799be7..d4cdffa210b6d 100644 --- a/tests/models/decoder_only/vision_language/processing/test_llava_onevision.py +++ b/tests/models/decoder_only/vision_language/processing/test_llava_onevision.py @@ -1,8 +1,13 @@ +import itertools +from functools import partial + import pytest from PIL import Image +from pqdm.threads import pqdm from transformers import AutoTokenizer from vllm.inputs import InputProcessingContext +from vllm.multimodal.parse import ImageSize from ....utils import build_model_context @@ -15,22 +20,68 @@ def processor_for_llava_onevision(): return LlavaOnevisionMultiModalProcessor +def _validate_image_prompt_replacements_one( + processor, + num_imgs: int, + failed_size_excs: list[tuple[ImageSize, Exception]], + image_size: ImageSize, +) -> None: + prompt = "" * num_imgs + image = Image.new("RGB", size=image_size) + mm_data = {"image": [image] * num_imgs} + + try: + # The processor will throw an error if there is a mismatch + # in the prompt replacements + processed_inputs = processor.apply(prompt, mm_data, {}) + + image_placeholders = processed_inputs["mm_placeholders"]["image"] + assert len(image_placeholders) == num_imgs + + first_placeholder = image_placeholders[0] + + assert first_placeholder["offset"] == 0 + assert first_placeholder["length"] == len( + processed_inputs["prompt_token_ids"]) // num_imgs + except Exception as exc: + failed_size_excs.append((image_size, exc)) + + +def _test_image_prompt_replacements( + processor, + *, + num_imgs: int, + image_sizes: list[ImageSize], +) -> None: + """ + Ensure LlavaOnevisionMultiModalProcessor + handles prompt replacement properly for input images. 
+ """ + failed_size_excs = list[tuple[ImageSize, Exception]]() + + validate_one = partial( + _validate_image_prompt_replacements_one, + processor, + num_imgs, + failed_size_excs, + ) + pqdm(image_sizes, validate_one, n_jobs=8, desc="Validating image sizes") + + if failed_size_excs: + msg = "Found failing image sizes:" \ + + "\n========\n".join(f"[{size}]\n{exc}" + for size, exc in failed_size_excs) + raise AssertionError(msg) + + @pytest.mark.parametrize("model_id", ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) -@pytest.mark.parametrize("image_size", [(1669, 2560), (2560, 1669), (183, 488), - (488, 183), (198, 176), (176, 198), - (161, 184), (184, 161)]) @pytest.mark.parametrize("num_imgs", [1, 2]) -def test_processor_prompt_replacements( +def test_processor_prompt_replacements_regression( processor_for_llava_onevision, model_id: str, - image_size: tuple[int, int], num_imgs: int, ): - """ - Ensure LlavaOnevisionMultiModalProcessor handles prompt replacement - properly. - """ ctx = build_model_context( model_name=model_id, tokenizer_name=model_id, @@ -39,22 +90,56 @@ def test_processor_prompt_replacements( ) tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) ctx = InputProcessingContext(ctx.model_config, tokenizer) + processor = processor_for_llava_onevision(ctx) - # Build the image str / prompt based on the number of images we pass - prompt = "" * num_imgs - mm_data = {"image": [Image.new("RGB", size=image_size)] * num_imgs} + image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328), + (488, 183), (2560, 1669)] + image_sizes = [ + size for w, h in image_ratios + for size in [ImageSize(w, h), ImageSize(h, w)] + ] + + _test_image_prompt_replacements( + processor, + num_imgs=num_imgs, + image_sizes=image_sizes, + ) - # The processor will throw an error if there is a mismatch - # in the prompt replacements + +@pytest.mark.skip("This test takes around 2 hours to run. 
" + "Comment this out to run it manually.") +@pytest.mark.parametrize("model_id", + ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) +@pytest.mark.parametrize("num_imgs", [1]) +def test_processor_prompt_replacements_all( + processor_for_llava_onevision, + model_id: str, + num_imgs: int, +): + ctx = build_model_context( + model_name=model_id, + tokenizer_name=model_id, + mm_processor_kwargs=None, + limit_mm_per_prompt={"image": num_imgs}, + ) + tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) + ctx = InputProcessingContext(ctx.model_config, tokenizer) processor = processor_for_llava_onevision(ctx) - processed_inputs = processor.apply(prompt, mm_data, {}) - image_placeholders = processed_inputs["mm_placeholders"]["image"] - assert len(image_placeholders) == num_imgs + seen_aspect_ratios = set[float]() + image_sizes = list[ImageSize]() - first_placeholder = image_placeholders[0] + # The aspect ratio of the grid layout is between 1 and 6 + # NOTE: Assumes that feature size calculation is the same if we + # swap the width and height of the image + for w, h in itertools.product(range(64, 1024), repeat=2): + aspect_ratio = w / h + if 1 <= aspect_ratio <= 6 and aspect_ratio not in seen_aspect_ratios: + image_sizes.append(ImageSize(w, h)) + seen_aspect_ratios.add(aspect_ratio) - # NOTE: There is a BOS token - assert first_placeholder["offset"] == 0 - assert first_placeholder["length"] == len( - processed_inputs["prompt_token_ids"]) // num_imgs + _test_image_prompt_replacements( + processor, + num_imgs=num_imgs, + image_sizes=image_sizes, + ) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 258352416d4a7..815456dac2a2f 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -2,7 +2,6 @@ from typing import (Final, Iterable, List, Literal, Mapping, Optional, Protocol, Set, Tuple, TypedDict, Union) -import numpy as np import torch import torch.nn as nn from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor @@ -74,7 +73,7 @@ def _get_hf_config(self) -> LlavaNextLikeConfig: def _get_hf_processor(self): return self.ctx.get_hf_processor(LlavaNextProcessor) - # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106 + # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L113 def _get_num_image_tokens( self, *, @@ -111,7 +110,7 @@ def _get_num_image_tokens( return unpadded_feature_size + newline_feature_size + base_feature_size - # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79 + # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86 def _get_num_unpadded_features( self, *, @@ -121,29 +120,23 @@ def _get_num_unpadded_features( num_patch_height: int, num_patch_width: int, ) -> tuple[int, int]: - # NOTE: Use float32 to remain consistent with HF output - current_height_f = np.float32(npatches * num_patch_height) - current_width_f = np.float32(npatches * num_patch_width) + current_height = npatches * num_patch_height + current_width = npatches * num_patch_width - original_width_f = np.float32(original_width) - original_height_f = np.float32(original_height) + aspect_ratio = original_width / original_height + current_aspect_ratio = 
current_width / current_height - original_aspect_ratio = original_width_f / original_height_f - current_aspect_ratio = current_width_f / current_height_f - - if original_aspect_ratio > current_aspect_ratio: - scale_factor = current_width_f / original_width_f - new_height = int(original_height_f * scale_factor) - padding = (current_height_f - new_height) // 2 - current_height_f -= 2 * padding + if aspect_ratio > current_aspect_ratio: + new_height = (original_height * current_width) // original_width + padding = (current_height - new_height) // 2 + current_height = current_height - (2 * padding) else: - scale_factor = current_height_f / original_height_f - new_width = int(original_width_f * scale_factor) - padding = (current_width_f - new_width) // 2 - current_width_f -= 2 * padding + new_width = (original_width * current_height) // original_height + padding = (current_width - new_width) // 2 + current_width = current_width - (2 * padding) - unpadded_features = int(current_height_f * current_width_f) - newline_features = int(current_height_f) + unpadded_features = current_height * current_width + newline_features = current_height return (unpadded_features, newline_features) diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 62dae74e377be..b5e3edba1f01c 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -3,7 +3,6 @@ from typing import (Final, Iterable, List, Literal, Mapping, Optional, Protocol, Set, Tuple, TypedDict, Union) -import numpy as np import torch import torch.nn as nn from transformers import (BatchFeature, LlavaOnevisionConfig, @@ -98,6 +97,8 @@ def _get_hf_config(self) -> LlavaOnevisionLikeConfig: def _get_hf_processor(self): return self.ctx.get_hf_processor(LlavaOnevisionProcessor) + # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86 + # with additional logic afterwards taken from LlavaOnevisionProcessor def _get_num_unpadded_features( self, *, @@ -107,35 +108,28 @@ def _get_num_unpadded_features( num_patch_height: int, num_patch_width: int, ) -> tuple[int, int]: - # NOTE: Use float32 to remain consistent with HF output - current_height_f = np.float32(npatches * num_patch_height) - current_width_f = np.float32(npatches * num_patch_width) + current_height = npatches * num_patch_height + current_width = npatches * num_patch_width - original_width_f = np.float32(original_width) - original_height_f = np.float32(original_height) + aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height - original_aspect_ratio = original_width_f / original_height_f - current_aspect_ratio = current_width_f / current_height_f - - if original_aspect_ratio > current_aspect_ratio: - scale_factor = current_width_f / original_width_f - new_height = int(original_height_f * scale_factor) - padding = (current_height_f - new_height) // 2 - current_height_f -= 2 * padding + if aspect_ratio > current_aspect_ratio: + new_height = (original_height * current_width) // original_width + padding = (current_height - new_height) // 2 + current_height = current_height - (2 * padding) else: - scale_factor = current_height_f / original_height_f - new_width = int(original_width_f * scale_factor) - padding = (current_width_f - new_width) // 2 - current_width_f -= 2 * padding + new_width = (original_width * current_height) // original_height + padding = (current_width - 
new_width) // 2 + current_width = current_width - (2 * padding) - unpadded_features = int(current_height_f * current_width_f) - newline_features = int(current_height_f) + unpadded_features = current_height * current_width + newline_features = current_height - ratio = math.sqrt(current_height_f * current_width_f / - (9 * npatches**2)) + ratio = math.sqrt(current_height * current_width / (9 * npatches**2)) if ratio > 1.1: - height_factor = int(current_height_f // ratio) - width_factor = int(current_width_f // ratio) + height_factor = int(current_height // ratio) + width_factor = int(current_width // ratio) unpadded_features = height_factor * width_factor newline_features = height_factor
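
---
Note (not part of the diff): for reference, below is a minimal, standalone sketch of the integer-only unpadded-feature calculation that this patch adopts for LLaVA-NeXT, mirroring the added lines in `_get_num_unpadded_features` (which replace the earlier float32 emulation). The function name and the demo values (a 400x300 image, `npatches=24`, a 2x2 patch grid) are illustrative assumptions only and are not taken from the vLLM API; as the second hunk shows, LLaVA-OneVision additionally applies a downscaling step (the `ratio > 1.1` branch) on top of this result.

    # Illustrative sketch of the patched LLaVA-NeXT unpadded-feature math.
    # Demo values below (npatches=24, 2x2 patch grid) are assumptions.

    def get_num_unpadded_features(
        *,
        original_height: int,
        original_width: int,
        npatches: int,
        num_patch_height: int,
        num_patch_width: int,
    ) -> tuple[int, int]:
        # Size of the patch grid covering the resized image, in vision patches
        current_height = npatches * num_patch_height
        current_width = npatches * num_patch_width

        aspect_ratio = original_width / original_height
        current_aspect_ratio = current_width / current_height

        # Remove the letterbox padding rows/columns using integer arithmetic,
        # matching the TGI-derived logic added by this patch
        if aspect_ratio > current_aspect_ratio:
            new_height = (original_height * current_width) // original_width
            padding = (current_height - new_height) // 2
            current_height = current_height - (2 * padding)
        else:
            new_width = (original_width * current_height) // original_height
            padding = (current_width - new_width) // 2
            current_width = current_width - (2 * padding)

        unpadded_features = current_height * current_width
        newline_features = current_height  # one newline feature per row
        return unpadded_features, newline_features


    if __name__ == "__main__":
        # A 400x300 image mapped onto a hypothetical 2x2 grid of 24x24-patch
        # tiles: prints (1728, 36).
        print(get_num_unpadded_features(
            original_height=300,
            original_width=400,
            npatches=24,
            num_patch_height=2,
            num_patch_width=2,
        ))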