
Commit

fix rebase
Signed-off-by: Yu Zhentao <[email protected]>
zhentaoyu committed Dec 3, 2024
1 parent e3e2f39 commit 314bb37
Showing 1 changed file with 58 additions and 36 deletions.
optimum/habana/transformers/models/llama/modeling_llama.py: 58 additions & 36 deletions
@@ -696,40 +696,31 @@ def pre_attn_forward(
else:
past_key_value = None

if use_flash_attention and FusedSDPA is not None:
if q_len == 1:
# next token
attn_output = self.fused_scaled_dot_product_attention(
query_states,
key_states,
value_states,
attention_mask,
0.0,
False,
None,
"None",
False,
None,
"None",
)
else:
# first token
softmax_mode = "fast" if flash_attention_fast_softmax else "None"
if flash_attention_causal_mask:
attn_output = self.fused_scaled_dot_product_attention(
query_states,
key_states,
value_states,
None,
0.0,
True,
None,
softmax_mode,
flash_attention_recompute,
valid_sequence_lengths,
"left",
)
else:
kv_cache_on_host = (key_states.device == torch.device("cpu") and value_states.device == torch.device("cpu"))
# CPU SDPA for next token
if kv_cache_on_host and q_len == 1 and not self.training:
query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv_cpu(
query_states, key_states, value_states, attention_mask, self.num_key_value_groups
)
# PyTorch SDPA: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
# dispatch to the flash attention implementation
attn_output = F.scaled_dot_product_attention(query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
scale=self.norm_factor)
attn_output = attn_output.to("hpu", non_blocking=True)

else:
if kv_cache_on_host:
key_states = key_states.to("hpu", non_blocking=True)
value_states = value_states.to("hpu", non_blocking=True)

if use_flash_attention and FusedSDPA is not None:
if q_len == 1:
# next token
attn_output = self.fused_scaled_dot_product_attention(
query_states,
key_states,
@@ -738,11 +729,42 @@ def pre_attn_forward(
0.0,
False,
None,
softmax_mode,
flash_attention_recompute,
"None",
False,
None,
"None",
)
else:
# first token
softmax_mode = "fast" if flash_attention_fast_softmax else "None"
if flash_attention_causal_mask:
attn_output = self.fused_scaled_dot_product_attention(
query_states,
key_states,
value_states,
None,
0.0,
True,
None,
softmax_mode,
flash_attention_recompute,
valid_sequence_lengths,
"left",
)
else:
attn_output = self.fused_scaled_dot_product_attention(
query_states,
key_states,
value_states,
attention_mask,
0.0,
False,
None,
softmax_mode,
flash_attention_recompute,
None,
"None",
)

else:
query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv(
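For readers skimming the diff: the new code keeps next-token (q_len == 1) attention on the CPU whenever both key_states and value_states already live in host memory, repeating the KV heads there and copying only the small attention output to the HPU; otherwise it moves the cache to the HPU and uses the fused SDPA path as before. The sketch below illustrates that dispatch pattern under stated assumptions: decode_attention, repeat_kv, and compute_device are illustrative names, and the standard PyTorch scaled_dot_product_attention stands in for Habana's fused kernel, so this is not the exact optimum-habana implementation.

# Hedged sketch of the dispatch pattern introduced above (not the exact
# optimum-habana code): run next-token attention on the CPU when the KV
# cache lives on host memory, otherwise move the cache to the compute
# device and attend there. repeat_kv is a hypothetical GQA helper.
import torch
import torch.nn.functional as F


def repeat_kv(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Duplicate each KV head n_rep times so K/V match the query head count (GQA)."""
    if n_rep == 1:
        return hidden
    bsz, n_kv_heads, seq_len, head_dim = hidden.shape
    return (
        hidden[:, :, None, :, :]
        .expand(bsz, n_kv_heads, n_rep, seq_len, head_dim)
        .reshape(bsz, n_kv_heads * n_rep, seq_len, head_dim)
    )


def decode_attention(query, key_cache, value_cache, attn_mask, n_rep, compute_device="hpu"):
    kv_cache_on_host = key_cache.device.type == "cpu" and value_cache.device.type == "cpu"

    if kv_cache_on_host and query.shape[-2] == 1:
        # Next-token step with an offloaded cache: expand the KV heads and let
        # torch.nn.functional.scaled_dot_product_attention run on the CPU, then
        # ship only the one-token output back to the compute device.
        key_cache = repeat_kv(key_cache, n_rep)
        value_cache = repeat_kv(value_cache, n_rep)
        out = F.scaled_dot_product_attention(
            query.to("cpu"), key_cache, value_cache, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
        )
        return out.to(compute_device, non_blocking=True)

    # Otherwise move the cache to the compute device and attend there
    # (optimum-habana would call its fused SDPA kernel at this point).
    key_cache = repeat_kv(key_cache.to(compute_device, non_blocking=True), n_rep)
    value_cache = repeat_kv(value_cache.to(compute_device, non_blocking=True), n_rep)
    return F.scaled_dot_product_attention(
        query, key_cache, value_cache, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
    )


# Example (CPU-only so it runs anywhere; on Gaudi the compute device would be "hpu"):
q = torch.randn(1, 32, 1, 128)   # one decode-step query, 32 query heads
k = torch.randn(1, 8, 512, 128)  # host-resident KV cache, 8 KV heads
v = torch.randn(1, 8, 512, 128)
out = decode_attention(q, k, v, attn_mask=None, n_rep=4, compute_device="cpu")
print(out.shape)  # torch.Size([1, 32, 1, 128])

The presumable rationale is that a single decode step produces only a one-token output, so attending next to the host-resident cache avoids shipping the full cache to the HPU on every step, while prefill and non-offloaded runs keep using the fused HPU kernel.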

0 comments on commit 314bb37