From 1c559383dc276ae9fd630105a9aa3da199668e60 Mon Sep 17 00:00:00 2001 From: Sumit Vij Date: Mon, 18 Nov 2024 07:01:58 +0000 Subject: [PATCH 1/9] WIP: early draft of lora support in Ultravox Signed-off-by: Sumit Vij --- vllm/model_executor/models/ultravox.py | 29 ++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 0b83684c9bac5..be75ffa3946c9 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -30,10 +30,11 @@ MultiModalDataItems, ProcessorInputs, PromptReplacement) from vllm.sequence import IntermediateTensors +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.utils import is_list_of -from .interfaces import SupportsMultiModal, SupportsPP +from .interfaces import SupportsMultiModal, SupportsPP, SupportsLoRA from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings_from_map) @@ -317,7 +318,16 @@ def forward( @MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor) -class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): +class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): + #TODO: not sure what is right thing to do here yet + packed_modules_mapping = {} + #should all llama3 modules be supported here? + #source: https://github.com/fixie-ai/ultravox/blob/812f58c5f50c02589c08668d9afe6e4f8c6d0d74/ultravox/model/ultravox_config.py#L20 + supported_lora_modules = [ + 'linear_k', 'linear_q', 'k_proj', 'q_proj' + ] + embedding_modules = {} + embedding_padding_modules = [] hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."}) @@ -330,6 +340,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.multi_modal_config = multimodal_config assert self.multi_modal_config + #TODO: maybe log a warning if lora config is present in UltravoxConfig? 
+ #TODO: figure out if these prefixes need tweaking to support LoRA and/or + #use LLMWrapper or not like this https://github.com/vllm-project/vllm/pull/7199/files#diff-7b8a4e258637b7c94389c745c449c52137d33cf92957f3e5bcb18a0ee204b21bR807 + self.secondary_weights = [] self.audio_tower = ModifiedWhisperEncoder(config.audio_config) if config.audio_model_id is not None: @@ -365,6 +379,17 @@ def sampler(self): return get_sampler() + # Following PR: https://github.com/vllm-project/vllm/pull/7199/files + # check language_model and audio_tower prefixes + # can't tell if vLLM will apply audio lora or not based on following warning: + # https://github.com/vllm-project/vllm/pull/7199/files#diff-d3df23c3e3bcfe97ee8507061c6de54f0eff23a8c75d7f5999062c42245290f8R1033 + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field(language_model="language_model", + tower_model="audio_tower") + def _audio_features_to_embeddings( self, input_features: torch.Tensor) -> torch.Tensor: audio_input = input_features.to(self.audio_tower.dtype) From 5a6b79f79145d6ec79c71b8ad677d4ed34db844c Mon Sep 17 00:00:00 2001 From: Sumit Date: Tue, 19 Nov 2024 05:35:18 +0000 Subject: [PATCH 2/9] format fixes WIP: lora tests Minor tweaks Moar fixes Temp changes Cleanup Add more debugging logs and packed modules Signed-off-by: Sumit Vij --- tests/lora/conftest.py | 8 ++--- tests/lora/test_ultravox.py | 71 +++++++++++++++++++++++++++++++++++++ vllm/assets/audio.py | 7 ++++ vllm/lora/models.py | 26 ++++++++++++++ 4 files changed, 107 insertions(+), 5 deletions(-) create mode 100644 tests/lora/test_ultravox.py diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 57ebaa424fc59..022a920766188 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -147,17 +147,18 @@ def sql_lora_huggingface_id(): # huggingface repo id is used to test lora runtime downloading. return "yard1/llama-2-7b-sql-lora-test" - @pytest.fixture(scope="session") def sql_lora_files(sql_lora_huggingface_id): return snapshot_download(repo_id=sql_lora_huggingface_id) +@pytest.fixture(scope="session") +def llama3_1_8b_chess_lora(): + return snapshot_download(repo_id="mkopecki/chess-lora-adapter-llama-3.1-8b") @pytest.fixture(scope="session") def lora_bias_files(): return snapshot_download(repo_id="followumesh/granite-3b-lora8-bias") - @pytest.fixture(scope="session") def mixtral_lora_files(): # Note: this module has incorrect adapter_config.json to test @@ -213,7 +214,6 @@ def baichuan_zero_lora_files(): # all the lora_B weights are initialized to zero. 
return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init") - @pytest.fixture(scope="session") def baichuan_regex_lora_files(): return snapshot_download(repo_id="jeeejeee/baichuan-7b-lora-zero-regex") @@ -223,7 +223,6 @@ def baichuan_regex_lora_files(): def minicpmv_lora_files(): return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon") - @pytest.fixture(scope="session") def qwen2vl_lora_files(): return snapshot_download(repo_id="jeeejeee/qwen2-vl-lora-pokemon") @@ -233,7 +232,6 @@ def qwen2vl_lora_files(): def tinyllama_lora_files(): return snapshot_download(repo_id="jashing/tinyllama-colorist-lora") - @pytest.fixture(scope="session") def phi2_lora_files(): return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora") diff --git a/tests/lora/test_ultravox.py b/tests/lora/test_ultravox.py new file mode 100644 index 0000000000000..f84f9fb9ab8e4 --- /dev/null +++ b/tests/lora/test_ultravox.py @@ -0,0 +1,71 @@ + +from typing import List + +import pytest + +import vllm + +from transformers import AutoTokenizer +from vllm.lora.request import LoRARequest +from vllm.platforms import current_platform + +MODEL_NAME = "fixie-ai/ultravox-v0_3" + +VLLM_PLACEHOLDER = "<|reserved_special_token_0|>" + +EXPECTED_OUTPUT = [ + "Fool mate" +] + +def _get_prompt(audio_count, question, placeholder): + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + placeholder = f"{placeholder}\n" * audio_count + + return tokenizer.apply_chat_template([{ + 'role': 'user', + 'content': f"{placeholder}{question}" + }], + tokenize=False, + add_generation_prompt=True) + +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: + sampling_params = vllm.SamplingParams( + temperature=0, + max_tokens=1000, + ) + + inputs = [{ + "prompt":_get_prompt(1, "Tell me about a silly chess move in 20 words", VLLM_PLACEHOLDER), + }] + + outputs = llm.generate( + inputs, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None, + ) + generated_texts: List[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +def test_fixie_lora(llama3_1_8b_chess_lora): + llm = vllm.LLM( + MODEL_NAME, + max_num_seqs=2, + enable_lora=True, + max_loras=4, + max_lora_rank=128, + trust_remote_code=True, + dtype="bfloat16", + max_model_len=4096, + enforce_eager=True + ) + output1 = do_sample(llm, llama3_1_8b_chess_lora, lora_id=1) + for i in range(len(EXPECTED_OUTPUT)): + assert EXPECTED_OUTPUT[i].startswith(output1[i]) + return None \ No newline at end of file diff --git a/vllm/assets/audio.py b/vllm/assets/audio.py index a46c67ad7e00e..77af766e54d1b 100644 --- a/vllm/assets/audio.py +++ b/vllm/assets/audio.py @@ -20,6 +20,13 @@ class AudioAsset: name: Literal["winning_call", "mary_had_lamb"] + def __init__(self, audio_path=None): + if audio_path is None: + audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg", + s3_prefix=ASSET_DIR) + + object.__setattr__(self, '_audio_path', audio_path) + @property def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]: audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg", diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 9cfcc6bba727f..70f941a52384c 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -167,9 +167,14 @@ def from_lora_tensors( loras[module_name].lora_b = loras[ 
module_name].lora_b.pin_memory() + print_v=False for lora in loras.values(): + if "v_proj" in lora.module_name and not print_v: + print_v=True + logger.debug(f"Size of v_proj is: {lora.lora_a.size()}") lora.optimize() + logger.debug(f"Creating loras for {lora_model_id} with following modules {loras.keys()}") return cls(lora_model_id, peft_helper.r, loras, @@ -390,6 +395,8 @@ def activate_adapter( for module_name, module in self.modules.items(): module_lora = lora_model.get_lora(module_name) if module_lora: + logger.debug("Setting LoRA. int id: %d, module: %s", + lora_model.id, module_name) module_lora.optimize() # Bias is not explicitly enabled with the flag enable_lora_bias. bias = module_lora.bias @@ -405,6 +412,8 @@ def activate_adapter( module_lora.embeddings_tensor, module_lora.bias) else: + logger.debug("Reseting lora. int id: %d, module: %s", + lora_model.id, module_name) module.reset_lora(index) return True @@ -461,6 +470,11 @@ def remove_all_adapters(self): def _create_lora_modules(self): for module_name, module in self.model.named_modules( remove_duplicate=False): + + logger.debug( + "Create lora module if applicable %s", + module_name, + ) if isinstance(module, PPMissingLayer): continue if not self._match_target_modules(module_name): @@ -506,7 +520,16 @@ def _create_lora_modules(self): # aims to prevent this error if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA): + logger.warning( + "%s module will be ignored because it isn't of type BaseLayerWithLoRA", + module_name, + ) continue + + logger.debug( + "Going to apply lora on %s module", + module_name, + ) self.register_module(module_name, new_module) self._register_packed_modules(module_name) # All lora layers share the same punica_wrapper based on reference. @@ -522,6 +545,9 @@ def create_dummy_lora( rank: int, scaling_factor: Optional[float], embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel: + logger.debug( + f"Creating a dummy lora with id: {lora_id}" + ) """Create zero-initialized LoRAModel for warmup.""" model = LoRAModel(lora_id, rank, {}, scaling_factor) for module_name, module in self.model.named_modules(): From 3f5996c7ab661eb9340e84a3b8daf7a0a8d9cfe0 Mon Sep 17 00:00:00 2001 From: Sumit Vij Date: Tue, 17 Dec 2024 06:26:52 +0000 Subject: [PATCH 3/9] Fix lora modules and formatting Remove stale comment Add llama lora modules Add llama test case Add test case and log warning on missing lora modules Rollback unwanted changes and format fixes Signed-off-by: Sumit Vij --- tests/conftest.py | 16 +++- tests/lora/conftest.py | 16 +++- tests/lora/test_ultravox.py | 125 +++++++++++++++---------- vllm/assets/audio.py | 7 -- vllm/lora/models.py | 31 ++---- vllm/model_executor/models/ultravox.py | 32 +++---- 6 files changed, 128 insertions(+), 99 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 917151ddcb8d4..c42de316c1c01 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -733,6 +733,7 @@ def generate( images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, + **kwargs: Any, ) -> List[Tuple[List[List[int]], List[str]]]: inputs = self.get_inputs(prompts, images=images, @@ -740,7 +741,8 @@ def generate( audios=audios) req_outputs = self.model.generate(inputs, - sampling_params=sampling_params) + sampling_params=sampling_params, + **kwargs) outputs: List[Tuple[List[List[int]], List[str]]] = [] for req_output in req_outputs: @@ -778,6 +780,7 @@ def generate_w_logprobs( images: 
Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, videos: Optional[PromptVideoInput] = None, + **kwargs: Any, ) -> Union[List[TokensTextLogprobs], List[TokensTextLogprobsPromptLogprobs]]: inputs = self.get_inputs(prompts, @@ -786,7 +789,8 @@ def generate_w_logprobs( audios=audios) req_outputs = self.model.generate(inputs, - sampling_params=sampling_params) + sampling_params=sampling_params, + **kwargs) toks_str_logsprobs_prompt_logprobs = ( self._final_steps_generate_w_logprobs(req_outputs)) @@ -822,13 +826,15 @@ def generate_greedy( images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, + **kwargs: Any, ) -> List[Tuple[List[int], str]]: greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) outputs = self.generate(prompts, greedy_params, images=images, videos=videos, - audios=audios) + audios=audios, + **kwargs) return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs] @@ -843,6 +849,7 @@ def generate_greedy_logprobs( videos: Optional[PromptVideoInput] = None, stop_token_ids: Optional[List[int]] = None, stop: Optional[List[str]] = None, + **kwargs: Any, ) -> Union[List[TokensTextLogprobs], List[TokensTextLogprobsPromptLogprobs]]: greedy_logprobs_params = SamplingParams( @@ -857,7 +864,8 @@ def generate_greedy_logprobs( greedy_logprobs_params, images=images, audios=audios, - videos=videos) + videos=videos, + **kwargs) def generate_encoder_decoder_greedy_logprobs( self, diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 022a920766188..00f55d621978f 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -147,18 +147,29 @@ def sql_lora_huggingface_id(): # huggingface repo id is used to test lora runtime downloading. return "yard1/llama-2-7b-sql-lora-test" + @pytest.fixture(scope="session") def sql_lora_files(sql_lora_huggingface_id): return snapshot_download(repo_id=sql_lora_huggingface_id) + @pytest.fixture(scope="session") def llama3_1_8b_chess_lora(): - return snapshot_download(repo_id="mkopecki/chess-lora-adapter-llama-3.1-8b") + return snapshot_download( + repo_id="mkopecki/chess-lora-adapter-llama-3.1-8b") + + +@pytest.fixture(scope="session") +def llama3_1_8b_ultravox_chess_lora(): + # ultravox chess lora is result of transformation of above chess llama lora + return snapshot_download(repo_id="thedebugger11/ultravox-chess-lora") + @pytest.fixture(scope="session") def lora_bias_files(): return snapshot_download(repo_id="followumesh/granite-3b-lora8-bias") + @pytest.fixture(scope="session") def mixtral_lora_files(): # Note: this module has incorrect adapter_config.json to test @@ -214,6 +225,7 @@ def baichuan_zero_lora_files(): # all the lora_B weights are initialized to zero. 
return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init") + @pytest.fixture(scope="session") def baichuan_regex_lora_files(): return snapshot_download(repo_id="jeeejeee/baichuan-7b-lora-zero-regex") @@ -223,6 +235,7 @@ def baichuan_regex_lora_files(): def minicpmv_lora_files(): return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon") + @pytest.fixture(scope="session") def qwen2vl_lora_files(): return snapshot_download(repo_id="jeeejeee/qwen2-vl-lora-pokemon") @@ -232,6 +245,7 @@ def qwen2vl_lora_files(): def tinyllama_lora_files(): return snapshot_download(repo_id="jashing/tinyllama-colorist-lora") + @pytest.fixture(scope="session") def phi2_lora_files(): return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora") diff --git a/tests/lora/test_ultravox.py b/tests/lora/test_ultravox.py index f84f9fb9ab8e4..f3986c0ba29fc 100644 --- a/tests/lora/test_ultravox.py +++ b/tests/lora/test_ultravox.py @@ -1,24 +1,21 @@ +from typing import List, Tuple -from typing import List +from transformers import AutoTokenizer -import pytest - -import vllm - -from transformers import AutoTokenizer from vllm.lora.request import LoRARequest -from vllm.platforms import current_platform -MODEL_NAME = "fixie-ai/ultravox-v0_3" +from ..models.utils import check_outputs_equal + +ULTRAVOX_MODEL_NAME = "fixie-ai/ultravox-v0_3" +LLMA_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct" VLLM_PLACEHOLDER = "<|reserved_special_token_0|>" -EXPECTED_OUTPUT = [ - "Fool mate" -] +PROMPT = "Tell me about a silly chess move in 20 words" + -def _get_prompt(audio_count, question, placeholder): - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +def _get_prompt(audio_count, question, placeholder, model_name) -> str: + tokenizer = AutoTokenizer.from_pretrained(model_name) placeholder = f"{placeholder}\n" * audio_count return tokenizer.apply_chat_template([{ @@ -28,44 +25,74 @@ def _get_prompt(audio_count, question, placeholder): tokenize=False, add_generation_prompt=True) -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: - sampling_params = vllm.SamplingParams( - temperature=0, - max_tokens=1000, - ) - inputs = [{ - "prompt":_get_prompt(1, "Tell me about a silly chess move in 20 words", VLLM_PLACEHOLDER), - }] +def test_ultravox_lora(vllm_runner, llama3_1_8b_chess_lora, + llama3_1_8b_ultravox_chess_lora): + with vllm_runner( + ULTRAVOX_MODEL_NAME, + enforce_eager=True, + max_num_seqs=128, + enable_lora=True, + max_loras=4, + max_lora_rank=128, + dtype="bfloat16", + max_model_len=4096, + ) as vllm_model: + ultravox_outputs: List[Tuple[List[int], + str]] = vllm_model.generate_greedy( + [ + _get_prompt( + 0, PROMPT, VLLM_PLACEHOLDER, + ULTRAVOX_MODEL_NAME) + ], + 256, + lora_request=LoRARequest( + str(1), 1, + llama3_1_8b_ultravox_chess_lora), + ) + + # run llama with and without lora to compare outputs with above + with vllm_runner( + LLMA_MODEL_NAME, + enforce_eager=True, + max_num_seqs=128, + enable_lora=True, + max_loras=4, + max_lora_rank=128, + dtype="bfloat16", + max_model_len=4096, + ) as vllm_model: + llama_outputs_no_lora: List[Tuple[List[int], + str]] = vllm_model.generate_greedy( + [ + _get_prompt( + 0, PROMPT, + VLLM_PLACEHOLDER, + LLMA_MODEL_NAME) + ], + 256, + ) + llama_outputs: List[Tuple[List[int], + str]] = vllm_model.generate_greedy( + [ + _get_prompt(0, PROMPT, + VLLM_PLACEHOLDER, + LLMA_MODEL_NAME) + ], + 256, + lora_request=LoRARequest( + str(1), 1, llama3_1_8b_chess_lora), + ) - outputs = llm.generate( - inputs, - sampling_params, - 
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None, + check_outputs_equal( + outputs_0_lst=ultravox_outputs, + outputs_1_lst=llama_outputs, + name_0="ultravox", + name_1="llama", ) - generated_texts: List[str] = [] - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text.strip() - generated_texts.append(generated_text) - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - return generated_texts + _, llama_no_lora_str = llama_outputs_no_lora[0] + _, ultravox_str = ultravox_outputs[0] -def test_fixie_lora(llama3_1_8b_chess_lora): - llm = vllm.LLM( - MODEL_NAME, - max_num_seqs=2, - enable_lora=True, - max_loras=4, - max_lora_rank=128, - trust_remote_code=True, - dtype="bfloat16", - max_model_len=4096, - enforce_eager=True - ) - output1 = do_sample(llm, llama3_1_8b_chess_lora, lora_id=1) - for i in range(len(EXPECTED_OUTPUT)): - assert EXPECTED_OUTPUT[i].startswith(output1[i]) - return None \ No newline at end of file + # verify that text don't match with no lora + assert llama_no_lora_str != ultravox_str diff --git a/vllm/assets/audio.py b/vllm/assets/audio.py index 77af766e54d1b..a46c67ad7e00e 100644 --- a/vllm/assets/audio.py +++ b/vllm/assets/audio.py @@ -20,13 +20,6 @@ class AudioAsset: name: Literal["winning_call", "mary_had_lamb"] - def __init__(self, audio_path=None): - if audio_path is None: - audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg", - s3_prefix=ASSET_DIR) - - object.__setattr__(self, '_audio_path', audio_path) - @property def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]: audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg", diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 70f941a52384c..278616c45d8a7 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -167,14 +167,9 @@ def from_lora_tensors( loras[module_name].lora_b = loras[ module_name].lora_b.pin_memory() - print_v=False for lora in loras.values(): - if "v_proj" in lora.module_name and not print_v: - print_v=True - logger.debug(f"Size of v_proj is: {lora.lora_a.size()}") lora.optimize() - logger.debug(f"Creating loras for {lora_model_id} with following modules {loras.keys()}") return cls(lora_model_id, peft_helper.r, loras, @@ -392,11 +387,10 @@ def activate_adapter( logger.debug("Activating LoRA. int id: %d, slot index: %d", lora_model.id, index) self.lora_index_to_id[index] = lora_model.id + missing_modules = [] for module_name, module in self.modules.items(): module_lora = lora_model.get_lora(module_name) if module_lora: - logger.debug("Setting LoRA. int id: %d, module: %s", - lora_model.id, module_name) module_lora.optimize() # Bias is not explicitly enabled with the flag enable_lora_bias. bias = module_lora.bias @@ -412,9 +406,14 @@ def activate_adapter( module_lora.embeddings_tensor, module_lora.bias) else: - logger.debug("Reseting lora. 
int id: %d, module: %s", - lora_model.id, module_name) + missing_modules.append(module_name) module.reset_lora(index) + + if len(missing_modules) > 0: + logger.warning( + "Lora adapter int id %d is activated but is missing \ + base model modules %s which could impact output", + lora_model.id, missing_modules) return True def _deactivate_adapter(self, lora_id: int): @@ -471,10 +470,6 @@ def _create_lora_modules(self): for module_name, module in self.model.named_modules( remove_duplicate=False): - logger.debug( - "Create lora module if applicable %s", - module_name, - ) if isinstance(module, PPMissingLayer): continue if not self._match_target_modules(module_name): @@ -521,15 +516,12 @@ def _create_lora_modules(self): if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA): logger.warning( - "%s module will be ignored because it isn't of type BaseLayerWithLoRA", + "%s module will be ignored because it isn't of type \ + BaseLayerWithLoRA", module_name, ) continue - logger.debug( - "Going to apply lora on %s module", - module_name, - ) self.register_module(module_name, new_module) self._register_packed_modules(module_name) # All lora layers share the same punica_wrapper based on reference. @@ -545,9 +537,6 @@ def create_dummy_lora( rank: int, scaling_factor: Optional[float], embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel: - logger.debug( - f"Creating a dummy lora with id: {lora_id}" - ) """Create zero-initialized LoRAModel for warmup.""" model = LoRAModel(lora_id, rank, {}, scaling_factor) for module_name, module in self.model.named_modules(): diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index be75ffa3946c9..d563aba8a2838 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -21,6 +21,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.loader import DefaultModelLoader +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, @@ -30,11 +31,10 @@ MultiModalDataItems, ProcessorInputs, PromptReplacement) from vllm.sequence import IntermediateTensors -from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.utils import is_list_of -from .interfaces import SupportsMultiModal, SupportsPP, SupportsLoRA +from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings_from_map) @@ -319,12 +319,18 @@ def forward( @MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor) class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): - #TODO: not sure what is right thing to do here yet - packed_modules_mapping = {} - #should all llama3 modules be supported here? - #source: https://github.com/fixie-ai/ultravox/blob/812f58c5f50c02589c08668d9afe6e4f8c6d0d74/ultravox/model/ultravox_config.py#L20 + # same as llamaforcasuallm (language model) minus embedding and other + # modules. 
embedding modules haven't been added as a caution + # since it could affect text but not audio + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"] + } + + #lm_head is not added for now since it requires logits_processor + # which is missing from ultravox supported_lora_modules = [ - 'linear_k', 'linear_q', 'k_proj', 'q_proj' + "qkv_proj", "o_proj", "gate_up_proj", "down_proj" ] embedding_modules = {} embedding_padding_modules = [] @@ -340,10 +346,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.multi_modal_config = multimodal_config assert self.multi_modal_config - #TODO: maybe log a warning if lora config is present in UltravoxConfig? - #TODO: figure out if these prefixes need tweaking to support LoRA and/or - #use LLMWrapper or not like this https://github.com/vllm-project/vllm/pull/7199/files#diff-7b8a4e258637b7c94389c745c449c52137d33cf92957f3e5bcb18a0ee204b21bR807 - self.secondary_weights = [] self.audio_tower = ModifiedWhisperEncoder(config.audio_config) if config.audio_model_id is not None: @@ -379,16 +381,12 @@ def sampler(self): return get_sampler() - # Following PR: https://github.com/vllm-project/vllm/pull/7199/files - # check language_model and audio_tower prefixes - # can't tell if vLLM will apply audio lora or not based on following warning: - # https://github.com/vllm-project/vllm/pull/7199/files#diff-d3df23c3e3bcfe97ee8507061c6de54f0eff23a8c75d7f5999062c42245290f8R1033 def get_mm_mapping(self) -> MultiModelKeys: """ Get the module prefix in multimodal models """ - return MultiModelKeys.from_string_field(language_model="language_model", - tower_model="audio_tower") + return MultiModelKeys.from_string_field( + language_model="language_model", tower_model="audio_tower") def _audio_features_to_embeddings( self, input_features: torch.Tensor) -> torch.Tensor: From 2abf2abcab4b6dc170ed30916844c9c500254662 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 6 Jan 2025 10:34:00 +0000 Subject: [PATCH 4/9] Done Signed-off-by: Jee Jee Li --- vllm/model_executor/models/ultravox.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index e4b48e562d45f..16b5288ade92a 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -335,16 +335,16 @@ def forward( @MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor) class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): - # same as llamaforcasuallm (language model) minus embedding and other - # modules. embedding modules haven't been added as a caution - # since it could affect text but not audio + packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"] } + # LoRA specific attributes #lm_head is not added for now since it requires logits_processor # which is missing from ultravox + # TODO : Add LoRA to the audio tower and projector. 
supported_lora_modules = [ "qkv_proj", "o_proj", "gate_up_proj", "down_proj" ] @@ -402,7 +402,10 @@ def get_mm_mapping(self) -> MultiModelKeys: Get the module prefix in multimodal models """ return MultiModelKeys.from_string_field( - language_model="language_model", tower_model="audio_tower") + language_model="language_model.", + connector="multi_modal_projector.", + tower_model="audio_tower.", + ) def _audio_features_to_embeddings( self, input_features: torch.Tensor) -> torch.Tensor: From be8778814233d588f57bc397b0759f72ed7b253c Mon Sep 17 00:00:00 2001 From: Sumit Vij Date: Fri, 10 Jan 2025 21:20:02 +0000 Subject: [PATCH 5/9] Address code review feedback * Remove changes from conftest * Move transformation inside the test * Not include missing module changes in this PR Signed-off-by: Sumit Vij --- tests/lora/conftest.py | 12 ---- tests/lora/test_ultravox.py | 81 ++++++++++++++++++-------- vllm/lora/models.py | 14 ----- vllm/model_executor/models/ultravox.py | 5 +- 4 files changed, 62 insertions(+), 50 deletions(-) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 00f55d621978f..57ebaa424fc59 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -153,18 +153,6 @@ def sql_lora_files(sql_lora_huggingface_id): return snapshot_download(repo_id=sql_lora_huggingface_id) -@pytest.fixture(scope="session") -def llama3_1_8b_chess_lora(): - return snapshot_download( - repo_id="mkopecki/chess-lora-adapter-llama-3.1-8b") - - -@pytest.fixture(scope="session") -def llama3_1_8b_ultravox_chess_lora(): - # ultravox chess lora is result of transformation of above chess llama lora - return snapshot_download(repo_id="thedebugger11/ultravox-chess-lora") - - @pytest.fixture(scope="session") def lora_bias_files(): return snapshot_download(repo_id="followumesh/granite-3b-lora8-bias") diff --git a/tests/lora/test_ultravox.py b/tests/lora/test_ultravox.py index f3986c0ba29fc..206599f9940b7 100644 --- a/tests/lora/test_ultravox.py +++ b/tests/lora/test_ultravox.py @@ -1,5 +1,10 @@ +import shutil +from os import path +from tempfile import TemporaryDirectory from typing import List, Tuple +from huggingface_hub import snapshot_download +from safetensors.torch import load_file, save_file from transformers import AutoTokenizer from vllm.lora.request import LoRARequest @@ -14,6 +19,35 @@ PROMPT = "Tell me about a silly chess move in 20 words" +def llama3_1_8b_chess_lora_path(): + return snapshot_download( + repo_id="mkopecki/chess-lora-adapter-llama-3.1-8b") + + +# can't use llama lora adapter without module name transformation +# because ultravox nest language model +def transform_module_names_for_ultravox(state_dict): + transformed_state_dict = {} + for key, value in state_dict.items(): + new_key = key.replace("base_model.model", + "base_model.model.language_model") + transformed_state_dict[new_key] = value + return transformed_state_dict + + +def mk_llama3_1_8b_ultravox_chess_lora(source_repo, target_path): + tensor_file = "adapter_model.safetensors" + state_dict = load_file(path.join(source_repo, tensor_file)) + transformed_state_dict = transform_module_names_for_ultravox(state_dict) + + save_file(transformed_state_dict, path.join(target_path, tensor_file)) + + config_file = "adapter_config.json" + shutil.copyfile(path.join(source_repo, config_file), + path.join(target_path, config_file)) + return target_path + + def _get_prompt(audio_count, question, placeholder, model_name) -> str: tokenizer = AutoTokenizer.from_pretrained(model_name) placeholder = f"{placeholder}\n" * audio_count @@ 
-26,30 +60,31 @@ def _get_prompt(audio_count, question, placeholder, model_name) -> str: add_generation_prompt=True) -def test_ultravox_lora(vllm_runner, llama3_1_8b_chess_lora, - llama3_1_8b_ultravox_chess_lora): - with vllm_runner( - ULTRAVOX_MODEL_NAME, - enforce_eager=True, - max_num_seqs=128, - enable_lora=True, - max_loras=4, - max_lora_rank=128, - dtype="bfloat16", - max_model_len=4096, - ) as vllm_model: - ultravox_outputs: List[Tuple[List[int], - str]] = vllm_model.generate_greedy( - [ - _get_prompt( - 0, PROMPT, VLLM_PLACEHOLDER, - ULTRAVOX_MODEL_NAME) - ], - 256, - lora_request=LoRARequest( - str(1), 1, +def test_ultravox_lora(vllm_runner, ): + llama3_1_8b_chess_lora = llama3_1_8b_chess_lora_path() + with TemporaryDirectory() as temp_ultravox_lora_dir: + llama3_1_8b_ultravox_chess_lora = mk_llama3_1_8b_ultravox_chess_lora( + llama3_1_8b_chess_lora, temp_ultravox_lora_dir) + with vllm_runner( + ULTRAVOX_MODEL_NAME, + enforce_eager=True, + max_num_seqs=128, + enable_lora=True, + max_loras=4, + max_lora_rank=128, + dtype="bfloat16", + max_model_len=4096, + ) as vllm_model: + ultravox_outputs: List[Tuple[ + List[int], str]] = vllm_model.generate_greedy( + [ + _get_prompt(0, PROMPT, VLLM_PLACEHOLDER, + ULTRAVOX_MODEL_NAME) + ], + 256, + lora_request=LoRARequest(str(1), 1, llama3_1_8b_ultravox_chess_lora), - ) + ) # run llama with and without lora to compare outputs with above with vllm_runner( diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 966baa39581b9..1226a275f84a9 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -388,7 +388,6 @@ def activate_adapter( logger.debug("Activating LoRA. int id: %d, slot index: %d", lora_model.id, index) self.lora_index_to_id[index] = lora_model.id - missing_modules = [] for module_name, module in self.modules.items(): module_lora = lora_model.get_lora(module_name) if module_lora: @@ -407,14 +406,8 @@ def activate_adapter( module_lora.embeddings_tensor, module_lora.bias) else: - missing_modules.append(module_name) module.reset_lora(index) - if len(missing_modules) > 0: - logger.warning( - "Lora adapter int id %d is activated but is missing \ - base model modules %s which could impact output", - lora_model.id, missing_modules) return True def _deactivate_adapter(self, lora_id: int): @@ -470,7 +463,6 @@ def remove_all_adapters(self): def _create_lora_modules(self): for module_name, module in self.model.named_modules( remove_duplicate=False): - if isinstance(module, PPMissingLayer): continue if not self._match_target_modules(module_name): @@ -516,13 +508,7 @@ def _create_lora_modules(self): # aims to prevent this error if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA): - logger.warning( - "%s module will be ignored because it isn't of type \ - BaseLayerWithLoRA", - module_name, - ) continue - self.register_module(module_name, new_module) self._register_packed_modules(module_name) # All lora layers share the same punica_wrapper based on reference. diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 16b5288ade92a..a0e544c91b570 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -346,7 +346,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): # which is missing from ultravox # TODO : Add LoRA to the audio tower and projector. 
supported_lora_modules = [ - "qkv_proj", "o_proj", "gate_up_proj", "down_proj" + "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "lm_head" ] embedding_modules = {} embedding_padding_modules = [] @@ -379,6 +379,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): hf_config=config.text_config, prefix=maybe_prefix(prefix, "language_model"), ) + # logits_processor is added here to support 'lm_head' LoRA module + # for language model + self.logits_processor = self.language_model.logits_processor if config.text_model_id is not None: # this prefix is not for initialization, but for loading weights # note the trailing dot From 4a633d3ca44f6e1caaf0c9d7fc1dc3004dce595e Mon Sep 17 00:00:00 2001 From: Sumit Vij Date: Sat, 11 Jan 2025 01:21:08 +0000 Subject: [PATCH 6/9] Fix formatting and test case Signed-off-by: Sumit Vij --- tests/lora/test_ultravox.py | 4 ++-- vllm/lora/models.py | 1 - vllm/model_executor/models/ultravox.py | 2 -- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/lora/test_ultravox.py b/tests/lora/test_ultravox.py index 206599f9940b7..463c289fa3f8b 100644 --- a/tests/lora/test_ultravox.py +++ b/tests/lora/test_ultravox.py @@ -16,7 +16,7 @@ VLLM_PLACEHOLDER = "<|reserved_special_token_0|>" -PROMPT = "Tell me about a silly chess move in 20 words" +PROMPT = "Tell me about a Fool's mate move in 20 words. Provide the moves!" def llama3_1_8b_chess_lora_path(): @@ -60,7 +60,7 @@ def _get_prompt(audio_count, question, placeholder, model_name) -> str: add_generation_prompt=True) -def test_ultravox_lora(vllm_runner, ): +def test_ultravox_lora(vllm_runner): llama3_1_8b_chess_lora = llama3_1_8b_chess_lora_path() with TemporaryDirectory() as temp_ultravox_lora_dir: llama3_1_8b_ultravox_chess_lora = mk_llama3_1_8b_ultravox_chess_lora( diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 1226a275f84a9..5b7225bdc8f37 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -407,7 +407,6 @@ def activate_adapter( module_lora.bias) else: module.reset_lora(index) - return True def _deactivate_adapter(self, lora_id: int): diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index b28951fef6a9d..ef23672a086d3 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -346,7 +346,6 @@ def forward( return hidden_states - @MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor, info=UltravoxProcessingInfo, dummy_inputs=UltravoxDummyInputsBuilder @@ -359,7 +358,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): } # LoRA specific attributes - #lm_head is not added for now since it requires logits_processor # which is missing from ultravox # TODO : Add LoRA to the audio tower and projector. supported_lora_modules = [ From 769f7bd758892e054be0c095d22de55cbcfb8283 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 16 Jan 2025 06:09:40 +0000 Subject: [PATCH 7/9] Done Signed-off-by: Jee Jee Li --- vllm/model_executor/models/ultravox.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index af3f4c47b3afc..b1ac7c92a0be9 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -349,10 +349,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): } # LoRA specific attributes - # which is missing from ultravox # TODO : Add LoRA to the audio tower and projector. 
supported_lora_modules = [ - "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "lm_head" + "qkv_proj", "o_proj", "gate_up_proj", "down_proj" ] embedding_modules = {} embedding_padding_modules = [] @@ -385,9 +384,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): hf_config=config.text_config, prefix=maybe_prefix(prefix, "language_model"), ) - # logits_processor is added here to support 'lm_head' LoRA module - # for language model - self.logits_processor = self.language_model.logits_processor if config.text_model_id is not None: # this prefix is not for initialization, but for loading weights # note the trailing dot From 208e662f370d42e0f74629aec527ded12bd2396d Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 16 Jan 2025 06:12:54 +0000 Subject: [PATCH 8/9] Add doc Signed-off-by: Jee Jee Li --- docs/source/models/supported_models.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 85d844f3d3f55..bf703178e6cca 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -759,7 +759,7 @@ See [this page](#generative-models) for more information on how to use generativ - Ultravox - T + AE+ - `fixie-ai/ultravox-v0_3` - - + - ✅︎ - ✅︎ - ✅︎ ``` From f483d9abd9a71cb631bf4c121d267cc7983267d3 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 20 Jan 2025 03:54:26 +0000 Subject: [PATCH 9/9] Optmize unit test Signed-off-by: Jee Jee Li --- tests/lora/test_ultravox.py | 38 ++++++++----------------------------- 1 file changed, 8 insertions(+), 30 deletions(-) diff --git a/tests/lora/test_ultravox.py b/tests/lora/test_ultravox.py index 463c289fa3f8b..e0049180710c3 100644 --- a/tests/lora/test_ultravox.py +++ b/tests/lora/test_ultravox.py @@ -9,8 +9,6 @@ from vllm.lora.request import LoRARequest -from ..models.utils import check_outputs_equal - ULTRAVOX_MODEL_NAME = "fixie-ai/ultravox-v0_3" LLMA_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct" @@ -61,6 +59,9 @@ def _get_prompt(audio_count, question, placeholder, model_name) -> str: def test_ultravox_lora(vllm_runner): + """ + TODO: Train an Ultravox LoRA instead of using a Llama LoRA. + """ llama3_1_8b_chess_lora = llama3_1_8b_chess_lora_path() with TemporaryDirectory() as temp_ultravox_lora_dir: llama3_1_8b_ultravox_chess_lora = mk_llama3_1_8b_ultravox_chess_lora( @@ -97,34 +98,11 @@ def test_ultravox_lora(vllm_runner): dtype="bfloat16", max_model_len=4096, ) as vllm_model: - llama_outputs_no_lora: List[Tuple[List[int], - str]] = vllm_model.generate_greedy( - [ - _get_prompt( - 0, PROMPT, - VLLM_PLACEHOLDER, - LLMA_MODEL_NAME) - ], - 256, - ) - llama_outputs: List[Tuple[List[int], - str]] = vllm_model.generate_greedy( - [ - _get_prompt(0, PROMPT, - VLLM_PLACEHOLDER, - LLMA_MODEL_NAME) - ], - 256, - lora_request=LoRARequest( - str(1), 1, llama3_1_8b_chess_lora), - ) - - check_outputs_equal( - outputs_0_lst=ultravox_outputs, - outputs_1_lst=llama_outputs, - name_0="ultravox", - name_1="llama", - ) + llama_outputs_no_lora: List[Tuple[List[int], str]] = ( + vllm_model.generate_greedy( + [_get_prompt(0, PROMPT, VLLM_PLACEHOLDER, LLMA_MODEL_NAME)], + 256, + )) _, llama_no_lora_str = llama_outputs_no_lora[0] _, ultravox_str = ultravox_outputs[0]
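
For readers unfamiliar with the packed_modules_mapping introduced in these patches: vLLM's Llama layers fuse the attention and MLP projections into qkv_proj and gate_up_proj, while a PEFT adapter trained against the HF checkpoint ships separate q_proj/k_proj/v_proj (and gate_proj/up_proj) LoRA weights. The mapping tells the LoRA manager which per-projection adapters stack onto which fused layer. Below is a simplified, self-contained sketch of that idea, not vLLM's actual implementation; tensor names, shapes, and the function itself are illustrative only, and the usual alpha/r scaling factor is omitted.

import torch

def fused_proj_with_lora(x, w_fused, lora_pairs, out_sizes):
    """Apply per-projection LoRA deltas to the column blocks of one fused projection.

    x:          [tokens, hidden] activations
    w_fused:    [sum(out_sizes), hidden] fused weight (e.g. q, k, v stacked row-wise)
    lora_pairs: [(A_q, B_q), (A_k, B_k), (A_v, B_v)] with A: [r, hidden], B: [out_i, r]
    out_sizes:  output width of each packed projection
    """
    out = x @ w_fused.T                            # fused base projection
    start = 0
    for (lora_a, lora_b), size in zip(lora_pairs, out_sizes):
        delta = (x @ lora_a.T) @ lora_b.T          # low-rank delta for one projection
        out[:, start:start + size] += delta        # add it to that projection's column block
        start += size
    return out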
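
And a minimal end-to-end sketch of what the finished feature enables, mirroring the test above: running Ultravox with a LoRA adapter whose tensors are keyed under "language_model.". The adapter path below is a placeholder, and in practice the prompt should be built with the model's chat template as _get_prompt does in the test.

import vllm
from vllm.lora.request import LoRARequest

llm = vllm.LLM(
    "fixie-ai/ultravox-v0_3",
    enable_lora=True,
    max_loras=4,
    max_lora_rank=128,          # matches the rank of the chess adapter used in the tests
    dtype="bfloat16",
    max_model_len=4096,
)

outputs = llm.generate(
    ["<chat-formatted prompt>"],                   # placeholder; use apply_chat_template
    vllm.SamplingParams(temperature=0, max_tokens=256),
    # adapter name, integer id, and local directory of the converted adapter (placeholder path)
    lora_request=LoRARequest("chess", 1, "/path/to/ultravox-chess-lora"),
)
print(outputs[0].outputs[0].text)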