[Model][LoRA]LoRA support added for idefics3 (vllm-project#10281)
Signed-off-by: B-201 <[email protected]>
Signed-off-by: Maxime Fournioux <[email protected]>
B-201 authored and mfournioux committed Nov 20, 2024
1 parent d4b31b3 commit b1555d2
Showing 2 changed files with 47 additions and 10 deletions.
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.rst
@@ -450,7 +450,7 @@ Text Generation
     - Idefics3
     - T + I
     - :code:`HuggingFaceM4/Idefics3-8B-Llama3` etc.
-    -
+    - ✅︎
     -
   * - :code:`InternVLChatModel`
     - InternVL2
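Note: the row above now marks LoRA support (✅︎) for Idefics3 in the supported-models matrix. With this commit, serving an Idefics3 adapter follows vLLM's usual LoRA flow; a minimal sketch, assuming a local adapter directory ./idefics3-lora (the path and adapter name are illustrative, not part of this commit):

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Build the engine with LoRA enabled; max_lora_rank is a tunable assumption here.
llm = LLM(model="HuggingFaceM4/Idefics3-8B-Llama3",
          enable_lora=True,
          max_lora_rank=64)

# Attach the adapter per request; the integer is a unique adapter id.
out = llm.generate(
    "Describe the image in one sentence.",
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=LoRARequest("idefics3-lora", 1, "./idefics3-lora"))
print(out[0].outputs[0].text)

(Image inputs are omitted for brevity; they are passed via multi_modal_data exactly as with the base model.)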
55 changes: 46 additions & 9 deletions vllm/model_executor/models/idefics3.py
@@ -33,6 +33,7 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.image import cached_get_image_processor
@@ -44,7 +45,7 @@
 from .idefics2_vision_model import (
     Idefics2VisionTransformer as Idefics3VisionTransformer)
 # yapf: enable
-from .interfaces import SupportsMultiModal
+from .interfaces import SupportsLoRA, SupportsMultiModal
 from .llama import LlamaModel
 from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                     merge_multimodal_embeddings)
Expand All @@ -58,8 +59,6 @@ class Idefics3ImagePixelInputs(TypedDict):
"""
Shape: `(batch_size * num_images, num_channels, height, width)`
"""
rows: List[int]
cols: List[int]
pixel_attention_mask: Optional[torch.BoolTensor]


@@ -356,8 +355,15 @@ def dummy_data_for_idefics3(
     image_seq_len = processor.image_seq_len
     max_llm_image_tokens = max_num_image_patches * image_seq_len * num_images
 
+    if seq_len - max_llm_image_tokens < 0:
+        raise RuntimeError(
+            f"Idefics3 cannot process {num_images} images in a prompt, "
+            "please increase max_model_len or reduce image limit by "
+            "--limit-mm-per-prompt.")
+
     seq_data = SequenceData.from_prompt_token_counts(
-        (hf_config.image_token_id, max_llm_image_tokens), (0, seq_len))
+        (hf_config.image_token_id, max_llm_image_tokens),
+        (0, seq_len - max_llm_image_tokens))
 
     width = height = hf_config.vision_config.image_size
     image = Image.new("RGB", (width, height), color=0)
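Note: the new guard fails fast when the requested image count cannot fit in the context window, and the dummy sequence now pads only the remaining budget (seq_len - max_llm_image_tokens) with token id 0 instead of appending seq_len extra tokens. A standalone sketch of the same arithmetic, with illustrative numbers not taken from the commit:

# Hypothetical values; only the formula mirrors dummy_data_for_idefics3.
num_images = 1
image_seq_len = 169           # tokens emitted per image patch grid (assumed)
max_num_image_patches = 17    # assumed patch count per image
seq_len = 4096

max_llm_image_tokens = max_num_image_patches * image_seq_len * num_images
if seq_len - max_llm_image_tokens < 0:
    raise RuntimeError("increase max_model_len or lower --limit-mm-per-prompt")
text_budget = seq_len - max_llm_image_tokens
print(text_budget)  # 4096 - 17 * 169 * 1 = 1223 padding tokens of id 0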
@@ -463,8 +469,6 @@ def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         image_embeds = kwargs.pop("image_embeds", None)
-        rows = kwargs.pop("rows", None)
-        cols = kwargs.pop("cols", None)
         pixel_attention_mask = kwargs.pop("pixel_attention_mask", None)
 
         if pixel_values is None and image_embeds is None:
@@ -489,8 +493,6 @@
                 data=self._validate_pixel_values(
                     flatten_bn(pixel_values,
                                concat=True)),
-                rows=rows,
-                cols=cols,
                 pixel_attention_mask=flatten_bn(
                     pixel_attention_mask,
                     concat=True))
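Note: dropping the unused rows/cols plumbing leaves only tensors the model actually consumes. For reference, flatten_bn(..., concat=True) collapses the per-request image lists into one batch dimension before validation; a rough stand-in showing the intended shape change (not vLLM's implementation):

import torch
from typing import List

def flatten_bn_concat(x: List[torch.Tensor]) -> torch.Tensor:
    # Each element: (num_images_i, C, H, W) -> (sum_i num_images_i, C, H, W)
    return torch.cat(x, dim=0)

batch = [torch.zeros(2, 3, 364, 364), torch.zeros(1, 3, 364, 364)]
print(flatten_bn_concat(batch).shape)  # torch.Size([3, 3, 364, 364])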
@@ -610,7 +612,33 @@ def forward(
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_idefics3_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_idefics3)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_idefics3)
-class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal):
+class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
+                                       SupportsLoRA):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        # vision_model
+        "fc1",
+        "fc2",
+        "out_proj",
+        # text_model
+        "qkv_proj",  # same name with vision encoder
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+    ]
+    embedding_modules = {}
+    embedding_padding_modules = []
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
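Note: packed_modules_mapping declares that adapter weights named q_proj/k_proj/v_proj in a checkpoint target the fused qkv_proj layer (likewise gate_proj/up_proj target gate_up_proj), while supported_lora_modules lists every layer eligible for adapters in both the vision and text towers. A toy sketch of the packing idea only, not vLLM's LoRA manager:

import torch

# Hypothetical per-module LoRA B matrices with equal output dims; in the
# real model q/k/v widths can differ (e.g. GQA), which the manager handles.
r, out_dim = 8, 4096
lora_b = {n: torch.randn(out_dim, r) for n in ("q_proj", "k_proj", "v_proj")}

# Fuse along the output dimension to match the packed qkv_proj weight
# layout, so one matmul applies all three adapters.
qkv_lora_b = torch.cat(
    [lora_b["q_proj"], lora_b["k_proj"], lora_b["v_proj"]], dim=0)
assert qkv_lora_b.shape == (3 * out_dim, r)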
@@ -672,3 +700,12 @@ def sample(
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
         loader.load_weights(weights)
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="model.text_model",
+            connector="model.connector",
+            tower_model="model.vision_model")
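Note: get_mm_mapping gives the LoRA stack a way to tell the three sub-models apart by parameter prefix (e.g. so text-model adapters are not applied to the vision tower). A toy illustration of prefix-based routing under the prefixes returned above, not the vLLM internals:

PREFIXES = {
    "language_model": "model.text_model",
    "connector": "model.connector",
    "tower_model": "model.vision_model",
}

def classify(param_name: str) -> str:
    # Map a parameter name to its sub-model by first matching prefix.
    for role, prefix in PREFIXES.items():
        if param_name.startswith(prefix):
            return role
    return "other"

print(classify("model.text_model.layers.0.self_attn.qkv_proj.weight"))  # language_model
print(classify("model.vision_model.encoder.layers.0.mlp.fc1.weight"))   # tower_model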
