[Misc] Further reduce BNB static variable (vllm-project#10597)
Signed-off-by: Jee Jee Li <[email protected]>
jeejeelee authored and weilong.yu committed Dec 13, 2024
1 parent 557d233 commit cd3a6f3
Showing 14 changed files with 131 additions and 219 deletions.
218 changes: 130 additions & 88 deletions vllm/model_executor/model_loader/loader.py

Large diffs are not rendered by default.
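The loader.py diff above carries the substance of this change but is not rendered here. Judging from the per-model deletions below, the commit moves BNB (bitsandbytes) target-module selection out of hard-coded per-model lists and into the model loader. A minimal sketch of that idea, assuming the loader now derives the targets by introspecting the instantiated model; the helper name and logic below are illustrative, not vLLM's actual implementation:

    # Illustrative sketch only -- not vLLM's actual loader code.
    # Instead of every model class declaring default_bitsandbytes_target_modules,
    # the loader can discover quantizable projections by walking the model.
    from typing import List

    import torch.nn as nn

    def infer_bnb_target_modules(model: nn.Module) -> List[str]:
        """Collect fully-qualified names of linear layers as BNB quantization targets."""
        targets = []
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):  # any linear projection is a candidate
                targets.append(name)
        return targets

Under that assumption, the static lists deleted in the files below become redundant, which is consistent with the 219 deletions reported above.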

8 changes: 0 additions & 8 deletions vllm/model_executor/models/baichuan.py
@@ -351,14 +351,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_padding_modules = []

# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".W_pack.",
".o_proj.",
".down_proj.",
".up_proj.",
".gate_proj.",
".up_proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"gate_proj": ("gate_up_proj", 0),
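Only the static target-module list is deleted here; the bitsandbytes_stacked_params_mapping stays, since it tells the BNB loader how checkpoint shard names fold into vLLM's fused parameters. A hedged sketch of how such a mapping could be consulted when routing a checkpoint weight; the helper below is illustrative, not the loader's actual code:

    # Illustrative use of a bitsandbytes_stacked_params_mapping entry.
    # Each entry reads: shard_name -> (fused_param_name, shard_index).
    from typing import Optional, Tuple

    stacked_params_mapping = {
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def resolve_fused_name(weight_name: str) -> Tuple[str, Optional[int]]:
        """Map a checkpoint weight name onto its fused parameter and shard index."""
        for shard_name, (fused_name, index) in stacked_params_mapping.items():
            if shard_name in weight_name:
                return weight_name.replace(shard_name, fused_name), index
        return weight_name, None  # not part of a stacked parameter

    # e.g. resolve_fused_name("model.layers.0.mlp.gate_proj.weight")
    #      -> ("model.layers.0.mlp.gate_up_proj.weight", 0)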
6 changes: 0 additions & 6 deletions vllm/model_executor/models/falcon.py
@@ -412,12 +412,6 @@ class FalconForCausalLM(nn.Module, SupportsPP):

# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {}
default_bitsandbytes_target_modules = [
".query_key_value.",
".dense.",
".dense_h_to_4h.",
".dense_4h_to_h.",
]

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
9 changes: 0 additions & 9 deletions vllm/model_executor/models/gemma.py
@@ -350,15 +350,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"down_proj",
]
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
9 changes: 0 additions & 9 deletions vllm/model_executor/models/gemma2.py
@@ -386,15 +386,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_padding_modules = []

# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
15 changes: 0 additions & 15 deletions vllm/model_executor/models/idefics3.py
@@ -656,21 +656,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
]

# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
# vision_model
".fc1.",
".fc2.",
".out_proj.",
# connector
".proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
9 changes: 0 additions & 9 deletions vllm/model_executor/models/llama.py
@@ -463,15 +463,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_padding_modules = ["lm_head"]

# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
34 changes: 0 additions & 34 deletions vllm/model_executor/models/minicpmv.py
@@ -822,25 +822,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
]

# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
# vision encoder
".fc1.",
".fc2.",
# Currently, vllm does not support BNB quantization for the `out_proj`
# of the resampler, so it's necessary to distinguish between the
# vision encoder and the resampler's out_proj. The same applies to
# MiniCPMV2_6.
".self_attn.out_proj.", # vision encoder out_proj
# resampler
".kv_proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
@@ -964,21 +945,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
]

# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
# vision encoder
".fc1.",
".fc2.",
".self_attn.out_proj.",
# resampler
".kv_proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
14 changes: 0 additions & 14 deletions vllm/model_executor/models/mllama.py
@@ -1104,20 +1104,6 @@ def forward(
@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
".fc1.",
".fc2.",
# The `multi_modal_projector` is at the top level of the model,
# so we can't add a dot in front of it.
"multi_modal_projector."
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
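The comments in the lists removed above (the resampler out_proj note in minicpmv.py and the multi_modal_projector note in mllama.py) hint at how these strings were matched: as substrings of fully-qualified module names, with a leading dot to pin a nested module and no leading dot for a top-level one. A small hypothetical illustration of that matching rule, not code from this repository:

    # Hypothetical matcher showing the leading-dot convention in the removed lists.
    from typing import List

    def is_target_module(qualified_name: str, targets: List[str]) -> bool:
        # A trailing dot lets ".self_attn.out_proj." match a name ending in
        # "self_attn.out_proj" while skipping the resampler's "attn.out_proj".
        padded = qualified_name + "."
        return any(target in padded for target in targets)

    # is_target_module("vision_model.layers.0.self_attn.out_proj",
    #                  [".self_attn.out_proj."])   -> True
    # is_target_module("resampler.attn.out_proj",
    #                  [".self_attn.out_proj."])   -> False
    # is_target_module("multi_modal_projector",
    #                  ["multi_modal_projector."]) -> True   (top-level, no leading dot)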
3 changes: 0 additions & 3 deletions vllm/model_executor/models/opt.py
@@ -337,9 +337,6 @@ class OPTForCausalLM(nn.Module, SupportsPP):
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
}
default_bitsandbytes_target_modules = [
".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2."
]

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
3 changes: 0 additions & 3 deletions vllm/model_executor/models/phi.py
@@ -286,9 +286,6 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
}
default_bitsandbytes_target_modules = [
".q_proj.", ".k_proj.", ".v_proj.", ".fc1.", ".fc2.", ".dense."
]

embedding_modules = {}
embedding_padding_modules = []
6 changes: 0 additions & 6 deletions vllm/model_executor/models/phi3.py
@@ -16,11 +16,5 @@ class Phi3ForCausalLM(LlamaForCausalLM):
}

# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_up_proj.",
".down_proj.",
".qkv_proj.",
".o_proj.",
]
# Initialize an empty dict when there is no stacked parameter mapping.
bitsandbytes_stacked_params_mapping = {}
7 changes: 1 addition & 6 deletions vllm/model_executor/models/qwen.py
@@ -1028,12 +1028,7 @@ class QWenLLM(QWenBaseModel):
embedding_modules = {}
embedding_padding_modules = []

default_bitsandbytes_target_modules = [
".c_attn.",
".c_proj.",
".w1.",
".w2.",
]
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"w2": ("gate_up_proj", 0),
9 changes: 0 additions & 9 deletions vllm/model_executor/models/qwen2.py
@@ -419,15 +419,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_padding_modules = []

# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
