Use new sdp again (#11025)
MeouSker77 authored May 16, 2024
1 parent 7e29928 commit 59df750
Showing 5 changed files with 34 additions and 95 deletions.
python/llm/src/ipex_llm/transformers/models/baichuan2.py (22 changes: 10 additions & 12 deletions)
@@ -28,10 +28,11 @@
     restore_fp8_kv_cache, use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
     append_kv_cache, is_enough_kv_cache_room_4_31
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, SILU
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from ipex_llm.transformers.models.utils import mlp_fusion_check
+from ipex_llm.utils.common.log4Error import invalidInputError
 from transformers.utils import logging
 logger = logging.get_logger(__name__)

@@ -166,9 +167,8 @@ def baichuan_attention_forward_7b_quantized(

     past_key_value = (key_states, value_states) if use_cache else None

-    if attention_mask is not None:
-        if attention_mask.dtype == torch.bool:
-            attention_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
+    invalidInputError(attention_mask is None or attention_mask.dtype != torch.bool,
+                      "attention_mask's dtype cannot be bool")

     scaling_factor = 1 / math.sqrt(query_states.size(-1))
     if query_states.size(2) != 1 or device.type != 'xpu':
@@ -279,6 +279,9 @@ def baichuan_attention_forward_7b_origin(

     past_key_value = (key_states, value_states) if use_cache else None

+    invalidInputError(attention_mask is None or attention_mask.dtype != torch.bool,
+                      "attention_mask's dtype cannot be bool")
+
     if xops is not None and self.training:
         attn_weights = None
         query_states = query_states.transpose(1, 2)
@@ -296,17 +299,12 @@ def baichuan_attention_forward_7b_origin(
                                                      is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
-        import linear_fp16_esimd
-        attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                    key_states,
-                                                    value_states)
+            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:
-        if attention_mask is not None:
-            if attention_mask.dtype == torch.bool:
-                attention_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
         if should_split_qkv_tensor(query_states, bsz, self.num_heads,
                                    q_len, kv_seq_len, output_attentions):
             attn_output, attn_weights = native_sdp_split_qkv_tensor(query_states,
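For context, the fused linear_q4_0.sdp call that replaces linear_fp16_esimd.sdp_forward here computes standard scaled-dot-product attention with an additive mask. Below is a minimal plain-PyTorch sketch of the equivalent math, not the kernel's actual implementation; the shapes and the additive-mask convention are assumptions.

```python
import math
import torch

def sdp_reference(query_states, key_states, value_states, attention_mask=None):
    """Plain-PyTorch equivalent of softmax(QK^T / sqrt(d) + mask) @ V.

    Assumed shapes: [batch, num_heads, seq_len, head_dim]. The mask must be an
    additive float bias (0 = keep, -inf = masked), which is why the patch above
    rejects boolean masks with invalidInputError before reaching the kernel.
    """
    head_dim = query_states.size(-1)
    attn = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
    if attention_mask is not None:
        attn = attn + attention_mask
    attn = torch.softmax(attn, dim=-1, dtype=torch.float32).to(value_states.dtype)
    return torch.matmul(attn, value_states)
```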
python/llm/src/ipex_llm/transformers/models/chatglm2.py (33 changes: 18 additions & 15 deletions)
@@ -25,7 +25,7 @@
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
     restore_fp8_kv_cache, use_quantize_kv_cache
-from ipex_llm.transformers.models.utils import use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_sdp


 import os
@@ -558,25 +558,28 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
                                                        value_layer,
                                                        is_causal=True).to(key_layer.dtype)
     else:
-        if use_esimd_sdp(query_layer.shape[2], key_layer.shape[2],
-                         query_layer.shape[-1], query_layer):
-            import linear_fp16_esimd
-            attn_output = linear_fp16_esimd.sdp_forward(query_layer,
-                                                        key_layer,
-                                                        value_layer)
+        # attention_mask is not None only when past_key_value is not None and q_len > 1
+        if attention_mask is not None:
+            attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
+                                    device=query_layer.device)
+            attention_mask = ~attention_mask
+            if attention_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias += attention_mask
+        else:
+            attn_bias = None
+
+        if use_sdp(query_layer.shape[2], key_layer.shape[2],
+                   query_layer.shape[-1], query_layer):
+            import linear_q4_0
+            attn_output = linear_q4_0.sdp(query_layer, key_layer, value_layer, attn_bias)
             context_layer = attn_output.view(query_layer.shape)
         else:
             head_dim = query_layer.size(-1)
             attn = torch.matmul(query_layer.to(key_layer.dtype),
                                 key_layer.transpose(2, 3)) / math.sqrt(head_dim)
-            if attention_mask is not None:
-                attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
-                                        device=query_layer.device)
-                attention_mask = ~attention_mask
-                if attention_mask.dtype == torch.bool:
-                    attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
-                else:
-                    attn_bias += attention_mask
+            if attn_bias is not None:
                 attn += attn_bias
             attn = F.softmax(attn, dim=-1,
                              dtype=torch.float32).to(value_layer.dtype)
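The key behavioural change in this file is that the boolean attention mask is converted into an additive float bias once, up front, and that bias is then passed both to the fused linear_q4_0.sdp kernel and to the native fallback. A standalone sketch of that conversion follows; the mask convention (True marks positions to mask out) is inferred from the double negation in the hunk above, and non-bool masks are assumed to already be additive.

```python
import torch

def build_attn_bias(attention_mask, dtype, device):
    # Turn a boolean mask (True = do not attend, per the inferred convention)
    # into the additive float bias expected by the fused SDP kernel; non-bool
    # masks are assumed to already be additive and are added through unchanged.
    if attention_mask is None:
        return None
    attn_bias = torch.zeros(attention_mask.shape, dtype=dtype, device=device)
    if attention_mask.dtype == torch.bool:
        attn_bias.masked_fill_(attention_mask, float("-inf"))
    else:
        attn_bias += attention_mask
    return attn_bias

# Example: mask out the last key position for a single query row.
mask = torch.tensor([[False, False, True]])
print(build_attn_bias(mask, torch.float16, "cpu"))
# tensor([[0., 0., -inf]], dtype=torch.float16)
```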
python/llm/src/ipex_llm/transformers/models/phi3.py (6 changes: 3 additions & 3 deletions)
@@ -41,8 +41,8 @@
     apply_rotary_pos_emb_cache_freq_xpu
 )
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
-from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, use_quantize_kv_cache
-from ipex_llm.transformers.models.utils import use_sdp_fp8, restore_fp8_kv_cache
+from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache

 from typing import Optional, Tuple, List
@@ -144,7 +144,7 @@ def attention_forward(
                                               attention_mask)
         else:
             attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, query_states, self.training):
+    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
         import linear_q4_0
         if isinstance(past_key_value, DynamicFp8Cache):
             attn_output = linear_q4_0.sdp_fp8_causal(query_states, key_states, value_states)
python/llm/src/ipex_llm/transformers/models/qwen2_moe.py (2 changes: 1 addition & 1 deletion)
@@ -52,7 +52,7 @@
 from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention
 from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeModel, apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.kv import DynamicFp8Cache
python/llm/src/ipex_llm/transformers/models/utils.py (66 changes: 2 additions & 64 deletions)
@@ -318,69 +318,6 @@ def use_flash_attention(query, key, attention_mask=None):
     return True


-def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None):
-    if head_dim != 128:
-        # esimd_sdp only support head_dim = 128 now
-        return False
-    elif q_len != 1:
-        # esimd_sdp only support rest token and q_len == 1 now
-        return False
-    elif k_len < 8:
-        # esimd_sdp will cause wrong output when k_len < 8
-        return False
-    elif query_states.device.type != "xpu":
-        # esimd_sdp only support GPU now
-        return False
-    elif query_states.dtype != torch.float16:
-        # esimd_sdp only has optimization for FP16 now
-        return False
-
-    device_name = torch.xpu.get_device_name(query_states.device.index)
-    if device_name.startswith("Intel(R) Arc(TM) A") or \
-            device_name.startswith("Intel(R) Data Center GPU Flex") or \
-            device_name.startswith("Intel(R) Data Center GPU Max"):
-        import linear_fp16_esimd
-        if not hasattr(linear_fp16_esimd, "sdp_forward"):
-            return False
-    else:
-        return False
-
-    if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Data Center GPU Max"):
-        # esimd_sdp not support PVC GPU when batch size > 1 for now
-        return False
-    if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Arc(TM) A") \
-            and is_deepspeed_available:
-        # esimd_sdp not support ARC GPU when batch size > 1 using DeepSpeed AutoTP for now
-        return False
-    if query_states.shape[0] > 1 and attention_mask is not None:
-        # for batched input, can't accept attention_mask
-        # TODO: this check needs some time
-        if not torch.all(attention_mask.eq(0)):
-            return False
-
-    return True
-
-
-def use_new_esimd_sdp_fp16(q_len, k_len, head_dim, query_states):
-    if query_states.device.type != "xpu":
-        # esimd_sdp only support GPU now
-        return False
-    elif query_states.dtype != torch.float16:
-        # esimd_sdp only has optimization for FP16 now
-        return False
-    elif head_dim not in [64, 96, 128]:
-        # esimd_sdp only support head_dim = 128 and 64 now
-        return False
-    elif q_len == k_len:
-        # new sdp_fp16 only support rest token now
-        return False
-    elif q_len > 32:
-        # Use new sdp_fp16 only when q_len <= 32
-        return False
-
-    return True
-
-
 def use_sdp(q_len, kv_len, head_dim, query_states):
     return (
         query_states.device.type == "xpu"
@@ -400,9 +337,10 @@ def use_sdp_fp8(q_len, kv_len, query_states):
     )


-def use_sdp_causal(q_len, kv_len, query_states, training):
+def use_sdp_causal(q_len, kv_len, head_dim, query_states, training):
     return (
         q_len == kv_len  # first token
+        and head_dim in [64, 96, 128]  # for now
         and query_states.device.type == "xpu"  # GPU
         and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
         and not query_states.requires_grad and not training  # not training
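After this commit the model files share one dispatch pattern: gate the decode path on use_sdp and the prefill path on use_sdp_causal (now head_dim-aware), then call into the linear_q4_0 extension. A minimal caller-side sketch follows; only the gates and kernel calls visible in this diff are used, the fp8 causal call appears simply because it is the causal variant shown here (phi3 uses it for an FP8 KV cache), and the native fallback plus surrounding control flow are illustrative assumptions.

```python
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal

def fused_attention(query_states, key_states, value_states, attention_mask,
                    q_len, kv_seq_len, head_dim, training=False):
    # Decode path: a few new tokens against a longer KV cache, with an additive
    # float mask (bool masks are rejected upstream via invalidInputError).
    if use_sdp(q_len, kv_seq_len, head_dim, query_states):
        import linear_q4_0  # ipex-llm's native XPU kernel extension
        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
        return attn_output.view(query_states.shape)
    # Prefill path: q_len == kv_seq_len and head_dim in {64, 96, 128}; causal
    # masking is handled inside the kernel (fp8 KV-cache variant from the phi3 hunk).
    if use_sdp_causal(q_len, kv_seq_len, head_dim, query_states, training):
        import linear_q4_0
        return linear_q4_0.sdp_fp8_causal(query_states, key_states, value_states)
    # Otherwise fall back to a native softmax(QK^T / sqrt(d) + mask) @ V implementation.
    raise NotImplementedError("native fallback omitted in this sketch")
```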
