From 52e22cb31849b6929b1d8b291d671b2fa901e00e Mon Sep 17 00:00:00 2001
From: Sayantan Sarkar
Date: Fri, 23 Aug 2024 13:42:33 -0700
Subject: [PATCH] Get seq len fix propagate (#1291)

Co-authored-by: Libin Tang
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
---
 examples/text-generation/README.md                     | 2 +-
 optimum/habana/transformers/models/phi/modeling_phi.py | 7 ++++++-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md
index 29b754731d..b720936ff4 100755
--- a/examples/text-generation/README.md
+++ b/examples/text-generation/README.md
@@ -284,7 +284,7 @@ PT_ENABLE_INT64_SUPPORT=1 PT_HPU_LAZY_MODE=0 python ../gaudi_spawn.py --world_s
 
 ### Running with FP8
 
-Llama2-70b, Llama2-7b, Llama3-70b, Llama3-8b, Mixtral-8x7B, Falcon-7B, Falcon-40B, Falcon-180B and phi-2 in FP8 are enabled using the Intel Neural Compressor (INC), which provides model measurement and quantization capabilities in PyTorch.
+Llama2-70b, Llama2-7b, Llama3-70b, Llama3-8b, Mixtral-8x7B, Falcon-7B, Falcon-40B, Falcon-180B and phi-2 in FP8 are enabled using the [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html), which provides model measurement and quantization capabilities in PyTorch.
 From synapse 1.17 / optimum-habana 1.13 release, INC is used by default for measuring and quantization. Habana Quantization Toolkit (HQT), which was used earlier, will be removed in future releases. To use HQT, disable INC by setting the following environment variable: `USE_INC=0`.
 
 More information on enabling fp8 in SynapseAI is available here: https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html
diff --git a/optimum/habana/transformers/models/phi/modeling_phi.py b/optimum/habana/transformers/models/phi/modeling_phi.py
index 07f4d0cd71..1e21735add 100644
--- a/optimum/habana/transformers/models/phi/modeling_phi.py
+++ b/optimum/habana/transformers/models/phi/modeling_phi.py
@@ -430,7 +430,12 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)
 
         if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            past_seen_tokens = 0
+            if past_key_values is not None:
+                if isinstance(past_key_values, Cache):
+                    past_seen_tokens = past_key_values.get_seq_length()
+                else:
+                    past_seen_tokens = past_key_values[0][0].shape[2]
             cache_position = torch.arange(
                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
             )
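
For reference, below is a minimal standalone sketch (not part of the patch) of the cache-length logic the modeling_phi.py change introduces. The helper name `get_past_seen_tokens` is made up for illustration; the patch inlines this logic in `forward`. It assumes the two `past_key_values` layouts the patch handles: a `transformers` `Cache` object, and the legacy per-layer tuple of `(key, value)` tensors shaped `[batch, num_heads, seq_len, head_dim]`.

```python
import torch
from transformers.cache_utils import Cache, DynamicCache


def get_past_seen_tokens(past_key_values):
    """Hypothetical helper mirroring the patched branch in GaudiPhiModel.forward."""
    past_seen_tokens = 0
    if past_key_values is not None:
        if isinstance(past_key_values, Cache):
            # Cache objects expose the stored length directly.
            past_seen_tokens = past_key_values.get_seq_length()
        else:
            # Legacy tuple cache: dim 2 of the first layer's key tensor is the sequence length.
            past_seen_tokens = past_key_values[0][0].shape[2]
    return past_seen_tokens


# Legacy tuple cache with 5 cached tokens: one layer, key/value of shape
# [batch=1, num_heads=2, seq_len=5, head_dim=4].
legacy_cache = ((torch.zeros(1, 2, 5, 4), torch.zeros(1, 2, 5, 4)),)
print(get_past_seen_tokens(None))            # 0
print(get_past_seen_tokens(legacy_cache))    # 5
print(get_past_seen_tokens(DynamicCache()))  # 0 (empty Cache object)
```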