diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py
index 29ecf37808205..8b247fb9b2388 100644
--- a/tests/lora/conftest.py
+++ b/tests/lora/conftest.py
@@ -200,6 +200,11 @@ def minicpmv_lora_files():
     return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon")
 
 
+@pytest.fixture(scope="session")
+def qwen2vl_lora_files():
+    return snapshot_download(repo_id="jeeejeee/qwen2-vl-lora-pokemon")
+
+
 @pytest.fixture(scope="session")
 def tinyllama_lora_files():
     return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
diff --git a/tests/lora/test_lora_checkpoints.py b/tests/lora/test_lora_checkpoints.py
index 9a529e27b4cd8..9842203eb15e0 100644
--- a/tests/lora/test_lora_checkpoints.py
+++ b/tests/lora/test_lora_checkpoints.py
@@ -4,6 +4,7 @@
 from vllm.lora.models import LoRAModel
 from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
+from vllm.model_executor.models.utils import WeightsMapper
 
 lora_lst = [
     "baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"
 ]
@@ -71,3 +72,32 @@ def test_load_checkpoints(
             device="cpu",
             embedding_modules=embedding_modules,
             embedding_padding_modules=embed_padding_modules)
+
+
+def test_lora_weights_mapping(baichuan_lora_files, ):
+    supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
+    packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
+    embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
+    embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
+    expected_lora_modules: List[str] = []
+    for module in supported_lora_modules:
+        if module in packed_modules_mapping:
+            expected_lora_modules.extend(packed_modules_mapping[module])
+        else:
+            expected_lora_modules.append(module)
+
+    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
+        "model.": "language_model.model.",
+    }, )
+
+    lora_model = LoRAModel.from_local_checkpoint(
+        baichuan_lora_files,
+        expected_lora_modules,
+        lora_model_id=1,
+        device="cpu",
+        embedding_modules=embedding_modules,
+        embedding_padding_modules=embed_padding_modules,
+        weights_mapper=hf_to_vllm_mapper,
+    )
+    for name in lora_model.loras:
+        assert name.startswith(hf_to_vllm_mapper.orig_to_new_prefix["model."])
diff --git a/tests/lora/test_qwen2vl.py b/tests/lora/test_qwen2vl.py
new file mode 100644
index 0000000000000..c8c720ff0c776
--- /dev/null
+++ b/tests/lora/test_qwen2vl.py
@@ -0,0 +1,78 @@
+from typing import List
+
+import pytest
+
+import vllm
+from vllm.assets.image import ImageAsset
+from vllm.lora.request import LoRARequest
+from vllm.platforms import current_platform
+
+MODEL_PATH = "Qwen/Qwen2-VL-7B-Instruct"
+
+PROMPT_TEMPLATE = (
+    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
+    "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
+    "What is in the image?<|im_end|>\n"
+    "<|im_start|>assistant\n")
+
+IMAGE_ASSETS = [
+    ImageAsset("stop_sign"),
+    ImageAsset("cherry_blossom"),
+]
+
+# After fine-tuning with LoRA, all generated content should start with `A`.
+EXPECTED_OUTPUT = [
+    "A stop sign stands prominently in the foreground, with a traditional Chinese gate and a black SUV in the background, illustrating a blend of modern and cultural elements.",  # noqa: E501
+    "A majestic skyscraper stands tall, partially obscured by a vibrant canopy of cherry blossoms, against a clear blue sky.",  # noqa: E501
+]
+
+
+def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
+    sampling_params = vllm.SamplingParams(
+        temperature=0,
+        max_tokens=5,
+    )
+
+    inputs = [{
+        "prompt": PROMPT_TEMPLATE,
+        "multi_modal_data": {
+            "image": asset.pil_image
+        },
+    } for asset in IMAGE_ASSETS]
+
+    outputs = llm.generate(
+        inputs,
+        sampling_params,
+        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
+        if lora_id else None,
+    )
+    # Print the outputs.
+    generated_texts: List[str] = []
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text.strip()
+        generated_texts.append(generated_text)
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    return generated_texts
+
+
+@pytest.mark.xfail(current_platform.is_rocm(),
+                   reason="Qwen2-VL dependency xformers incompatible with ROCm"
+                   )
+def test_qwen2vl_lora(qwen2vl_lora_files):
+    llm = vllm.LLM(
+        MODEL_PATH,
+        max_num_seqs=2,
+        enable_lora=True,
+        max_loras=2,
+        max_lora_rank=16,
+        trust_remote_code=True,
+        mm_processor_kwargs={
+            "min_pixels": 28 * 28,
+            "max_pixels": 1280 * 28 * 28,
+        },
+        max_model_len=4096,
+    )
+    output1 = do_sample(llm, qwen2vl_lora_files, lora_id=1)
+    for i in range(len(EXPECTED_OUTPUT)):
+        assert EXPECTED_OUTPUT[i].startswith(output1[i])
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 70806a77b9fff..f50db8e3b8e10 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -28,7 +28,7 @@
                              parse_fine_tuned_lora_name, replace_submodule)
 from vllm.model_executor.models import SupportsLoRA, supports_multimodal
 from vllm.model_executor.models.module_mapping import MultiModelKeys
-from vllm.model_executor.models.utils import PPMissingLayer
+from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
 from vllm.utils import is_pin_memory_available
 
 logger = init_logger(__name__)
@@ -113,13 +113,14 @@ def from_lora_tensors(
         target_embedding_padding: Optional[int] = None,
         embedding_modules: Optional[Dict[str, str]] = None,
         embedding_padding_modules: Optional[List[str]] = None,
+        weights_mapper: Optional[WeightsMapper] = None,
     ) -> "LoRAModel":
         """Create a LoRAModel from a dictionary of tensors."""
         pin_memory = str(device) == "cpu" and is_pin_memory_available()
         loras: Dict[str, LoRALayerWeights] = {}
         for tensor_name, tensor in tensors.items():
             module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
                 tensor_name)
-                tensor_name)
+                tensor_name, weights_mapper)
             if module_name not in loras:
                 lora_embeddings_tensor = None
                 if embeddings:
@@ -187,6 +188,7 @@ def from_local_checkpoint(
         target_embedding_padding: Optional[int] = None,
         embedding_modules: Optional[Dict[str, str]] = None,
         embedding_padding_modules: Optional[List[str]] = None,
+        weights_mapper: Optional[WeightsMapper] = None,
     ) -> "LoRAModel":
         """Create a LoRAModel from a local checkpoint.
 
@@ -289,7 +291,8 @@ def from_local_checkpoint(
             embeddings=embeddings,
             target_embedding_padding=target_embedding_padding,
             embedding_modules=embedding_modules,
-            embedding_padding_modules=embedding_padding_modules)
+            embedding_padding_modules=embedding_padding_modules,
+            weights_mapper=weights_mapper)
 
 
 class LoRAModelManager(AdapterModelManager):
diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py
index 5876494ce2824..3a84a6ae1c02a 100644
--- a/vllm/lora/utils.py
+++ b/vllm/lora/utils.py
@@ -1,3 +1,4 @@
+import copy
 import os
 import re
 from typing import List, Optional, Set, Tuple, Type, Union
@@ -30,6 +31,8 @@
 # yapf: enable
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.models.utils import WeightsMapper
+from vllm.utils import print_warning_once
 
 logger = init_logger(__name__)
 
@@ -91,28 +94,54 @@ def replace_submodule(model: nn.Module, module_name: str,
     return new_module
 
 
-def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool, bool]:
+def parse_fine_tuned_lora_name(
+        name: str,
+        weights_mapper: Optional[WeightsMapper] = None
+) -> Tuple[str, bool, bool]:
     """Parse the name of lora weights.
 
     args:
         name: the name of the fine-tuned LoRA, e.g.
             base_model.model.dense1.weight
+        weights_mapper: maps the name of the weight, e.g.
+            `model.` -> `language_model.model.`.
     return:
         Tuple(module_name, is_lora_a):
             module_name: the name of the module, e.g. model.dense1,
             is_lora_a whether the tensor is lora_a or lora_b.
             is_bias whether the tensor is lora bias.
     """
+
+    w_mapper = None
+    if weights_mapper:
+        w_mapper = copy.deepcopy(weights_mapper)
+        # TODO: Currently only prefix mapping is supported; substr and
+        # suffix mapping will be supported in the future.
+        for attr, mapping in [
+            ("orig_to_new_substr", w_mapper.orig_to_new_substr),
+            ("orig_to_new_suffix", w_mapper.orig_to_new_suffix),
+        ]:
+            if mapping:
+                print_warning_once(
+                    f"vLLM currently does not support mapping of LoRA weights "
+                    f"for {mapping}.")
+                setattr(w_mapper, attr, {})
+
+    mapper = (lambda name: w_mapper._map_name(name)
+              if w_mapper is not None else name)
     parts = name.split(".")
 
     if parts[-1] == "weight" and (parts[-2] == "lora_A"
                                   or parts[-2] == "lora_B"):
-        return ".".join(parts[2:-2]), parts[-2] == "lora_A", False
+        new_name = ".".join(parts[2:-2])
+        return mapper(new_name), parts[-2] == "lora_A", False
 
     if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
-        return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A", False
+        new_name = ".".join(parts[2:-1])
+        return mapper(new_name), parts[-1] == "lora_embedding_A", False
 
     if parts[-1] == "bias":
-        return ".".join(parts[2:-2]), False, True
+        new_name = ".".join(parts[2:-2])
+        return mapper(new_name), False, True
 
     raise ValueError(f"{name} is unsupported LoRA weight")
diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py
index 93a5e27621912..ef8cc5886103e 100644
--- a/vllm/lora/worker_manager.py
+++ b/vllm/lora/worker_manager.py
@@ -92,6 +92,14 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
                 else:
                     expected_lora_modules.append(module)
             lora_path = get_adapter_absolute_path(lora_request.lora_path)
+
+            # For some models like Qwen2-VL, we need to use hf_to_vllm_mapper
+            # to ensure correct loading of LoRA weights.
+            hf_to_vllm_mapper = None
+            if (hasattr(model, "hf_to_vllm_mapper")
+                    and model.hf_to_vllm_mapper is not None):
+                hf_to_vllm_mapper = model.hf_to_vllm_mapper
+
             lora = self._lora_model_cls.from_local_checkpoint(
                 lora_path,
                 expected_lora_modules,
@@ -103,7 +111,8 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
                 self.lora_config.lora_extra_vocab_size,
                 embedding_modules=self.embedding_modules,
                 embedding_padding_modules=self.embedding_padding_modules,
-            )
+                weights_mapper=hf_to_vllm_mapper)
+
         except Exception as e:
             raise RuntimeError(f"Loading lora {lora_path} failed") from e
         if lora.rank > self.lora_config.max_lora_rank:
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index b38ea923f0bf1..fb97eb1916002 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -901,6 +901,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
     ]
     embedding_modules = {}
    embedding_padding_modules = []
     embedding_padding_modules = []
+    # To ensure correct weight loading and mapping.
+    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
+        "lm_head.": "language_model.lm_head.",
+        "model.": "language_model.model.",
+    })
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -1190,11 +1195,6 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        hf_to_vllm_mapper = WeightsMapper(
-            orig_to_new_prefix={
-                "lm_head.": "language_model.lm_head.",
-                "model.": "language_model.model.",
-            })
 
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
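
Below is a minimal sketch (not part of the diff) of what the new weights_mapper plumbing does: parse_fine_tuned_lora_name strips the PEFT "base_model.model." prefix and, when a mapper is supplied, rewrites the remaining HF prefix onto vLLM's Qwen2-VL module tree. The example tensor name is a hypothetical PEFT-style key chosen for illustration; the import paths follow the modules edited above.

# Sketch only: the tensor name is a hypothetical PEFT-style key for a
# Qwen2-VL LoRA checkpoint; import paths assume the modules edited above.
from vllm.lora.utils import parse_fine_tuned_lora_name
from vllm.model_executor.models.utils import WeightsMapper

hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
    "lm_head.": "language_model.lm_head.",
    "model.": "language_model.model.",
})

name = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"

# Without a mapper, the parsed module keeps the HF "model." prefix and does
# not line up with vLLM's Qwen2VLForConditionalGeneration module names.
module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(name)
print(module_name)  # model.layers.0.self_attn.q_proj

# With the mapper, the prefix is rewritten so the LoRA weight resolves to the
# "language_model.model." submodule, which is what the worker_manager change
# enables for models that define hf_to_vllm_mapper as a class attribute.
module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
    name, hf_to_vllm_mapper)
print(module_name)  # language_model.model.layers.0.self_attn.q_proj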