From 2f4fe81f7511b0d2a53d889cd76481edf7972f27 Mon Sep 17 00:00:00 2001
From: Jimin Ha
Date: Thu, 25 Jan 2024 08:34:41 -0800
Subject: [PATCH] Fix compile error for torch_compile for llama (#662)

---
 .../models/llama/modeling_llama.py | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 30322babbb..9222afd793 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -291,10 +291,13 @@ def pre_attn_forward(
 
         attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor
 
-        if attn_weights.size() not in [
-            (bsz, self.num_heads, q_len, kv_seq_len),
-            (bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, kv_seq_len),
-        ]:
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len) and attn_weights.size() != (
+            bsz,
+            self.num_key_value_heads,
+            self.num_key_value_groups,
+            q_len,
+            kv_seq_len,
+        ):
             raise ValueError(
                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)} or"
                 f" {(bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, kv_seq_len)}, but is"
@@ -302,7 +305,13 @@ def pre_attn_forward(
             )
 
         if attention_mask is not None:
-            if attention_mask.size() not in [(bsz, 1, q_len, kv_seq_len), (bsz, 1, 1, q_len, kv_seq_len)]:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len) and attention_mask.size() != (
+                bsz,
+                1,
+                1,
+                q_len,
+                kv_seq_len,
+            ):
                 raise ValueError(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)} or {(bsz, 1, 1, q_len, kv_seq_len)},"
                     f" but is {attention_mask.size()}"
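
Note (not part of the patch): both hunks apply the same refactor, replacing a `size() not in [tuple, tuple]` membership test with explicit `!=` comparisons against each allowed shape, which the commit message indicates avoids the torch.compile error for llama. Below is a minimal standalone sketch of that pattern; the function name `check_attn_weights_shape` and the plain integer parameters are hypothetical stand-ins for the module attributes used in the real code, not part of optimum-habana.

import torch


def check_attn_weights_shape(attn_weights, bsz, num_heads, num_kv_heads, num_kv_groups, q_len, kv_seq_len):
    # Same shape check as in the patch, written as two explicit != comparisons
    # instead of `attn_weights.size() not in [tuple, tuple]`.
    if attn_weights.size() != (bsz, num_heads, q_len, kv_seq_len) and attn_weights.size() != (
        bsz,
        num_kv_heads,
        num_kv_groups,
        q_len,
        kv_seq_len,
    ):
        raise ValueError(
            f"Attention weights should be of size {(bsz, num_heads, q_len, kv_seq_len)} or"
            f" {(bsz, num_kv_heads, num_kv_groups, q_len, kv_seq_len)}, but is {attn_weights.size()}"
        )


# Usage: a (bsz, num_heads, q_len, kv_seq_len)-shaped tensor passes the check;
# any other shape raises ValueError. torch.Size compares element-wise against
# a plain tuple, so the semantics match the original membership test.
check_attn_weights_shape(
    torch.zeros(2, 8, 4, 4), bsz=2, num_heads=8, num_kv_heads=2, num_kv_groups=4, q_len=4, kv_seq_len=4
)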