From ad723b5e5f883f8a6e26657a9786461fbc54d691 Mon Sep 17 00:00:00 2001
From: Luo Cheng
Date: Mon, 27 Nov 2023 10:19:30 +0800
Subject: [PATCH] clean code

---
 .../intel_cpu/src/nodes/scaled_attn.cpp      | 81 +++++++++----------
 1 file changed, 40 insertions(+), 41 deletions(-)

diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
index 29de7a07177c00..dc0349fa614980 100644
--- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
+++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
@@ -573,23 +573,16 @@ struct MHASingleToken {
 
 template
 struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAttention::Executor {
-    PlainTensor q_input;       // f32[B, L1, H*S] / [B, H, L1, S]
-    PlainTensor k_input;       // f32[B, L1, H*S]
-    PlainTensor v_input;       // f32[B, L1, H*S]
-    PlainTensor k_cache;       // f32[B, H, max_kvLen, S]
-    PlainTensor v_cache;       // f32[B, H, max_kvLen, S]
+    PlainTensor q_input;       // f32[B, H, L1, S]
+    PlainTensor k_input;       // f32[B, H|1, L1, S] / [B, H|1, L0+L1, S]
+    PlainTensor v_input;       // f32[B, H|1, L1, S] / [B, H|1, L0+L1, S]
     PlainTensor beam_table;    // i32[B, max_kvLen]
-    PlainTensor attn_mask;     // f32[B, qLen + kvLen]
-    float scale_input = 0.0f;  // f32[B, qLen + kvLen]
-    PlainTensor cos_tab;       // f32[max_kv_len, rotary_dims//2]
-    PlainTensor sin_tab;       // f32[max_kv_len, rotary_dims//2]
-
-    PlainTensor output_emb;    // f32[B, L1, H*S]
+    PlainTensor attn_mask;     // f32[[B|1],[H|1], L1|1, L0+L1]
+    float scale_input = 0.0f;
 
     MHAKernel kernel;
     MHASingleToken kernel_single_token;
 
-    PlainTensor m_query_emb;  // query with RoPE position embedding
-
     size_t B, H, L1, L0, S;
 
     Config config;
@@ -609,15 +602,19 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
                        PlainTensor& past_k_output,
                        PlainTensor& past_v_output) {
         if (config.config.fuse_concat) {
+            k_input.assert_dims({B, 0, L1, S}, true);
+            v_input.assert_dims({B, 0, L1, S}, true);
             auto past_k_idx = inputs.size() - 2;
             auto past_k_mem = inputs[past_k_idx + 0];
             L0 = past_k_mem->getStaticDims()[2];
+            // k,v may support multiquery
+            auto Hk = past_k_mem->getStaticDims()[1];
             // [S, B, L0, S]
-            past_k_output.resize({L0 + L1, B, H, S}, static_cast(outputs[1]->getData()));
-            past_v_output.resize({L0 + L1, B, H, S}, static_cast(outputs[2]->getData()));
+            past_k_output.resize({L0 + L1, B, Hk, S}, static_cast(outputs[1]->getData()));
+            past_v_output.resize({L0 + L1, B, Hk, S}, static_cast(outputs[2]->getData()));
             past_k_output = past_k_output.permute({1, 2, 0, 3});
             past_v_output = past_v_output.permute({1, 2, 0, 3});
-            parallel_for3d(B, H, L1, [&](size_t b, size_t h, size_t m) {
+            parallel_for3d(B, Hk, L1, [&](size_t b, size_t h, size_t m) {
                 std::memcpy(&past_k_output.at({b, h, m + L0, 0}),
                             &k_input.at({b, h, m, 0}),
                             S * sizeof(T));
@@ -627,19 +624,26 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
             });
             if (!config.skipPastKVCopy) {
                 PlainTensor past_k_input, past_v_input;
-                past_k_input.resize({L0, B, H, S}, static_cast(past_k_mem->getData()));
-                past_v_input.resize({L0, B, H, S}, static_cast(inputs[past_k_idx + 1]->getData()));
+                past_k_input.resize({L0, B, Hk, S}, static_cast(past_k_mem->getData()));
+                past_v_input.resize({L0, B, Hk, S}, static_cast(inputs[past_k_idx + 1]->getData()));
                 past_k_input = past_k_input.permute({1, 2, 0, 3});
                 past_v_input = past_v_input.permute({1, 2, 0, 3});
-                parallel_for3d(B, H, L0, [&](size_t b, size_t h, size_t m) {
-                    memcpy(&past_k_output.at({b, h, m, 0}),
-                           &past_k_input.at({b, h, m, 0}),
-                           S * sizeof(T));
-                    memcpy(&past_v_output.at({b, h, m, 0}),
-                           &past_v_input.at({b, h, m, 0}),
-                           S * sizeof(T));
+                parallel_for3d(B, Hk, L0, [&](size_t b, size_t h, size_t m) {
+                    std::memcpy(&past_k_output.at({b, h, m, 0}),
+                                &past_k_input.at({b, h, m, 0}),
+                                S * sizeof(T));
+                    std::memcpy(&past_v_output.at({b, h, m, 0}),
+                                &past_v_input.at({b, h, m, 0}),
+                                S * sizeof(T));
                 });
             }
+        } else {
+            // k,v inputs are already concatenated
+            L0 = k_input.size(2) - L1;
+            k_input.assert_dims({B, 0, L0 + L1, S}, true);
+            v_input.assert_dims({B, 0, L0 + L1, S}, true);
+            past_k_output = k_input;
+            past_v_output = v_input;
         }
     }
 
@@ -667,7 +671,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
             }
         }
 
-        // q, k, v: [B, H, L1, S]
+        // q: [B, H, L1, S]
        B = q_input.size(0);
        H = q_input.size(1);
        L1 = q_input.size(2);
@@ -676,24 +680,13 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
 
         PlainTensor present_key, present_value;
         concat_pastkv(inputs, outputs, k_input, v_input, present_key, present_value);
-        L0 = k_input.size(2) - L1;
-
         ov::intel_cpu::PlainTensor output_emb(outputs[0]);
-        q_input.assert_dims({B, H, L1, S});
-        k_input.assert_dims({B, 0, L0 + L1, S}, true);
-        v_input.assert_dims({B, 0, L0 + L1, S}, true);
-        m_query_emb = q_input;
-        if (!fuse_concat || L1 > 1) {
-            present_key = k_input;
-            present_value = v_input;
-        }
-
         bool auto_causal;
         bool use_attn_mask;
         if (fuse_causal_attn) {
             assert(attn_mask);
-            attn_mask.assert_dims({B, 1, 1, L0 + L1});
+            attn_mask.assert_dims({B, 1, L1, L0 + L1});
             auto_causal = true;
             use_attn_mask = true;
         } else {
@@ -720,7 +713,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
 
         if (L1 > 1) {
             // multi-token version
-            kernel(strm, m_query_emb, present_key, present_value, {}, use_attn_mask ? attn_mask : PlainTensor(),
+            kernel(strm, q_input, k_input, v_input, {}, use_attn_mask ? attn_mask : PlainTensor(),
                    output_emb, has_out_transpose, auto_causal, scale_input);
         } else {
             // 1-token version
@@ -728,7 +721,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
             // 1, in matrix mutiply, using AMX is not efficency because the M dimension of A will alway be 1
             // 2, using float will save the repack cost which typically is required for bf16/int8 opt
             // 3, using dot product can leverage the SIMD while easily adapt to indirect kv cache
-            kernel_single_token(m_query_emb, present_key, present_value, {}, use_attn_mask ? attn_mask : PlainTensor(),
+            kernel_single_token(q_input, present_key, present_value, {}, use_attn_mask ? attn_mask : PlainTensor(),
                                 output_emb, beam_table, has_out_transpose, auto_causal, scale_input);
         }
     }
@@ -851,7 +844,6 @@ void ScaledDotProductAttention::execute(dnnl::stream strm) {
 bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr& op,
                                                      std::string& errorMessage) noexcept {
     try {
-        const auto node = std::dynamic_pointer_cast(op);
         if (!std::dynamic_pointer_cast(op) &&
             !std::dynamic_pointer_cast(op)) {
             errorMessage = "Only ScaledDotProductAttention or ScaledDotProductAttentionStub operation are supported";
@@ -863,7 +855,14 @@ bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr
-        if (op->get_input_size() > 3) {
+        int orgSDPAInput = static_cast(op->get_input_size());
+        const auto node = std::dynamic_pointer_cast(op);
+        if (node) {
+            if (node->get_config().fuse_concat) {
+                orgSDPAInput -= 2;
+            }
+        }
+        if (orgSDPAInput > 3) {
            inRank = op->get_input_partial_shape(3).size();
            if (inRank > 4u) {
                errorMessage = "Doesn't support 'attention mask' with rank: " + std::to_string(inRank);
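
Note on the concat_pastkv() change above: the past-KV buffers are now sized and filled per key/value head (Hk, taken from the past-key input's dims) instead of per query head (H), which is what the "// k,v may support multiquery" comment refers to. The snippet below is a minimal standalone sketch of that concatenation step, not the PlainTensor/parallel_for3d code from the file; it assumes a plain contiguous [B, Hk, L, S] layout, and the name concat_kv is illustrative only.

    #include <cstddef>
    #include <cstring>
    #include <vector>

    // Illustrative sketch only (assumed contiguous [B, Hk, L, S] layout, not the
    // permuted PlainTensor views used in scaled_attn.cpp): append the K or V rows
    // of the L1 new tokens behind the L0 cached rows for every batch and KV head.
    void concat_kv(const std::vector<float>& past,   // [B, Hk, L0, S]
                   const std::vector<float>& fresh,  // [B, Hk, L1, S]
                   std::vector<float>& out,          // [B, Hk, L0 + L1, S]
                   size_t B, size_t Hk, size_t L0, size_t L1, size_t S) {
        const size_t L = L0 + L1;
        out.resize(B * Hk * L * S);
        for (size_t b = 0; b < B; b++) {
            for (size_t h = 0; h < Hk; h++) {  // loop over KV heads, not query heads
                float* dst = out.data() + ((b * Hk + h) * L) * S;
                const float* cached = past.data() + ((b * Hk + h) * L0) * S;
                const float* added = fresh.data() + ((b * Hk + h) * L1) * S;
                if (L0 != 0) {
                    std::memcpy(dst, cached, L0 * S * sizeof(float));  // old tokens first
                }
                std::memcpy(dst + L0 * S, added, L1 * S * sizeof(float));  // then the new ones
            }
        }
    }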
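The fuse_causal_attn branch now asserts an attention mask of shape [B, 1, L1, L0 + L1], i.e. one mask row per new query token rather than a single broadcast row. Below is a small sketch of how a causal mask of that shape is commonly laid out; the additive -infinity convention and the name build_causal_mask are assumptions for illustration, not taken from the patch.

    #include <cstddef>
    #include <limits>
    #include <vector>

    // Illustrative convention only: additive causal mask of shape [B, 1, L1, L0 + L1].
    // Query row m (absolute position L0 + m) may attend to every cached position and
    // to new positions up to itself; strictly later positions get -inf and so vanish
    // after softmax.
    std::vector<float> build_causal_mask(size_t B, size_t L0, size_t L1) {
        const float neg_inf = -std::numeric_limits<float>::infinity();
        std::vector<float> mask(B * L1 * (L0 + L1), 0.0f);  // the head dim of 1 is folded away
        for (size_t b = 0; b < B; b++) {
            for (size_t m = 0; m < L1; m++) {
                float* row = mask.data() + (b * L1 + m) * (L0 + L1);
                for (size_t n = L0 + m + 1; n < L0 + L1; n++) {
                    row[n] = neg_inf;  // mask out future positions
                }
            }
        }
        return mask;
    }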
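On the isSupportedOperation() change: the fused-concat variant appends past_k and past_v as its last two inputs (concat_pastkv() reads them at inputs.size() - 2 and inputs.size() - 1), so the check for an optional attention-mask input now discounts them first. A schematic restatement with a hypothetical helper name:

    #include <cstddef>

    // Hypothetical helper restating the input-count adjustment: drop the two trailing
    // past-KV inputs of the fused-concat variant before testing whether input #3
    // (after q, k, v) is an attention mask.
    bool has_attention_mask_input(size_t input_count, bool fuse_concat) {
        size_t org_sdpa_inputs = input_count;
        if (fuse_concat) {
            org_sdpa_inputs -= 2;  // past_k and past_v are appended at the end
        }
        return org_sdpa_inputs > 3;  // q, k, v, then the optional attention mask
    }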