[Model] Update multi-modal processor to support Mantis(LLaVA) model #10711

Merged Dec 7, 2024 (24 commits)

Commits (24)
7b6c4f1  Add `get_dummy_data` to `MultiModalProcessor`; fix and test `iter_pla…` (DarkLight1337, Nov 26, 2024)
de8332a  Use merged processor for llava model (DarkLight1337, Nov 26, 2024)
8b6804e  format (DarkLight1337, Nov 26, 2024)
26e3fdf  Fix typo (DarkLight1337, Nov 26, 2024)
93d27bc  Enable the test to pass on V1 (DarkLight1337, Nov 26, 2024)
d697241  Handle embedding inputs (DarkLight1337, Nov 26, 2024)
ca11cc9  format (DarkLight1337, Nov 26, 2024)
c32cba9  Merge branch 'main' into llava-mm-processor (DarkLight1337, Nov 27, 2024)
6c5c9ca  Fix wrong ndim (DarkLight1337, Nov 27, 2024)
0194324  Factor out `merge_placeholders` (DarkLight1337, Nov 27, 2024)
09618d0  Fix placeholder maps handling on V0 (DarkLight1337, Nov 27, 2024)
5501458  Remove unused dummy data code (DarkLight1337, Nov 27, 2024)
f3673c7  Update dummy model (DarkLight1337, Nov 27, 2024)
37bc008  Enable overriding hf processor and tokenizer; fix `_apply_prompt_repl…` (DarkLight1337, Nov 27, 2024)
4805a9e  Improve error handling in `_resolve_matches`; merge matches directly (DarkLight1337, Nov 27, 2024)
8539008  Avoid hashing (DarkLight1337, Nov 27, 2024)
1e82a4a  Support and test Mantis model (DarkLight1337, Nov 27, 2024)
cfbece4  Update docs (DarkLight1337, Nov 27, 2024)
af68652  Merge branch 'main' into mantis (DarkLight1337, Nov 29, 2024)
69aa12d  Merge branch 'main' into mantis (DarkLight1337, Dec 7, 2024)
70c87d1  Fix type error (DarkLight1337, Dec 7, 2024)
27b276b  Fix redundant code (DarkLight1337, Dec 7, 2024)
c10c1cc  Remove convenience function as it makes things more complicated (DarkLight1337, Dec 7, 2024)
d77cadd  Fix commands (DarkLight1337, Dec 7, 2024)
2 changes: 2 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -362,6 +362,7 @@ steps:
- tests/models/embedding/vision_language
- tests/models/encoder_decoder/vision_language
commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
- pytest -v -s models/embedding/vision_language -m core_model
@@ -377,6 +378,7 @@ steps:
- tests/models/embedding/vision_language
- tests/models/encoder_decoder/vision_language
commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model'
# HACK - run phi3v tests separately to sidestep this transformers bug
# https://github.com/huggingface/transformers/issues/34307
6 changes: 5 additions & 1 deletion docs/source/models/supported_models.rst
@@ -555,7 +555,7 @@ Text Generation
* - :code:`LlavaForConditionalGeneration`
- LLaVA-1.5
- T + I\ :sup:`E+`
- :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
- :code:`llava-hf/llava-1.5-7b-hf`, :code:`TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc.
-
- ✅︎
* - :code:`LlavaNextForConditionalGeneration`
@@ -664,6 +664,10 @@ Text Generation
.. note::
vLLM currently only supports adding LoRA to the language backbone of multimodal models.

.. note::
To use :code:`TIGER-Lab/Mantis-8B-siglip-llama3`, you have to install their GitHub repo (:code:`pip install git+https://github.com/TIGER-AI-Lab/Mantis.git`)
and pass :code:`--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.

.. note::
The official :code:`openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
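Why the override in the note above is needed: the Mantis checkpoint's Hugging Face config declares the stock LLaVA architecture, so without the override vLLM would load it through the plain LLaVA path and miss Mantis-specific processing. The snippet below is only an illustrative sketch (it assumes network access to the Hub and that the published config still reports LlavaForConditionalGeneration); it is not part of this PR's diff.

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("TIGER-Lab/Mantis-8B-siglip-llama3")
# Expected to print ['LlavaForConditionalGeneration'], which is why an override is needed.
print(cfg.architectures)
# --hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}' on the CLI, or
# hf_overrides={"architectures": ["MantisForConditionalGeneration"]} in the Python API,
# swaps this field before vLLM resolves the model class:
cfg.architectures = ["MantisForConditionalGeneration"]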
17 changes: 17 additions & 0 deletions examples/offline_inference_vision_language.py
@@ -419,6 +419,22 @@ def run_aria(question: str, modality: str):
return llm, prompt, stop_token_ids


# Mantis
def run_mantis(question: str, modality: str):
assert modality == "image"

llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa: E501
prompt = llama3_template.format(f"{question}\n<image>")

llm = LLM(
model="TIGER-Lab/Mantis-8B-siglip-llama3",
max_model_len=4096,
hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
)
stop_token_ids = [128009]
return llm, prompt, stop_token_ids


model_example_map = {
"llava": run_llava,
"llava-next": run_llava_next,
@@ -441,6 +457,7 @@ def run_aria(question: str, modality: str):
"glm4v": run_glm4v,
"idefics3": run_idefics3,
"aria": run_aria,
"mantis": run_mantis,
}


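For readers following the example above, here is a hedged sketch of how the values returned by run_mantis plug into vLLM's standard offline multi-modal API; the image path and sampling settings are illustrative assumptions, not part of this PR.

from PIL import Image
from vllm import SamplingParams

llm, prompt, stop_token_ids = run_mantis("What is in this image?", modality="image")
image = Image.open("example.jpg")  # placeholder input image (assumption)

sampling_params = SamplingParams(temperature=0.0,
                                 max_tokens=64,
                                 stop_token_ids=stop_token_ids)
outputs = llm.generate(
    {
        "prompt": prompt,  # already formatted with the Llama-3 chat template and <image>
        "multi_modal_data": {"image": image},
    },
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)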
3 changes: 0 additions & 3 deletions requirements-test.in
@@ -24,9 +24,6 @@ mistral_common[opencv] >= 1.5.0 # required for pixtral test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.4 # required for model evaluation test

# TODO: Add this after fully implementing llava(mantis)
# git+https://github.com/TIGER-AI-Lab/Mantis.git # required for llava(mantis) test

# quantization
bitsandbytes>=0.44.0
buildkite-test-collector==0.1.9
30 changes: 22 additions & 8 deletions tests/models/decoder_only/vision_language/test_models.py
@@ -34,7 +34,7 @@
"dtype": "half",
"max_tokens": 5,
"tensor_parallel_size": 2,
"model_kwargs": {"device_map": "auto"},
"hf_model_kwargs": {"device_map": "auto"},
"image_size_factors": [(.25, 0.5, 1.0)],
"distributed_executor_backend": (
"ray",
@@ -108,7 +108,7 @@
"cherry_blossom": "What is in the picture?",
}),
auto_cls=AutoModelForVision2Seq,
postprocess_inputs=model_utils.get_key_type_post_processor(
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output,
@@ -151,7 +151,7 @@
"cherry_blossom": "<vlm_image>Please infer the season with reason.",
}),
multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.", # noqa: E501
postprocess_inputs=model_utils.get_key_type_post_processor("pixel_values"),
postprocess_inputs=model_utils.cast_dtype_post_processor("pixel_values"),
stop_str=["<|im_end|>"],
image_size_factors=[(0.10, 0.15)],
max_tokens=64,
@@ -177,7 +177,7 @@
prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:",
max_model_len=4096,
auto_cls=AutoModelForVision2Seq,
postprocess_inputs=model_utils.get_key_type_post_processor(
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
# For chameleon, we only compare the sequences
@@ -281,7 +281,7 @@
prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
num_video_frames=16,
max_model_len=16384,
postprocess_inputs=model_utils.get_key_type_post_processor(
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values_videos"
),
auto_cls=AutoModelForVision2Seq,
@@ -306,6 +306,20 @@
vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output,
image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))],
),
"mantis": VLMTestInfo(
models=["TIGER-Lab/Mantis-8B-siglip-llama3"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
max_model_len=4096,
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
vllm_runner_kwargs={"hf_overrides": {"architectures": ["MantisForConditionalGeneration"]}}, # noqa: E501
get_stop_token_ids=lambda tok: [128009],
auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.mantis_vllm_to_hf_output,
patch_hf_runner=model_utils.mantis_patch_hf_runner,
),
"minicpmv_25": VLMTestInfo(
models=["openbmb/MiniCPM-Llama3-V-2_5"],
test_type=VLMTestType.IMAGE,
@@ -342,7 +356,7 @@
# max_num_seqs=2,
# task="generate",
# # use eager mode for hf runner since phi3v didn't work with flash_attn
# model_kwargs={"_attn_implementation": "eager"},
# hf_model_kwargs={"_attn_implementation": "eager"},
# use_tokenizer_eos=True,
# vllm_output_post_proc=model_utils.phi3v_vllm_to_hf_output,
# num_logprobs=10,
@@ -373,7 +387,7 @@
prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:",
max_model_len=4096,
auto_cls=AutoModelForVision2Seq,
postprocess_inputs=model_utils.get_key_type_post_processor(
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
vllm_output_post_proc = lambda vllm_output, model: vllm_output[:2],
@@ -438,7 +452,7 @@
test_type=VLMTestType.CUSTOM_INPUTS,
max_model_len=16384,
max_num_seqs=2,
postprocess_inputs=model_utils.get_key_type_post_processor(
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
auto_cls=AutoModelForVision2Seq,
20 changes: 14 additions & 6 deletions tests/models/decoder_only/vision_language/vlm_utils/core.py
@@ -3,9 +3,11 @@

import torch
from PIL.Image import Image
from transformers import AutoTokenizer, BatchEncoding
from transformers import AutoTokenizer, BatchEncoding, PreTrainedTokenizerBase
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from vllm.config import TaskOption

from .....conftest import HfRunner, VllmRunner
from .types import RunnerOutput

@@ -28,13 +30,15 @@ def run_test(
use_tokenizer_eos: bool,
postprocess_inputs: Callable[[BatchEncoding], BatchEncoding],
comparator: Callable[..., None],
get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]],
get_stop_token_ids: Optional[Callable[[PreTrainedTokenizerBase],
List[int]]],
stop_str: Optional[List[str]],
tokenizer_mode: str,
limit_mm_per_prompt: Dict[str, int],
model_kwargs: Optional[Dict[str, Any]],
vllm_runner_kwargs: Optional[Dict[str, Any]],
hf_model_kwargs: Optional[Dict[str, Any]],
patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]],
task: str = "auto",
task: TaskOption = "auto",
runner_mm_key: str = "images",
distributed_executor_backend: Optional[str] = None,
tensor_parallel_size: int = 1,
@@ -58,6 +62,9 @@ def run_test(
if stop_str:
vllm_kwargs["stop"] = stop_str

if vllm_runner_kwargs is None:
vllm_runner_kwargs = {}

with vllm_runner(model,
tokenizer_mode=tokenizer_mode,
max_model_len=max_model_len,
@@ -67,7 +74,8 @@
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=enforce_eager,
task=task) as vllm_model:
task=task,
**vllm_runner_kwargs) as vllm_model:
for prompts, media in vllm_inputs:
vllm_kwargs[runner_mm_key] = media
vllm_output = vllm_model.generate_greedy_logprobs(
@@ -78,7 +86,7 @@
dtype=dtype,
auto_cls=auto_cls,
postprocess_inputs=postprocess_inputs,
model_kwargs=model_kwargs)
model_kwargs=hf_model_kwargs)

# Some models need to patch things like the model processor, e.g., internvl
if patch_hf_runner is not None:
tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
@@ -126,6 +126,16 @@ def llava_onevision_vllm_to_hf_output(vllm_output: RunnerOutput,
return hf_output_ids, hf_output_str, out_logprobs


def mantis_vllm_to_hf_output(vllm_output: RunnerOutput,
model: str) -> RunnerOutput:
"""Sanitize vllm output [mantis] to compare with hf output."""
output_ids, output_str, out_logprobs = vllm_output

hf_output_str = output_str + "<|eot_id|>"

return output_ids, hf_output_str, out_logprobs


def phi3v_vllm_to_hf_output(vllm_output: RunnerOutput,
model: str) -> RunnerOutput:
"""Sanitize vllm output [phi3v] to be comparable with hf output."""
@@ -184,7 +194,7 @@ def get_llava_embeddings(image_assets: _ImageAssets):


####### postprocessors to run on HF BatchEncoding
def get_key_type_post_processor(
def cast_dtype_post_processor(
hf_inp_key: str) -> Callable[[BatchEncoding, str], BatchEncoding]:
"""Gets a handle to a post processor which converts a given key into a
target data type."""
@@ -418,3 +428,26 @@ def _internvl_generate(
)

return outputs


def mantis_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
from mantis.models.mllava import MLlavaProcessor

hf_model.processor = MLlavaProcessor.from_pretrained(hf_model.model_name)

orig_generate = hf_model.model.generate
tokenizer = hf_model.processor.tokenizer

def _generate(self, *args, **kwargs):
return orig_generate(
*args,
**kwargs,
eos_token_id=[
tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids("<|eot_id|>"),
],
)

hf_model.model.generate = types.MethodType(_generate, hf_model.model)

return hf_model
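A note on the hard-coded 128009 used in both the Mantis test entry and run_mantis: for the Llama-3 tokenizer shipped with TIGER-Lab/Mantis-8B-siglip-llama3, it is the ID of the <|eot_id|> token that mantis_patch_hf_runner also registers as an extra EOS. A small sanity-check sketch (assumes the tokenizer can be downloaded; not part of this diff):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("TIGER-Lab/Mantis-8B-siglip-llama3")
# Same ID that the test config and run_mantis pass as a stop token.
assert tok.convert_tokens_to_ids("<|eot_id|>") == 128009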
19 changes: 12 additions & 7 deletions tests/models/decoder_only/vision_language/vlm_utils/types.py
@@ -7,9 +7,11 @@
import torch
from PIL.Image import Image
from pytest import MarkDecorator
from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding
from transformers import (AutoModelForCausalLM, BatchEncoding,
PreTrainedTokenizerBase)
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from vllm.config import TaskOption
from vllm.sequence import SampleLogprobs
from vllm.utils import identity

@@ -66,7 +68,7 @@ class ImageSizeWrapper(NamedTuple):
class VLMTestInfo(NamedTuple):
"""Holds the configuration for 1+ tests for one model architecture."""

models: Union[List[str]]
models: List[str]
test_type: Union[VLMTestType, Iterable[VLMTestType]]

# Should be None only if this is a CUSTOM_INPUTS test
@@ -92,18 +94,20 @@ class VLMTestInfo(NamedTuple):
enforce_eager: bool = True
max_model_len: int = 1024
max_num_seqs: int = 256
task: str = "auto"
task: TaskOption = "auto"
tensor_parallel_size: int = 1
vllm_runner_kwargs: Optional[Dict[str, Any]] = None

# Optional callable which gets a list of token IDs from the model tokenizer
get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]] = None
get_stop_token_ids: Optional[Callable[[PreTrainedTokenizerBase],
List[int]]] = None
# Optional list of strings to stop generation, useful when stop tokens are
# not special tokens in the tokenizer
stop_str: Optional[List[str]] = None

# Exposed options for HF runner
model_kwargs: Optional[Dict[str, Any]] = None
# Indicates we should explicitly pass the EOS from the tokeniezr
hf_model_kwargs: Optional[Dict[str, Any]] = None
# Indicates we should explicitly pass the EOS from the tokenizer
use_tokenizer_eos: bool = False
auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM
# Callable to pass to the HF runner to run on inputs; for now, we also pass
@@ -164,15 +168,16 @@ def get_non_parametrized_runner_kwargs(self):
"max_num_seqs": self.max_num_seqs,
"task": self.task,
"tensor_parallel_size": self.tensor_parallel_size,
"vllm_runner_kwargs": self.vllm_runner_kwargs,
"hf_output_post_proc": self.hf_output_post_proc,
"vllm_output_post_proc": self.vllm_output_post_proc,
"auto_cls": self.auto_cls,
"use_tokenizer_eos": self.use_tokenizer_eos,
"postprocess_inputs": self.postprocess_inputs,
"comparator": self.comparator,
"get_stop_token_ids": self.get_stop_token_ids,
"hf_model_kwargs": self.hf_model_kwargs,
"stop_str": self.stop_str,
"model_kwargs": self.model_kwargs,
"patch_hf_runner": self.patch_hf_runner,
"tokenizer_mode": self.tokenizer_mode
}
1 change: 1 addition & 0 deletions tests/models/registry.py
@@ -176,6 +176,7 @@ class _HfExamplesInfo:
"LlavaNextForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-v1.6-mistral-7b-hf"), # noqa: E501
"LlavaNextVideoForConditionalGeneration": _HfExamplesInfo("llava-hf/LLaVA-NeXT-Video-7B-hf"), # noqa: E501
"LlavaOnevisionForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501
"MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3"), # noqa: E501
"MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
trust_remote_code=True),
"MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
(file header not captured; this hunk updates the MyLlava dummy-model plugin)
@@ -3,16 +3,14 @@
import torch

from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
create_metadata_for_llava,
dummy_mm_kwargs_for_llava,
LlavaProcessor,
get_max_llava_image_tokens)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY


@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor_by_metadata(create_metadata_for_llava,
dummy_mm_kwargs_for_llava)
@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor)
class MyLlava(LlavaForConditionalGeneration):

def compute_logits(
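The hunk above updates the MyLlava dummy model used by the out-of-tree plugin tests: instead of registering processor metadata and dummy-data helpers separately, the class now registers the merged LlavaProcessor directly. As a rough sketch of how such a plugin class is surfaced to vLLM (the import path and register() wrapper are assumptions, not shown in this diff):

from vllm import ModelRegistry

from my_plugin import MyLlava  # hypothetical import path for the class defined above


def register():
    # Map an architecture name to the plugin class so vLLM can resolve it
    # just like a built-in model.
    ModelRegistry.register_model("MyLlava", MyLlava)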