From b1555d225b6200363489a685ccea4808a38232d9 Mon Sep 17 00:00:00 2001
From: B-201
Date: Wed, 13 Nov 2024 17:25:59 +0800
Subject: [PATCH] [Model][LoRA]LoRA support added for idefics3 (#10281)

Signed-off-by: B-201
Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com>
---
 docs/source/models/supported_models.rst |  2 +-
 vllm/model_executor/models/idefics3.py  | 55 +++++++++++++++++++++----
 2 files changed, 47 insertions(+), 10 deletions(-)

diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index 58ec3acc6aea5..161733c049bbe 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -450,7 +450,7 @@ Text Generation
     - Idefics3
     - T + I
     - :code:`HuggingFaceM4/Idefics3-8B-Llama3` etc.
-    -
+    - ✅︎
     -
   * - :code:`InternVLChatModel`
     - InternVL2
diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py
index 8845b2f58af07..85f23a1da533b 100644
--- a/vllm/model_executor/models/idefics3.py
+++ b/vllm/model_executor/models/idefics3.py
@@ -33,6 +33,7 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.image import cached_get_image_processor
@@ -44,7 +45,7 @@
 from .idefics2_vision_model import (
     Idefics2VisionTransformer as Idefics3VisionTransformer)
 # yapf: enable
-from .interfaces import SupportsMultiModal
+from .interfaces import SupportsLoRA, SupportsMultiModal
 from .llama import LlamaModel
 from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                     merge_multimodal_embeddings)
@@ -58,8 +59,6 @@ class Idefics3ImagePixelInputs(TypedDict):
     """
     Shape: `(batch_size * num_images, num_channels, height, width)`
     """
-    rows: List[int]
-    cols: List[int]
     pixel_attention_mask: Optional[torch.BoolTensor]


@@ -356,8 +355,15 @@ def dummy_data_for_idefics3(
     image_seq_len = processor.image_seq_len
     max_llm_image_tokens = max_num_image_patches * image_seq_len * num_images

+    if seq_len - max_llm_image_tokens < 0:
+        raise RuntimeError(
+            f"Idefics3 cannot process {num_images} images in a prompt, "
+            "please increase max_model_len or reduce image limit by "
+            "--limit-mm-per-prompt.")
+
     seq_data = SequenceData.from_prompt_token_counts(
-        (hf_config.image_token_id, max_llm_image_tokens), (0, seq_len))
+        (hf_config.image_token_id, max_llm_image_tokens),
+        (0, seq_len - max_llm_image_tokens))

     width = height = hf_config.vision_config.image_size
     image = Image.new("RGB", (width, height), color=0)
@@ -463,8 +469,6 @@ def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         image_embeds = kwargs.pop("image_embeds", None)
-        rows = kwargs.pop("rows", None)
-        cols = kwargs.pop("cols", None)
         pixel_attention_mask = kwargs.pop("pixel_attention_mask", None)

         if pixel_values is None and image_embeds is None:
@@ -489,8 +493,6 @@ def _parse_and_validate_image_input(
                                             data=self._validate_pixel_values(
                                                 flatten_bn(pixel_values,
                                                            concat=True)),
-                                            rows=rows,
-                                            cols=cols,
                                             pixel_attention_mask=flatten_bn(
                                                 pixel_attention_mask,
                                                 concat=True))
@@ -610,7 +612,33 @@ def forward(
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_idefics3_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_idefics3)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_idefics3)
-class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal):
+class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
+                                       SupportsLoRA):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        # vision_model
+        "fc1",
+        "fc2",
+        "out_proj",
+        # text_model
+        "qkv_proj",  # same name with vision encoder
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+    ]
+    embedding_modules = {}
+    embedding_padding_modules = []

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -672,3 +700,12 @@ def sample(
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
         loader.load_weights(weights)
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="model.text_model",
+            connector="model.connector",
+            tower_model="model.vision_model")
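
Usage note (not part of the patch): a minimal sketch of how the LoRA support added above might be exercised through vLLM's offline LLM API. The adapter path, LoRA rank, and prompt template below are placeholders/assumptions, not values taken from this change.

    # Sketch only: adapter path, max_lora_rank, and prompt format are hypothetical.
    from PIL import Image

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    llm = LLM(
        model="HuggingFaceM4/Idefics3-8B-Llama3",
        enable_lora=True,                 # enables the SupportsLoRA path added in this patch
        max_lora_rank=64,                 # assumption: must cover the adapter's LoRA rank
        limit_mm_per_prompt={"image": 1},
    )

    image = Image.new("RGB", (224, 224), color=0)  # stand-in image for illustration
    outputs = llm.generate(
        {
            "prompt": "User:<image>Describe this image.\nAssistant:",  # simplified prompt
            "multi_modal_data": {"image": image},
        },
        SamplingParams(max_tokens=64),
        lora_request=LoRARequest("idefics3-adapter", 1, "/path/to/lora_adapter"),
    )
    print(outputs[0].outputs[0].text)

The same adapter could instead be served online via `vllm serve ... --enable-lora --lora-modules idefics3-adapter=/path/to/lora_adapter` (paths again hypothetical).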