From a81f36b704dafb52c5cd6e71e9c9ba24f272f0f3 Mon Sep 17 00:00:00 2001 From: ceciliapeng Date: Sun, 19 Nov 2023 18:16:46 -0800 Subject: [PATCH] choose fp16 for kvcaches in SDPA node. OV_ENABLE_SDPA_KVCACHE_FP16 (default false) --- .../compile_flags/os_flags.cmake | 4 +- src/plugins/intel_cpu/CMakeLists.txt | 1 + src/plugins/intel_cpu/src/graph.cpp | 15 ++++ .../nodes/kernels/scaled_attn/acc_value.cpp | 10 ++- .../src/nodes/kernels/scaled_attn/common.hpp | 22 +++++ .../nodes/kernels/scaled_attn/dot_product.cpp | 31 +++++-- .../nodes/kernels/scaled_attn/dot_product.hpp | 2 +- src/plugins/intel_cpu/src/nodes/memory.cpp | 2 +- .../intel_cpu/src/nodes/scaled_attn.cpp | 87 +++++++++++-------- src/plugins/intel_cpu/src/nodes/scaled_attn.h | 2 +- .../subgraph_tests/src/concat_sdp.cpp | 2 +- .../tests/unit/kernel/scaled_attn_test.cpp | 0 12 files changed, 127 insertions(+), 51 deletions(-) create mode 100644 src/plugins/intel_cpu/tests/unit/kernel/scaled_attn_test.cpp diff --git a/cmake/developer_package/compile_flags/os_flags.cmake b/cmake/developer_package/compile_flags/os_flags.cmake index c0c878e0183eb0..2c621d93425f4b 100644 --- a/cmake/developer_package/compile_flags/os_flags.cmake +++ b/cmake/developer_package/compile_flags/os_flags.cmake @@ -125,7 +125,7 @@ macro(ov_avx2_optimization_flags flags) set(${flags} -xCORE-AVX2) endif() elseif(OV_COMPILER_IS_CLANG OR CMAKE_COMPILER_IS_GNUCXX) - set(${flags} -mavx2 -mfma) + set(${flags} -mavx2 -mfma -mf16c) else() message(WARNING "Unsupported CXX compiler ${CMAKE_CXX_COMPILER_ID}") endif() @@ -147,7 +147,7 @@ macro(ov_avx512_optimization_flags flags) set(${flags} -xCOMMON-AVX512) endif() elseif(OV_COMPILER_IS_CLANG OR CMAKE_COMPILER_IS_GNUCXX) - set(${flags} -mavx512f -mfma) + set(${flags} -mavx512f -mfma -mf16c) else() message(WARNING "Unsupported CXX compiler ${CMAKE_CXX_COMPILER_ID}") endif() diff --git a/src/plugins/intel_cpu/CMakeLists.txt b/src/plugins/intel_cpu/CMakeLists.txt index d65c8661fcef0d..3156df1122b9a7 100644 --- a/src/plugins/intel_cpu/CMakeLists.txt +++ b/src/plugins/intel_cpu/CMakeLists.txt @@ -180,6 +180,7 @@ cross_compiled_file(${TARGET_NAME} NAME attn_reduce NAMESPACE InferenceEngine::Extensions::Cpu::XARCH ) + # system dependencies must go last target_link_libraries(${TARGET_NAME} PRIVATE openvino::pugixml) ov_set_threading_interface_for(${TARGET_NAME}) diff --git a/src/plugins/intel_cpu/src/graph.cpp b/src/plugins/intel_cpu/src/graph.cpp index ee2c3f45320465..7c000ef5cbbf18 100644 --- a/src/plugins/intel_cpu/src/graph.cpp +++ b/src/plugins/intel_cpu/src/graph.cpp @@ -312,6 +312,21 @@ void Graph::InitDescriptors() { OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, node->profiling.selectOptimalPrimitiveDescriptor); DEBUG_LOG("Select optimal primitive descriptors for node: ", node->getName()); node->selectOptimalPrimitiveDescriptor(); +#ifdef CPU_DEBUG_CAPS + const auto& SPDs = node->getSupportedPrimitiveDescriptors(); + for (size_t i = 0; i < SPDs.size(); i++) { + DEBUG_LOG("#", + node->getExecIndex(), + " ", + node->getName(), + " SupportedPrimitiveDescriptors [", + i, + "/", + SPDs.size(), + "]: \n", + SPDs[i]); + } +#endif } } diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/acc_value.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/acc_value.cpp index 994fb55e971525..f51232e8b01e08 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/acc_value.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/acc_value.cpp @@ -55,13 +55,19 @@ void attn_acc_values(float** outs, float* weights, void** vs, size_t vec_num, si auto v_ptr = static_cast(vs[i]); attn_acc_value_inner(out_ptr, weights[i], v_ptr, vec_len); } - } else { - assert(input_precision == ov::element::bf16); + } else if (input_precision == ov::element::bf16) { for (size_t i = 0; i < vec_num; i++) { auto out_ptr = outs[i]; auto v_ptr = static_cast(vs[i]); attn_acc_value_inner(out_ptr, weights[i], v_ptr, vec_len); } + } else { + assert(input_precision == ov::element::f16); + for (size_t i = 0; i < vec_num; i++) { + auto out_ptr = outs[i]; + auto v_ptr = static_cast(vs[i]); + attn_acc_value_inner(out_ptr, weights[i], v_ptr, vec_len); + } } } diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp index b9d4ecbba2bf8c..bae3e19771002a 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp @@ -7,6 +7,9 @@ #include #include +#include "openvino/core/type/bfloat16.hpp" +#include "openvino/core/type/float16.hpp" + namespace InferenceEngine { namespace Extensions { namespace Cpu { @@ -55,6 +58,15 @@ static constexpr size_t vec_len_f32_avx2 = vec_len_avx2 / sizeof(float); x = _mm512_mask_blend_epi32(mask, nan, x); // Check NaN before converting back to bf16 _mm512_mask_cvtepi32_storeu_epi16(addr, mask_addr, x); } + + inline __m512 mm512_uni_loadu_ps(ov::float16* a) { + auto vec_f16 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a)); + return _mm512_cvtph_ps(vec_f16); + } + inline void mm512_uni_storeu_ps(ov::float16* addr, __m512 v) { + __m256i vec_f16 = _mm512_cvtps_ph(v, 0); + _mm256_storeu_si256(reinterpret_cast<__m256i *>(addr), vec_f16); + } #endif #ifdef HAVE_AVX2 @@ -87,6 +99,16 @@ static constexpr size_t vec_len_f32_avx2 = vec_len_avx2 / sizeof(float); _mm_storeu_si128(reinterpret_cast<__m128i *>(addr), bf16_o); } + inline __m256 mm256_uni_loadu_ps(ov::float16* a) { + auto vec_f16 = _mm_loadu_si128(reinterpret_cast<__m128i*>(a)); + auto o = _mm256_cvtph_ps(vec_f16); + return o; + } + inline void mm256_uni_storeu_ps(ov::float16* a, __m256 v) { + __m128i vec_f16 = _mm256_cvtps_ph(v, 0); // FIXME: rounding + _mm_storeu_si128(reinterpret_cast<__m128i *>(a), vec_f16); + } + inline void hsum(__m256& x) { __m256 y; // x: 0 1 2 3 4 5 6 7 y = _mm256_permute_ps(x, 0x39); // y: 1 2 3 0 5 6 7 4 diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/dot_product.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/dot_product.cpp index 3a963b68842d16..9de3884fdacbf7 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/dot_product.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/dot_product.cpp @@ -23,8 +23,8 @@ namespace Extensions { namespace Cpu { namespace XARCH { -template -float dot_product_inner(T* a, T* b, size_t n) { +template +float dot_product_inner(TA* a, TB* b, size_t n) { size_t i = 0; float sum = 0.0f; #if defined(HAVE_AVX512F) @@ -51,15 +51,28 @@ float dot_product_inner(T* a, T* b, size_t n) { return sum; } -void attn_dot_products(void** a, void** b, void**c, size_t vec_num, size_t vec_len, ov::element::Type input_precision) { - if (input_precision == ov::element::f32) { - for (size_t i = 0; i < vec_num; i++) { - auto a_ptr = static_cast(a[i]); - auto b_ptr = static_cast(b[i]); - auto c_ptr = static_cast(c[i]); - c_ptr[0] = dot_product_inner(a_ptr, b_ptr, vec_len); +void attn_dot_products(void** a, void** b, void**c, size_t vec_num, size_t vec_len, + ov::element::Type a_precision, ov::element::Type b_precision) { + if (a_precision == ov::element::f32) { + if (b_precision == ov::element::f32) { + for (size_t i = 0; i < vec_num; i++) { + auto a_ptr = static_cast(a[i]); + auto b_ptr = static_cast(b[i]); + auto c_ptr = static_cast(c[i]); + c_ptr[0] = dot_product_inner(a_ptr, b_ptr, vec_len); + } + } else { + assert(b_precision == ov::element::f16); + for (size_t i = 0; i < vec_num; i++) { + auto a_ptr = static_cast(a[i]); + auto b_ptr = static_cast(b[i]); + auto c_ptr = static_cast(c[i]); + c_ptr[0] = dot_product_inner(a_ptr, b_ptr, vec_len); + } } } else { + assert(a_precision == ov::element::bf16); + assert(b_precision == ov::element::bf16); for (size_t i = 0; i < vec_num; i++) { auto a_ptr = static_cast(a[i]); auto b_ptr = static_cast(b[i]); diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/dot_product.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/dot_product.hpp index 161fd9b890cda2..3a402540f6ac4e 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/dot_product.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/dot_product.hpp @@ -13,7 +13,7 @@ namespace Extensions { namespace Cpu { namespace XARCH { -void attn_dot_products(void** a, void** b, void**c, size_t vec_num, size_t vec_len, ov::element::Type input_precision); +void attn_dot_products(void** a, void** b, void**c, size_t vec_num, size_t vec_len, ov::element::Type precision_a, ov::element::Type precision_b); } // namespace XARCH } // namespace Cpu diff --git a/src/plugins/intel_cpu/src/nodes/memory.cpp b/src/plugins/intel_cpu/src/nodes/memory.cpp index cf215313ed2b9c..0afa69d422a070 100644 --- a/src/plugins/intel_cpu/src/nodes/memory.cpp +++ b/src/plugins/intel_cpu/src/nodes/memory.cpp @@ -611,4 +611,4 @@ MemStatePtr MemoryInputSDPA::makeState() const { } // namespace node } // namespace intel_cpu -} // namespace ov +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp index dc0349fa614980..6d349a557eee20 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp @@ -31,8 +31,11 @@ #include "kernels/scaled_attn/acc_value.hpp" #include "kernels/scaled_attn/reduce.hpp" +#include "common/cpu_convert.h" + using namespace InferenceEngine; using namespace InferenceEngine::Extensions::Cpu::XARCH; +using namespace dnnl::impl::cpu::x64; namespace ov { namespace intel_cpu { @@ -445,7 +448,7 @@ struct MHAKernel { #endif // 2nd token case : only 1 token in query -template +template struct MHASingleToken { PlainTensor m_attn_w; PlainTensor m_temp; @@ -467,8 +470,8 @@ struct MHASingleToken { // alibi // output_emb [B, L1, H*S] void operator()(PlainTensor& query, - PlainTensor& present_key, - PlainTensor& present_value, + PlainTensor& present_key, + PlainTensor& present_value, const PlainTensor& alibi_mask, const PlainTensor& attention_mask, PlainTensor& output_emb, @@ -492,8 +495,10 @@ struct MHASingleToken { parallel_for3d(B, H, kv_len, [&](size_t b, size_t h, size_t pk) { // which batch item should be used at postion pk? auto b_kv = beams ? beams.at({b, pk}) : b; - std::vector as(q_len), bs(q_len); + std::vector as(q_len); + std::vector bs(q_len); std::vector cs(q_len); + for (size_t pq = 0; pq < q_len; pq++) { as[pq] = &query.at({b, h, pq, 0}); bs[pq] = &present_key.at({b_kv, h, pk, 0}, true); @@ -504,7 +509,8 @@ struct MHASingleToken { reinterpret_cast(cs.data()), q_len, S, - precision_of::value); + precision_of::value, + precision_of::value); }); parallel_for3d(B, H, q_len, [&](size_t b, size_t h, size_t pq) { @@ -538,7 +544,7 @@ struct MHASingleToken { size_t b, h, pv; if (start < end) { parallel_it_init(start, b, B, h, H, pv, kv_len); - std::vector vs(q_len * (end - start)); + std::vector vs(q_len * (end - start)); std::vector weights(q_len * (end - start)); std::vector outs(q_len * (end - start)); size_t idx = 0; @@ -558,7 +564,7 @@ struct MHASingleToken { reinterpret_cast(vs.data()), q_len * (end - start), S, - precision_of::value); + precision_of::value); } }); @@ -571,7 +577,7 @@ struct MHASingleToken { } }; -template +template struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAttention::Executor { PlainTensor q_input; // f32[B, H, L1, S] PlainTensor k_input; // f32[B, H|1, L1, S] / [B, H|1, L0+L1, S] @@ -581,7 +587,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt float scale_input = 0.0f; MHAKernel kernel; - MHASingleToken kernel_single_token; + MHASingleToken kernel_single_token; size_t B, H, L1, L0, S; @@ -599,51 +605,49 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt const std::vector& outputs, const PlainTensor& k_input, const PlainTensor& v_input, - PlainTensor& past_k_output, - PlainTensor& past_v_output) { + PlainTensor& past_k_output, + PlainTensor& past_v_output) { if (config.config.fuse_concat) { k_input.assert_dims({B, 0, L1, S}, true); v_input.assert_dims({B, 0, L1, S}, true); + auto past_k_idx = inputs.size() - 2; auto past_k_mem = inputs[past_k_idx + 0]; L0 = past_k_mem->getStaticDims()[2]; // k,v may support multiquery auto Hk = past_k_mem->getStaticDims()[1]; // [S, B, L0, S] - past_k_output.resize({L0 + L1, B, Hk, S}, static_cast(outputs[1]->getData())); - past_v_output.resize({L0 + L1, B, Hk, S}, static_cast(outputs[2]->getData())); + past_k_output.resize({L0 + L1, B, Hk, S}, static_cast(outputs[1]->getData())); + past_v_output.resize({L0 + L1, B, Hk, S}, static_cast(outputs[2]->getData())); past_k_output = past_k_output.permute({1, 2, 0, 3}); past_v_output = past_v_output.permute({1, 2, 0, 3}); parallel_for3d(B, Hk, L1, [&](size_t b, size_t h, size_t m) { - std::memcpy(&past_k_output.at({b, h, m + L0, 0}), - &k_input.at({b, h, m, 0}), - S * sizeof(T)); - std::memcpy(&past_v_output.at({b, h, m + L0, 0}), - &v_input.at({b, h, m, 0}), - S * sizeof(T)); + cpu_convert(&k_input.at({b, h, m, 0}), &past_k_output.at({b, h, m + L0, 0}), precision_of::value, precision_of::value, S); + cpu_convert(&v_input.at({b, h, m, 0}), &past_v_output.at({b, h, m + L0, 0}), precision_of::value, precision_of::value, S); }); if (!config.skipPastKVCopy) { - PlainTensor past_k_input, past_v_input; - past_k_input.resize({L0, B, Hk, S}, static_cast(past_k_mem->getData())); - past_v_input.resize({L0, B, Hk, S}, static_cast(inputs[past_k_idx + 1]->getData())); + PlainTensor past_k_input, past_v_input; + past_k_input.resize({L0, B, Hk, S}, static_cast(past_k_mem->getData())); + past_v_input.resize({L0, B, Hk, S}, static_cast(inputs[past_k_idx + 1]->getData())); past_k_input = past_k_input.permute({1, 2, 0, 3}); past_v_input = past_v_input.permute({1, 2, 0, 3}); parallel_for3d(B, Hk, L0, [&](size_t b, size_t h, size_t m) { std::memcpy(&past_k_output.at({b, h, m, 0}), &past_k_input.at({b, h, m, 0}), - S * sizeof(T)); + S * sizeof(T2)); std::memcpy(&past_v_output.at({b, h, m, 0}), &past_v_input.at({b, h, m, 0}), - S * sizeof(T)); + S * sizeof(T2)); }); } } else { // k,v inputs are already concatenated + OPENVINO_ASSERT(precision_of::value == precision_of::value); L0 = k_input.size(2) - L1; k_input.assert_dims({B, 0, L0 + L1, S}, true); v_input.assert_dims({B, 0, L0 + L1, S}, true); - past_k_output = k_input; - past_v_output = v_input; + past_k_output = static_cast>(k_input); + past_v_output = static_cast>(v_input); } } @@ -677,7 +681,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt L1 = q_input.size(2); S = q_input.size(-1); - PlainTensor present_key, present_value; + PlainTensor present_key, present_value; concat_pastkv(inputs, outputs, k_input, v_input, present_key, present_value); ov::intel_cpu::PlainTensor output_emb(outputs[0]); @@ -748,15 +752,30 @@ void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() { return; auto rtPrecision = getOriginalInputPrecisionAtPort(0); + bool enable_fp16_kvcache = true; + const char* enable_fp16 = std::getenv("OV_ENABLE_SDPA_KVCACHE_FP16"); + if (enable_fp16 && std::atoi(enable_fp16) > 0) { + enable_fp16_kvcache = true; + } + + auto kvCachePrecision = (m_config.config.fuse_concat && enable_fp16_kvcache && mayiuse(cpu_isa_t::avx2)) ? ov::element::f16 : rtPrecision; + std::cout << "===================== kvPrecision = " << kvCachePrecision << ", rtPrecision = " << rtPrecision << std::endl; + if (rtPrecision == ov::element::bf16) { - m_executor = std::make_shared>(m_config); + m_executor = std::make_shared>(m_config); } else { // only support bf16/f32 rtPrecision = ov::element::f32; #ifdef OV_CPU_WITH_MLAS - m_executor = std::make_shared>(m_config); + if (kvCachePrecision == ov::element::f16) + m_executor = std::make_shared>(m_config); + else + m_executor = std::make_shared>(m_config); #else - m_executor = std::make_shared>(m_config); + if (kvCachePrecision == ov::element::f16) + m_executor = std::make_shared>(m_config); + else + m_executor = std::make_shared>(m_config); #endif } NodeConfig config; @@ -788,9 +807,9 @@ void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() { } if (m_config.config.fuse_concat) { config.inConfs[orginSDPInputNumber + 0].setMemDesc(creatorsMap.at(LayoutType::cabd)->createSharedDesc( - rtPrecision, getInputShapeAtPort(orginSDPInputNumber + 0))); + kvCachePrecision, getInputShapeAtPort(orginSDPInputNumber + 0))); config.inConfs[orginSDPInputNumber + 1].setMemDesc(creatorsMap.at(LayoutType::cabd)->createSharedDesc( - rtPrecision, getInputShapeAtPort(orginSDPInputNumber + 1))); + kvCachePrecision, getInputShapeAtPort(orginSDPInputNumber + 1))); } config.outConfs[0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( @@ -798,10 +817,10 @@ void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() { if (m_config.config.fuse_concat) { config.outConfs[1].setMemDesc(creatorsMap.at(LayoutType::cabd)->createSharedDesc( - rtPrecision, getOutputShapeAtPort(1))); + kvCachePrecision, getOutputShapeAtPort(1))); config.outConfs[1].inPlace(orginSDPInputNumber + 0); config.outConfs[2].setMemDesc(creatorsMap.at(LayoutType::cabd)->createSharedDesc( - rtPrecision, getOutputShapeAtPort(2))); + kvCachePrecision, getOutputShapeAtPort(2))); config.outConfs[2].inPlace(orginSDPInputNumber + 1); } supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::ref_any); diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.h b/src/plugins/intel_cpu/src/nodes/scaled_attn.h index 716362f0043d05..b506857a580fa2 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.h +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.h @@ -53,7 +53,7 @@ class ScaledDotProductAttention : public Node { struct Config m_config; std::shared_ptr m_executor; - template struct AttentionExecutor; + template struct AttentionExecutor; }; } // namespace node diff --git a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_sdp.cpp b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_sdp.cpp index bf5dac2d822ba6..796a7186daed48 100644 --- a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_sdp.cpp +++ b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_sdp.cpp @@ -73,7 +73,7 @@ class ConcatSDPTest : public testing::WithParamInterface, v bool hasShapeOf; std::tie(inType, inputShapes, hasShapeOf) = this->GetParam(); targetDevice = ov::test::utils::DEVICE_CPU; - rel_threshold = 1e-4f; + rel_threshold = 1e-2f; if (inType == ElementType::bf16) { configuration.insert({"ENFORCE_BF16", "YES"}); rel_threshold = 0.01f; diff --git a/src/plugins/intel_cpu/tests/unit/kernel/scaled_attn_test.cpp b/src/plugins/intel_cpu/tests/unit/kernel/scaled_attn_test.cpp new file mode 100644 index 00000000000000..e69de29bb2d1d6