Fix compile error for torch_compile for llama (huggingface#662)
jiminha authored Jan 25, 2024
1 parent 9841f1e commit 2f4fe81
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -291,18 +291,27 @@ def pre_attn_forward(
 
         attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor
 
-        if attn_weights.size() not in [
-            (bsz, self.num_heads, q_len, kv_seq_len),
-            (bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, kv_seq_len),
-        ]:
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len) and attn_weights.size() != (
+            bsz,
+            self.num_key_value_heads,
+            self.num_key_value_groups,
+            q_len,
+            kv_seq_len,
+        ):
             raise ValueError(
                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)} or"
                 f" {(bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, kv_seq_len)}, but is"
                 f" {attn_weights.size()}"
             )
 
         if attention_mask is not None:
-            if attention_mask.size() not in [(bsz, 1, q_len, kv_seq_len), (bsz, 1, 1, q_len, kv_seq_len)]:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len) and attention_mask.size() != (
+                bsz,
+                1,
+                1,
+                q_len,
+                kv_seq_len,
+            ):
                 raise ValueError(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)} or {(bsz, 1, 1, q_len, kv_seq_len)},"
                     f" but is {attention_mask.size()}"
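The commit replaces the list-membership shape checks (`size() not in [...]`) with explicit `!=` comparisons against each allowed shape, which is what resolves the compile error reported for torch_compile. Below is a minimal, self-contained sketch of the resulting check pattern; the helper name and the dummy dimensions are illustrative assumptions, not part of the commit.

```python
# Minimal sketch of the shape-check pattern the diff switches to.
# The function name and the example dimensions are illustrative only.
import torch


def check_attn_weights_shape(attn_weights, bsz, num_heads, num_kv_heads, num_kv_groups, q_len, kv_seq_len):
    # Two layouts are accepted: the usual (bsz, num_heads, q_len, kv_seq_len)
    # and a grouped layout (bsz, num_kv_heads, num_kv_groups, q_len, kv_seq_len).
    mha_shape = (bsz, num_heads, q_len, kv_seq_len)
    grouped_shape = (bsz, num_kv_heads, num_kv_groups, q_len, kv_seq_len)
    # Explicit != comparisons instead of `size() not in [mha_shape, grouped_shape]`,
    # mirroring the fix in the commit.
    if attn_weights.size() != mha_shape and attn_weights.size() != grouped_shape:
        raise ValueError(
            f"Attention weights should be of size {mha_shape} or {grouped_shape}, but is {attn_weights.size()}"
        )


# Example with dummy values: the tensor matches the first allowed shape, so no error is raised.
weights = torch.zeros(2, 8, 16, 16)  # (bsz, num_heads, q_len, kv_seq_len)
check_attn_weights_shape(weights, bsz=2, num_heads=8, num_kv_heads=2, num_kv_groups=4, q_len=16, kv_seq_len=16)
```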
