Get seq len fix propagate (#1291)
Co-authored-by: Libin Tang <[email protected]>
Co-authored-by: regisss <[email protected]>
3 people committed Aug 23, 2024
1 parent cbdd335 commit 52e22cb
Showing 2 changed files with 7 additions and 2 deletions.
examples/text-generation/README.md (1 addition, 1 deletion)
@@ -284,7 +284,7 @@ PT_ENABLE_INT64_SUPPORT=1 PT_HPU_LAZY_MODE=0 python ../gaudi_spawn.py --world_s

### Running with FP8

- Llama2-70b, Llama2-7b, Llama3-70b, Llama3-8b, Mixtral-8x7B, Falcon-7B, Falcon-40B, Falcon-180B and phi-2 in FP8 are enabled using the Intel Neural Compressor (INC), which provides model measurement and quantization capabilities in PyTorch.
+ Llama2-70b, Llama2-7b, Llama3-70b, Llama3-8b, Mixtral-8x7B, Falcon-7B, Falcon-40B, Falcon-180B and phi-2 in FP8 are enabled using the [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html), which provides model measurement and quantization capabilities in PyTorch. Starting with the Synapse 1.17 / optimum-habana 1.13 releases, INC is used by default for measurement and quantization. The Habana Quantization Toolkit (HQT), which was used previously, will be removed in future releases. To use HQT instead, disable INC by setting the following environment variable: `USE_INC=0`.

More information on enabling FP8 in SynapseAI is available here:
https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html
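For context on the `USE_INC` switch mentioned in the new README text, here is a minimal sketch of opting back into the deprecated HQT path. It assumes the variable is consulted when the FP8 quantization backend is selected, so it must be set before the library is imported (the shell-prefix form `USE_INC=0 python ...` is the usual equivalent):

```python
import os

# Assumption for illustration: USE_INC is read when FP8 measurement/quantization
# is set up, so it must be set before optimum-habana / INC is imported.
os.environ["USE_INC"] = "0"  # fall back to the deprecated HQT instead of INC
```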
optimum/habana/transformers/models/phi/modeling_phi.py (6 additions, 1 deletion)
@@ -430,7 +430,12 @@ def forward(
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
-           past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+           past_seen_tokens = 0
+           if past_key_values is not None:
+               if isinstance(past_key_values, Cache):
+                   past_seen_tokens = past_key_values.get_seq_length()
+               else:
+                   past_seen_tokens = past_key_values[0][0].shape[2]
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
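The fix amounts to supporting both cache representations when computing how many tokens are already cached: `get_seq_length()` is only valid on a `Cache` object, while the legacy tuple format stores per-layer key tensors shaped `(batch, num_heads, seq_len, head_dim)`. A minimal standalone sketch of the same logic, assuming transformers' `Cache`/`DynamicCache` API (the helper name is illustrative, not part of the commit):

```python
import torch
from transformers.cache_utils import Cache, DynamicCache


def get_past_seen_tokens(past_key_values):
    """Illustrative helper: number of tokens already cached, for either
    a Cache object or the legacy tuple-of-tuples format."""
    if past_key_values is None:
        return 0
    if isinstance(past_key_values, Cache):
        return past_key_values.get_seq_length()
    # Legacy format: one (key, value) pair per layer, key shaped
    # (batch, num_heads, seq_len, head_dim), so dim 2 is the cached length.
    return past_key_values[0][0].shape[2]


# Both representations of a 5-token cache report the same length.
cache = DynamicCache()
cache.update(torch.zeros(1, 4, 5, 8), torch.zeros(1, 4, 5, 8), layer_idx=0)
legacy = ((torch.zeros(1, 4, 5, 8), torch.zeros(1, 4, 5, 8)),)
assert get_past_seen_tokens(cache) == get_past_seen_tokens(legacy) == 5
```

This mirrors the diff above: the `isinstance(past_key_values, Cache)` check guards the `get_seq_length()` call, and the tuple branch recovers the cached length from the key tensor shape instead.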
