Skip to content

Commit

Permalink
[Model] Add BNB quantization support for Idefics3 (vllm-project#10310)
Browse files Browse the repository at this point in the history
Signed-off-by: B-201 <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
  • Loading branch information
B-201 and jeejeelee authored Nov 14, 2024
1 parent 52b48c1 commit 294bf46
Showing 1 changed file with 61 additions and 7 deletions.
68 changes: 61 additions & 7 deletions vllm/model_executor/models/idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from PIL import Image
from torch import nn
# Temporary solution for transformers below 4.46.0: alias the generic
# PretrainedConfig / ProcessorMixin base classes under the Idefics3 names.
from transformers import PretrainedConfig as Idefics3Config
from transformers import ProcessorMixin as Idefics3ImageProcessor

from vllm.attention import AttentionMetadata
Expand All @@ -31,6 +32,7 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.module_mapping import MultiModelKeys
Expand Down Expand Up @@ -374,12 +376,23 @@ def dummy_data_for_idefics3(

class Idefics3SimpleMLP(nn.Module):

def __init__(self, config):
def __init__(
self,
config: Idefics3Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
input_size = config.vision_config.hidden_size * (config.scale_factor**
2)
output_size = config.text_config.hidden_size
self.proj = ReplicatedLinear(input_size, output_size, bias=False)
self.proj = ReplicatedLinear(
input_size,
output_size,
bias=False,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "proj"),
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
out, _ = self.proj(x)
Expand All @@ -388,10 +401,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

class Idefics3Connector(nn.Module):

def __init__(self, config):
def __init__(
self,
config: Idefics3Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.scale_factor = config.scale_factor
self.modality_projection = Idefics3SimpleMLP(config)
self.modality_projection = Idefics3SimpleMLP(
config,
quant_config,
prefix=maybe_prefix(prefix, "modality_projection"),
)

def pixel_shuffle(self,
x: torch.Tensor,
Expand Down Expand Up @@ -431,9 +453,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.config = config
self.padding_idx = self.config.text_config.pad_token_id
self.vocab_size = self.config.text_config.vocab_size
self.vision_model = Idefics3VisionTransformer(config.vision_config,
quant_config)
self.connector = Idefics3Connector(config)
self.vision_model = Idefics3VisionTransformer(
config.vision_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "vision_model"))
self.connector = Idefics3Connector(
config,
quant_config,
prefix=maybe_prefix(prefix, "connector"),
)
self.text_model = LlamaModel(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "text_model"),
Expand Down Expand Up @@ -637,6 +665,32 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
"gate_up_proj",
"down_proj",
]

# bitsandbytes-specific attributes.
# Default module-name patterns targeted for bitsandbytes quantization.
# Every entry is wrapped in dots, so it presumably matches a whole dotted
# component of a weight name rather than a bare substring -- TODO confirm
# against the loader that consumes this list.
default_bitsandbytes_target_modules = [
# presumably the text_model MLP and attention projections -- verify
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
# vision_model
".fc1.",
".fc2.",
".out_proj.",
# connector (the `proj` linear layer created in Idefics3SimpleMLP)
".proj.",
]
# Maps each separately named shard to the stacked parameter it belongs to and
# its index within that parameter: q/k/v go into "qkv_proj" at indices 0/1/2,
# gate/up go into "gate_up_proj" at indices 0/1.
# NOTE(review): presumably read by the bitsandbytes weight loader -- confirm.
bitsandbytes_stacked_params_mapping = {
# shard_name: (weight_name, index)
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}

embedding_modules = {}
embedding_padding_modules = []

Expand Down

0 comments on commit 294bf46

Please sign in to comment.