diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 873407fbddd..94a543497bb 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -47,8 +47,7 @@
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
-from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8, \
-    use_sdp_causal
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
 from ipex_llm.transformers.models.utils import use_decoding_fast_path, get_q_proj_or_qkv_proj
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -599,7 +598,7 @@ def llama_attention_forward_4_31_quantized(
         kv_seq_len = key_states.shape[-2]
     past_key_value = (key_states, value_states)
 
-    if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
+    if not use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                         query_states.dtype)
         # repeat k/v heads if n_kv_heads < n_heads
@@ -1282,7 +1281,7 @@ def llama_attention_forward_4_41_quantized(
         key_states, value_states = past_key_value.update(key_states, value_states,
                                                          self.layer_idx, cache_kwargs)
         kv_seq_len = key_states.shape[-2]
-        if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
+        if not use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                             query_states.dtype)
             key_states = repeat_kv(key_states, self.num_key_value_groups)\
@@ -1873,7 +1872,7 @@ def llama_attention_forward_4_38_quantized(
         key_states, value_states = past_key_value.update(key_states, value_states,
                                                          self.layer_idx, cache_kwargs)
         kv_seq_len = key_states.shape[-2]
-        if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
+        if not use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                             query_states.dtype)
             key_states = repeat_kv(key_states, self.num_key_value_groups)\
diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py
index 2694031d586..a3e27763fbc 100644
--- a/python/llm/src/ipex_llm/transformers/models/mistral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -51,9 +51,7 @@
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     is_enough_kv_cache_room_4_36
-from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
-from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8, \
-    use_sdp_causal
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from ipex_llm.transformers.models.llama import llama_decoding_fast_path_qtype_check
 from ipex_llm.transformers.models.llama import should_use_xetla_mm_qkv
@@ -409,7 +407,7 @@ def mistral_attention_forward_quantized(
         kv_seq_len = key_states.shape[-2]
     past_key_value = (key_states, value_states)
 
-    if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
+    if not use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                         query_states.dtype)
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
@@ -845,7 +843,7 @@ def mistral_attention_forward_4_36_quantized(
         key_states, value_states = past_key_value.update(key_states, value_states,
                                                          self.layer_idx, cache_kwargs)
         kv_seq_len = key_states.shape[-2]
-        if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
+        if not use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                             query_states.dtype)
             attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 1de802967d1..1f14bf2f376 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -22,7 +22,6 @@
 from ipex_llm.transformers.utils import get_ipex_version, get_xpu_device_type
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8, FP8E5, IQ2_XXS, FP4, FP8E4,\
     FP6, ASYM_INT4
-from ipex_llm.transformers.convert import is_deepspeed_available
 
 FP8_KV_ALLOC_LENGTH = 512
 KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
@@ -335,15 +334,6 @@ def use_sdp(q_len, kv_len, head_dim, query_states):
     )
 
 
-def use_sdp_fp8(q_len, kv_len, query_states):
-    return (
-        query_states.device.type == "xpu"
-        and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
-        and q_len != kv_len  # next token
-        and q_len <= 32  # lookup
-    )
-
-
 def use_sdp_causal(q_len, kv_len, head_dim, query_states, training):
     return (
         q_len == kv_len  # first token
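
The patch above drops the separate use_sdp_fp8 gate and routes the quantized (FP8 KV cache) attention paths through the shared use_sdp check, which additionally receives self.head_dim. The body of use_sdp is not part of this diff; the sketch below is an illustration only, reusing the conditions of the removed use_sdp_fp8 and adding an assumed head_dim filter, to show how the changed call sites are expected to behave.

import torch


def use_sdp(q_len, kv_len, head_dim, query_states):
    # Sketch of the consolidated gate (assumed body; only the signature and
    # closing paren are visible in this diff). Conditions mirror the removed
    # use_sdp_fp8; the allowed head_dim values are an assumption for this sketch.
    return (
        query_states.device.type == "xpu"
        and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
        and head_dim in [-1, 64, 80, 96, 128]  # assumed head_dim filter
        and q_len != kv_len  # next token (decoding), not the prefill step
        and q_len <= 32  # short query window ("lookup")
    )


# Call-site pattern used at every changed location: when the fused SDP kernel
# cannot be used, the FP8 KV cache is restored to fp16/fp32 and the plain
# matmul attention path runs instead.
q_len, kv_len, head_dim = 1, 128, 128
query_states = torch.randn(1, 32, q_len, head_dim, dtype=torch.half)  # CPU tensor
if not use_sdp(q_len, kv_len, head_dim, query_states):
    print("restore fp8 kv cache and fall back to matmul attention")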