From 69a2c7fdb5ab7af243d49fd2da5fbddc55c962f9 Mon Sep 17 00:00:00 2001
From: "s.kochetkov"
Date: Fri, 3 Jan 2025 10:11:20 +0000
Subject: [PATCH] punica_gpu

Signed-off-by: s.kochetkov
---
 vllm/lora/punica_wrapper/punica_gpu.py | 50 ++++++++++++++++++++++++++
 1 file changed, 50 insertions(+)

diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py
index de378df8b3cfa..8f6e513ceb828 100644
--- a/vllm/lora/punica_wrapper/punica_gpu.py
+++ b/vllm/lora/punica_wrapper/punica_gpu.py
@@ -18,6 +18,9 @@
     from vllm.lora.ops.sgmv_expand import sgmv_expand
     from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
     from vllm.lora.ops.sgmv_shrink import sgmv_shrink
+    from vllm.lora.ops.bgmv_sample import bgmv_sample
+    from vllm.lora.ops.bgmv_embed import bgmv_embed
+
 
 from .punica_base import PunicaWrapperBase
 
@@ -356,3 +359,50 @@ def add_lora_logits(self,
                     self.sampler_indices,
                     add_inputs=True)
         y = y.view_as(y_org)
+
+    def bgmv_sample(self, hidden_states: torch.Tensor,
+                    lm_heads_all: torch.Tensor, lm_head_base: torch.Tensor):
+        '''
+        hidden_states - [num_tokens, hidden_dim]
+        lm_heads_all - [num_loras, vocab_size, hidden_dim]
+        lm_head_base - [vocab_size, hidden_dim] - used for tokens
+            whose sampler index is -1 (no active LoRA)
+
+        Equivalent to:
+        num_tokens = hidden_states.size(0)
+        vocab_size = lm_heads_all.shape[-2]
+        logits = torch.zeros((num_tokens, vocab_size),
+                             dtype=torch.float32,
+                             device=hidden_states.device)
+        for i in range(num_tokens):
+            if indices[i] == -1:
+                logits[i] = lm_head_base @ hidden_states[i]
+            else:
+                logits[i] = lm_heads_all[indices[i]] @ hidden_states[i]
+        '''
+
+        indices = self.sampler_indices
+
+        logits = bgmv_sample(hidden_states, lm_heads_all, lm_head_base,
+                             indices)
+        return logits
+
+    def bgmv_embedding(self, tokens: torch.LongTensor,
+                       embed_tokens_all: torch.Tensor,
+                       embed_tokens_base: torch.Tensor) -> torch.Tensor:
+        '''
+        tokens - [num_tokens]
+        embed_tokens_all - [num_loras, vocab_size, hidden_dim] -
+            per-LoRA (modules_to_save) embedding tables
+        embed_tokens_base - [vocab_size, hidden_dim] - base layer
+            embeddings, used for tokens whose LoRA index is -1
+
+        Returns:
+            embeddings: [num_tokens, hidden_dim]
+        '''
+
+        embeddings = bgmv_embed(tokens,
+                                embed_tokens_all,
+                                embed_tokens_base,
+                                token_indices=self.token_lora_indices.long())
+        return embeddings
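
Note (not part of the patch): the docstring loop in bgmv_sample can be
checked against the fused Triton kernel with a dense pure-PyTorch
reference. The sketch below assumes only the shapes documented above;
bgmv_sample_ref is a hypothetical test helper, not an existing vllm
function.

    import torch

    def bgmv_sample_ref(hidden_states: torch.Tensor,
                        lm_heads_all: torch.Tensor,
                        lm_head_base: torch.Tensor,
                        indices: torch.Tensor) -> torch.Tensor:
        # Pick a [vocab_size, hidden_dim] head per token: the base head
        # where indices == -1, else the LoRA head at that index.
        heads = torch.where((indices == -1).view(-1, 1, 1),
                            lm_head_base.unsqueeze(0),
                            lm_heads_all[indices.clamp(min=0)])
        # Batched matrix-vector product -> [num_tokens, vocab_size].
        return torch.bmm(heads.float(),
                         hidden_states.float().unsqueeze(-1)).squeeze(-1)

This materializes one lm_head per token, so it is only suitable for
small unit tests, which is exactly why the patch adds a fused kernel
instead of computing logits this way.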
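
Similarly, a dense reference for the bgmv_embedding semantics, again
assuming only the documented shapes (bgmv_embedding_ref is a
hypothetical test helper):

    import torch

    def bgmv_embedding_ref(tokens: torch.LongTensor,
                           embed_tokens_all: torch.Tensor,
                           embed_tokens_base: torch.Tensor,
                           token_indices: torch.Tensor) -> torch.Tensor:
        # Base-table lookup for every token: [num_tokens, hidden_dim].
        base = embed_tokens_base[tokens]
        # Per-LoRA lookup: table token_indices[i], row tokens[i].
        lora = embed_tokens_all[token_indices.clamp(min=0), tokens]
        # Keep the base embedding where no LoRA is assigned (index -1).
        return torch.where((token_indices == -1).unsqueeze(-1), base, lora)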