
clean code
luo-cheng2021 committed Nov 27, 2023
1 parent 5faa0cf commit ad723b5
Showing 1 changed file with 40 additions and 41 deletions.
81 changes: 40 additions & 41 deletions src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
@@ -573,23 +573,16 @@ struct MHASingleToken {

template <ScaledDotProductAttention::KernelTypes KType, typename T>
struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAttention::Executor {
- PlainTensor<T> q_input; // f32[B, L1, H*S] / [B, H, L1, S]
- PlainTensor<T> k_input; // f32[B, L1, H*S]
- PlainTensor<T> v_input; // f32[B, L1, H*S]
- PlainTensor<T> k_cache; // f32[B, H, max_kvLen, S]
- PlainTensor<T> v_cache; // f32[B, H, max_kvLen, S]
+ PlainTensor<T> q_input; // f32[B, H, L1, S]
+ PlainTensor<T> k_input; // f32[B, H|1, L1, S] / [B, H|1, L0+L1, S]
+ PlainTensor<T> v_input; // f32[B, H|1, L1, S] / [B, H|1, L0+L1, S]
PlainTensor<int32_t> beam_table; // i32[B, max_kvLen]
- PlainTensor<float> attn_mask; // f32[B, qLen + kvLen]
- float scale_input = 0.0f; // f32[B, qLen + kvLen]
- PlainTensor<float> cos_tab; // f32[max_kv_len, rotary_dims//2]
- PlainTensor<float> sin_tab; // f32[max_kv_len, rotary_dims//2]
-
- PlainTensor<T> output_emb; // f32[B, L1, H*S]
+ PlainTensor<float> attn_mask; // f32[[B|1],[H|1], L1|1, L0+L1]
+ float scale_input = 0.0f;

MHAKernel<KType, T> kernel;
MHASingleToken<T> kernel_single_token;

- PlainTensor<T> m_query_emb; // query with RoPE position embedding
size_t B, H, L1, L0, S;

Config config;
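The new k_input/v_input comments use H|1 to indicate that K and V may carry either H heads or a single shared head (multi-query). As a minimal illustration of that convention (not code from this patch, and the helper name is hypothetical), the mapping from a query head to the KV head that serves it, when the cache holds Hk heads, could look like:

    // Illustrative only: map query head h (0..H-1) onto the KV head that serves it
    // when K/V were produced with Hk heads; assumes H is a multiple of Hk.
    #include <cassert>
    #include <cstddef>

    inline std::size_t kv_head_for_query_head(std::size_t h, std::size_t H, std::size_t Hk) {
        assert(Hk != 0 && H % Hk == 0);     // each KV head serves H / Hk query heads
        const std::size_t group = H / Hk;   // query heads per KV head (1 for MHA, H for MQA)
        return h / group;                   // Hk == 1 collapses every query head onto head 0
    }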
@@ -609,15 +602,19 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
PlainTensor<T>& past_k_output,
PlainTensor<T>& past_v_output) {
if (config.config.fuse_concat) {
+ k_input.assert_dims({B, 0, L1, S}, true);
+ v_input.assert_dims({B, 0, L1, S}, true);
auto past_k_idx = inputs.size() - 2;
auto past_k_mem = inputs[past_k_idx + 0];
L0 = past_k_mem->getStaticDims()[2];
+ // k,v may support multiquery
+ auto Hk = past_k_mem->getStaticDims()[1];
// [S, B, L0, S]
- past_k_output.resize({L0 + L1, B, H, S}, static_cast<T*>(outputs[1]->getData()));
- past_v_output.resize({L0 + L1, B, H, S}, static_cast<T*>(outputs[2]->getData()));
+ past_k_output.resize({L0 + L1, B, Hk, S}, static_cast<T*>(outputs[1]->getData()));
+ past_v_output.resize({L0 + L1, B, Hk, S}, static_cast<T*>(outputs[2]->getData()));
past_k_output = past_k_output.permute({1, 2, 0, 3});
past_v_output = past_v_output.permute({1, 2, 0, 3});
- parallel_for3d(B, H, L1, [&](size_t b, size_t h, size_t m) {
+ parallel_for3d(B, Hk, L1, [&](size_t b, size_t h, size_t m) {
std::memcpy(&past_k_output.at({b, h, m + L0, 0}),
&k_input.at({b, h, m, 0}),
S * sizeof(T));
@@ -627,19 +624,26 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
});
if (!config.skipPastKVCopy) {
PlainTensor<T> past_k_input, past_v_input;
- past_k_input.resize({L0, B, H, S}, static_cast<T*>(past_k_mem->getData()));
- past_v_input.resize({L0, B, H, S}, static_cast<T*>(inputs[past_k_idx + 1]->getData()));
+ past_k_input.resize({L0, B, Hk, S}, static_cast<T*>(past_k_mem->getData()));
+ past_v_input.resize({L0, B, Hk, S}, static_cast<T*>(inputs[past_k_idx + 1]->getData()));
past_k_input = past_k_input.permute({1, 2, 0, 3});
past_v_input = past_v_input.permute({1, 2, 0, 3});
- parallel_for3d(B, H, L0, [&](size_t b, size_t h, size_t m) {
- memcpy(&past_k_output.at({b, h, m, 0}),
- &past_k_input.at({b, h, m, 0}),
- S * sizeof(T));
- memcpy(&past_v_output.at({b, h, m, 0}),
- &past_v_input.at({b, h, m, 0}),
- S * sizeof(T));
+ parallel_for3d(B, Hk, L0, [&](size_t b, size_t h, size_t m) {
+ std::memcpy(&past_k_output.at({b, h, m, 0}),
+ &past_k_input.at({b, h, m, 0}),
+ S * sizeof(T));
+ std::memcpy(&past_v_output.at({b, h, m, 0}),
+ &past_v_input.at({b, h, m, 0}),
+ S * sizeof(T));
});
}
+ } else {
+ // k,v inputs are already concatenated
+ L0 = k_input.size(2) - L1;
+ k_input.assert_dims({B, 0, L0 + L1, S}, true);
+ v_input.assert_dims({B, 0, L0 + L1, S}, true);
+ past_k_output = k_input;
+ past_v_output = v_input;
}
}
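To summarize the fused-concat branch above: the past K/V (L0 positions) and the freshly projected K/V of the current step (L1 positions) are copied per (batch, KV-head) pair into a present buffer covering L0 + L1 positions, viewed through a permute of an [L0+L1, B, Hk, S] layout. A self-contained sketch of the same concatenation over plain row-major [B, Hk, L, S] buffers (simplified, not the node's PlainTensor/parallel_for3d code) is:

    // Simplified concatenation of past and current K (V is handled identically):
    // past_k is [B, Hk, L0, S], new_k is [B, Hk, L1, S], present_k is [B, Hk, L0+L1, S].
    #include <cstddef>
    #include <cstring>

    void concat_kv_sketch(const float* past_k, const float* new_k, float* present_k,
                          std::size_t B, std::size_t Hk, std::size_t L0, std::size_t L1,
                          std::size_t S) {
        const std::size_t Lt = L0 + L1;
        for (std::size_t b = 0; b < B; ++b) {
            for (std::size_t h = 0; h < Hk; ++h) {
                // positions 0..L0-1 come from the past cache
                std::memcpy(present_k + ((b * Hk + h) * Lt + 0) * S,
                            past_k + ((b * Hk + h) * L0 + 0) * S,
                            L0 * S * sizeof(float));
                // positions L0..L0+L1-1 come from the freshly projected K of this step
                std::memcpy(present_k + ((b * Hk + h) * Lt + L0) * S,
                            new_k + ((b * Hk + h) * L1 + 0) * S,
                            L1 * S * sizeof(float));
            }
        }
    }

The real code runs these copies under parallel_for3d and skips the past-KV copy when config.skipPastKVCopy is set.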

@@ -667,7 +671,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
}
}

- // q, k, v: [B, H, L1, S]
+ // q: [B, H, L1, S]
B = q_input.size(0);
H = q_input.size(1);
L1 = q_input.size(2);
@@ -676,24 +680,13 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
PlainTensor<T> present_key, present_value;
concat_pastkv(inputs, outputs, k_input, v_input, present_key, present_value);

- L0 = k_input.size(2) - L1;

ov::intel_cpu::PlainTensor<T> output_emb(outputs[0]);

- q_input.assert_dims({B, H, L1, S});
- k_input.assert_dims({B, 0, L0 + L1, S}, true);
- v_input.assert_dims({B, 0, L0 + L1, S}, true);
- m_query_emb = q_input;
- if (!fuse_concat || L1 > 1) {
- present_key = k_input;
- present_value = v_input;
- }

bool auto_causal;
bool use_attn_mask;
if (fuse_causal_attn) {
assert(attn_mask);
- attn_mask.assert_dims({B, 1, 1, L0 + L1});
+ attn_mask.assert_dims({B, 1, L1, L0 + L1});
auto_causal = true;
use_attn_mask = true;
} else {
@@ -720,15 +713,15 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt

if (L1 > 1) {
// multi-token version
- kernel(strm, m_query_emb, present_key, present_value, {}, use_attn_mask ? attn_mask : PlainTensor<float>(),
+ kernel(strm, q_input, k_input, v_input, {}, use_attn_mask ? attn_mask : PlainTensor<float>(),
output_emb, has_out_transpose, auto_causal, scale_input);
} else {
// 1-token version
// for second token, using a special AVX2/AVX512 float path:
// 1, in the matrix multiply, using AMX is not efficient because the M dimension of A will always be 1
// 2, using float saves the repack cost which is typically required for bf16/int8 optimization
// 3, using dot product can leverage SIMD while easily adapting to the indirect kv cache
- kernel_single_token(m_query_emb, present_key, present_value, {}, use_attn_mask ? attn_mask : PlainTensor<float>(),
+ kernel_single_token(q_input, present_key, present_value, {}, use_attn_mask ? attn_mask : PlainTensor<float>(),
output_emb, beam_table, has_out_transpose, auto_causal, scale_input);
}
}
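The comment above motivates the dedicated 1-token path: with L1 == 1 the score matrix degenerates to a single row, so a plain dot-product loop avoids AMX tiling overhead and adapts naturally to an indirect (beam-reordered) KV cache. A scalar reference of what the single-token kernel computes for one (batch, head) pair is sketched below; it is purely illustrative, not the optimized AVX2/AVX512 implementation, uses a simplified flat [B, L, S] layout, and omits the attention mask and causal handling:

    // Scalar reference: one query vector q[S] attends over L cached positions for a
    // single (batch, head) pair. beam_idx (i32[L], may be null) redirects each cached
    // position to the batch row it was generated in, i.e. an indirect kv cache lookup.
    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    std::vector<float> single_token_attention_ref(const float* q,
                                                  const float* k,          // [B, L, S], one head
                                                  const float* v,          // [B, L, S], one head
                                                  const int32_t* beam_idx, // [L] or nullptr
                                                  std::size_t b, std::size_t L, std::size_t S,
                                                  float scale) {
        std::vector<float> w(L);
        std::vector<float> out(S, 0.0f);
        float w_max = -INFINITY;
        for (std::size_t p = 0; p < L; ++p) {
            const std::size_t src_b = beam_idx ? static_cast<std::size_t>(beam_idx[p]) : b;
            const float* kp = k + (src_b * L + p) * S;
            float dot = 0.0f;
            for (std::size_t s = 0; s < S; ++s)  // the dot product the comment refers to
                dot += q[s] * kp[s];
            w[p] = dot * scale;
            w_max = std::max(w_max, w[p]);
        }
        float sum = 0.0f;
        for (std::size_t p = 0; p < L; ++p) {    // softmax over the L0 + L1 positions
            w[p] = std::exp(w[p] - w_max);
            sum += w[p];
        }
        for (std::size_t p = 0; p < L; ++p) {
            const std::size_t src_b = beam_idx ? static_cast<std::size_t>(beam_idx[p]) : b;
            const float* vp = v + (src_b * L + p) * S;
            for (std::size_t s = 0; s < S; ++s)
                out[s] += (w[p] / sum) * vp[s];  // weighted sum of values
        }
        return out;
    }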
@@ -851,7 +844,6 @@ void ScaledDotProductAttention::execute(dnnl::stream strm) {

bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
try {
- const auto node = std::dynamic_pointer_cast<const ov::op::v13::ScaledDotProductAttention>(op);
if (!std::dynamic_pointer_cast<const ov::op::v13::ScaledDotProductAttention>(op) &&
!std::dynamic_pointer_cast<const ScaledDotProductAttentionStub>(op)) {
errorMessage = "Only ScaledDotProductAttention or ScaledDotProductAttentionStub operation are supported";
@@ -863,7 +855,14 @@ bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr<const
errorMessage = "Doesn't support 'data' input with rank: " + std::to_string(inRank);
return false;
}
- if (op->get_input_size() > 3) {
+ int orgSDPAInput = static_cast<int>(op->get_input_size());
+ const auto node = std::dynamic_pointer_cast<const ScaledDotProductAttentionStub>(op);
+ if (node) {
+ if (node->get_config().fuse_concat) {
+ orgSDPAInput -= 2;
+ }
+ }
+ if (orgSDPAInput > 3) {
inRank = op->get_input_partial_shape(3).size();
if (inRank > 4u) {
errorMessage = "Doesn't support 'attention mask' with rank: " + std::to_string(inRank);
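The rewritten check no longer counts the fused past-K/V inputs when deciding whether an attention-mask input is present. A standalone restatement of that rule (illustrative only, with a hypothetical helper name; it assumes, consistently with the code above, that port 3 is the attention mask and that a fused-concat stub carries its two past-KV inputs last):

    // Illustrative restatement: with fused concat the trailing two inputs are
    // past-K/past-V, so they are ignored before testing whether the optional
    // attention-mask input (port 3) exists.
    #include <cstddef>

    inline bool has_attention_mask_input(std::size_t input_count, bool fuse_concat) {
        std::size_t org_sdpa_inputs = input_count;
        if (fuse_concat)
            org_sdpa_inputs -= 2;    // drop the past-K / past-V ports added by the stub
        return org_sdpa_inputs > 3;  // ports 0..2 are q, k, v; port 3 is the attention mask
    }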
