format fixes
WIP: lora tests

Minor tweaks

Moar fixes

Temp changes

Cleanup

Add more debugging logs and packed modules
thedebugger committed Dec 31, 2024
1 parent 37d18c8 commit 9035ed3
Showing 4 changed files with 107 additions and 5 deletions.
8 changes: 3 additions & 5 deletions tests/lora/conftest.py
@@ -147,17 +147,18 @@ def sql_lora_huggingface_id():
    # huggingface repo id is used to test lora runtime downloading.
    return "yard1/llama-2-7b-sql-lora-test"


@pytest.fixture(scope="session")
def sql_lora_files(sql_lora_huggingface_id):
    return snapshot_download(repo_id=sql_lora_huggingface_id)

@pytest.fixture(scope="session")
def llama3_1_8b_chess_lora():
    return snapshot_download(repo_id="mkopecki/chess-lora-adapter-llama-3.1-8b")

@pytest.fixture(scope="session")
def lora_bias_files():
    return snapshot_download(repo_id="followumesh/granite-3b-lora8-bias")


@pytest.fixture(scope="session")
def mixtral_lora_files():
    # Note: this module has incorrect adapter_config.json to test
@@ -213,7 +214,6 @@ def baichuan_zero_lora_files():
    # all the lora_B weights are initialized to zero.
    return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init")


@pytest.fixture(scope="session")
def baichuan_regex_lora_files():
    return snapshot_download(repo_id="jeeejeee/baichuan-7b-lora-zero-regex")
@@ -223,7 +223,6 @@ def baichuan_regex_lora_files():
def minicpmv_lora_files():
    return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon")


@pytest.fixture(scope="session")
def qwen2vl_lora_files():
    return snapshot_download(repo_id="jeeejeee/qwen2-vl-lora-pokemon")
@@ -233,7 +232,6 @@ def qwen2vl_lora_files():
def tinyllama_lora_files():
    return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")


@pytest.fixture(scope="session")
def phi2_lora_files():
    return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora")
71 changes: 71 additions & 0 deletions tests/lora/test_ultravox.py
@@ -0,0 +1,71 @@

from typing import List

import pytest

import vllm

from transformers import AutoTokenizer
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform

MODEL_NAME = "fixie-ai/ultravox-v0_3"

VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"

EXPECTED_OUTPUT = [
    "Fool mate"
]

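# Builds a chat-formatted prompt that prepends audio_count audio placeholder
# tokens to the question.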
def _get_prompt(audio_count, question, placeholder):
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    placeholder = f"{placeholder}\n" * audio_count

    return tokenizer.apply_chat_template([{
        'role': 'user',
        'content': f"{placeholder}{question}"
    }],
                                         tokenize=False,
                                         add_generation_prompt=True)

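# Runs greedy decoding, optionally attaching the LoRA adapter identified by
# lora_id, and returns the stripped generated texts.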
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
    sampling_params = vllm.SamplingParams(
        temperature=0,
        max_tokens=1000,
    )

    inputs = [{
        "prompt": _get_prompt(1, "Tell me about a silly chess move in 20 words", VLLM_PLACEHOLDER),
    }]

    outputs = llm.generate(
        inputs,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None,
    )
    generated_texts: List[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


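# Loads Ultravox with LoRA enabled, applies the chess adapter fixture, and
# checks the generations against EXPECTED_OUTPUT.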
def test_fixie_lora(llama3_1_8b_chess_lora):
    llm = vllm.LLM(
        MODEL_NAME,
        max_num_seqs=2,
        enable_lora=True,
        max_loras=4,
        max_lora_rank=128,
        trust_remote_code=True,
        dtype="bfloat16",
        max_model_len=4096,
        enforce_eager=True,
    )
    output1 = do_sample(llm, llama3_1_8b_chess_lora, lora_id=1)
    for i in range(len(EXPECTED_OUTPUT)):
        assert EXPECTED_OUTPUT[i].startswith(output1[i])
    return None
7 changes: 7 additions & 0 deletions vllm/assets/audio.py
@@ -20,6 +20,13 @@
class AudioAsset:
    name: Literal["winning_call", "mary_had_lamb"]

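    # Use a caller-supplied audio path when given; otherwise download the
    # named asset via get_vllm_public_assets.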
    def __init__(self, audio_path=None):
        if audio_path is None:
            audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
                                                s3_prefix=ASSET_DIR)

        object.__setattr__(self, '_audio_path', audio_path)

    @property
    def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]:
        audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
26 changes: 26 additions & 0 deletions vllm/lora/models.py
@@ -167,9 +167,14 @@ def from_lora_tensors(
                    loras[module_name].lora_b = loras[
                        module_name].lora_b.pin_memory()

        print_v = False
        for lora in loras.values():
            if "v_proj" in lora.module_name and not print_v:
                print_v = True
                logger.debug("Size of v_proj is: %s", lora.lora_a.size())
            lora.optimize()

        logger.debug("Creating LoRAs for %s with the following modules: %s",
                     lora_model_id, loras.keys())
        return cls(lora_model_id,
                   peft_helper.r,
                   loras,
@@ -390,6 +395,8 @@ def activate_adapter(
        for module_name, module in self.modules.items():
            module_lora = lora_model.get_lora(module_name)
            if module_lora:
                logger.debug("Setting LoRA. int id: %d, module: %s",
                             lora_model.id, module_name)
                module_lora.optimize()
                # Bias is not explicitly enabled with the flag enable_lora_bias.
                bias = module_lora.bias
@@ -405,6 +412,8 @@ def activate_adapter(
                                module_lora.embeddings_tensor,
                                module_lora.bias)
            else:
                logger.debug("Resetting LoRA. int id: %d, module: %s",
                             lora_model.id, module_name)
                module.reset_lora(index)
        return True

@@ -461,6 +470,11 @@ def remove_all_adapters(self):
    def _create_lora_modules(self):
        for module_name, module in self.model.named_modules(
                remove_duplicate=False):
            logger.debug("Creating LoRA module (if applicable) for %s",
                         module_name)
            if isinstance(module, PPMissingLayer):
                continue
            if not self._match_target_modules(module_name):
@@ -506,7 +520,16 @@ def _create_lora_modules(self):
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module,
                                                   BaseLayerWithLoRA):
                logger.warning(
                    "%s module will be ignored because it is not of type "
                    "BaseLayerWithLoRA", module_name)
                continue

            logger.debug("Applying LoRA to module %s", module_name)
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
@@ -522,6 +545,9 @@ def create_dummy_lora(
            rank: int,
            scaling_factor: Optional[float],
            embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup."""
        logger.debug("Creating a dummy LoRA with id: %d", lora_id)
        model = LoRAModel(lora_id, rank, {}, scaling_factor)
        for module_name, module in self.model.named_modules():
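The logger.debug calls added above are only emitted when vLLM logs at DEBUG level. A minimal sketch of enabling that for a local run follows; it assumes VLLM_LOGGING_LEVEL is the standard vLLM logging knob, and the getLogger fallback is plain stdlib logging, neither of which is part of this diff.

import logging
import os

# Assumption: vLLM reads VLLM_LOGGING_LEVEL when it configures its logger,
# so set it before importing vllm.
os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"

import vllm  # noqa: E402  imported after setting the env var on purpose

# Fallback: raise the level on the already-configured "vllm" logger directly.
logging.getLogger("vllm").setLevel(logging.DEBUG)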
