
Commit

punica_gpu
Signed-off-by: s.kochetkov <[email protected]>
s.kochetkov committed Jan 3, 2025
1 parent b1cdc0f commit 69a2c7f
Showing 1 changed file with 50 additions and 0 deletions.
vllm/lora/punica_wrapper/punica_gpu.py
@@ -18,6 +18,9 @@
from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.lora.ops.bgmv_sample import bgmv_sample
from vllm.lora.ops.bgmv_embed import bgmv_embed


from .punica_base import PunicaWrapperBase

@@ -356,3 +359,50 @@ def add_lora_logits(self,
                    self.sampler_indices,
                    add_inputs=True)
        y = y.view_as(y_org)

    def bgmv_sample(self, hidden_states: torch.Tensor,
                    lm_heads_all: torch.Tensor,
                    lm_head_base: torch.Tensor) -> torch.Tensor:
        '''
        hidden_states - [num_tokens, hidden_dim]
        lm_heads_all - [num_loras, vocab_size, hidden_dim],
            the lm_head weights saved per LoRA (modules_to_save)
        lm_head_base - [vocab_size, hidden_dim],
            the base-model lm_head, used for tokens whose LoRA index is -1

        Equivalent to:
            num_tokens = hidden_states.size(0)
            vocab_size = lm_heads_all.shape[-2]
            indices = self.sampler_indices
            logits = torch.zeros((num_tokens, vocab_size),
                                 dtype=torch.float32,
                                 device=hidden_states.device)
            for i in range(num_tokens):
                if indices[i] == -1:
                    logits[i] = lm_head_base @ hidden_states[i]
                else:
                    logits[i] = lm_heads_all[indices[i]] @ hidden_states[i]

        returns:
            logits - [num_tokens, vocab_size]
        '''

        indices = self.sampler_indices

        logits = bgmv_sample(hidden_states, lm_heads_all, lm_head_base,
                             indices)
        return logits
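
For illustration, a self-contained pure-PyTorch check of the index convention the docstring describes (the shapes and index values below are made up for the example; the real bgmv_sample op runs this as a fused kernel):

import torch

num_tokens, hidden_dim, vocab_size, num_loras = 4, 8, 16, 2
hidden_states = torch.randn(num_tokens, hidden_dim)
lm_heads_all = torch.randn(num_loras, vocab_size, hidden_dim)
lm_head_base = torch.randn(vocab_size, hidden_dim)
indices = torch.tensor([-1, 0, 1, -1])  # -1 -> base-model lm_head

logits = torch.empty(num_tokens, vocab_size)
for i in range(num_tokens):
    # pick the base lm_head or the adapter's saved lm_head per token
    weight = lm_head_base if indices[i] == -1 else lm_heads_all[indices[i]]
    logits[i] = weight @ hidden_states[i]

assert logits.shape == (num_tokens, vocab_size)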

    def bgmv_embedding(self, tokens: torch.LongTensor,
                       embed_tokens_all: torch.Tensor,
                       embed_tokens_base: torch.Tensor) -> torch.Tensor:
        '''
        tokens - [num_tokens]
        embed_tokens_all - [num_loras, vocab_size, hidden_dim],
            the embedding tables saved per LoRA (modules_to_save)
        embed_tokens_base - [vocab_size, hidden_dim],
            the base-layer embedding table, applied to tokens whose
            LoRA index is -1

        returns:
            embeddings - [num_tokens, hidden_dim]
        '''

        embeddings = bgmv_embed(tokens,
                                embed_tokens_all,
                                embed_tokens_base,
                                token_indices=self.token_lora_indices.long())

        return embeddings
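
And a minimal pure-PyTorch sketch of the lookup bgmv_embed performs under the shapes documented above (bgmv_embed_ref is a hypothetical reference helper, not the real op):

import torch

def bgmv_embed_ref(tokens: torch.LongTensor,         # [num_tokens]
                   embed_tokens_all: torch.Tensor,   # [num_loras, vocab_size, hidden_dim]
                   embed_tokens_base: torch.Tensor,  # [vocab_size, hidden_dim]
                   token_indices: torch.LongTensor) -> torch.Tensor:
    # start from the base-layer lookup for every token
    embeddings = embed_tokens_base[tokens].clone()
    mask = token_indices != -1
    if mask.any():
        # overwrite rows whose token carries a LoRA id with the lookup
        # from that adapter's saved embedding table
        embeddings[mask] = embed_tokens_all[token_indices[mask], tokens[mask]]
    return embeddings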
