[V1][VLM] V1 support for selected single-image models. #11632

Merged: 43 commits into main from v1-rearch-vlm on Dec 31, 2024. The diff shown below covers changes from 10 of those commits.

Commits (43)
425d3c4  batch (ywang96, Dec 29, 2024)
1ca9369  Merge branch 'vllm-project:main' into v1-rearch-vlm (ywang96, Dec 29, 2024)
8edcc83  blip2 (ywang96, Dec 30, 2024)
5f76291  chameleon (ywang96, Dec 30, 2024)
814f3bd  fix util (ywang96, Dec 30, 2024)
efeb999  fuyu (ywang96, Dec 30, 2024)
5e568e8  aria (ywang96, Dec 30, 2024)
135fd5c  fix profiling (ywang96, Dec 30, 2024)
0a8dbe0  update (ywang96, Dec 30, 2024)
03f741d  add llava-next (ywang96, Dec 30, 2024)
8bce949  revert testing code (ywang96, Dec 30, 2024)
bbde414  revert testing code (ywang96, Dec 30, 2024)
ea928c6  tweak and clarify (ywang96, Dec 30, 2024)
55eada7  clarify (ywang96, Dec 30, 2024)
bbd5752  reword (ywang96, Dec 30, 2024)
0452b99  Use merged multi-modal processor for blip2 and chameleon (DarkLight1337, Dec 30, 2024)
938c0bf  Limit max num seqs (DarkLight1337, Dec 30, 2024)
6cc54a7  Update comments (DarkLight1337, Dec 30, 2024)
ba713ba  Be more clear (DarkLight1337, Dec 30, 2024)
b0efc4f  Merged multi-modal processor for Aria (DarkLight1337, Dec 31, 2024)
cdbd969  initialize fuyu merged processor (Isotr0py, Dec 31, 2024)
48c6946  Clean up (DarkLight1337, Dec 31, 2024)
ea76759  Clean up (DarkLight1337, Dec 31, 2024)
bc976a7  Try remove mark (DarkLight1337, Dec 31, 2024)
f79f79a  Consolidate dummy data code (DarkLight1337, Dec 31, 2024)
45ec10c  fix fuyu variant images test (Isotr0py, Dec 31, 2024)
0926717  Merge branch 'main' into v1-rearch-vlm (DarkLight1337, Dec 31, 2024)
0fe561d  Fix some type errors in Pixtral-HF (DarkLight1337, Dec 31, 2024)
3512ed6  fix missing flatten_bn in fuyu (Isotr0py, Dec 31, 2024)
5e0f66c  Update docs (DarkLight1337, Dec 31, 2024)
1c243ab  Update docs (DarkLight1337, Dec 31, 2024)
09d64f4  Get fuyu processor tests to pass (DarkLight1337, Dec 31, 2024)
6d6d71c  Oops (DarkLight1337, Dec 31, 2024)
ea93a2c  Fix unable to run model (DarkLight1337, Dec 31, 2024)
9aeb7b2  Avoid warning from HF (DarkLight1337, Dec 31, 2024)
768c1d9  fix too large image for fuyu (Isotr0py, Dec 31, 2024)
0c82c51  fix prompt token ids (Isotr0py, Dec 31, 2024)
d0d1fdc  Fix missing batch dimension in vision embeddings (DarkLight1337, Dec 31, 2024)
afcf7b1  fix variant patches batching (Isotr0py, Dec 31, 2024)
cb9522d  Simplify the code (DarkLight1337, Dec 31, 2024)
df832df  format (DarkLight1337, Dec 31, 2024)
868e8e9  Merge branch 'vllm-project:main' into v1-rearch-vlm (ywang96, Dec 31, 2024)
cc9c5f1  simplify (ywang96, Dec 31, 2024)
docs/source/models/supported_models.md (10 changes: 5 additions & 5 deletions)

@@ -570,28 +570,28 @@ See [this page](#generative-models) for more information on how to use generative models.
   - `rhymes-ai/Aria`
   -
   - ✅︎
-  -
+  - ✅︎
 * - `Blip2ForConditionalGeneration`
   - BLIP-2
   - T + I<sup>E</sup>
   - `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc.
   -
   - ✅︎
-  -
+  - ✅︎
 * - `ChameleonForConditionalGeneration`
   - Chameleon
   - T + I
   - `facebook/chameleon-7b` etc.
   -
   - ✅︎
-  -
+  - ✅︎
 * - `FuyuForCausalLM`
   - Fuyu
   - T + I
   - `adept/fuyu-8b` etc.
   -
   - ✅︎
-  -
+  - ✅︎
 * - `ChatGLMModel`
   - GLM-4V
   - T + I
@@ -633,7 +633,7 @@ See [this page](#generative-models) for more information on how to use generative models.
   - `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
   -
   - ✅︎
-  -
+  - ✅︎
Member Author (ywang96):

LLaVA-NeXT was already supported on V1, so this is just a doc update.

 * - `LlavaNextVideoForConditionalGeneration`
   - LLaVA-NeXT-Video
   - T + V
examples/offline_inference_vision_language.py (9 changes: 6 additions & 3 deletions)

@@ -28,6 +28,7 @@ def run_aria(question: str, modality: str):
tokenizer_mode="slow",
trust_remote_code=True,
dtype="bfloat16",
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)

prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
@@ -191,8 +192,10 @@ def run_llava_next(question: str, modality: str):

     prompt = f"[INST] <image>\n{question} [/INST]"
     llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf",
-              max_model_len=8192,
-              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
+              max_num_batched_tokens=32768,
+              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+              limit_mm_per_prompt={"image": 4},
+              enable_prefix_caching=False)
     stop_token_ids = None
     return llm, prompt, stop_token_ids

@@ -591,7 +594,7 @@ def main(args):

     # We set temperature to 0.2 so that outputs can be different
     # even when all prompts are identical when running batch inference.
-    sampling_params = SamplingParams(temperature=0.2,
+    sampling_params = SamplingParams(temperature=0,
                                      max_tokens=64,
                                      stop_token_ids=stop_token_ids)

vllm/model_executor/models/aria.py (115 changes: 89 additions & 26 deletions)

@@ -1,24 +1,27 @@
 import math
-from typing import Iterable, List, Optional, Set, Tuple, TypedDict, Union
+from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
+                    Union)

 import torch
 import torch.nn as nn
+from PIL import Image
 from torch.nn.init import trunc_normal_
 from transformers import LlamaConfig

 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_rank
-from vllm.inputs import INPUT_REGISTRY, token_inputs
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
+                         InputContext, token_inputs)
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     get_compressed_tensors_cache_scale)
-from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput,
-                                                SamplingMetadata)
+from vllm.model_executor.layers.sampler import (SamplerOutput,
+                                                SamplingMetadata, get_sampler)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
@@ -35,10 +38,12 @@
 from vllm.multimodal.image import cached_get_image_processor
 from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
 from vllm.multimodal.utils import (cached_get_tokenizer,
+                                   consecutive_placeholder_ranges,
                                    repeat_and_pad_placeholder_tokens)
-from vllm.sequence import IntermediateTensors
+from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.transformers_utils.configs.aria import (AriaMoELMConfig,
                                                   AriaVisionConfig)
+from vllm.utils import is_list_of

 from .utils import flatten_bn

@@ -445,15 +450,74 @@ def build_mm_projector(config):
     )


-def get_max_multimodal_tokens(ctx):
-    return max(ctx.model_config.hf_config.image_size2tokens.values())
-
-
-def input_mapper_for_aria(ctx, data):
-    return MultiModalKwargs(data)
-
-
-def input_processor(ctx, llm_inputs):
+def get_aria_max_multimodal_tokens(ctx: InputContext):
+    hf_config = ctx.get_hf_config()
+    image_size2tokens = {
+        int(math.sqrt(k) * hf_config.vision_config.patch_size): v
+        for k, v in hf_config.projector_patch_to_query_dict.items()
+    }
+    return max(image_size2tokens.values())
+
+
+def dummy_seq_data_for_aria(ctx: InputContext, seq_len: int, num_images: int):
+    image_feature_size = get_aria_max_multimodal_tokens(ctx)
+    hf_config = ctx.get_hf_config()
+    return SequenceData.from_prompt_token_counts(
+        (hf_config.image_token_index, image_feature_size * num_images),
+        (0, seq_len - image_feature_size * num_images),
+    ), {
+        "image":
+        consecutive_placeholder_ranges(num_items=num_images,
+                                       item_size=image_feature_size)
+    }
+
+
+def dummy_image_for_aria(
+    ctx: InputContext,
+    num_images: int,
+):
+    hf_config = ctx.get_hf_config()
+    max_image_size = hf_config.vision_config.image_size
+    image = Image.new("RGB", (max_image_size, max_image_size), color=0)
+    images = [image] * num_images
+
+    return {"image": images}
+
+
+def dummy_data_for_aria(ctx: InputContext, seq_len: int,
+                        mm_counts: Mapping[str, int]):
+    num_images = mm_counts["image"]
+    seq_data, ranges = dummy_seq_data_for_aria(ctx, seq_len, num_images)
+    mm_data = dummy_image_for_aria(ctx, num_images)
+    return DummyData(seq_data, mm_data, ranges)
+
+
+def input_mapper_for_aria(ctx: InputContext, data: object):
+    data_list = data if isinstance(data, list) else [data]
+
+    # For profiling with dummy image data
+    if is_list_of(data_list, Image.Image):
+        hf_config = ctx.get_hf_config()
+        max_image_size = hf_config.vision_config.image_size
+        model_config = ctx.model_config
+        image_processor = cached_get_image_processor(
+            model_config.model,
+            trust_remote_code=model_config.trust_remote_code)
+        image_inputs = image_processor.preprocess(
+            data_list,
+            max_image_size=max_image_size,
+            split_image=False,
+            return_tensors="pt").data
+        image_inputs['pixel_values'] = image_inputs['pixel_values'].to(
+            ctx.model_config.dtype)
+        return MultiModalKwargs(image_inputs)
+
+    # For actual inference when image has been processed with
+    # prompt in input processor
+    return MultiModalKwargs(data_list[0])
+
+
+def input_processor_for_aria(ctx: InputContext, llm_inputs: DecoderOnlyInputs):
     multi_modal_data = llm_inputs.get("multi_modal_data")
     # if it is pure text input, use it as is
     if multi_modal_data is None or "image" not in multi_modal_data:
@@ -494,9 +558,12 @@ def input_processor(ctx, llm_inputs):
         repeat_count=num_crops,
     )

-    repeat_count = [hf_config.image_size2tokens[max_image_size]
-                    ] * sum(num_crops).item()
-    new_prompt, new_token_ids, _ = repeat_and_pad_placeholder_tokens(
+    image_size2tokens = {
+        int(math.sqrt(k) * hf_config.vision_config.patch_size): v
+        for k, v in hf_config.projector_patch_to_query_dict.items()
+    }
Collaborator:

Seems that this is a fixed value; perhaps we can move it to AriaMoELMConfig initialization in vllm/transformers_utils/configs/aria.py to avoid repeating the calculation?
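To make the suggestion concrete, here is a minimal sketch of computing the mapping once at config-construction time; the class name and constructor signature below are illustrative assumptions, not the actual contents of vllm/transformers_utils/configs/aria.py:

```python
import math


class AriaConfigSketch:
    """Illustrative stand-in for the Aria config defined in configs/aria.py."""

    def __init__(self, vision_config, projector_patch_to_query_dict):
        self.vision_config = vision_config
        self.projector_patch_to_query_dict = {
            int(k): int(v)
            for k, v in projector_patch_to_query_dict.items()
        }
        # Precompute the image-size -> token-count mapping a single time here,
        # instead of rebuilding the dict on every input-processor call.
        self.image_size2tokens = {
            int(math.sqrt(k) * vision_config.patch_size): v
            for k, v in self.projector_patch_to_query_dict.items()
        }
```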

Member Author (ywang96):

Yep, I can do that.

Member Author (ywang96):

I realized we actually don't need the int(math.sqrt(k) * hf_config.vision_config.patch_size) calculation at all, since we only care about the values here, so I will just simplify this.
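With that simplification, the helper registered above could reduce to something like the sketch below (assuming only the maximum token count is needed; this may not match the exact code that was merged):

```python
from vllm.inputs import InputContext


def get_aria_max_multimodal_tokens(ctx: InputContext):
    # Only the values of projector_patch_to_query_dict (query/token counts per
    # image size) matter for the maximum, so the sqrt(k) * patch_size key
    # reconstruction can be dropped entirely.
    hf_config = ctx.get_hf_config()
    return max(hf_config.projector_patch_to_query_dict.values())
```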

+    repeat_count = [image_size2tokens[max_image_size]] * sum(num_crops).item()
+    new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
         tokenizer,
         None,
         prompt_token_ids,
@@ -508,12 +575,14 @@ def input_processor(ctx, llm_inputs):
         prompt_token_ids=new_token_ids,
         prompt=new_prompt,
         multi_modal_data={"image": image_inputs},
+        multi_modal_placeholders={"image": ranges},
     )


-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_multimodal_tokens)
+@MULTIMODAL_REGISTRY.register_max_image_tokens(get_aria_max_multimodal_tokens)
 @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_aria)
-@INPUT_REGISTRY.register_input_processor(input_processor)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_aria)
Member Author (ywang96), Dec 30, 2024:

The code for dummy data generation was entirely missing and I'm not sure why, so I added it in this PR since it's required for V1. cc @xffxff, who originally added this model.
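For reference, a rough sketch of what the new dummy-data path produces during profiling, built from the helpers added above; the seq_len value and the assumption that ctx is an already-populated InputContext for the Aria model are illustrative only:

```python
# Illustrative fragment, not actual vLLM internals: dummy data for one image.
seq_data, ranges = dummy_seq_data_for_aria(ctx, seq_len=4096, num_images=1)
mm_data = dummy_image_for_aria(ctx, num_images=1)

# seq_data: the image token repeated image_feature_size times, then padding
# mm_data:  {"image": [one RGB PIL image at vision_config.image_size]}
# ranges:   {"image": placeholder ranges marking where the image tokens sit}
```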

+@INPUT_REGISTRY.register_input_processor(input_processor_for_aria)
 class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
     """
     Aria model for conditional generation tasks.
@@ -540,12 +609,6 @@ def __init__(
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config

-        # prepare the image_size to tokens mapping for the image preprocess, see
-        # input_processor
-        config.image_size2tokens = {
-            int(math.sqrt(k) * config.vision_config.patch_size): v
-            for k, v in config.projector_patch_to_query_dict.items()
-        }
         self.config = config
         self.vision_tower = AriaVisionModel(config.vision_config)
         self.multi_modal_projector = build_mm_projector(config)
@@ -566,7 +629,7 @@ def __init__(
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 self.vocab_size, logit_scale)
-        self.sampler = Sampler()
+        self.sampler = get_sampler()

     def _validate_image_sizes(
             self, images: List[torch.Tensor]) -> List[torch.Tensor]:
vllm/model_executor/models/blip2.py (8 changes: 6 additions & 2 deletions)

@@ -16,7 +16,7 @@
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors
+from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
 from vllm.multimodal.utils import consecutive_placeholder_ranges
 from vllm.sequence import IntermediateTensors, SequenceData

@@ -468,14 +468,18 @@ def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs):
     # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514
     new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
     new_token_ids += inputs["prompt_token_ids"]
+    placeholder_ranges = [
+        PlaceholderRange(offset=0, length=image_feature_size)
+    ]

     new_prompt = inputs.get("prompt")
     if new_prompt is not None:
         new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt

     return token_inputs(prompt_token_ids=new_token_ids,
                         prompt=new_prompt,
-                        multi_modal_data=multi_modal_data)
+                        multi_modal_data=multi_modal_data,
+                        multi_modal_placeholders={"image": placeholder_ranges})


 @MULTIMODAL_REGISTRY.register_image_input_mapper()
vllm/model_executor/models/chameleon.py (3 changes: 2 additions & 1 deletion)

@@ -153,7 +153,8 @@ def input_processor_for_chameleon(ctx: InputContext,
     # NOTE: Create a defensive copy of the original inputs
     return token_inputs(prompt_token_ids=new_token_ids,
                         prompt=new_prompt,
-                        multi_modal_data=multi_modal_data)
+                        multi_modal_data=multi_modal_data,
+                        multi_modal_placeholders={"image": ranges})


 class ChameleonLayerNorm(nn.LayerNorm):