LoRA Support for Ultravox model #11253

Open · wants to merge 17 commits into `main`
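Once merged, a LoRA adapter should be usable with Ultravox the same way as with other LoRA-enabled models. A hedged sketch of the intended offline usage (the adapter path and name are hypothetical; `enable_lora`, `max_lora_rank`, and `LoRARequest` are the existing vLLM LoRA knobs):

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Assumes a LoRA adapter whose module names target the nested language
# model (see transform_module_names_for_ultravox in the test below).
llm = LLM(model="fixie-ai/ultravox-v0_3", enable_lora=True, max_lora_rank=128)
outputs = llm.generate(
    ["What is a Fool's mate in chess?"],
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=LoRARequest("chess", 1, "/path/to/ultravox-chess-adapter"),
)
```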
2 changes: 1 addition & 1 deletion — docs/source/models/supported_models.md

````diff
@@ -759,7 +759,7 @@ See [this page](#generative-models) for more information on how to use generative models.
 - Ultravox
 - T + A<sup>E+</sup>
 - `fixie-ai/ultravox-v0_3`
--
+- ✅︎
 - ✅︎
 - ✅︎
 ```
````
16 changes: 12 additions & 4 deletions — tests/conftest.py

```diff
@@ -734,14 +734,16 @@ def generate(
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
         audios: Optional[PromptAudioInput] = None,
+        **kwargs: Any,
     ) -> List[Tuple[List[List[int]], List[str]]]:
         inputs = self.get_inputs(prompts,
                                  images=images,
                                  videos=videos,
                                  audios=audios)
 
         req_outputs = self.model.generate(inputs,
-                                          sampling_params=sampling_params)
+                                          sampling_params=sampling_params,
+                                          **kwargs)
 
         outputs: List[Tuple[List[List[int]], List[str]]] = []
         for req_output in req_outputs:
@@ -779,6 +781,7 @@ def generate_w_logprobs(
         images: Optional[PromptImageInput] = None,
         audios: Optional[PromptAudioInput] = None,
         videos: Optional[PromptVideoInput] = None,
+        **kwargs: Any,
     ) -> Union[List[TokensTextLogprobs],
                List[TokensTextLogprobsPromptLogprobs]]:
         inputs = self.get_inputs(prompts,
@@ -787,7 +790,8 @@ def generate_w_logprobs(
                                  audios=audios)
 
         req_outputs = self.model.generate(inputs,
-                                          sampling_params=sampling_params)
+                                          sampling_params=sampling_params,
+                                          **kwargs)
 
         toks_str_logsprobs_prompt_logprobs = (
             self._final_steps_generate_w_logprobs(req_outputs))
@@ -823,13 +827,15 @@ def generate_greedy(
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
         audios: Optional[PromptAudioInput] = None,
+        **kwargs: Any,
     ) -> List[Tuple[List[int], str]]:
         greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
         outputs = self.generate(prompts,
                                 greedy_params,
                                 images=images,
                                 videos=videos,
-                                audios=audios)
+                                audios=audios,
+                                **kwargs)
         return [(output_ids[0], output_str[0])
                 for output_ids, output_str in outputs]
 
@@ -844,6 +850,7 @@ def generate_greedy_logprobs(
         videos: Optional[PromptVideoInput] = None,
         stop_token_ids: Optional[List[int]] = None,
         stop: Optional[List[str]] = None,
+        **kwargs: Any,
     ) -> Union[List[TokensTextLogprobs],
                List[TokensTextLogprobsPromptLogprobs]]:
         greedy_logprobs_params = SamplingParams(
@@ -858,7 +865,8 @@ def generate_greedy_logprobs(
             greedy_logprobs_params,
             images=images,
             audios=audios,
-            videos=videos)
+            videos=videos,
+            **kwargs)
 
     def generate_encoder_decoder_greedy_logprobs(
         self,
```
Collaborator review comment: “The changes to this file are not related to this PR, please revert.”
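For context, threading `**kwargs` through these helpers is what lets a test attach a per-request LoRA without adding a dedicated parameter to every helper. A minimal sketch of the intended call pattern, assuming a `vllm_model` handle from the `vllm_runner` fixture (the adapter path is hypothetical):

```python
from vllm.lora.request import LoRARequest

# Any extra keyword passed to generate_greedy() now flows through
# generate() into LLM.generate(), so a LoRARequest can ride along.
outputs = vllm_model.generate_greedy(
    ["What opening leads to a Fool's mate?"],  # prompts
    64,                                        # max_tokens
    lora_request=LoRARequest("chess-lora", 1, "/path/to/chess-adapter"),
)
```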
111 changes: 111 additions & 0 deletions — tests/lora/test_ultravox.py (new file)

```python
import shutil
from os import path
from tempfile import TemporaryDirectory
from typing import List, Tuple

from huggingface_hub import snapshot_download
from safetensors.torch import load_file, save_file
from transformers import AutoTokenizer

from vllm.lora.request import LoRARequest

ULTRAVOX_MODEL_NAME = "fixie-ai/ultravox-v0_3"
LLAMA_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"

VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"

PROMPT = "Tell me about a Fool's mate move in 20 words. Provide the moves!"


def llama3_1_8b_chess_lora_path():
    return snapshot_download(
        repo_id="mkopecki/chess-lora-adapter-llama-3.1-8b")


# A Llama LoRA adapter can't be used as-is: Ultravox nests the language
# model, so the adapter's module names must be remapped first.
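# For example (assuming standard PEFT key naming), the adapter key
#   "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
# becomes
#   "base_model.model.language_model.model.layers.0.self_attn.q_proj.lora_A.weight"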
def transform_module_names_for_ultravox(state_dict):
    transformed_state_dict = {}
    for key, value in state_dict.items():
        new_key = key.replace("base_model.model",
                              "base_model.model.language_model")
        transformed_state_dict[new_key] = value
    return transformed_state_dict


def mk_llama3_1_8b_ultravox_chess_lora(source_repo, target_path):
    tensor_file = "adapter_model.safetensors"
    state_dict = load_file(path.join(source_repo, tensor_file))
    transformed_state_dict = transform_module_names_for_ultravox(state_dict)

    save_file(transformed_state_dict, path.join(target_path, tensor_file))

    config_file = "adapter_config.json"
    shutil.copyfile(path.join(source_repo, config_file),
                    path.join(target_path, config_file))
    return target_path


def _get_prompt(audio_count, question, placeholder, model_name) -> str:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    placeholder = f"{placeholder}\n" * audio_count

    return tokenizer.apply_chat_template([{
        'role': 'user',
        'content': f"{placeholder}{question}"
    }],
                                         tokenize=False,
                                         add_generation_prompt=True)


def test_ultravox_lora(vllm_runner):
"""
TODO: Train an Ultravox LoRA instead of using a Llama LoRA.
"""
llama3_1_8b_chess_lora = llama3_1_8b_chess_lora_path()
with TemporaryDirectory() as temp_ultravox_lora_dir:
llama3_1_8b_ultravox_chess_lora = mk_llama3_1_8b_ultravox_chess_lora(
llama3_1_8b_chess_lora, temp_ultravox_lora_dir)
with vllm_runner(
ULTRAVOX_MODEL_NAME,
enforce_eager=True,
max_num_seqs=128,
enable_lora=True,
max_loras=4,
max_lora_rank=128,
dtype="bfloat16",
max_model_len=4096,
) as vllm_model:
ultravox_outputs: List[Tuple[
List[int], str]] = vllm_model.generate_greedy(
[
_get_prompt(0, PROMPT, VLLM_PLACEHOLDER,
ULTRAVOX_MODEL_NAME)
],
256,
lora_request=LoRARequest(str(1), 1,
llama3_1_8b_ultravox_chess_lora),
)

# run llama with and without lora to compare outputs with above
with vllm_runner(
LLMA_MODEL_NAME,
enforce_eager=True,
max_num_seqs=128,
enable_lora=True,
max_loras=4,
max_lora_rank=128,
dtype="bfloat16",
max_model_len=4096,
) as vllm_model:
llama_outputs_no_lora: List[Tuple[List[int], str]] = (
vllm_model.generate_greedy(
[_get_prompt(0, PROMPT, VLLM_PLACEHOLDER, LLMA_MODEL_NAME)],
256,
))

_, llama_no_lora_str = llama_outputs_no_lora[0]
_, ultravox_str = ultravox_outputs[0]

# verify that text don't match with no lora
assert llama_no_lora_str != ultravox_str
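Assuming a standard dev checkout, the new test should run under the usual pytest flow, e.g. `pytest tests/lora/test_ultravox.py`; note it downloads both the Ultravox model and the chess LoRA adapter from the Hub, so it needs network access and a GPU.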
28 changes: 26 additions & 2 deletions — vllm/model_executor/models/ultravox.py

```diff
@@ -20,6 +20,7 @@
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
@@ -31,7 +32,7 @@
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
 
-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings,
@@ -333,7 +334,20 @@ def forward(
     info=UltravoxProcessingInfo,
     dummy_inputs=UltravoxDummyInputsBuilder
 )
-class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
+class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
+
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"]
+    }
+
+    # LoRA specific attributes
+    # TODO: Add LoRA to the audio tower and projector.
+    supported_lora_modules = [
+        "qkv_proj", "o_proj", "gate_up_proj", "down_proj"
+    ]
+    embedding_modules = {}
+    embedding_padding_modules = []
 
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})
```
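A note on `packed_modules_mapping`: vLLM fuses the attention and MLP projections into single layers (`qkv_proj`, `gate_up_proj`), while PEFT adapters store LoRA weights per unfused module. The mapping tells vLLM's LoRA loader which unfused adapter weights feed each fused layer. A rough sketch of the stacking idea, with illustrative shapes (this is not vLLM's actual loader code):

```python
import torch

# Shapes are illustrative; real ranks and hidden sizes come from the
# adapter config and the model config.
rank, hidden = 8, 4096
lora_b = {name: torch.randn(hidden, rank)
          for name in ("q_proj", "k_proj", "v_proj")}

packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

# Stack the per-module LoRA B matrices along the output dimension so one
# (3*hidden, rank) matrix lines up with the fused qkv_proj weight.
qkv_lora_b = torch.cat(
    [lora_b[m] for m in packed_modules_mapping["qkv_proj"]], dim=0)
assert qkv_lora_b.shape == (3 * hidden, rank)
```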
```diff
@@ -381,6 +395,16 @@ def sampler(self):
 
         return get_sampler()
 
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="language_model.",
+            connector="multi_modal_projector.",
+            tower_model="audio_tower.",
+        )
+
     def _audio_features_to_embeddings(
             self, input_features: torch.Tensor) -> torch.Tensor:
         audio_input = input_features.to(self.audio_tower.dtype)
```
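`get_mm_mapping` exposes the module prefixes of each component so the LoRA machinery can tell language-model modules apart from the audio tower and projector. A hedged illustration of how such prefixes could be used to restrict LoRA targets (a simplification for exposition, not the real resolver):

```python
def is_lora_target(module_name: str) -> bool:
    # Hypothetical helper: only language-model submodules receive LoRA;
    # the audio tower and projector are skipped (see the TODO above).
    if module_name.startswith(("audio_tower.", "multi_modal_projector.")):
        return False
    return module_name.startswith("language_model.")

assert is_lora_target("language_model.model.layers.0.self_attn.qkv_proj")
assert not is_lora_target("audio_tower.layers.0.self_attn.k_proj")
```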