Skip to content

Commit

Permalink
[Model][LoRA]LoRA support added for LlamaEmbeddingModel (vllm-project…
Browse files Browse the repository at this point in the history
…#10071)

Signed-off-by: Jee Jee Li <[email protected]>
Signed-off-by: Sumit Dubey <[email protected]>
  • Loading branch information
jeejeelee authored and sumitd2 committed Nov 14, 2024
1 parent e9691b7 commit c21c61e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ Text Embedding
* - :code:`MistralModel`
- Mistral-based
- :code:`intfloat/e5-mistral-7b-instruct`, etc.
-
- ✅︎
- ✅︎

.. important::
Expand Down
20 changes: 19 additions & 1 deletion vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ def permute(w: torch.Tensor, n_heads: int):
return name, loaded_weight


class LlamaEmbeddingModel(nn.Module, SupportsPP):
class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
"""
A model that uses Llama with additional embedding functionalities.
Expand All @@ -638,6 +638,19 @@ class LlamaEmbeddingModel(nn.Module, SupportsPP):
model: An instance of LlamaModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}

# LoRA specific attributes
supported_lora_modules = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
]
embedding_modules = {
"embed_tokens": "input_embeddings",
}
embedding_padding_modules = []

def __init__(
self,
Expand Down Expand Up @@ -679,3 +692,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

def load_kv_cache_scales(self, quantization_param_path: str) -> None:
self.model.load_kv_cache_scales(quantization_param_path)

# LRUCacheWorkerLoRAManager instantiation requires model config.
@property
def config(self):
return self.model.config

0 comments on commit c21c61e

Please sign in to comment.