Fix lora modules and formatting
Remove stale comment

Add llama lora modules

Add llama test case

Add test case and log warning on missing lora modules

Rollback unwanted changes and format fixes

Signed-off-by: Sumit Vij <[email protected]>
thedebugger committed Jan 1, 2025
1 parent 5a6b79f commit 3f5996c
Showing 6 changed files with 128 additions and 99 deletions.
16 changes: 12 additions & 4 deletions tests/conftest.py
@@ -733,14 +733,16 @@ def generate(
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
**kwargs: Any,
) -> List[Tuple[List[List[int]], List[str]]]:
inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)

req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
sampling_params=sampling_params,
**kwargs)

outputs: List[Tuple[List[List[int]], List[str]]] = []
for req_output in req_outputs:
@@ -778,6 +780,7 @@ def generate_w_logprobs(
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
**kwargs: Any,
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
inputs = self.get_inputs(prompts,
@@ -786,7 +789,8 @@ def generate_w_logprobs(
audios=audios)

req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
sampling_params=sampling_params,
**kwargs)

toks_str_logsprobs_prompt_logprobs = (
self._final_steps_generate_w_logprobs(req_outputs))
@@ -822,13 +826,15 @@ def generate_greedy(
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
**kwargs: Any,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts,
greedy_params,
images=images,
videos=videos,
audios=audios)
audios=audios,
**kwargs)
return [(output_ids[0], output_str[0])
for output_ids, output_str in outputs]

@@ -843,6 +849,7 @@ def generate_greedy_logprobs(
videos: Optional[PromptVideoInput] = None,
stop_token_ids: Optional[List[int]] = None,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
greedy_logprobs_params = SamplingParams(
@@ -857,7 +864,8 @@ def generate_greedy_logprobs(
greedy_logprobs_params,
images=images,
audios=audios,
videos=videos)
videos=videos,
**kwargs)

def generate_encoder_decoder_greedy_logprobs(
self,
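The practical effect of the new **kwargs plumbing above is that per-request options now reach LLM.generate() from the test helpers. A minimal sketch of how a test can thread a LoRA adapter through generate_greedy, mirroring the call pattern in tests/lora/test_ultravox.py below (the helper name greedy_with_adapter is illustrative, not part of the commit):

from typing import List, Tuple

from vllm.lora.request import LoRARequest


def greedy_with_adapter(vllm_model, prompt: str,
                        lora_path: str) -> List[Tuple[List[int], str]]:
    # Extra keyword arguments are forwarded generate_greedy -> generate ->
    # LLM.generate, so the LoRARequest rides along with the prompt.
    return vllm_model.generate_greedy(
        [prompt],
        256,  # max_tokens
        lora_request=LoRARequest("adapter", 1, lora_path),
    )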
16 changes: 15 additions & 1 deletion tests/lora/conftest.py
@@ -147,18 +147,29 @@ def sql_lora_huggingface_id():
# huggingface repo id is used to test lora runtime downloading.
return "yard1/llama-2-7b-sql-lora-test"


@pytest.fixture(scope="session")
def sql_lora_files(sql_lora_huggingface_id):
return snapshot_download(repo_id=sql_lora_huggingface_id)


@pytest.fixture(scope="session")
def llama3_1_8b_chess_lora():
return snapshot_download(repo_id="mkopecki/chess-lora-adapter-llama-3.1-8b")
return snapshot_download(
repo_id="mkopecki/chess-lora-adapter-llama-3.1-8b")


@pytest.fixture(scope="session")
def llama3_1_8b_ultravox_chess_lora():
    # the ultravox chess lora is the result of transforming the chess llama lora above
return snapshot_download(repo_id="thedebugger11/ultravox-chess-lora")


@pytest.fixture(scope="session")
def lora_bias_files():
return snapshot_download(repo_id="followumesh/granite-3b-lora8-bias")


@pytest.fixture(scope="session")
def mixtral_lora_files():
# Note: this module has incorrect adapter_config.json to test
@@ -214,6 +225,7 @@ def baichuan_zero_lora_files():
# all the lora_B weights are initialized to zero.
return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init")


@pytest.fixture(scope="session")
def baichuan_regex_lora_files():
return snapshot_download(repo_id="jeeejeee/baichuan-7b-lora-zero-regex")
@@ -223,6 +235,7 @@ def baichuan_regex_lora_files():
def minicpmv_lora_files():
return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon")


@pytest.fixture(scope="session")
def qwen2vl_lora_files():
return snapshot_download(repo_id="jeeejeee/qwen2-vl-lora-pokemon")
@@ -232,6 +245,7 @@ def qwen2vl_lora_files():
def tinyllama_lora_files():
return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")


@pytest.fixture(scope="session")
def phi2_lora_files():
return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora")
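For reference, each fixture above simply returns the local directory that snapshot_download resolves to, so a test consumes the adapters by declaring the fixture names. A hypothetical sketch (the test name is illustrative, not part of the suite):

import os


def test_chess_adapters_are_downloaded(llama3_1_8b_chess_lora,
                                       llama3_1_8b_ultravox_chess_lora):
    # Session-scoped fixtures hand back plain paths to the downloaded snapshots.
    assert os.path.isdir(llama3_1_8b_chess_lora)
    assert os.path.isdir(llama3_1_8b_ultravox_chess_lora)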
125 changes: 76 additions & 49 deletions tests/lora/test_ultravox.py
@@ -1,24 +1,21 @@
from typing import List, Tuple

from typing import List
from transformers import AutoTokenizer

import pytest

import vllm

from transformers import AutoTokenizer
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform

MODEL_NAME = "fixie-ai/ultravox-v0_3"
from ..models.utils import check_outputs_equal

ULTRAVOX_MODEL_NAME = "fixie-ai/ultravox-v0_3"
LLMA_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"

VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"

EXPECTED_OUTPUT = [
"Fool mate"
]
PROMPT = "Tell me about a silly chess move in 20 words"


def _get_prompt(audio_count, question, placeholder):
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
def _get_prompt(audio_count, question, placeholder, model_name) -> str:
tokenizer = AutoTokenizer.from_pretrained(model_name)
placeholder = f"{placeholder}\n" * audio_count

return tokenizer.apply_chat_template([{
@@ -28,44 +25,74 @@ def _get_prompt(audio_count, question, placeholder):
tokenize=False,
add_generation_prompt=True)

def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
sampling_params = vllm.SamplingParams(
temperature=0,
max_tokens=1000,
)

inputs = [{
"prompt":_get_prompt(1, "Tell me about a silly chess move in 20 words", VLLM_PLACEHOLDER),
}]
def test_ultravox_lora(vllm_runner, llama3_1_8b_chess_lora,
llama3_1_8b_ultravox_chess_lora):
with vllm_runner(
ULTRAVOX_MODEL_NAME,
enforce_eager=True,
max_num_seqs=128,
enable_lora=True,
max_loras=4,
max_lora_rank=128,
dtype="bfloat16",
max_model_len=4096,
) as vllm_model:
ultravox_outputs: List[Tuple[List[int],
str]] = vllm_model.generate_greedy(
[
_get_prompt(
0, PROMPT, VLLM_PLACEHOLDER,
ULTRAVOX_MODEL_NAME)
],
256,
lora_request=LoRARequest(
str(1), 1,
llama3_1_8b_ultravox_chess_lora),
)

# run llama with and without lora to compare outputs with above
with vllm_runner(
LLMA_MODEL_NAME,
enforce_eager=True,
max_num_seqs=128,
enable_lora=True,
max_loras=4,
max_lora_rank=128,
dtype="bfloat16",
max_model_len=4096,
) as vllm_model:
llama_outputs_no_lora: List[Tuple[List[int],
str]] = vllm_model.generate_greedy(
[
_get_prompt(
0, PROMPT,
VLLM_PLACEHOLDER,
LLMA_MODEL_NAME)
],
256,
)
llama_outputs: List[Tuple[List[int],
str]] = vllm_model.generate_greedy(
[
_get_prompt(0, PROMPT,
VLLM_PLACEHOLDER,
LLMA_MODEL_NAME)
],
256,
lora_request=LoRARequest(
str(1), 1, llama3_1_8b_chess_lora),
)

outputs = llm.generate(
inputs,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None,
check_outputs_equal(
outputs_0_lst=ultravox_outputs,
outputs_1_lst=llama_outputs,
name_0="ultravox",
name_1="llama",
)
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts

_, llama_no_lora_str = llama_outputs_no_lora[0]
_, ultravox_str = ultravox_outputs[0]

def test_fixie_lora(llama3_1_8b_chess_lora):
llm = vllm.LLM(
MODEL_NAME,
max_num_seqs=2,
enable_lora=True,
max_loras=4,
max_lora_rank=128,
trust_remote_code=True,
dtype="bfloat16",
max_model_len=4096,
enforce_eager=True
)
output1 = do_sample(llm, llama3_1_8b_chess_lora, lora_id=1)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output1[i])
return None
    # verify that the no-lora llama output does not match the ultravox lora output
assert llama_no_lora_str != ultravox_str
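check_outputs_equal, imported from ..models.utils, asserts an element-wise match between the two runs; it is roughly equivalent to the simplified stand-in below (a sketch only; the real helper also reports name_0/name_1 in its assertion messages):

from typing import List, Tuple


def outputs_equal(outputs_0: List[Tuple[List[int], str]],
                  outputs_1: List[Tuple[List[int], str]]) -> bool:
    # Each entry is (token_ids, text) from generate_greedy; both must match
    # position by position for the two adapters to be considered equivalent.
    if len(outputs_0) != len(outputs_1):
        return False
    return all(ids_0 == ids_1 and text_0 == text_1
               for (ids_0, text_0), (ids_1, text_1) in zip(outputs_0, outputs_1))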
7 changes: 0 additions & 7 deletions vllm/assets/audio.py
@@ -20,13 +20,6 @@
class AudioAsset:
name: Literal["winning_call", "mary_had_lamb"]

def __init__(self, audio_path=None):
if audio_path is None:
audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
s3_prefix=ASSET_DIR)

object.__setattr__(self, '_audio_path', audio_path)

@property
def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]:
audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
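With the ad-hoc __init__ removed, AudioAsset goes back to being a plain dataclass keyed by name, and the .ogg file is resolved from the public asset bucket only when a property is accessed. Typical usage stays roughly as follows (a sketch, assuming the asset is reachable):

from vllm.assets.audio import AudioAsset

# The asset is fetched lazily on first property access, not at construction.
audio, sample_rate = AudioAsset("mary_had_lamb").audio_and_sample_rate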
31 changes: 10 additions & 21 deletions vllm/lora/models.py
@@ -167,14 +167,9 @@ def from_lora_tensors(
loras[module_name].lora_b = loras[
module_name].lora_b.pin_memory()

print_v=False
for lora in loras.values():
if "v_proj" in lora.module_name and not print_v:
print_v=True
logger.debug(f"Size of v_proj is: {lora.lora_a.size()}")
lora.optimize()

logger.debug(f"Creating loras for {lora_model_id} with following modules {loras.keys()}")
return cls(lora_model_id,
peft_helper.r,
loras,
@@ -392,11 +387,10 @@ def activate_adapter(
logger.debug("Activating LoRA. int id: %d, slot index: %d",
lora_model.id, index)
self.lora_index_to_id[index] = lora_model.id
missing_modules = []
for module_name, module in self.modules.items():
module_lora = lora_model.get_lora(module_name)
if module_lora:
logger.debug("Setting LoRA. int id: %d, module: %s",
lora_model.id, module_name)
module_lora.optimize()
# Bias is not explicitly enabled with the flag enable_lora_bias.
bias = module_lora.bias
@@ -412,9 +406,14 @@
module_lora.embeddings_tensor,
module_lora.bias)
else:
logger.debug("Reseting lora. int id: %d, module: %s",
lora_model.id, module_name)
missing_modules.append(module_name)
module.reset_lora(index)

if len(missing_modules) > 0:
logger.warning(
"Lora adapter int id %d is activated but is missing \
base model modules %s which could impact output",
lora_model.id, missing_modules)
return True

def _deactivate_adapter(self, lora_id: int):
@@ -471,10 +470,6 @@ def _create_lora_modules(self):
for module_name, module in self.model.named_modules(
remove_duplicate=False):

logger.debug(
"Create lora module if applicable %s",
module_name,
)
if isinstance(module, PPMissingLayer):
continue
if not self._match_target_modules(module_name):
@@ -521,15 +516,12 @@ def _create_lora_modules(self):
if self.supports_mm and not isinstance(new_module,
BaseLayerWithLoRA):
logger.warning(
"%s module will be ignored because it isn't of type BaseLayerWithLoRA",
"%s module will be ignored because it isn't of type \
BaseLayerWithLoRA",
module_name,
)
continue

logger.debug(
"Going to apply lora on %s module",
module_name,
)
self.register_module(module_name, new_module)
self._register_packed_modules(module_name)
# All lora layers share the same punica_wrapper based on reference.
@@ -545,9 +537,6 @@ def create_dummy_lora(
rank: int,
scaling_factor: Optional[float],
embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
logger.debug(
f"Creating a dummy lora with id: {lora_id}"
)
"""Create zero-initialized LoRAModel for warmup."""
model = LoRAModel(lora_id, rank, {}, scaling_factor)
for module_name, module in self.model.named_modules():
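Condensed from the interleaved diff above, the adapter-activation change amounts to the following paraphrased sketch (not a verbatim excerpt; bias gating and the unchanged set_lora arguments are abbreviated): modules the adapter does not cover are collected and reported once via a single warning, instead of one debug line per module.

import logging

logger = logging.getLogger(__name__)


def activate_adapter_sketch(modules, lora_model, index: int) -> bool:
    # modules: base-model module name -> LoRA-wrapped layer; lora_model: the
    # adapter being activated (duck-typed stand-ins for the real attributes).
    missing_modules = []
    for module_name, module in modules.items():
        module_lora = lora_model.get_lora(module_name)
        if module_lora:
            module_lora.optimize()
            module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
                            module_lora.embeddings_tensor, module_lora.bias)
        else:
            missing_modules.append(module_name)
            module.reset_lora(index)
    if missing_modules:
        logger.warning(
            "Lora adapter int id %d is activated but is missing "
            "base model modules %s which could impact output",
            lora_model.id, missing_modules)
    return True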