Skip to content

Commit

Permalink
Add PunicaWrapperHPU to handle LoRA computations in HPU
Browse files Browse the repository at this point in the history
Signed-off-by: Sanju C Sudhakaran <[email protected]>
  • Loading branch information
SanjuCSudhakaran committed Dec 10, 2024
1 parent 0922d0d commit d51d66c
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 5 deletions.
2 changes: 1 addition & 1 deletion requirements-hpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ pandas
tabulate
setuptools>=61
setuptools-scm>=8
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@e096d6f
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@4312768
4 changes: 0 additions & 4 deletions vllm/lora/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@
VocabParallelEmbedding)
from vllm.platforms import current_platform

if current_platform.is_hpu():
from vllm_hpu_extension.punica_hpu import GaudiPunicaWrapper

if TYPE_CHECKING:
from vllm.lora.punica_wrapper import PunicaWrapperBase

Expand Down Expand Up @@ -259,7 +256,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
full_lora_a_embeddings,
self.lora_b_stacked,
add_input=True)

return full_output.view_as(full_output_org)

@classmethod
Expand Down
87 changes: 87 additions & 0 deletions vllm/lora/punica_wrapper/punica_hpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from typing import Optional, Tuple, Union, final

import torch
from vllm_hpu_extension.ops import (dispatch_bgmv_embedding,
dispatch_bgmv_linear)

from .punica_base import PunicaWrapperBase


@final
class PunicaWrapperHPU(PunicaWrapperBase):

def __init__(self, max_num_batched_tokens: int, max_batches: int,
device: Union[torch.device, str], **kwargs):
# Increasing max_num_batched_tokens by 3x to handle increase in
# tensor size due to padding.
PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens,
max_batches, device)

def add_lora_embedding(self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_input: bool = True,
**kwargs) -> None:
dispatch_bgmv_embedding(y, x, lora_b_stacked, 0)

def add_lora_linear(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
scale: float,
output_slices: Tuple[int, ...],
*,
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
**kwargs) -> None:
y_org = y
x = x.view(-1, x.shape[-1])
y = y.view(-1, y.shape[-1])
offset_left = 0

for slice_idx in range(len(output_slices)):
dispatch_bgmv_linear(
y[:, offset_left:offset_left + output_slices[slice_idx]], x,
lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], 0, scale)
offset_left += output_slices[slice_idx]
y = y.view_as(y_org)

def add_lora_logits(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: Optional[torch.Tensor] = None,
**kwargs) -> None:
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
dispatch_bgmv_linear(y, x, lora_a_stacked, lora_b_stacked, 0, scale)
y = y.view_as(y_org)

def add_shrink(
self,
y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
scale: float,
**kwargs,
) -> None:
raise NotImplementedError

def add_expand(
self,
y: torch.Tensor,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...],
offset_start: int = 0,
add_input=True,
**kwargs,
) -> None:
raise NotImplementedError
5 changes: 5 additions & 0 deletions vllm/lora/punica_wrapper/punica_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,10 @@ def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
print_info_once("Using PunicaWrapperGPU.")
return PunicaWrapperGPU(*args, **kwargs)
elif current_platform.is_hpu():
# Lazy import to avoid ImportError
from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU
print_info_once("Using PunicaWrapperHPU.")
return PunicaWrapperHPU(*args, **kwargs)
else:
raise NotImplementedError

0 comments on commit d51d66c

Please sign in to comment.