From 42c7f66a386b2243dcd313feed7dec4e1d508167 Mon Sep 17 00:00:00 2001
From: Jiaxin Shan
Date: Mon, 22 Jul 2024 15:42:40 -0700
Subject: [PATCH] [Core] Support dynamically loading Lora adapter from
 HuggingFace (#6234)

Co-authored-by: Antoni Baum
---
 tests/core/test_scheduler.py              |  4 +-
 tests/lora/conftest.py                    | 10 +++-
 tests/lora/test_long_context.py           |  2 +-
 tests/lora/test_lora_huggingface.py       | 39 ++++++++++++++++
 tests/lora/test_utils.py                  | 57 ++++++++++++++++++++++-
 vllm/entrypoints/openai/serving_engine.py |  4 +-
 vllm/lora/request.py                      | 42 +++++++++++++++--
 vllm/lora/utils.py                        | 47 +++++++++++++++++++
 vllm/lora/worker_manager.py               |  7 +--
 vllm/transformers_utils/tokenizer.py      |  5 +-
 vllm/worker/model_runner.py               |  2 +-
 11 files changed, 201 insertions(+), 18 deletions(-)
 create mode 100644 tests/lora/test_lora_huggingface.py

diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py
index bae958211cb7b..4ca2260b5e017 100644
--- a/tests/core/test_scheduler.py
+++ b/tests/core/test_scheduler.py
@@ -462,7 +462,7 @@ def test_prefill_schedule_max_lora():
                                            lora_request=LoRARequest(
                                                lora_name=str(i),
                                                lora_int_id=i + 1,
-                                               lora_local_path="abc"))
+                                               lora_path="abc"))
         waiting.append(seq_group)
     # Add two more requests to verify lora is prioritized.
     # 0: Lora, 1: Lora, 2: regular, 3: regular
@@ -760,7 +760,7 @@ def test_schedule_swapped_max_loras():
                                            lora_request=LoRARequest(
                                                lora_name=str(i),
                                                lora_int_id=i + 1,
-                                               lora_local_path="abc"))
+                                               lora_path="abc"))
         scheduler._allocate_and_set_running(seq_group)
         append_new_token_seq_group(60, seq_group, 1)
         scheduler._swap_out(seq_group, blocks_to_swap_out)
diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py
index bda123bf13139..0bcae5b0c96dc 100644
--- a/tests/lora/conftest.py
+++ b/tests/lora/conftest.py
@@ -159,8 +159,14 @@ def dummy_model_gate_up() -> nn.Module:


 @pytest.fixture(scope="session")
-def sql_lora_files():
-    return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
+def sql_lora_huggingface_id():
+    # huggingface repo id is used to test lora runtime downloading.
+    return "yard1/llama-2-7b-sql-lora-test"
+
+
+@pytest.fixture(scope="session")
+def sql_lora_files(sql_lora_huggingface_id):
+    return snapshot_download(repo_id=sql_lora_huggingface_id)


 @pytest.fixture(scope="session")
diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py
index 853fd9fb3ce7a..389a3ccbc17ec 100644
--- a/tests/lora/test_long_context.py
+++ b/tests/lora/test_long_context.py
@@ -29,7 +29,7 @@ def _create_lora_request(lora_id, long_context_infos):
     context_len = long_context_infos[lora_id]["context_length"]
     scaling_factor = context_len_to_scaling_factor[context_len]
     return LoRARequest(context_len, lora_id,
-                       long_context_infos[lora_id]["lora"],
+                       long_context_infos[lora_id]["lora"], None,
                        4096 * scaling_factor)


diff --git a/tests/lora/test_lora_huggingface.py b/tests/lora/test_lora_huggingface.py
new file mode 100644
index 0000000000000..e2daf9d135113
--- /dev/null
+++ b/tests/lora/test_lora_huggingface.py
@@ -0,0 +1,39 @@
+from typing import List
+
+import pytest
+
+from vllm.lora.models import LoRAModel
+from vllm.lora.utils import get_adapter_absolute_path
+from vllm.model_executor.models.llama import LlamaForCausalLM
+
+# Provide absolute path and huggingface lora ids
+lora_fixture_name = ["sql_lora_files", "sql_lora_huggingface_id"]
+
+
+@pytest.mark.parametrize("lora_fixture_name", lora_fixture_name)
+def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
+    lora_name = request.getfixturevalue(lora_fixture_name)
+    supported_lora_modules = LlamaForCausalLM.supported_lora_modules
+    packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping
+    embedding_modules = LlamaForCausalLM.embedding_modules
+    embed_padding_modules = LlamaForCausalLM.embedding_padding_modules
+    expected_lora_modules: List[str] = []
+    for module in supported_lora_modules:
+        if module in packed_modules_mapping:
+            expected_lora_modules.extend(packed_modules_mapping[module])
+        else:
+            expected_lora_modules.append(module)
+
+    lora_path = get_adapter_absolute_path(lora_name)
+
+    # lora loading should work for either an absolute path or a huggingface id.
+    lora_model = LoRAModel.from_local_checkpoint(
+        lora_path,
+        expected_lora_modules,
+        lora_model_id=1,
+        device="cpu",
+        embedding_modules=embedding_modules,
+        embedding_padding_modules=embed_padding_modules)
+
+    # Assertions to ensure the model is loaded correctly
+    assert lora_model is not None, "LoRAModel is not loaded correctly"
diff --git a/tests/lora/test_utils.py b/tests/lora/test_utils.py
index 4ff9715b4ca8d..db02bacdb6439 100644
--- a/tests/lora/test_utils.py
+++ b/tests/lora/test_utils.py
@@ -1,9 +1,12 @@
 from collections import OrderedDict
+from unittest.mock import patch

 import pytest
+from huggingface_hub.utils import HfHubHTTPError
 from torch import nn

-from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
+from vllm.lora.utils import (get_adapter_absolute_path,
+                             parse_fine_tuned_lora_name, replace_submodule)
 from vllm.utils import LRUCache


@@ -182,3 +185,55 @@ def test_lru_cache():
     assert 2 in cache
     assert 4 in cache
     assert 6 in cache
+
+
+# Unit tests for get_adapter_absolute_path
+@patch('os.path.isabs')
+def test_get_adapter_absolute_path_absolute(mock_isabs):
+    path = '/absolute/path/to/lora'
+    mock_isabs.return_value = True
+    assert get_adapter_absolute_path(path) == path
+
+
+@patch('os.path.expanduser')
+def test_get_adapter_absolute_path_expanduser(mock_expanduser):
+    # Path with ~ that needs to be expanded
+    path = '~/relative/path/to/lora'
+    absolute_path = '/home/user/relative/path/to/lora'
+    mock_expanduser.return_value = absolute_path
+    assert get_adapter_absolute_path(path) == absolute_path
+
+
+@patch('os.path.exists')
+@patch('os.path.abspath')
+def test_get_adapter_absolute_path_local_existing(mock_abspath, mock_exist):
+    # Relative path that exists locally
+    path = 'relative/path/to/lora'
+    absolute_path = '/absolute/path/to/lora'
+    mock_exist.return_value = True
+    mock_abspath.return_value = absolute_path
+    assert get_adapter_absolute_path(path) == absolute_path
+
+
+@patch('huggingface_hub.snapshot_download')
+@patch('os.path.exists')
+def test_get_adapter_absolute_path_huggingface(mock_exist,
+                                               mock_snapshot_download):
+    # Hugging Face model identifier
+    path = 'org/repo'
+    absolute_path = '/mock/snapshot/path'
+    mock_exist.return_value = False
+    mock_snapshot_download.return_value = absolute_path
+    assert get_adapter_absolute_path(path) == absolute_path
+
+
+@patch('huggingface_hub.snapshot_download')
+@patch('os.path.exists')
+def test_get_adapter_absolute_path_huggingface_error(mock_exist,
+                                                      mock_snapshot_download):
+    # Hugging Face model identifier with download error
+    path = 'org/repo'
+    mock_exist.return_value = False
+    mock_snapshot_download.side_effect = HfHubHTTPError(
+        "failed to query model info")
+    assert get_adapter_absolute_path(path) == path
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 7578dc9dc3c0c..8c6bd10b9b4d4 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -43,7 +43,7 @@ class PromptAdapterPath:
 @dataclass
 class LoRAModulePath:
     name: str
-    local_path: str
+    path: str


 AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
@@ -83,7 +83,7 @@ def __init__(
                 LoRARequest(
                     lora_name=lora.name,
                     lora_int_id=i,
-                    lora_local_path=lora.local_path,
+                    lora_path=lora.path,
                 ) for i, lora in enumerate(lora_modules, start=1)
             ]

diff --git a/vllm/lora/request.py b/vllm/lora/request.py
index 2d10d037760e2..5d791424fbe6e 100644
--- a/vllm/lora/request.py
+++ b/vllm/lora/request.py
@@ -1,4 +1,5 @@
-from dataclasses import dataclass
+import warnings
+from dataclasses import dataclass, field
 from typing import Optional

 from vllm.adapter_commons.request import AdapterRequest
@@ -20,10 +21,25 @@ class LoRARequest(AdapterRequest):

     lora_name: str
     lora_int_id: int
-    lora_local_path: str
+    lora_path: str = ""
+    lora_local_path: Optional[str] = field(default=None, repr=False)
     long_lora_max_len: Optional[int] = None
     __hash__ = AdapterRequest.__hash__

+    def __post_init__(self):
+        if 'lora_local_path' in self.__dict__:
+            warnings.warn(
+                "The 'lora_local_path' attribute is deprecated "
+                "and will be removed in a future version. "
+                "Please use 'lora_path' instead.",
+                DeprecationWarning,
+                stacklevel=2)
+            if not self.lora_path:
+                self.lora_path = self.lora_local_path or ""
+
+        # Ensure lora_path is not empty
+        assert self.lora_path, "lora_path cannot be empty"
+
     @property
     def adapter_id(self):
         return self.lora_int_id
@@ -32,6 +48,26 @@ def adapter_id(self):
     def name(self):
         return self.lora_name

+    @property
+    def path(self):
+        return self.lora_path
+
     @property
     def local_path(self):
-        return self.lora_local_path
+        warnings.warn(
+            "The 'local_path' attribute is deprecated "
+            "and will be removed in a future version. "
+            "Please use 'path' instead.",
+            DeprecationWarning,
+            stacklevel=2)
+        return self.lora_path
+
+    @local_path.setter
+    def local_path(self, value):
+        warnings.warn(
+            "The 'local_path' attribute is deprecated "
+            "and will be removed in a future version. "
+            "Please use 'path' instead.",
+            DeprecationWarning,
+            stacklevel=2)
+        self.lora_path = value
diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py
index ab3b99eee6fc1..4513337299e16 100644
--- a/vllm/lora/utils.py
+++ b/vllm/lora/utils.py
@@ -1,5 +1,9 @@
+import os
 from typing import List, Optional, Set, Tuple, Type

+import huggingface_hub
+from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
+                                   HFValidationError, RepositoryNotFoundError)
 from torch import nn
 from transformers import PretrainedConfig

@@ -105,3 +109,46 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
         return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"

     raise ValueError(f"{name} is unsupported LoRA weight")
+
+
+def get_adapter_absolute_path(lora_path: str) -> str:
+    """
+    Resolves the given lora_path to an absolute local path.
+
+    If the lora_path is identified as a Hugging Face model identifier,
+    it will download the model and return the local snapshot path.
+    Otherwise, it treats the lora_path as a local file path and
+    converts it to an absolute path.
+
+    Parameters:
+    lora_path (str): The path to the lora model, which can be an absolute path,
+                     a relative path, or a Hugging Face model identifier.
+
+    Returns:
+    str: The resolved absolute local path to the lora model.
+    """
+
+    # Check if the path is an absolute path. Return it whether it exists or not.
+    if os.path.isabs(lora_path):
+        return lora_path
+
+    # If the path starts with ~, expand the user home directory.
+    if lora_path.startswith('~'):
+        return os.path.expanduser(lora_path)
+
+    # Check if the expanded relative path exists locally.
+    if os.path.exists(lora_path):
+        return os.path.abspath(lora_path)
+
+    # If the path does not exist locally, assume it's a Hugging Face repo.
+    try:
+        local_snapshot_path = huggingface_hub.snapshot_download(
+            repo_id=lora_path)
+    except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError,
+            HFValidationError):
+        # Handle errors that may occur during the download
+        # Return the original path instead of raising an error here
+        logger.exception("Error downloading the HuggingFace model")
+        return lora_path
+
+    return local_snapshot_path
diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py
index 3d0ef4252b024..724c308a07a27 100644
--- a/vllm/lora/worker_manager.py
+++ b/vllm/lora/worker_manager.py
@@ -13,6 +13,7 @@
 from vllm.lora.models import (LoRAModel, LoRAModelManager,
                               LRUCacheLoRAModelManager, create_lora_manager)
 from vllm.lora.request import LoRARequest
+from vllm.lora.utils import get_adapter_absolute_path

 logger = init_logger(__name__)

@@ -89,8 +90,9 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
                         packed_modules_mapping[module])
                 else:
                     expected_lora_modules.append(module)
+            lora_path = get_adapter_absolute_path(lora_request.lora_path)
             lora = self._lora_model_cls.from_local_checkpoint(
-                lora_request.lora_local_path,
+                lora_path,
                 expected_lora_modules,
                 max_position_embeddings=self.max_position_embeddings,
                 lora_model_id=lora_request.lora_int_id,
@@ -102,8 +104,7 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
                 embedding_padding_modules=self.embedding_padding_modules,
             )
         except Exception as e:
-            raise RuntimeError(
-                f"Loading lora {lora_request.lora_local_path} failed") from e
+            raise RuntimeError(f"Loading lora {lora_path} failed") from e
         if lora.rank > self.lora_config.max_lora_rank:
             raise ValueError(
                 f"LoRA rank {lora.rank} is greater than max_lora_rank "
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index 7553249544211..c515f46ecc299 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -137,14 +137,13 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args,
     if lora_request is None:
         return None
     try:
-        tokenizer = get_tokenizer(lora_request.lora_local_path, *args,
-                                  **kwargs)
+        tokenizer = get_tokenizer(lora_request.lora_path, *args, **kwargs)
     except OSError as e:
         # No tokenizer was found in the LoRA folder,
         # use base model tokenizer
         logger.warning(
             "No tokenizer found in %s, using base model tokenizer instead. "
-            "(Exception: %s)", lora_request.lora_local_path, e)
+            "(Exception: %s)", lora_request.lora_path, e)
         tokenizer = None
     return tokenizer

diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index d810443665024..31e9fc1eed548 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -691,7 +691,7 @@ def profile_run(self) -> None:
                 dummy_lora_request = LoRARequest(
                     lora_name=f"warmup_{lora_id}",
                     lora_int_id=lora_id,
-                    lora_local_path="/not/a/real/path",
+                    lora_path="/not/a/real/path",
                 )
                 self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                  rank=LORA_WARMUP_RANK)
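
Usage note (illustrative only, not part of the diff): after this change, the `lora_path` of a LoRARequest may be either a local directory or a Hugging Face repo id; repo ids are resolved through `get_adapter_absolute_path`, which downloads a snapshot on first use. A minimal sketch follows, assuming the adapter id from the tests above and a placeholder base model name (neither is prescribed by the patch).

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    # Placeholder base model; any LoRA-compatible base/adapter pair works.
    llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)

    # lora_path may now be a Hugging Face repo id instead of a local directory;
    # the worker resolves it via get_adapter_absolute_path() at load time.
    sql_lora = LoRARequest(
        lora_name="sql-lora",
        lora_int_id=1,
        lora_path="yard1/llama-2-7b-sql-lora-test",
    )

    outputs = llm.generate(["Write a SQL query that lists all users."],
                           SamplingParams(max_tokens=64),
                           lora_request=sql_lora)
    print(outputs[0].outputs[0].text)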