Skip to content

Commit

Permalink
Enable LoRA support for Intel Gaudi
Browse files Browse the repository at this point in the history
Signed-off-by: Sanju C Sudhakaran <[email protected]>
  • Loading branch information
SanjuCSudhakaran committed Dec 10, 2024
1 parent d1f6d1c commit 0922d0d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 13 deletions.
7 changes: 7 additions & 0 deletions vllm/lora/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
LinearScalingRotaryEmbedding, RotaryEmbedding)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.platforms import current_platform

if current_platform.is_hpu():
from vllm_hpu_extension.punica_hpu import GaudiPunicaWrapper

if TYPE_CHECKING:
from vllm.lora.punica_wrapper import PunicaWrapperBase
Expand Down Expand Up @@ -255,6 +259,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
full_lora_a_embeddings,
self.lora_b_stacked,
add_input=True)

return full_output.view_as(full_output_org)

@classmethod
Expand Down Expand Up @@ -1068,6 +1073,8 @@ def _get_logits(
).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
posinf=float("inf"),
neginf=float("-inf")))
if current_platform.is_hpu():
lora_logits = lora_logits[:logits.shape[0], :]
logits[:,
self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
lora_logits.shape[1]] = lora_logits
Expand Down
17 changes: 4 additions & 13 deletions vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1282,11 +1282,9 @@ def create_dummy_seq_group_metadata(self,
def profile_run(self) -> None:
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
max_batch_size = self.bucketing_global_state.prompt_bs_bucket_cfg[-1]
max_seq_len = min(
self.bucketing_global_state.prompt_seq_bucket_cfg[-1],
self.max_num_batched_tokens // max_batch_size)

max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1]
max_batch_size = min(self.max_num_batched_tokens // max_seq_len,
self.scheduler_config.max_num_seqs)
self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
False, True)
return
Expand All @@ -1304,7 +1302,6 @@ def warmup_scenario(self,
f"bs{batch_size}_"
f"seq{seq_len}_"
f"graphs{'T' if use_graphs else 'F'}")
max_num_seqs = self.scheduler_config.max_num_seqs
# This represents the maximum number of different requests
# that will have unique loras, an therefore the max amount of memory
# consumption create dummy lora request copies from the lora request
Expand All @@ -1326,16 +1323,10 @@ def warmup_scenario(self,
dummy_lora_requests.append(dummy_lora_request)
dummy_lora_requests_per_seq = [
dummy_lora_requests[idx % len(dummy_lora_requests)]
for idx in range(max_num_seqs)
for idx in range(batch_size)
]
self.profiler.start('internal', scenario_name)
times = 3 if use_graphs or is_pt_profiler_run else 1
if self.lora_config and not is_lora_profile_run:
lora_mapping = LoRAMapping(
**dict(index_mapping=[0] * batch_size * seq_len,
prompt_mapping=[0] * batch_size * seq_len,
is_prefill=is_prompt))
self.set_active_loras(set(), lora_mapping)
if is_prompt:
seqs = [
self.create_dummy_seq_group_metadata(
Expand Down

0 comments on commit 0922d0d

Please sign in to comment.