From dc4885752bf2a85fd4de893a4eef7c6e554447f7 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 28 Oct 2024 15:51:34 +0000 Subject: [PATCH 01/19] Init Signed-off-by: Jee Jee Li --- vllm/model_executor/models/idefics3.py | 492 +++++++++++++++++++++++++ 1 file changed, 492 insertions(+) create mode 100644 vllm/model_executor/models/idefics3.py diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py new file mode 100644 index 0000000000000..ac62a5667862d --- /dev/null +++ b/vllm/model_executor/models/idefics3.py @@ -0,0 +1,492 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Idefics3 model compatible with HuggingFace weights.""" + +from typing import ( + Iterable, + List, + Literal, + Mapping, + Optional, + Tuple, + TypedDict, + Union, +) + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers import Idefics3Config + +from vllm.logger import init_logger +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, MultiModalConfig +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.model_executor.layers.logits_processor import LogitsProcessor + +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput + +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sequence import IntermediateTensors + + +from .llama import LlamaModel +from .idefics2_vision_model import ( + Idefics2VisionTransformer as Idefics3VisionTransformer, ) +from .interfaces import SupportsMultiModal +from .siglip import ( + SiglipVisionModel, + dummy_image_for_siglip, + dummy_seq_data_for_siglip, + get_max_siglip_image_tokens, + input_processor_for_siglip, +) +from .utils import ( + AutoWeightsLoader, + flatten_bn, + merge_multimodal_embeddings, +) + +logger = init_logger(__name__) + + +class ImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: `(batch_size * num_images, num_channels, height, width)`""" + + +class ImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. 
+ """ + + +ImageInputs = Union[ImagePixelInputs, ImageEmbeddingInputs] + + +def input_processor_for_idefics3(ctx: InputContext, llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + model_config = ctx.model_config + version = get_version_by_config(model_config.hf_config) + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code) + image_processor = cached_get_image_processor(model_config.tokenizer) + + def get_placeholder(image_size: Tuple[int, int], num_image: int): + if version == (2, 0) or version == (2, 5): + return image_processor. \ + get_slice_image_placeholder(image_size) + return image_processor. \ + get_slice_image_placeholder(image_size, num_image) + + prompt = llm_inputs.get("prompt") + token_ids = llm_inputs.get("prompt_token_ids") + if prompt is None: + prompt = tokenizer.decode(token_ids) + + pattern = "(./)" + images = multi_modal_data["image"] + image_tags = re.findall(pattern, prompt) + if len(image_tags) == 0: + new_token_ids = token_ids + new_prompt = prompt + else: + if isinstance(images, dict): + image_size_list = images.get("image_size_list") + images = [images.get("image_embeds")] + else: + if isinstance(images, Image.Image): + images = [images] + image_size_list = [image.size for image in images] + + text_chunks = prompt.split(pattern) + new_prompt_chunks: List[str] = [] + for i in range(len(image_size_list)): + new_prompt_chunks += [ + text_chunks[i], + get_placeholder(image_size_list[i], i) + ] + new_prompt_chunks.append(text_chunks[-1]) + new_prompt = "".join(new_prompt_chunks) + new_token_ids = tokenizer.encode(new_prompt) + + multi_modal_data["image"] = [ + _build_image_input(ctx, image) for image in images + ] + + llm_inputs = LLMInputs( + prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data, + ) + return llm_inputs + + +class Idefics3SimpleMLP(nn.Module): + + def __init__(self, config): + super().__init__() + input_size = config.vision_config.hidden_size * (config.scale_factor** + 2) + output_size = config.text_config.hidden_size + self.proj = nn.Linear(input_size, output_size, bias=False) + + def forward(self, x): + return self.proj(x) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, + None, :, :].expand(batch, + num_key_value_heads, + n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, + head_dim) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Idefics3 +class Idefics3RMSNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-6): + """ + Idefics3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Idefics3Connector(nn.Module): + + def __init__(self, config): + super().__init__() + self.scale_factor = config.scale_factor + self.modality_projection = Idefics3SimpleMLP(config) + + def pixel_shuffle(self, x, scale_factor=2): + bsz, seq, embed_dim = x.size() + height = width = int(seq**0.5) + x = x.view(bsz, height, width, embed_dim) + x = x.view(bsz, height, int(width / scale_factor), + embed_dim * scale_factor) + x = x.permute(0, 2, 1, 3) + x = x.reshape( + bsz, + int(width / scale_factor), + int(height / scale_factor), + embed_dim * (scale_factor**2), + ) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(seq / (scale_factor**2)), + embed_dim * (scale_factor**2)) + return x + + def forward(self, image_hidden_states): + image_hidden_states = self.pixel_shuffle(image_hidden_states, + self.scale_factor) + image_hidden_states = self.modality_projection(image_hidden_states) + return image_hidden_states + + +class Idefics3Model(nn.Module): + + def __init__( + self, + config: Idefics3Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.padding_idx = self.config.text_config.pad_token_id + self.vocab_size = self.config.text_config.vocab_size + + self.vision_model = Idefics3VisionTransformer(config.vision_config, ) + self.connector = Idefics3Connector(config) + self.text_model = LlamaModel(config.text_config, cache_config, + quant_config) + + self.image_seq_len = int( + ((config.vision_config.image_size // + config.vision_config.patch_size)**2) / (config.scale_factor**2)) + self.image_token_id = self.config.image_token_id + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + actual_dims = tuple(data.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("batch_size", *map(str, expected_dims)) + raise ValueError( + f"The expected shape of pixel values is {expected_expr}. 
" + f"You supplied {tuple(data.shape)}.") + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[ImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + return ImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values( + flatten_bn(pixel_values, concat=True)), + ) + + if image_embeds is not None: + if not isinstance(image_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + + return ImageEmbeddingInputs( + type="image_embeds", + data=flatten_bn(image_embeds, concat=True), + ) + + raise AssertionError("This line should be unreachable.") + + def _select_image_features(self, image_features: torch.Tensor, *, + strategy: str) -> torch.Tensor: + # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa + if strategy == "default": + return image_features[:, 1:] + elif strategy == "full": + return image_features + + raise ValueError(f"Unexpected select feature strategy: {strategy}") + + def _image_pixels_to_features( + self, + vision_tower: Union[SiglipVisionModel], + pixel_values: torch.Tensor, + ) -> torch.Tensor: + + # NOTE: we skip the step to select the vision feature layer since + # this is already done inside the vision tower + image_features = vision_tower(pixel_values) + + return self._select_image_features( + image_features, + strategy=self.config.vision_feature_select_strategy, + ) + + def _process_image_pixels(self, inputs: ImagePixelInputs) -> torch.Tensor: + assert self.vision_tower is not None + + pixel_values = inputs["data"] + + return self._image_pixels_to_features(self.vision_tower, pixel_values) + + def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor: + + if image_input["type"] == "image_embeds": + return image_input["data"] + + assert self.vision_tower is not None + image_features = self._process_image_pixels(image_input) + return self.multi_modal_projector(image_features) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + r""" + TODO + ```""" + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent + image_input = self._parse_and_validate_image_input(**kwargs) + + if image_input is not None: + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.config.image_token_index) + else: + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + input_ids = None + + hidden_states = self.text_model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + +def 
dummy_data_for_idefics3(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + hf_config = ctx.get_hf_config(Idefics3Config) + vision_config = hf_config.vision_config + num_images = mm_counts["image"] + + image_feature_size = get_max_llava_image_tokens(ctx) + + if isinstance(vision_config, CLIPVisionConfig): + seq_data = dummy_seq_data_for_clip( + vision_config, + seq_len, + num_images, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + + mm_data = dummy_image_for_clip(vision_config, num_images) + return seq_data, mm_data + elif isinstance(vision_config, SiglipVisionConfig): + seq_data = dummy_seq_data_for_siglip( + vision_config, + seq_len, + num_images, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + + mm_data = dummy_image_for_siglip(vision_config, num_images) + return seq_data, mm_data + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +@MULTIMODAL_REGISTRY.register_image_input_mapper() +@MULTIMODAL_REGISTRY.register_max_image_tokens() +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_idefics3) +@INPUT_REGISTRY.register_input_processor(input_processor_for_idefics3) +class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal): + _tied_weights_keys = ["lm_head.weight"] + + def __init__( + self, + config: Idefics3Config, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.multimodal_config = multimodal_config + + self.model = Idefics3Model(config, cache_config, quant_config) + self.image_token_id = self.config.image_token_id + + self.lm_head = nn.Linear( + config.text_config.hidden_size, + config.text_config.vocab_size, + bias=False, + ) + self.vocab_size = config.text_config.vocab_size + + self.logits_processor = LogitsProcessor(self.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, + ) -> SamplerOutput: + r""" + TODO + ```""" + + outputs = self.model( + input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + **kwargs, + ) + + hidden_states = outputs[0] + + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + loader.load_weights(weights) From e2e9811871a02dc125abf6d54b7fcfd2de82df3b Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 29 Oct 2024 03:41:57 +0000 Subject: [PATCH 02/19] Modify model code Signed-off-by: Jee Jee Li --- vllm/model_executor/models/idefics3.py | 218 +++++-------------------- vllm/model_executor/models/registry.py | 1 + 2 files changed, 38 insertions(+), 181 deletions(-) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index ac62a5667862d..4e0ddbf84bbf8 100644 --- 
a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -15,6 +15,8 @@ """Inference-only Idefics3 model compatible with HuggingFace weights.""" from typing import ( + Any, + Callable, Iterable, List, Literal, @@ -29,7 +31,7 @@ import torch.utils.checkpoint from torch import nn -from transformers import Idefics3Config +from transformers import PretrainedConfig from vllm.logger import init_logger from vllm.attention import AttentionMetadata @@ -37,25 +39,19 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput - +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors - from .llama import LlamaModel from .idefics2_vision_model import ( Idefics2VisionTransformer as Idefics3VisionTransformer, ) from .interfaces import SupportsMultiModal -from .siglip import ( - SiglipVisionModel, - dummy_image_for_siglip, - dummy_seq_data_for_siglip, - get_max_siglip_image_tokens, - input_processor_for_siglip, -) + from .utils import ( AutoWeightsLoader, flatten_bn, @@ -83,67 +79,6 @@ class ImageEmbeddingInputs(TypedDict): ImageInputs = Union[ImagePixelInputs, ImageEmbeddingInputs] -def input_processor_for_idefics3(ctx: InputContext, llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs - model_config = ctx.model_config - version = get_version_by_config(model_config.hf_config) - tokenizer = cached_get_tokenizer( - model_config.tokenizer, - trust_remote_code=model_config.trust_remote_code) - image_processor = cached_get_image_processor(model_config.tokenizer) - - def get_placeholder(image_size: Tuple[int, int], num_image: int): - if version == (2, 0) or version == (2, 5): - return image_processor. \ - get_slice_image_placeholder(image_size) - return image_processor. 
\ - get_slice_image_placeholder(image_size, num_image) - - prompt = llm_inputs.get("prompt") - token_ids = llm_inputs.get("prompt_token_ids") - if prompt is None: - prompt = tokenizer.decode(token_ids) - - pattern = "(./)" - images = multi_modal_data["image"] - image_tags = re.findall(pattern, prompt) - if len(image_tags) == 0: - new_token_ids = token_ids - new_prompt = prompt - else: - if isinstance(images, dict): - image_size_list = images.get("image_size_list") - images = [images.get("image_embeds")] - else: - if isinstance(images, Image.Image): - images = [images] - image_size_list = [image.size for image in images] - - text_chunks = prompt.split(pattern) - new_prompt_chunks: List[str] = [] - for i in range(len(image_size_list)): - new_prompt_chunks += [ - text_chunks[i], - get_placeholder(image_size_list[i], i) - ] - new_prompt_chunks.append(text_chunks[-1]) - new_prompt = "".join(new_prompt_chunks) - new_token_ids = tokenizer.encode(new_prompt) - - multi_modal_data["image"] = [ - _build_image_input(ctx, image) for image in images - ] - - llm_inputs = LLMInputs( - prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data, - ) - return llm_inputs - - class Idefics3SimpleMLP(nn.Module): def __init__(self, config): @@ -151,50 +86,11 @@ def __init__(self, config): input_size = config.vision_config.hidden_size * (config.scale_factor** 2) output_size = config.text_config.hidden_size - self.proj = nn.Linear(input_size, output_size, bias=False) - - def forward(self, x): - return self.proj(x) - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, - None, :, :].expand(batch, - num_key_value_heads, - n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, - head_dim) - - -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Idefics3 -class Idefics3RMSNorm(nn.Module): - - def __init__(self, hidden_size, eps=1e-6): - """ - Idefics3RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps + self.proj = ReplicatedLinear(input_size, output_size, bias=False) - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + - self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + def forward(self, x) -> torch.Tensor: + out, _ = self.proj(x) + return out class Idefics3Connector(nn.Module): @@ -233,7 +129,7 @@ class Idefics3Model(nn.Module): def __init__( self, - config: Idefics3Config, + config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): @@ -242,7 +138,8 @@ def __init__( self.padding_idx = self.config.text_config.pad_token_id self.vocab_size = self.config.text_config.vocab_size - self.vision_model = 
Idefics3VisionTransformer(config.vision_config, ) + self.vision_model = Idefics3VisionTransformer(config.vision_config, + quant_config) self.connector = Idefics3Connector(config) self.text_model = LlamaModel(config.text_config, cache_config, quant_config) @@ -308,10 +205,9 @@ def _select_image_features(self, image_features: torch.Tensor, *, def _image_pixels_to_features( self, - vision_tower: Union[SiglipVisionModel], + vision_tower: Idefics3VisionTransformer, pixel_values: torch.Tensor, ) -> torch.Tensor: - # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower image_features = vision_tower(pixel_values) @@ -329,7 +225,6 @@ def _process_image_pixels(self, inputs: ImagePixelInputs) -> torch.Tensor: return self._image_pixels_to_features(self.vision_tower, pixel_values) def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor: - if image_input["type"] == "image_embeds": return image_input["data"] @@ -363,68 +258,36 @@ def forward( input_ids) inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.config.image_token_index) + input_ids, + inputs_embeds, + vision_embeddings, + self.config.image_token_index, + ) else: inputs_embeds = self.language_model.model.get_input_embeddings( input_ids) input_ids = None - hidden_states = self.text_model(input_ids, - positions, - kv_caches, - attn_metadata, - intermediate_tensors, - inputs_embeds=inputs_embeds) - - return hidden_states - - -def dummy_data_for_idefics3(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): - hf_config = ctx.get_hf_config(Idefics3Config) - vision_config = hf_config.vision_config - num_images = mm_counts["image"] - - image_feature_size = get_max_llava_image_tokens(ctx) - - if isinstance(vision_config, CLIPVisionConfig): - seq_data = dummy_seq_data_for_clip( - vision_config, - seq_len, - num_images, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) - - mm_data = dummy_image_for_clip(vision_config, num_images) - return seq_data, mm_data - elif isinstance(vision_config, SiglipVisionConfig): - seq_data = dummy_seq_data_for_siglip( - vision_config, - seq_len, - num_images, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, + hidden_states = self.text_model( + input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds, ) - - mm_data = dummy_image_for_siglip(vision_config, num_images) - return seq_data, mm_data - - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) + return hidden_states @MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_max_image_tokens() -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_idefics3) -@INPUT_REGISTRY.register_input_processor(input_processor_for_idefics3) +@INPUT_REGISTRY.register_dummy_data() +@INPUT_REGISTRY.register_input_processor() class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal): - _tied_weights_keys = ["lm_head.weight"] def __init__( self, - config: Idefics3Config, + config: PretrainedConfig, multimodal_config: MultiModalConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, @@ -437,14 +300,14 @@ def __init__( self.model = Idefics3Model(config, cache_config, quant_config) self.image_token_id = self.config.image_token_id - self.lm_head = nn.Linear( + self.lm_head = ParallelLMHead( 
config.text_config.hidden_size, config.text_config.vocab_size, - bias=False, + quant_config=quant_config, ) - self.vocab_size = config.text_config.vocab_size - - self.logits_processor = LogitsProcessor(self.vocab_size) + if self.config.text_config.tie_word_embeddings: + self.lm_head.weight = self.model.text_model.wte.weight + self.logits_processor = LogitsProcessor(config.text_config.vocab_size) self.sampler = Sampler() def forward( @@ -455,12 +318,8 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, - ) -> SamplerOutput: - r""" - TODO - ```""" - - outputs = self.model( + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model( input_ids, positions, kv_caches, @@ -468,9 +327,6 @@ def forward( intermediate_tensors, **kwargs, ) - - hidden_states = outputs[0] - return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 595a9256f958e..6520eff181d26 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -111,6 +111,7 @@ "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "InternVLChatModel": ("internvl", "InternVLChatModel"), + "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"), "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501 From 120047540ddd397cafdcaafff991294799750469 Mon Sep 17 00:00:00 2001 From: B-201 Date: Thu, 31 Oct 2024 09:41:44 +0800 Subject: [PATCH 03/19] Update code Signed-off-by: B-201 --- examples/offline_inference_vision_language.py | 17 + .../models/idefics2_vision_model.py | 25 +- vllm/model_executor/models/idefics3.py | 439 +++++++++++++++--- 3 files changed, 419 insertions(+), 62 deletions(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 83d2548a506e4..2c22c796189b9 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -353,6 +353,22 @@ def run_glm4v(question: str, modality: str): return llm, prompt, stop_token_ids +# Idefics3-8B-Llama3 +def run_idefics3(question: str, modality: str): + assert modality == "image" + model_name = ("HuggingFaceM4/Idefics3-8B-Llama3") + + llm = LLM(model=model_name, + max_model_len=2048, + max_num_seqs=2, + enforce_eager=True) + prompt = ( + f"<|begin_of_text|>User:{question}\nAssistant:" + ) + stop_token_ids = None + return llm, prompt, stop_token_ids + + model_example_map = { "llava": run_llava, "llava-next": run_llava_next, @@ -372,6 +388,7 @@ def run_glm4v(question: str, modality: str): "mllama": run_mllama, "molmo": run_molmo, "glm4v": run_glm4v, + "idefics3": run_idefics3, } diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py index 43f4f29814e6d..58be828c9e405 100644 --- a/vllm/model_executor/models/idefics2_vision_model.py +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -17,7 +17,7 @@ # limitations under the License. 
"""PyTorch Idefics2 model.""" -from typing import Optional +from typing import Optional, Iterable, Tuple import torch from torch import nn @@ -26,6 +26,7 @@ from xformers import ops as xops from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -331,3 +332,25 @@ def forward( encoder_outputs = self.encoder(hidden_states) last_hidden_state = self.post_layernorm(encoder_outputs) return last_hidden_state + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 4e0ddbf84bbf8..e0df74534a5e1 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -15,8 +15,7 @@ """Inference-only Idefics3 model compatible with HuggingFace weights.""" from typing import ( - Any, - Callable, + Dict, Iterable, List, Literal, @@ -27,16 +26,23 @@ Union, ) +import math import torch import torch.utils.checkpoint from torch import nn -from transformers import PretrainedConfig +from transformers import Idefics3Config +from PIL import Image from vllm.logger import init_logger from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import ( + INPUT_REGISTRY, + InputContext, + DecoderOnlyInputs, + token_inputs, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.linear import ReplicatedLinear @@ -44,8 +50,11 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import IntermediateTensors +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs +from vllm.multimodal.image import cached_get_image_processor +from vllm.transformers_utils.processor import cached_get_processor +from vllm.sequence import IntermediateTensors, SequenceData +from vllm.utils import is_list_of from .llama import LlamaModel from .idefics2_vision_model import ( @@ -55,19 +64,21 @@ from .utils import ( AutoWeightsLoader, flatten_bn, - merge_multimodal_embeddings, ) logger = init_logger(__name__) -class ImagePixelInputs(TypedDict): +class Idefics3ImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: torch.Tensor """Shape: `(batch_size * num_images, num_channels, height, width)`""" + rows: List[int] + cols: List[int] + pixel_attention_mask: Optional[torch.BoolTensor] -class 
ImageEmbeddingInputs(TypedDict): +class Idefics3ImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] data: torch.Tensor """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` @@ -76,7 +87,249 @@ class ImageEmbeddingInputs(TypedDict): """ -ImageInputs = Union[ImagePixelInputs, ImageEmbeddingInputs] +ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs] + + +def input_mapper_for_idefics3( + ctx: InputContext, + data: object, +): + model_config = ctx.model_config + image_processor = cached_get_image_processor( + model_config.model, trust_remote_code=model_config.trust_remote_code) + if image_processor is None: + raise RuntimeError("No HuggingFace processor is available " + "to process the image object") + + if isinstance(data, list): + images = [data] + else: + raise ValueError() + + try: + batch_data = image_processor(images, + return_tensors="pt", + return_row_col_info=True).data + except Exception: + logger.error("Failed to process image (%s)", data) + raise + + return MultiModalInputs(batch_data) + + +def _resize_output_size_rescale_to_max_len( + height: int, + width: int, + min_len: Optional[int] = 1, + max_len: Optional[int] = None) -> Tuple[int, int]: + max_len = max(height, width) if max_len is None else max_len + aspect_ratio = width / height + + if width >= height: + width = max_len + height = int(width / aspect_ratio) + if height % 2 != 0: + height += 1 + elif height > width: + height = max_len + width = int(height * aspect_ratio) + if width % 2 != 0: + width += 1 + + # Avoid resizing to a size smaller than min_len + height = max(height, min_len) + width = max(width, min_len) + return height, width + + +def _resize_output_size_scale_below_upper_bound( + height: int, + width: int, + max_len: Optional[Dict[str, int]] = None) -> Tuple[int, int]: + max_len = max(height, width) if max_len is None else max_len + + aspect_ratio = width / height + if width >= height and width > max_len: + width = max_len + height = int(width / aspect_ratio) + elif height > width and height > max_len: + height = max_len + width = int(height * aspect_ratio) + + # Avoid resizing to a size smaller than 1 + height = max(height, 1) + width = max(width, 1) + return height, width + + +def _get_resize_output_image_size( + image_size, + resolution_max_side: int, + max_image_size: int = 1820, +) -> Tuple[int, int]: + if resolution_max_side > max_image_size: + raise ValueError( + "`resolution_max_side` cannot be larger than `max_image_size`") + + height, width = image_size + + # Find the output size, when rescaling the longest edge to max_len and + # preserving the aspect ratio + height, width = _resize_output_size_rescale_to_max_len( + height, width, max_len=resolution_max_side) + # Find the output size when scaling the image to be below the max_image_size + height, width = _resize_output_size_scale_below_upper_bound( + height, width, max_len=max_image_size) + return height, width + + +def _prompt_split_image(image_seq_len, image_rows, image_cols, + fake_token_around_image, image_token, + global_img_token): + """ + Prompt with expanded image tokens for when the image is split + into patches. 
+ """ + text_split_images = "" + for n_h in range(image_rows): + for n_w in range(image_cols): + text_split_images += (f"{fake_token_around_image}" + + f"" + + f"{image_token}" * image_seq_len) + text_split_images += "\n" + + text_split_images += (f"\n{fake_token_around_image}" + + f"{global_img_token}" + + f"{image_token}" * image_seq_len + + f"{fake_token_around_image}") + return text_split_images + + +def _prompt_single_image(image_seq_len, fake_token_around_image, image_token, + global_img_token): + """Prompt with expanded image tokens for a single image.""" + return (f"{fake_token_around_image}" + f"{global_img_token}" + + f"{image_token}" * image_seq_len + f"{fake_token_around_image}") + + +def _get_image_prompt_string(image_rows, image_cols, image_seq_len, + fake_token_around_image, image_token, + global_img_token): + if image_rows == 0 and image_cols == 0: + return _prompt_single_image( + image_seq_len, + fake_token_around_image=fake_token_around_image, + image_token=image_token, + global_img_token=global_img_token, + ) + return _prompt_split_image(image_seq_len, image_rows, image_cols, + fake_token_around_image, image_token, + global_img_token) + + +def input_processor_for_idefics3(ctx: InputContext, inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return inputs + + model_config = ctx.model_config + processor = cached_get_processor(model_config.model) + image_processor = processor.image_processor + tokenizer = processor.tokenizer + size = image_processor.size['longest_edge'] + max_image_size = image_processor.max_image_size['longest_edge'] + + image_data = multi_modal_data["image"] + if isinstance(image_data, Image.Image): + image_list = [image_data] + elif is_list_of(image_data, Image.Image): + image_list = image_data + else: + raise TypeError(f"Invalid image type: {type(image_data)}") + + image_rows = [] + image_cols = [] + for image in image_list: + height, width = _get_resize_output_image_size(image.size, size) + + rows = math.ceil(height / max_image_size) + cols = math.ceil(width / max_image_size) + image_rows.append(rows) + image_cols.append(cols) + image_rows = [image_rows] + image_cols = [image_cols] + + n_images_in_text = [] + + text = inputs.get("prompt") + if text is not None: + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError( + "Invalid input text. 
Please provide a string, or a list of strings" + ) + + fake_image_token = processor.fake_image_token.content + image_token = processor.image_token.content + global_img_token = processor.global_image_tag + + prompt_strings = [] + for sample, sample_rows, sample_cols in zip(text, image_rows, + image_cols): + n_images_in_text.append(sample.count(image_token)) + + # Replace the image token with fake tokens around the expanded + # image token sequence of length `image_seq_len` + image_prompt_strings = [] + for n_rows, n_cols in zip(sample_rows, sample_cols): + image_prompt_string = _get_image_prompt_string( + n_rows, + n_cols, + processor.image_seq_len, + image_token=image_token, + fake_token_around_image=fake_image_token, + global_img_token=global_img_token, + ) + image_prompt_strings.append(image_prompt_string) + + split_sample = sample.split(image_token) + if len(split_sample) == 0: + raise ValueError( + "The image token should be present in the text.") + + # Place in the image prompt strings where the image tokens are + sample = split_sample[0] + for i, image_prompt_string in enumerate(image_prompt_strings): + sample += image_prompt_string + split_sample[i + 1] + prompt_strings.append(sample) + + prompt_token_ids = tokenizer(text=prompt_strings[0]).input_ids + + return token_inputs( + prompt_token_ids=prompt_token_ids, + prompt=prompt_strings[0], + multi_modal_data=multi_modal_data, + ) + + +def dummy_data_for_idefics3(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + hf_config = ctx.get_hf_config() + num_images = mm_counts["image"] + + processor = cached_get_processor(ctx.model_config.model) + image_seq_len = processor.image_seq_len + max_llm_image_tokens = 17 * image_seq_len + + seq_data = SequenceData.from_prompt_token_counts( + (hf_config.image_token_id, max_llm_image_tokens), (0, seq_len)) + + width = height = hf_config.vision_config.image_size + image = Image.new("RGB", (width, height), color=0) + mm_data = {"image": [image] if num_images == 1 else [image] * num_images} + + return seq_data, mm_data class Idefics3SimpleMLP(nn.Module): @@ -129,7 +382,7 @@ class Idefics3Model(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Idefics3Config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): @@ -149,16 +402,24 @@ def __init__( config.vision_config.patch_size)**2) / (config.scale_factor**2)) self.image_token_id = self.config.image_token_id - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + def _validate_pixel_values( + self, data: Union[torch.Tensor, List[torch.Tensor]] + ) -> Union[torch.Tensor, List[torch.Tensor]]: + h = w = self.config.vision_config.image_size expected_dims = (3, h, w) - actual_dims = tuple(data.shape[1:]) - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. " - f"You supplied {tuple(data.shape)}.") + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("num_patches", *map(str, expected_dims)) + raise ValueError( + "The expected shape of pixel values per image per batch " + f"is {expected_expr}. 
You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) return data @@ -166,71 +427,130 @@ def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) + rows = kwargs.pop("rows", None) + cols = kwargs.pop("cols", None) + pixel_attention_mask = kwargs.pop("pixel_attention_mask", None) if pixel_values is None and image_embeds is None: return None - if pixel_values is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - - return ImagePixelInputs( - type="pixel_values", - data=self._validate_pixel_values( - flatten_bn(pixel_values, concat=True)), - ) - if image_embeds is not None: if not isinstance(image_embeds, (torch.Tensor, list)): raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") - return ImageEmbeddingInputs( + return Idefics3ImageEmbeddingInputs( type="image_embeds", data=flatten_bn(image_embeds, concat=True), ) - raise AssertionError("This line should be unreachable.") + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") - def _select_image_features(self, image_features: torch.Tensor, *, - strategy: str) -> torch.Tensor: - # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa - if strategy == "default": - return image_features[:, 1:] - elif strategy == "full": - return image_features + return Idefics3ImagePixelInputs(type="pixel_values", + data=self._validate_pixel_values( + flatten_bn(pixel_values, + concat=True)), + rows=rows, + cols=cols, + pixel_attention_mask=flatten_bn( + pixel_attention_mask, + concat=True)) - raise ValueError(f"Unexpected select feature strategy: {strategy}") + raise AssertionError("This line should be unreachable.") def _image_pixels_to_features( self, - vision_tower: Idefics3VisionTransformer, + vision_model: Idefics3VisionTransformer, pixel_values: torch.Tensor, + pixel_attention_mask: Optional[torch.BoolTensor] = None, ) -> torch.Tensor: # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower - image_features = vision_tower(pixel_values) - - return self._select_image_features( - image_features, - strategy=self.config.vision_feature_select_strategy, + batch_size, num_images, num_channels, height, width = pixel_values.shape + pixel_values = pixel_values.to( + dtype=self.vision_model.embeddings.patch_embedding.weight.dtype + ) # fp16 compatibility + pixel_values = pixel_values.view(batch_size * num_images, + *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. 
+ nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum( + dim=(-1, -2, -3)) != nb_values_per_image + pixel_values = pixel_values[real_images_inds].contiguous() + + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=(pixel_values.size(0), pixel_values.size(2), + pixel_values.size(3)), + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask + pixel_attention_mask = pixel_attention_mask.view( + batch_size * num_images, *pixel_attention_mask.shape[2:]) + pixel_attention_mask = pixel_attention_mask[ + real_images_inds].contiguous() + + patch_size = self.config.vision_config.patch_size + patches_subgrid = pixel_attention_mask.unfold(dimension=1, + size=patch_size, + step=patch_size) + patches_subgrid = patches_subgrid.unfold(dimension=2, + size=patch_size, + step=patch_size) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + # Get sequence from the vision encoder + image_hidden_states = vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, ) - def _process_image_pixels(self, inputs: ImagePixelInputs) -> torch.Tensor: - assert self.vision_tower is not None + return image_hidden_states + + def _process_image_pixels( + self, inputs: Idefics3ImagePixelInputs) -> torch.Tensor: + assert self.vision_model is not None pixel_values = inputs["data"] + pixel_attention_mask = inputs["pixel_attention_mask"] - return self._image_pixels_to_features(self.vision_tower, pixel_values) + return self._image_pixels_to_features(self.vision_model, pixel_values, + pixel_attention_mask) def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor: if image_input["type"] == "image_embeds": return image_input["data"] - assert self.vision_tower is not None + assert self.vision_model is not None image_features = self._process_image_pixels(image_input) - return self.multi_modal_projector(image_features) + return self.connector(image_features) + + def _merge_multimodal_embeddings( + self, + input_ids: torch.LongTensor, + inputs_embeds: Optional[torch.Tensor], + vision_embeddings: Optional[torch.Tensor], + ): + num_images, _, vision_hidden_size = vision_embeddings.shape + special_image_token_mask = input_ids == self.image_token_id + # Fixes RuntimeError: a leaf Variable that requires grad is being used + # in an in-place operation. 
+ new_inputs_embeds = inputs_embeds.clone() + reshaped_vision_embeddings = vision_embeddings.view( + -1, vision_hidden_size) + # cast to the dtype of the input_embeds to support quantized models + reshaped_vision_embeddings = reshaped_vision_embeddings.to( + inputs_embeds.dtype) + new_inputs_embeds[ + special_image_token_mask] = reshaped_vision_embeddings + return new_inputs_embeds def forward( self, @@ -254,18 +574,15 @@ def forward( if image_input is not None: vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) + inputs_embeds = self.text_model.get_input_embeddings(input_ids) - inputs_embeds = merge_multimodal_embeddings( + inputs_embeds = self._merge_multimodal_embeddings( input_ids, inputs_embeds, vision_embeddings, - self.config.image_token_index, ) else: - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) + inputs_embeds = self.text_model.get_input_embeddings(input_ids) input_ids = None hidden_states = self.text_model( @@ -279,15 +596,15 @@ def forward( return hidden_states -@MULTIMODAL_REGISTRY.register_image_input_mapper() +@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_idefics3) @MULTIMODAL_REGISTRY.register_max_image_tokens() -@INPUT_REGISTRY.register_dummy_data() -@INPUT_REGISTRY.register_input_processor() +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_idefics3) +@INPUT_REGISTRY.register_input_processor(input_processor_for_idefics3) class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal): def __init__( self, - config: PretrainedConfig, + config: Idefics3Config, multimodal_config: MultiModalConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, @@ -301,8 +618,8 @@ def __init__( self.image_token_id = self.config.image_token_id self.lm_head = ParallelLMHead( - config.text_config.hidden_size, config.text_config.vocab_size, + config.text_config.hidden_size, quant_config=quant_config, ) if self.config.text_config.tie_word_embeddings: From 6cd4bfac2f22c9b9049d960931ebdfc9f6bba87a Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 31 Oct 2024 05:19:30 +0000 Subject: [PATCH 04/19] Fix code format --- .../models/idefics2_vision_model.py | 4 +- vllm/model_executor/models/idefics3.py | 61 ++++++------------- 2 files changed, 21 insertions(+), 44 deletions(-) diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py index 58be828c9e405..9132ba41dbda2 100644 --- a/vllm/model_executor/models/idefics2_vision_model.py +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -17,7 +17,7 @@ # limitations under the License. 
"""PyTorch Idefics2 model.""" -from typing import Optional, Iterable, Tuple +from typing import Iterable, Optional, Tuple import torch from torch import nn @@ -26,12 +26,12 @@ from xformers import ops as xops from vllm.distributed import divide, get_tensor_model_parallel_world_size -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader class Idefics2VisionEmbeddings(nn.Module): diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index e0df74534a5e1..769793f28c5de 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -14,57 +14,38 @@ # limitations under the License. """Inference-only Idefics3 model compatible with HuggingFace weights.""" -from typing import ( - Dict, - Iterable, - List, - Literal, - Mapping, - Optional, - Tuple, - TypedDict, - Union, -) - import math +from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union) + import torch import torch.utils.checkpoint +from PIL import Image from torch import nn - from transformers import Idefics3Config -from PIL import Image -from vllm.logger import init_logger from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import ( - INPUT_REGISTRY, - InputContext, - DecoderOnlyInputs, - token_inputs, -) -from vllm.model_executor.layers.logits_processor import LogitsProcessor - +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) +from vllm.logger import init_logger from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs from vllm.multimodal.image import cached_get_image_processor -from vllm.transformers_utils.processor import cached_get_processor from vllm.sequence import IntermediateTensors, SequenceData +from vllm.transformers_utils.processor import cached_get_processor from vllm.utils import is_list_of -from .llama import LlamaModel from .idefics2_vision_model import ( - Idefics2VisionTransformer as Idefics3VisionTransformer, ) + Idefics2VisionTransformer as Idefics3VisionTransformer) from .interfaces import SupportsMultiModal - -from .utils import ( - AutoWeightsLoader, - flatten_bn, -) +from .llama import LlamaModel +from .utils import AutoWeightsLoader, flatten_bn logger = init_logger(__name__) @@ -72,7 +53,9 @@ class Idefics3ImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: torch.Tensor - """Shape: `(batch_size * num_images, num_channels, height, width)`""" + """ + Shape: `(batch_size * num_images, num_channels, height, width)` + """ rows: List[int] cols: List[int] pixel_attention_mask: Optional[torch.BoolTensor] @@ -81,8 +64,8 @@ class Idefics3ImagePixelInputs(TypedDict): class Idefics3ImageEmbeddingInputs(TypedDict): type: 
Literal["image_embeds"] data: torch.Tensor - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - + """ + Shape: `(batch_size * num_images, image_feature_size, hidden_size)` `hidden_size` must match the hidden size of language model backbone. """ @@ -266,9 +249,8 @@ def input_processor_for_idefics3(ctx: InputContext, inputs: DecoderOnlyInputs): if isinstance(text, str): text = [text] elif not isinstance(text, list) and not isinstance(text[0], str): - raise ValueError( - "Invalid input text. Please provide a string, or a list of strings" - ) + raise ValueError("Invalid input text. Please provide a string, " + "or a list of strings") fake_image_token = processor.fake_image_token.content image_token = processor.image_token.content @@ -540,8 +522,6 @@ def _merge_multimodal_embeddings( ): num_images, _, vision_hidden_size = vision_embeddings.shape special_image_token_mask = input_ids == self.image_token_id - # Fixes RuntimeError: a leaf Variable that requires grad is being used - # in an in-place operation. new_inputs_embeds = inputs_embeds.clone() reshaped_vision_embeddings = vision_embeddings.view( -1, vision_hidden_size) @@ -561,9 +541,6 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: - r""" - TODO - ```""" if intermediate_tensors is not None: input_ids = None inputs_embeds = None From 9cb4e3298a5596058d0db979c7027ac28d546cf7 Mon Sep 17 00:00:00 2001 From: B-201 Date: Thu, 31 Oct 2024 15:38:19 +0800 Subject: [PATCH 05/19] Update code Signed-off-by: B-201 --- examples/offline_inference_vision_language.py | 10 ++++--- vllm/model_executor/models/idefics3.py | 26 ++++++++++++++++--- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 2c22c796189b9..64c016e56337b 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -5,6 +5,8 @@ For most models, the prompt format should follow corresponding examples on HuggingFace model repository. 
""" +import os +os.environ['CUDA_VISIBLE_DEVICES'] = '7' from transformers import AutoTokenizer from vllm import LLM, SamplingParams @@ -356,10 +358,10 @@ def run_glm4v(question: str, modality: str): # Idefics3-8B-Llama3 def run_idefics3(question: str, modality: str): assert modality == "image" - model_name = ("HuggingFaceM4/Idefics3-8B-Llama3") + model_name = ("/home/sobey/Models/llm_models/BaseModel/idefics/Idefics3-8B-Llama3") llm = LLM(model=model_name, - max_model_len=2048, + max_model_len=8192, max_num_seqs=2, enforce_eager=True) prompt = ( @@ -476,12 +478,12 @@ def main(args): parser.add_argument('--model-type', '-m', type=str, - default="llava", + default="idefics3", choices=model_example_map.keys(), help='Huggingface "model_type".') parser.add_argument('--num-prompts', type=int, - default=4, + default=1, help='Number of prompts to run.') parser.add_argument('--modality', type=str, diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 769793f28c5de..1e40a17f51d71 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -84,10 +84,12 @@ def input_mapper_for_idefics3( raise RuntimeError("No HuggingFace processor is available " "to process the image object") - if isinstance(data, list): + if isinstance(data, Image.Image): + images = [[data]] + elif is_list_of(data, Image.Image): images = [data] else: - raise ValueError() + raise TypeError(f"Invalid image type: {type(data)}") try: batch_data = image_processor(images, @@ -295,6 +297,24 @@ def input_processor_for_idefics3(ctx: InputContext, inputs: DecoderOnlyInputs): ) +def get_max_idefics3_image_tokens(ctx: InputContext, + *, + num_crops: Optional[int] = None): + model_config = ctx.model_config + processor = cached_get_processor(model_config.model) + image_seq_len = processor.image_seq_len + image_processor = processor.image_processor + + size = image_processor.size['longest_edge'] + max_image_size = image_processor.max_image_size['longest_edge'] + resized_height, resized_width = size, size + + grid_h = resized_height // max_image_size + grid_w = resized_width // max_image_size + + return (grid_h * grid_w + 1) * image_seq_len + + def dummy_data_for_idefics3(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): hf_config = ctx.get_hf_config() @@ -574,7 +594,7 @@ def forward( @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_idefics3) -@MULTIMODAL_REGISTRY.register_max_image_tokens() +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_idefics3_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_idefics3) @INPUT_REGISTRY.register_input_processor(input_processor_for_idefics3) class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal): From 6620b7ce190dcfe91773c934d23c075c6f9d61a5 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 31 Oct 2024 07:54:09 +0000 Subject: [PATCH 06/19] Delete dirty code Signed-off-by: Jee Jee Li --- examples/offline_inference_vision_language.py | 4 +--- vllm/model_executor/models/idefics3.py | 4 ++-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 64c016e56337b..24162e98f5155 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -5,8 +5,6 @@ For most models, the prompt format should follow corresponding examples on HuggingFace model repository. 
""" -import os -os.environ['CUDA_VISIBLE_DEVICES'] = '7' from transformers import AutoTokenizer from vllm import LLM, SamplingParams @@ -358,7 +356,7 @@ def run_glm4v(question: str, modality: str): # Idefics3-8B-Llama3 def run_idefics3(question: str, modality: str): assert modality == "image" - model_name = ("/home/sobey/Models/llm_models/BaseModel/idefics/Idefics3-8B-Llama3") + model_name = "HuggingFaceM4/Idefics3-8B-Llama3" llm = LLM(model=model_name, max_model_len=8192, diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 1e40a17f51d71..df8efda85a87d 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -298,8 +298,8 @@ def input_processor_for_idefics3(ctx: InputContext, inputs: DecoderOnlyInputs): def get_max_idefics3_image_tokens(ctx: InputContext, - *, - num_crops: Optional[int] = None): + *, + num_crops: Optional[int] = None): model_config = ctx.model_config processor = cached_get_processor(model_config.model) image_seq_len = processor.image_seq_len From 81168cfaed0be52f764686281539520dc7c97860 Mon Sep 17 00:00:00 2001 From: B-201 Date: Mon, 4 Nov 2024 00:01:43 +0800 Subject: [PATCH 07/19] Support multi-image Signed-off-by: B-201 --- examples/offline_inference_vision_language.py | 4 ++-- vllm/model_executor/models/idefics3.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 87a9c34f80411..ac33fe5a842e3 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -475,12 +475,12 @@ def main(args): parser.add_argument('--model-type', '-m', type=str, - default="idefics3", + default="llava", choices=model_example_map.keys(), help='Huggingface "model_type".') parser.add_argument('--num-prompts', type=int, - default=1, + default=4, help='Number of prompts to run.') parser.add_argument('--modality', type=str, diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index df8efda85a87d..2c8c1e8bc8d76 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -322,7 +322,7 @@ def dummy_data_for_idefics3(ctx: InputContext, seq_len: int, processor = cached_get_processor(ctx.model_config.model) image_seq_len = processor.image_seq_len - max_llm_image_tokens = 17 * image_seq_len + max_llm_image_tokens = 17 * image_seq_len * num_images seq_data = SequenceData.from_prompt_token_counts( (hf_config.image_token_id, max_llm_image_tokens), (0, seq_len)) From b35356172b1708fc4d703e3fcfb538d2c6c7244d Mon Sep 17 00:00:00 2001 From: B-201 Date: Mon, 4 Nov 2024 11:32:59 +0800 Subject: [PATCH 08/19] Add unit test Signed-off-by: B-201 --- .../vision_language/test_idefics3.py | 197 ++++++++++++++++++ vllm/model_executor/models/idefics3.py | 6 +- 2 files changed, 200 insertions(+), 3 deletions(-) create mode 100644 tests/models/decoder_only/vision_language/test_idefics3.py diff --git a/tests/models/decoder_only/vision_language/test_idefics3.py b/tests/models/decoder_only/vision_language/test_idefics3.py new file mode 100644 index 0000000000000..d153110445bd1 --- /dev/null +++ b/tests/models/decoder_only/vision_language/test_idefics3.py @@ -0,0 +1,197 @@ +from typing import List, Optional, Tuple, Type + +import pytest +from transformers import AutoTokenizer + +from vllm.multimodal.utils import rescale_image_size +from vllm.sequence import SampleLogprobs + +from conftest import 
IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner +from utils import check_logprobs_close + +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ + "stop_sign": + "<|begin_of_text|>User:What's the content of the image?\nAssistant:", # noqa: E501 + "cherry_blossom": + "<|begin_of_text|>User:What is the season?\nAssistant:", # noqa: E501 +}) +HF_MULTIIMAGE_IMAGE_PROMPT = "<|begin_of_text|>User:Describe these images.\nAssistant:" # noqa: E501 + +models = ["HuggingFaceM4/Idefics3-8B-Llama3"] + + +def vllm_to_hf_output(vllm_output: Tuple[List[int], str, + Optional[SampleLogprobs]], + model: str): + """Sanitize vllm output to be comparable with hf output.""" + _, output_str, out_logprobs = vllm_output + + output_str_without_image = output_str + + hf_output_str = output_str_without_image + "<|end_of_text|>" # noqa: E501 + + tokenizer = AutoTokenizer.from_pretrained(model) + hf_output_ids = tokenizer.encode(output_str_without_image) + assert hf_output_ids[0] == 128000 + hf_output_ids = hf_output_ids[1:] + + return hf_output_ids, hf_output_str, out_logprobs + + +target_dtype = "half" +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + inputs: List[Tuple[List[str], PromptImageInput]], + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, + mm_limit: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + """Inference result should be the same between hf and vllm. + + All the image fixtures for the test are from IMAGE_ASSETS. + For huggingface runner, we provide the PIL images as input. + For vllm runner, we provide MultiModalDataDict objects + and corresponding MultiModalConfig as input. + Note, the text input is also adjusted to abide by vllm contract. + The text output is sanitized to be able to compare with hf. + """ + # HACK - this is an attempted workaround for the following bug + # https://github.com/huggingface/transformers/issues/34307 + from transformers import AutoModelForVision2Seq # noqa: F401 + + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). 
+ # max_model_len should be greater than image_feature_size + with vllm_runner(model, + task="generate", + max_model_len=8192, + max_num_seqs=2, + dtype=dtype, + limit_mm_per_prompt={"image": mm_limit}, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True) as vllm_model: + vllm_outputs_per_case = [ + vllm_model.generate_greedy_logprobs(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images) + for prompts, images in inputs + ] + + # use eager mode for hf runner, since phi3_v didn't work with flash_attn + hf_model_kwargs = {"_attn_implementation": "eager"} + with hf_runner(model, dtype=dtype, + model_kwargs=hf_model_kwargs, + auto_cls=AutoModelForVision2Seq) as hf_model: + eos_token_id = hf_model.processor.tokenizer.eos_token_id + hf_outputs_per_case = [ + hf_model.generate_greedy_logprobs_limit(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images, + eos_token_id=eos_token_id) + for prompts, images in inputs + ] + + for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, + vllm_outputs_per_case): + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, model) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + ) + + +# Since we use _attn_implementation="eager" for hf_runner, there is more +# significant numerical difference. The basic `logprobs=5` fails to pass. +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize( + "size_factors", + [ + # No image + [], + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + ], +) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [10]) +def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, + dtype: str, max_tokens: int, num_logprobs: int) -> None: + images = [asset.pil_image for asset in image_assets] + + inputs_per_image = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + + run_test( + hf_runner, + vllm_runner, + inputs_per_image, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + mm_limit=1, + tensor_parallel_size=1, + ) + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize( + "size_factors", + [ + # No image + [], + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + ], +) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [10]) +def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, + size_factors, dtype: str, max_tokens: int, + num_logprobs: int) -> None: + images = [asset.pil_image for asset in image_assets] + + inputs_per_case = [ + ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors], + [[rescale_image_size(image, factor) for image in images] + for factor in size_factors]) + ] + + run_test( + hf_runner, + vllm_runner, + inputs_per_case, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + mm_limit=2, + tensor_parallel_size=1, + ) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 2c8c1e8bc8d76..6ed46e560dc57 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -26,8 
+26,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -331,7 +331,7 @@ def dummy_data_for_idefics3(ctx: InputContext, seq_len: int, image = Image.new("RGB", (width, height), color=0) mm_data = {"image": [image] if num_images == 1 else [image] * num_images} - return seq_data, mm_data + return DummyData(seq_data, mm_data) class Idefics3SimpleMLP(nn.Module): From 9e1d3cf9bbe04cf3d6ac315ce738f65cb5863d4c Mon Sep 17 00:00:00 2001 From: B-201 Date: Mon, 4 Nov 2024 11:39:18 +0800 Subject: [PATCH 09/19] Disable yapf Signed-off-by: B-201 --- vllm/model_executor/models/idefics3.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 6ed46e560dc57..13cf50a56545e 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -41,8 +41,10 @@ from vllm.transformers_utils.processor import cached_get_processor from vllm.utils import is_list_of +# yapf: disable from .idefics2_vision_model import ( Idefics2VisionTransformer as Idefics3VisionTransformer) +# yapf: enable from .interfaces import SupportsMultiModal from .llama import LlamaModel from .utils import AutoWeightsLoader, flatten_bn From 835d9ba7dcd3e90560a97045ebc8cc98c362de50 Mon Sep 17 00:00:00 2001 From: B-201 Date: Mon, 4 Nov 2024 12:56:53 +0800 Subject: [PATCH 10/19] Fix code format Signed-off-by: B-201 --- .../decoder_only/vision_language/test_idefics3.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/models/decoder_only/vision_language/test_idefics3.py b/tests/models/decoder_only/vision_language/test_idefics3.py index d153110445bd1..1dae8bf3ab5eb 100644 --- a/tests/models/decoder_only/vision_language/test_idefics3.py +++ b/tests/models/decoder_only/vision_language/test_idefics3.py @@ -6,18 +6,19 @@ from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs -from conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner -from utils import check_logprobs_close +from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner +from ...utils import check_logprobs_close HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": "<|begin_of_text|>User:What's the content of the image?\nAssistant:", # noqa: E501 "cherry_blossom": - "<|begin_of_text|>User:What is the season?\nAssistant:", # noqa: E501 + "<|begin_of_text|>User:What is the season?\nAssistant:", # noqa: E501 }) HF_MULTIIMAGE_IMAGE_PROMPT = "<|begin_of_text|>User:Describe these images.\nAssistant:" # noqa: E501 models = ["HuggingFaceM4/Idefics3-8B-Llama3"] +target_dtype = "half" def vllm_to_hf_output(vllm_output: Tuple[List[int], str, @@ -38,7 +39,6 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str, return hf_output_ids, hf_output_str, out_logprobs -target_dtype = "half" def run_test( hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner], @@ -89,8 +89,9 @@ def run_test( # use eager mode for hf runner, since phi3_v didn't work with flash_attn hf_model_kwargs = {"_attn_implementation": "eager"} - with hf_runner(model, dtype=dtype, - 
model_kwargs=hf_model_kwargs, + with hf_runner(model, + dtype=dtype, + model_kwargs=hf_model_kwargs, auto_cls=AutoModelForVision2Seq) as hf_model: eos_token_id = hf_model.processor.tokenizer.eos_token_id hf_outputs_per_case = [ From 469a96ea696918d507d0dcc1d02271af52043cf1 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 5 Nov 2024 02:34:48 +0000 Subject: [PATCH 11/19] Fix format Signed-off-by: Jee Jee Li --- .../decoder_only/vision_language/test_idefics3.py | 6 +----- vllm/model_executor/models/idefics3.py | 15 ++++++++------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/tests/models/decoder_only/vision_language/test_idefics3.py b/tests/models/decoder_only/vision_language/test_idefics3.py index 1dae8bf3ab5eb..669a806a16233 100644 --- a/tests/models/decoder_only/vision_language/test_idefics3.py +++ b/tests/models/decoder_only/vision_language/test_idefics3.py @@ -1,7 +1,7 @@ from typing import List, Optional, Tuple, Type import pytest -from transformers import AutoTokenizer +from transformers import AutoModelForVision2Seq, AutoTokenizer from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs @@ -61,10 +61,6 @@ def run_test( Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. """ - # HACK - this is an attempted workaround for the following bug - # https://github.com/huggingface/transformers/issues/34307 - from transformers import AutoModelForVision2Seq # noqa: F401 - # NOTE: take care of the order. run vLLM first, and then run HF. # vLLM needs a fresh new process without cuda initialization. # if we run HF first, the cuda initialization will be done and it diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 13cf50a56545e..cc2739de5a12f 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -318,7 +318,7 @@ def get_max_idefics3_image_tokens(ctx: InputContext, def dummy_data_for_idefics3(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): + mm_counts: Mapping[str, int]) -> DummyData: hf_config = ctx.get_hf_config() num_images = mm_counts["image"] @@ -345,7 +345,7 @@ def __init__(self, config): output_size = config.text_config.hidden_size self.proj = ReplicatedLinear(input_size, output_size, bias=False) - def forward(self, x) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: out, _ = self.proj(x) return out @@ -357,7 +357,9 @@ def __init__(self, config): self.scale_factor = config.scale_factor self.modality_projection = Idefics3SimpleMLP(config) - def pixel_shuffle(self, x, scale_factor=2): + def pixel_shuffle(self, + x: torch.Tensor, + scale_factor: int = 2) -> torch.Tensor: bsz, seq, embed_dim = x.size() height = width = int(seq**0.5) x = x.view(bsz, height, width, embed_dim) @@ -375,7 +377,7 @@ def pixel_shuffle(self, x, scale_factor=2): embed_dim * (scale_factor**2)) return x - def forward(self, image_hidden_states): + def forward(self, image_hidden_states: torch.Tensor) -> torch.Tensor: image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) image_hidden_states = self.modality_projection(image_hidden_states) @@ -467,7 +469,6 @@ def _parse_and_validate_image_input( def _image_pixels_to_features( self, - vision_model: Idefics3VisionTransformer, pixel_values: torch.Tensor, pixel_attention_mask: Optional[torch.BoolTensor] = None, ) -> torch.Tensor: @@ -511,7 +512,7 @@ def _image_pixels_to_features( 
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() # Get sequence from the vision encoder - image_hidden_states = vision_model( + image_hidden_states = self.vision_model( pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, ) @@ -525,7 +526,7 @@ def _process_image_pixels( pixel_values = inputs["data"] pixel_attention_mask = inputs["pixel_attention_mask"] - return self._image_pixels_to_features(self.vision_model, pixel_values, + return self._image_pixels_to_features(pixel_values, pixel_attention_mask) def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor: From 3285b9166f93d87877772dbe4a292fcee9cf20e6 Mon Sep 17 00:00:00 2001 From: B-201 Date: Tue, 5 Nov 2024 12:58:50 +0800 Subject: [PATCH 12/19] Fix code format Signed-off-by: B-201 --- vllm/model_executor/models/idefics3.py | 76 +++++++++++--------------- 1 file changed, 31 insertions(+), 45 deletions(-) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index cc2739de5a12f..4b2f2dda6cbe8 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -15,7 +15,7 @@ """Inference-only Idefics3 model compatible with HuggingFace weights.""" import math -from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Tuple, +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) import torch @@ -104,53 +104,40 @@ def input_mapper_for_idefics3( return MultiModalInputs(batch_data) -def _resize_output_size_rescale_to_max_len( - height: int, - width: int, - min_len: Optional[int] = 1, - max_len: Optional[int] = None) -> Tuple[int, int]: +def _resize_output_size(height: int, + width: int, + max_len: Optional[int] = None, + min_len: Optional[int] = 1, + max_size: Optional[int] = None) -> Tuple[int, int]: + # Set default value for max_len if not provided max_len = max(height, width) if max_len is None else max_len aspect_ratio = width / height + # Handle the maximum size constraint + if max_size is not None: + max_len = min(max_len, max_size) + + # Adjust dimensions according to the aspect ratio if width >= height: width = max_len height = int(width / aspect_ratio) - if height % 2 != 0: - height += 1 - elif height > width: + else: height = max_len width = int(height * aspect_ratio) - if width % 2 != 0: - width += 1 - # Avoid resizing to a size smaller than min_len + # Ensure both width and height are even (if needed) + height += 1 if height % 2 != 0 else 0 + width += 1 if width % 2 != 0 else 0 + + # Ensure dimensions are not smaller than the minimum length height = max(height, min_len) width = max(width, min_len) - return height, width - - -def _resize_output_size_scale_below_upper_bound( - height: int, - width: int, - max_len: Optional[Dict[str, int]] = None) -> Tuple[int, int]: - max_len = max(height, width) if max_len is None else max_len - aspect_ratio = width / height - if width >= height and width > max_len: - width = max_len - height = int(width / aspect_ratio) - elif height > width and height > max_len: - height = max_len - width = int(height * aspect_ratio) - - # Avoid resizing to a size smaller than 1 - height = max(height, 1) - width = max(width, 1) return height, width def _get_resize_output_image_size( - image_size, + image_size: Tuple[int, int], resolution_max_side: int, max_image_size: int = 1820, ) -> Tuple[int, int]: @@ -162,17 +149,16 @@ def _get_resize_output_image_size( # Find the output size, when rescaling the longest edge to max_len and # preserving the 
aspect ratio - height, width = _resize_output_size_rescale_to_max_len( - height, width, max_len=resolution_max_side) - # Find the output size when scaling the image to be below the max_image_size - height, width = _resize_output_size_scale_below_upper_bound( - height, width, max_len=max_image_size) + height, width = _resize_output_size(height, + width, + max_len=resolution_max_side) + return height, width -def _prompt_split_image(image_seq_len, image_rows, image_cols, - fake_token_around_image, image_token, - global_img_token): +def _prompt_split_image(image_seq_len: int, image_rows: int, image_cols: int, + fake_token_around_image: str, image_token: str, + global_img_token: str) -> str: """ Prompt with expanded image tokens for when the image is split into patches. @@ -192,16 +178,16 @@ def _prompt_split_image(image_seq_len, image_rows, image_cols, return text_split_images -def _prompt_single_image(image_seq_len, fake_token_around_image, image_token, - global_img_token): +def _prompt_single_image(image_seq_len: int, fake_token_around_image: str, + image_token: str, global_img_token: str): """Prompt with expanded image tokens for a single image.""" return (f"{fake_token_around_image}" + f"{global_img_token}" + f"{image_token}" * image_seq_len + f"{fake_token_around_image}") -def _get_image_prompt_string(image_rows, image_cols, image_seq_len, - fake_token_around_image, image_token, - global_img_token): +def _get_image_prompt_string(image_rows: int, image_cols: int, + image_seq_len: int, fake_token_around_image: str, + image_token: str, global_img_token: str): if image_rows == 0 and image_cols == 0: return _prompt_single_image( image_seq_len, From d632fa81ffa0e461ecd70ae5de2e10f2f6b1e318 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 5 Nov 2024 06:13:36 +0000 Subject: [PATCH 13/19] Update docs Signed-off-by: Jee Jee Li --- docs/source/models/supported_models.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 55835d945b00c..cbfcab26948ac 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -446,6 +446,12 @@ Text Generation - :code:`h2oai/h2ovl-mississippi-800m`, :code:`h2oai/h2ovl-mississippi-2b`, etc. - - ✅︎ + * - :code:`Idefics3ForConditionalGeneration` + - Idefics3 + - T + I + - :code:`HuggingFaceM4/Idefics3-8B-Llama3` etc. + - + - ✅︎ * - :code:`InternVLChatModel` - InternVL2 - T + I\ :sup:`E+` From 84d7428b322c86f5964c60484e1d24f27b987014 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 5 Nov 2024 06:28:22 +0000 Subject: [PATCH 14/19] Update docs Signed-off-by: Jee Jee Li --- docs/source/models/supported_models.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index cbfcab26948ac..f9f91fa358ba5 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -451,7 +451,7 @@ Text Generation - T + I - :code:`HuggingFaceM4/Idefics3-8B-Llama3` etc. 
- - - ✅︎ + - * - :code:`InternVLChatModel` - InternVL2 - T + I\ :sup:`E+` From ab8eb7c405984911ab473653686d35359d380677 Mon Sep 17 00:00:00 2001 From: B-201 Date: Tue, 5 Nov 2024 23:17:22 +0800 Subject: [PATCH 15/19] Integrate test code Signed-off-by: B-201 --- .../vision_language/test_idefics3.py | 194 ------------------ .../vision_language/test_models.py | 15 ++ 2 files changed, 15 insertions(+), 194 deletions(-) delete mode 100644 tests/models/decoder_only/vision_language/test_idefics3.py diff --git a/tests/models/decoder_only/vision_language/test_idefics3.py b/tests/models/decoder_only/vision_language/test_idefics3.py deleted file mode 100644 index 669a806a16233..0000000000000 --- a/tests/models/decoder_only/vision_language/test_idefics3.py +++ /dev/null @@ -1,194 +0,0 @@ -from typing import List, Optional, Tuple, Type - -import pytest -from transformers import AutoModelForVision2Seq, AutoTokenizer - -from vllm.multimodal.utils import rescale_image_size -from vllm.sequence import SampleLogprobs - -from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner -from ...utils import check_logprobs_close - -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - "<|begin_of_text|>User:What's the content of the image?\nAssistant:", # noqa: E501 - "cherry_blossom": - "<|begin_of_text|>User:What is the season?\nAssistant:", # noqa: E501 -}) -HF_MULTIIMAGE_IMAGE_PROMPT = "<|begin_of_text|>User:Describe these images.\nAssistant:" # noqa: E501 - -models = ["HuggingFaceM4/Idefics3-8B-Llama3"] -target_dtype = "half" - - -def vllm_to_hf_output(vllm_output: Tuple[List[int], str, - Optional[SampleLogprobs]], - model: str): - """Sanitize vllm output to be comparable with hf output.""" - _, output_str, out_logprobs = vllm_output - - output_str_without_image = output_str - - hf_output_str = output_str_without_image + "<|end_of_text|>" # noqa: E501 - - tokenizer = AutoTokenizer.from_pretrained(model) - hf_output_ids = tokenizer.encode(output_str_without_image) - assert hf_output_ids[0] == 128000 - hf_output_ids = hf_output_ids[1:] - - return hf_output_ids, hf_output_str, out_logprobs - - -def run_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], - inputs: List[Tuple[List[str], PromptImageInput]], - model: str, - *, - dtype: str, - max_tokens: int, - num_logprobs: int, - mm_limit: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -): - """Inference result should be the same between hf and vllm. - - All the image fixtures for the test are from IMAGE_ASSETS. - For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalDataDict objects - and corresponding MultiModalConfig as input. - Note, the text input is also adjusted to abide by vllm contract. - The text output is sanitized to be able to compare with hf. - """ - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with fork method (the default method). 
- # max_model_len should be greater than image_feature_size - with vllm_runner(model, - task="generate", - max_model_len=8192, - max_num_seqs=2, - dtype=dtype, - limit_mm_per_prompt={"image": mm_limit}, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: - vllm_outputs_per_case = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images) - for prompts, images in inputs - ] - - # use eager mode for hf runner, since phi3_v didn't work with flash_attn - hf_model_kwargs = {"_attn_implementation": "eager"} - with hf_runner(model, - dtype=dtype, - model_kwargs=hf_model_kwargs, - auto_cls=AutoModelForVision2Seq) as hf_model: - eos_token_id = hf_model.processor.tokenizer.eos_token_id - hf_outputs_per_case = [ - hf_model.generate_greedy_logprobs_limit(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images, - eos_token_id=eos_token_id) - for prompts, images in inputs - ] - - for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, - vllm_outputs_per_case): - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(vllm_output, model) - for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - ) - - -# Since we use _attn_implementation="eager" for hf_runner, there is more -# significant numerical difference. The basic `logprobs=5` fails to pass. -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize( - "size_factors", - [ - # No image - [], - # Single-scale - [1.0], - # Single-scale, batched - [1.0, 1.0, 1.0], - # Multi-scale - [0.25, 0.5, 1.0], - ], -) -@pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [10]) -def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, - dtype: str, max_tokens: int, num_logprobs: int) -> None: - images = [asset.pil_image for asset in image_assets] - - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] - - run_test( - hf_runner, - vllm_runner, - inputs_per_image, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - mm_limit=1, - tensor_parallel_size=1, - ) - - -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize( - "size_factors", - [ - # No image - [], - # Single-scale - [1.0], - # Single-scale, batched - [1.0, 1.0, 1.0], - # Multi-scale - [0.25, 0.5, 1.0], - ], -) -@pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [10]) -def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, - size_factors, dtype: str, max_tokens: int, - num_logprobs: int) -> None: - images = [asset.pil_image for asset in image_assets] - - inputs_per_case = [ - ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors], - [[rescale_image_size(image, factor) for image in images] - for factor in size_factors]) - ] - - run_test( - hf_runner, - vllm_runner, - inputs_per_case, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - mm_limit=2, - tensor_parallel_size=1, - ) diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index cfd2d61f2b633..ec3311791086a 100644 --- a/tests/models/decoder_only/vision_language/test_models.py 
+++ b/tests/models/decoder_only/vision_language/test_models.py @@ -327,6 +327,21 @@ vllm_output_post_proc=model_utils.qwen_vllm_to_hf_output, prompt_path_encoder=model_utils.qwen_prompt_path_encoder, ), + "idefics3": VLMTestInfo( + models=["HuggingFaceM4/Idefics3-8B-Llama3"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}\nAssistant:", # noqa: E501 + img_idx_to_prompt=lambda idx: "", + max_model_len=8192, + max_num_seqs=2, + auto_cls=AutoModelForVision2Seq, + marks=[ + pytest.mark.skipif( + transformers.__version__ < "4.46.0", + reason="Model introduced in HF >= 4.46.0" + ) + ], + ), ### Tensor parallel / multi-gpu broadcast tests "broadcast-chameleon": VLMTestInfo( models=["facebook/chameleon-7b"], From 041c0344f5be1f031381c2b40715a05fe60a85ac Mon Sep 17 00:00:00 2001 From: B-201 Date: Wed, 6 Nov 2024 00:02:57 +0800 Subject: [PATCH 16/19] Fix code format Signed-off-by: B-201 --- vllm/model_executor/models/idefics3.py | 47 ++++++++------------------ 1 file changed, 14 insertions(+), 33 deletions(-) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 4b2f2dda6cbe8..8dee9a3aabcd1 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -47,7 +47,7 @@ # yapf: enable from .interfaces import SupportsMultiModal from .llama import LlamaModel -from .utils import AutoWeightsLoader, flatten_bn +from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings logger = init_logger(__name__) @@ -166,23 +166,24 @@ def _prompt_split_image(image_seq_len: int, image_rows: int, image_cols: int, text_split_images = "" for n_h in range(image_rows): for n_w in range(image_cols): - text_split_images += (f"{fake_token_around_image}" + + text_split_images += (fake_token_around_image + f"" + - f"{image_token}" * image_seq_len) + image_token * image_seq_len) text_split_images += "\n" - text_split_images += (f"\n{fake_token_around_image}" + - f"{global_img_token}" + - f"{image_token}" * image_seq_len + - f"{fake_token_around_image}") + text_split_images += "\n" + _prompt_single_image( + image_seq_len=image_seq_len, + fake_token_around_image=fake_token_around_image, + image_token=image_token, + global_img_token=global_img_token) return text_split_images def _prompt_single_image(image_seq_len: int, fake_token_around_image: str, image_token: str, global_img_token: str): """Prompt with expanded image tokens for a single image.""" - return (f"{fake_token_around_image}" + f"{global_img_token}" + - f"{image_token}" * image_seq_len + f"{fake_token_around_image}") + return (fake_token_around_image + global_img_token + + image_token * image_seq_len + fake_token_around_image) def _get_image_prompt_string(image_rows: int, image_cols: int, @@ -190,7 +191,7 @@ def _get_image_prompt_string(image_rows: int, image_cols: int, image_token: str, global_img_token: str): if image_rows == 0 and image_cols == 0: return _prompt_single_image( - image_seq_len, + image_seq_len=image_seq_len, fake_token_around_image=fake_token_around_image, image_token=image_token, global_img_token=global_img_token, @@ -523,24 +524,6 @@ def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor: image_features = self._process_image_pixels(image_input) return self.connector(image_features) - def _merge_multimodal_embeddings( - self, - input_ids: torch.LongTensor, - inputs_embeds: Optional[torch.Tensor], - vision_embeddings: Optional[torch.Tensor], - ): - 
num_images, _, vision_hidden_size = vision_embeddings.shape - special_image_token_mask = input_ids == self.image_token_id - new_inputs_embeds = inputs_embeds.clone() - reshaped_vision_embeddings = vision_embeddings.view( - -1, vision_hidden_size) - # cast to the dtype of the input_embeds to support quantized models - reshaped_vision_embeddings = reshaped_vision_embeddings.to( - inputs_embeds.dtype) - new_inputs_embeds[ - special_image_token_mask] = reshaped_vision_embeddings - return new_inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -562,11 +545,9 @@ def forward( vision_embeddings = self._process_image_input(image_input) inputs_embeds = self.text_model.get_input_embeddings(input_ids) - inputs_embeds = self._merge_multimodal_embeddings( - input_ids, - inputs_embeds, - vision_embeddings, - ) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.image_token_id) else: inputs_embeds = self.text_model.get_input_embeddings(input_ids) input_ids = None From ba7e85ca5423df5ab9e6d718d11ac682820e5e5f Mon Sep 17 00:00:00 2001 From: B-201 Date: Wed, 6 Nov 2024 15:34:48 +0800 Subject: [PATCH 17/19] Update example & test Signed-off-by: B-201 --- ...e_inference_vision_language_multi_image.py | 25 +++++++++++++++++++ .../vision_language/test_models.py | 3 ++- vllm/entrypoints/chat_utils.py | 2 ++ 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index d99684078ff3d..7e883568995a4 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -290,6 +290,30 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData: ) +def load_idefics3(question, image_urls: List[str]) -> ModelRequestData: + model_name = "HuggingFaceM4/Idefics3-8B-Llama3" + + # The configuration below has been confirmed to launch on a single L40 GPU. 
+    llm = LLM(
+        model=model_name,
+        max_model_len=8192,
+        max_num_seqs=16,
+        enforce_eager=True,
+        limit_mm_per_prompt={"image": len(image_urls)},
+    )
+
+    placeholders = "\n".join(f"Image-{i}: <image>\n"
+                             for i, _ in enumerate(image_urls, start=1))
+    prompt = f"<|begin_of_text|>User:{placeholders}\n{question}\nAssistant:"  # noqa: E501
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=None,
+        image_data=[fetch_image(url) for url in image_urls],
+        chat_template=None,
+    )
+
+
 model_example_map = {
     "phi3_v": load_phi3v,
     "h2ovl_chat": load_h2onvl,
@@ -298,6 +322,7 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData:
     "qwen2_vl": load_qwen2_vl,
     "qwen_vl_chat": load_qwenvl_chat,
     "mllama": load_mllama,
+    "idefics3": load_idefics3,
 }

diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index ec3311791086a..3dbfaafb781af 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -339,7 +339,8 @@
             pytest.mark.skipif(
                 transformers.__version__ < "4.46.0",
                 reason="Model introduced in HF >= 4.46.0"
-            )
+            ),
+            large_gpu_mark(min_gb=48),
         ],
     ),
     ### Tensor parallel / multi-gpu broadcast tests
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 8da08d4b2c93c..c1df4d44a637e 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -196,6 +196,8 @@ def _placeholder_str(self, modality: ModalityStr,
             return "<|vision_start|><|image_pad|><|vision_end|>"
         if model_type == "molmo":
             return ""
+        if model_type == "idefics3":
+            return "<image>"

         raise TypeError(f"Unknown {modality} model type: {model_type}")
     elif modality == "audio":

From 10415d3169dc2a27b30ead1f8012898a6def8f9e Mon Sep 17 00:00:00 2001
From: B-201
Date: Wed, 6 Nov 2024 15:38:05 +0800
Subject: [PATCH 18/19] Fix code format

Signed-off-by: B-201
---
 vllm/model_executor/models/idefics3.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py
index 8dee9a3aabcd1..8671e4bc8d941 100644
--- a/vllm/model_executor/models/idefics3.py
+++ b/vllm/model_executor/models/idefics3.py
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

From e5bb291cd674f47286bb53ea3628c15a83f9444f Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Wed, 6 Nov 2024 09:08:12 +0000
Subject: [PATCH 19/19] Fix model registry

Signed-off-by: Jee Jee Li
---
 vllm/model_executor/models/idefics3.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py
index 8671e4bc8d941..e4c98f22fb16f 100644
--- a/vllm/model_executor/models/idefics3.py
+++ b/vllm/model_executor/models/idefics3.py
@@ -21,7 +21,8 @@
 import torch.utils.checkpoint
 from PIL import Image
 from torch import nn
-from transformers import Idefics3Config
+# Temporary solution for transformers below 4.46.0.
+from transformers import PretrainedConfig as Idefics3Config

 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
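For reference, the image-token expansion implemented by the _prompt_split_image / _prompt_single_image / _get_image_prompt_string helpers in this series can be sketched as follows. The angle-bracket token literals (<fake_token_around_image>, <global-img>, <row_i_col_j>, <image>) follow the HuggingFace Idefics3 processor conventions and are written out here as assumptions rather than quoted from the patches.

    def expand_image_prompt(rows: int, cols: int, seq_len: int) -> str:
        fake = "<fake_token_around_image>"
        img = "<image>"
        global_img = "<global-img>"
        if rows == 0 and cols == 0:
            # Image was not tiled: just the global image block.
            return fake + global_img + img * seq_len + fake
        out = ""
        for r in range(rows):
            for c in range(cols):
                # One tagged block of seq_len image tokens per tile.
                out += fake + f"<row_{r + 1}_col_{c + 1}>" + img * seq_len
            out += "\n"
        # The downscaled global image closes the sequence.
        return out + "\n" + fake + global_img + img * seq_len + fake

For a 2x2 grid with seq_len=169 this yields four tagged tile blocks followed by the global-image block, which is what the input processor splices into the prompt for each image.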
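The switch from the model-private _merge_multimodal_embeddings to the shared merge_multimodal_embeddings utility (PATCH 16) keeps the same behaviour; a minimal sketch of that behaviour, with names chosen here for illustration:

    import torch

    def merge_sketch(input_ids: torch.Tensor, inputs_embeds: torch.Tensor,
                     vision_embeddings: torch.Tensor,
                     image_token_id: int) -> torch.Tensor:
        # Positions holding the image placeholder token receive vision features.
        mask = input_ids == image_token_id
        # Clone so the scatter below is not an in-place edit of the input tensor.
        merged = inputs_embeds.clone()
        flat = vision_embeddings.reshape(-1, vision_embeddings.shape[-1])
        merged[mask] = flat.to(inputs_embeds.dtype)
        return merged

The number of masked positions must match the number of flattened vision rows, which is why the input processor expands exactly (rows * cols + 1) * image_seq_len placeholder tokens per image.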