-
-
Notifications
You must be signed in to change notification settings - Fork 5.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
LoRA Support for Ultravox model #11253
Open
thedebugger
wants to merge
17
commits into
vllm-project:main
Choose a base branch
from
thedebugger:svij-ultravox-lora-dec-16
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+150
−7
Open
Changes from all commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
1c55938
WIP: early draft of lora support in Ultravox
thedebugger 5a6b79f
format fixes
thedebugger 3f5996c
Fix lora modules and formatting
thedebugger d1b65eb
Merge branch 'main' of https://github.com/vllm-project/vllm into svij…
jeejeelee 7367bc2
Merge branch 'main' of https://github.com/vllm-project/vllm into svij…
jeejeelee 2abf2ab
Done
jeejeelee be87788
Address code review feedback
thedebugger 317fc38
Merge branch 'main' into svij-ultravox-lora-dec-16
thedebugger 4a633d3
Fix formatting and test case
thedebugger 224a65e
Merge branch 'main' of https://github.com/vllm-project/vllm into svij…
jeejeelee 769f7bd
Done
jeejeelee 907b3c7
Merge branch 'main' of https://github.com/vllm-project/vllm into svij…
jeejeelee 208e662
Add doc
jeejeelee 1248d5f
Merge branch 'main' of https://github.com/vllm-project/vllm into svij…
jeejeelee 575b5dc
Merge branch 'main' of https://github.com/vllm-project/vllm into svij…
jeejeelee 7cb7eba
Merge branch 'main' of https://github.com/vllm-project/vllm into svij…
jeejeelee f483d9a
Optmize unit test
jeejeelee File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
import shutil | ||
from os import path | ||
from tempfile import TemporaryDirectory | ||
from typing import List, Tuple | ||
|
||
from huggingface_hub import snapshot_download | ||
from safetensors.torch import load_file, save_file | ||
from transformers import AutoTokenizer | ||
|
||
from vllm.lora.request import LoRARequest | ||
|
||
ULTRAVOX_MODEL_NAME = "fixie-ai/ultravox-v0_3" | ||
LLMA_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct" | ||
|
||
VLLM_PLACEHOLDER = "<|reserved_special_token_0|>" | ||
|
||
PROMPT = "Tell me about a Fool's mate move in 20 words. Provide the moves!" | ||
|
||
|
||
def llama3_1_8b_chess_lora_path(): | ||
return snapshot_download( | ||
repo_id="mkopecki/chess-lora-adapter-llama-3.1-8b") | ||
|
||
|
||
# can't use llama lora adapter without module name transformation | ||
# because ultravox nest language model | ||
def transform_module_names_for_ultravox(state_dict): | ||
transformed_state_dict = {} | ||
for key, value in state_dict.items(): | ||
new_key = key.replace("base_model.model", | ||
"base_model.model.language_model") | ||
transformed_state_dict[new_key] = value | ||
return transformed_state_dict | ||
|
||
|
||
def mk_llama3_1_8b_ultravox_chess_lora(source_repo, target_path): | ||
tensor_file = "adapter_model.safetensors" | ||
state_dict = load_file(path.join(source_repo, tensor_file)) | ||
transformed_state_dict = transform_module_names_for_ultravox(state_dict) | ||
|
||
save_file(transformed_state_dict, path.join(target_path, tensor_file)) | ||
|
||
config_file = "adapter_config.json" | ||
shutil.copyfile(path.join(source_repo, config_file), | ||
path.join(target_path, config_file)) | ||
return target_path | ||
|
||
|
||
def _get_prompt(audio_count, question, placeholder, model_name) -> str: | ||
tokenizer = AutoTokenizer.from_pretrained(model_name) | ||
placeholder = f"{placeholder}\n" * audio_count | ||
|
||
return tokenizer.apply_chat_template([{ | ||
'role': 'user', | ||
'content': f"{placeholder}{question}" | ||
}], | ||
tokenize=False, | ||
add_generation_prompt=True) | ||
|
||
|
||
def test_ultravox_lora(vllm_runner): | ||
""" | ||
TODO: Train an Ultravox LoRA instead of using a Llama LoRA. | ||
""" | ||
llama3_1_8b_chess_lora = llama3_1_8b_chess_lora_path() | ||
with TemporaryDirectory() as temp_ultravox_lora_dir: | ||
llama3_1_8b_ultravox_chess_lora = mk_llama3_1_8b_ultravox_chess_lora( | ||
llama3_1_8b_chess_lora, temp_ultravox_lora_dir) | ||
with vllm_runner( | ||
ULTRAVOX_MODEL_NAME, | ||
enforce_eager=True, | ||
max_num_seqs=128, | ||
enable_lora=True, | ||
max_loras=4, | ||
max_lora_rank=128, | ||
dtype="bfloat16", | ||
max_model_len=4096, | ||
) as vllm_model: | ||
ultravox_outputs: List[Tuple[ | ||
List[int], str]] = vllm_model.generate_greedy( | ||
[ | ||
_get_prompt(0, PROMPT, VLLM_PLACEHOLDER, | ||
ULTRAVOX_MODEL_NAME) | ||
], | ||
256, | ||
lora_request=LoRARequest(str(1), 1, | ||
llama3_1_8b_ultravox_chess_lora), | ||
) | ||
|
||
# run llama with and without lora to compare outputs with above | ||
with vllm_runner( | ||
LLMA_MODEL_NAME, | ||
enforce_eager=True, | ||
max_num_seqs=128, | ||
enable_lora=True, | ||
max_loras=4, | ||
max_lora_rank=128, | ||
dtype="bfloat16", | ||
max_model_len=4096, | ||
) as vllm_model: | ||
llama_outputs_no_lora: List[Tuple[List[int], str]] = ( | ||
vllm_model.generate_greedy( | ||
[_get_prompt(0, PROMPT, VLLM_PLACEHOLDER, LLMA_MODEL_NAME)], | ||
256, | ||
)) | ||
|
||
_, llama_no_lora_str = llama_outputs_no_lora[0] | ||
_, ultravox_str = ultravox_outputs[0] | ||
|
||
# verify that text don't match with no lora | ||
assert llama_no_lora_str != ultravox_str | ||
thedebugger marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The changes to this file are not related to this PR, please revert.