From 5be4e52b6522113f7276e60b32cb5c1f912de6fd Mon Sep 17 00:00:00 2001 From: B-201 Date: Mon, 18 Nov 2024 20:57:10 +0800 Subject: [PATCH] [Model][LoRA]LoRA support added for glm-4v (#10418) Signed-off-by: B-201 --- vllm/model_executor/models/chatglm.py | 98 +++++++++++++++++++++------ 1 file changed, 79 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 81e56381eabd8..625e31bb0d368 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -30,6 +30,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalData, MultiModalKwargs @@ -574,25 +575,8 @@ def forward( return hidden_states -@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv) -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv) -@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv) -class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, - SupportsMultiModal): - packed_modules_mapping = { - "query_key_value": ["query_key_value"], - "dense_h_to_4h": ["dense_h_to_4h"] - } - # LoRA specific attributes - supported_lora_modules = [ - "query_key_value", - "dense", - "dense_h_to_4h", - "dense_4h_to_h", - ] - embedding_modules = {} - embedding_padding_modules = [] +class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP, + SupportsMultiModal): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -692,3 +676,79 @@ def load_weights(self, weights: Iterable[Tuple[str, weight_loader(param, combined_weight) loaded_params.add(combined_name) return loaded_params + + +class ChatGLM(ChatGLMBaseModel): + packed_modules_mapping = { + "query_key_value": ["query_key_value"], + "dense_h_to_4h": ["dense_h_to_4h"] + } + # LoRA specific attributes + supported_lora_modules = [ + "query_key_value", + "dense", + "dense_h_to_4h", + "dense_4h_to_h", + ] + + embedding_modules = {} + embedding_padding_modules = [] + + +class ChatGLMV(ChatGLMBaseModel): + packed_modules_mapping = { + "query_key_value": ["query_key_value"], + "dense_h_to_4h": ["dense_h_to_4h"], + "merged_proj": ["gate_proj", "dense_h_to_4h"] + } + # LoRA specific attributes + supported_lora_modules = [ + "query_key_value", + "dense", + "dense_h_to_4h", + "dense_4h_to_h", + # vision + "fc1", + "fc2", + "merged_proj", + "linear_proj" + ] + + embedding_modules = {} + embedding_padding_modules = [] + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="transformer.encoder", + connector="transformer.vision.linear_proj", + tower_model="transformer.vision.transformer") + + +@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv) +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv) +@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv) +class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, + SupportsMultiModal): + # Ensure that the LoRA support check passes when the class is not + # initialized, but set all these attributes to empty. + packed_modules_mapping = {} + supported_lora_modules = [] + embedding_modules = {} + embedding_padding_modules = [] + + def __new__( + cls, + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: + config = vllm_config.model_config.hf_config + # Initialize VL + if hasattr(config, "visual"): + return ChatGLM(vllm_config=vllm_config, prefix=prefix) + # Initialize LLM + else: + return ChatGLMV(vllm_config=vllm_config, prefix=prefix)