From 0a7d1d770ff37150e25d5cb94a213a5f63cbba6c Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 7 Dec 2023 07:24:49 +0800 Subject: [PATCH] [CPU] Optimize ScaledDotProductAttention performance (#21412) --- .../compile_flags/os_flags.cmake | 4 +- src/plugins/intel_cpu/CMakeLists.txt | 19 +- src/plugins/intel_cpu/src/cpu_types.cpp | 2 +- src/plugins/intel_cpu/src/extension.cpp | 4 +- .../nodes/kernels/scaled_attn/acc_value.cpp | 71 --- .../nodes/kernels/scaled_attn/attn_memcpy.cpp | 100 +++ .../{acc_value.hpp => attn_memcpy.hpp} | 7 +- .../src/nodes/kernels/scaled_attn/common.hpp | 23 + .../nodes/kernels/scaled_attn/dot_product.cpp | 75 --- .../nodes/kernels/scaled_attn/dot_product.hpp | 21 - .../kernels/scaled_attn/mha_single_token.cpp | 288 +++++++++ .../kernels/scaled_attn/mha_single_token.hpp | 34 ++ .../src/nodes/kernels/scaled_attn/reduce.cpp | 80 --- .../src/nodes/kernels/scaled_attn/reduce.hpp | 21 - .../src/nodes/kernels/scaled_attn/softmax.cpp | 544 +---------------- .../src/nodes/kernels/scaled_attn/softmax.hpp | 1 + .../kernels/scaled_attn/softmax_kernel.hpp | 576 ++++++++++++++++++ src/plugins/intel_cpu/src/nodes/rope.cpp | 48 +- .../intel_cpu/src/nodes/scaled_attn.cpp | 328 ++++------ src/plugins/intel_cpu/src/nodes/scaled_attn.h | 6 +- .../cpu_opset/common/op/{sdp.cpp => sdpa.cpp} | 18 +- .../cpu_opset/common/op/{sdp.hpp => sdpa.hpp} | 8 +- ...dp_fusion.cpp => stateful_sdpa_fusion.cpp} | 12 +- ...dp_fusion.hpp => stateful_sdpa_fusion.hpp} | 6 +- .../transformation_pipeline.cpp | 4 +- .../intel_cpu/src/utils/plain_tensor.hpp | 121 ++-- .../subgraph_tests/src/concat_sdp.cpp | 3 +- .../tests/unit/graph/scaled_attn.cpp | 29 +- .../transformations/state_concat_sdpa.cpp | 12 +- 29 files changed, 1291 insertions(+), 1174 deletions(-) delete mode 100644 src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/acc_value.cpp create mode 100644 src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.cpp rename src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/{acc_value.hpp => attn_memcpy.hpp} (57%) delete mode 100644 src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/dot_product.cpp delete mode 100644 src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/dot_product.hpp create mode 100644 src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp create mode 100644 src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.hpp delete mode 100644 src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/reduce.cpp delete mode 100644 src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/reduce.hpp create mode 100644 src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp rename src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/{sdp.cpp => sdpa.cpp} (71%) rename src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/{sdp.hpp => sdpa.hpp} (78%) rename src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/{stateful_sdp_fusion.cpp => stateful_sdpa_fusion.cpp} (94%) rename src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/{stateful_sdp_fusion.hpp => stateful_sdpa_fusion.hpp} (64%) diff --git a/cmake/developer_package/compile_flags/os_flags.cmake b/cmake/developer_package/compile_flags/os_flags.cmake index c0c878e0183eb0..2c621d93425f4b 100644 --- a/cmake/developer_package/compile_flags/os_flags.cmake +++ b/cmake/developer_package/compile_flags/os_flags.cmake @@ -125,7 +125,7 @@ macro(ov_avx2_optimization_flags flags) set(${flags} -xCORE-AVX2) endif() elseif(OV_COMPILER_IS_CLANG OR CMAKE_COMPILER_IS_GNUCXX) - 
set(${flags} -mavx2 -mfma) + set(${flags} -mavx2 -mfma -mf16c) else() message(WARNING "Unsupported CXX compiler ${CMAKE_CXX_COMPILER_ID}") endif() @@ -147,7 +147,7 @@ macro(ov_avx512_optimization_flags flags) set(${flags} -xCOMMON-AVX512) endif() elseif(OV_COMPILER_IS_CLANG OR CMAKE_COMPILER_IS_GNUCXX) - set(${flags} -mavx512f -mfma) + set(${flags} -mavx512f -mfma -mf16c) else() message(WARNING "Unsupported CXX compiler ${CMAKE_CXX_COMPILER_ID}") endif() diff --git a/src/plugins/intel_cpu/CMakeLists.txt b/src/plugins/intel_cpu/CMakeLists.txt index d65c8661fcef0d..efa25c321f9b1c 100644 --- a/src/plugins/intel_cpu/CMakeLists.txt +++ b/src/plugins/intel_cpu/CMakeLists.txt @@ -161,23 +161,16 @@ cross_compiled_file(${TARGET_NAME} ) cross_compiled_file(${TARGET_NAME} ARCH AVX512F AVX2 ANY - src/nodes/kernels/scaled_attn/dot_product.cpp - API src/nodes/kernels/scaled_attn/dot_product.hpp - NAME attn_dot_products + src/nodes/kernels/scaled_attn/mha_single_token.cpp + API src/nodes/kernels/scaled_attn/mha_single_token.hpp + NAME mha_single_token NAMESPACE InferenceEngine::Extensions::Cpu::XARCH ) cross_compiled_file(${TARGET_NAME} ARCH AVX512F AVX2 ANY - src/nodes/kernels/scaled_attn/acc_value.cpp - API src/nodes/kernels/scaled_attn/acc_value.hpp - NAME attn_acc_values - NAMESPACE InferenceEngine::Extensions::Cpu::XARCH -) -cross_compiled_file(${TARGET_NAME} - ARCH AVX512F AVX2 ANY - src/nodes/kernels/scaled_attn/reduce.cpp - API src/nodes/kernels/scaled_attn/reduce.hpp - NAME attn_reduce + src/nodes/kernels/scaled_attn/attn_memcpy.cpp + API src/nodes/kernels/scaled_attn/attn_memcpy.hpp + NAME attn_memcpy NAMESPACE InferenceEngine::Extensions::Cpu::XARCH ) # system dependencies must go last diff --git a/src/plugins/intel_cpu/src/cpu_types.cpp b/src/plugins/intel_cpu/src/cpu_types.cpp index 00896c6c8a276b..0175519b87a8e3 100644 --- a/src/plugins/intel_cpu/src/cpu_types.cpp +++ b/src/plugins/intel_cpu/src/cpu_types.cpp @@ -215,7 +215,7 @@ static const TypeToNameMap& get_type_to_name_tbl() { { "Unique", Type::Unique}, { "Ngram", Type::Ngram}, { "ScaledDotProductAttention", Type::ScaledDotProductAttention}, - { "ScaledDotProductAttentionStub", Type::ScaledDotProductAttention}, + { "ScaledDotProductAttentionWithKVCache", Type::ScaledDotProductAttention}, { "RoPE", Type::RoPE}, }; return type_to_name_tbl; diff --git a/src/plugins/intel_cpu/src/extension.cpp b/src/plugins/intel_cpu/src/extension.cpp index d9fe51c5151aa4..373c4b90e8cb93 100644 --- a/src/plugins/intel_cpu/src/extension.cpp +++ b/src/plugins/intel_cpu/src/extension.cpp @@ -6,7 +6,7 @@ #include "transformations/cpu_opset/common/op/fully_connected.hpp" #include "transformations/cpu_opset/common/op/leaky_relu.hpp" #include "transformations/cpu_opset/common/op/power_static.hpp" -#include "transformations/cpu_opset/common/op/sdp.hpp" +#include "transformations/cpu_opset/common/op/sdpa.hpp" #include "transformations/cpu_opset/common/op/swish_cpu.hpp" #include "transformations/cpu_opset/common/op/ngram.hpp" #include "transformations/cpu_opset/x64/op/mha.hpp" @@ -61,7 +61,7 @@ std::map Extension::getOpSets() { NGRAPH_OP(NgramNode, ov::intel_cpu) NGRAPH_OP_X64(MHANode, ov::intel_cpu) NGRAPH_OP_X64(InteractionNode, ov::intel_cpu) - NGRAPH_OP_X64(ScaledDotProductAttentionStub, ov::intel_cpu) + NGRAPH_OP_X64(ScaledDotProductAttentionWithKVCache, ov::intel_cpu) #undef NGRAPH_OP return opset; diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/acc_value.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/acc_value.cpp deleted file mode 
100644 index 994fb55e971525..00000000000000 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/acc_value.cpp +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (C) 2018-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include - -#include -#include -#include -#include -#include - -#if defined(HAVE_AVX2) || defined(HAVE_AVX512F) -# include -#endif - -#include "openvino/core/type/bfloat16.hpp" -#include "common.hpp" -#include "acc_value.hpp" - -namespace InferenceEngine { -namespace Extensions { -namespace Cpu { -namespace XARCH { - -template -void attn_acc_value_inner(float* out, float weight, T* v, size_t S) { - size_t i = 0; -#if defined(HAVE_AVX512F) - auto attn_w_vec_fp32 = _mm512_set1_ps(weight); - for (; i + vec_len_f32_avx512 <= S; i += vec_len_f32_avx512) { - auto v_value = mm512_uni_loadu_ps(v + i); - auto v_out = mm512_uni_loadu_ps(out + i); - v_out = _mm512_fmadd_ps(attn_w_vec_fp32, v_value, v_out); - _mm512_storeu_ps(out + i, v_out); - } -#elif defined(HAVE_AVX2) - auto attn_w_vec_fp32 = _mm256_set1_ps(weight); - for (; i + vec_len_f32_avx2 <= S; i += vec_len_f32_avx2) { - auto v_value = mm256_uni_loadu_ps(v + i); - auto v_out = mm256_uni_loadu_ps(out + i); - v_out = _mm256_fmadd_ps(attn_w_vec_fp32, v_value, v_out); - mm256_uni_storeu_ps(out + i, v_out); - } -#endif - for (; i < S; i++) { - out[i] += weight * v[i]; - } -} - -void attn_acc_values(float** outs, float* weights, void** vs, size_t vec_num, size_t vec_len, ov::element::Type input_precision) { - if (input_precision == ov::element::f32) { - for (size_t i = 0; i < vec_num; i++) { - auto out_ptr = outs[i]; - auto v_ptr = static_cast(vs[i]); - attn_acc_value_inner(out_ptr, weights[i], v_ptr, vec_len); - } - } else { - assert(input_precision == ov::element::bf16); - for (size_t i = 0; i < vec_num; i++) { - auto out_ptr = outs[i]; - auto v_ptr = static_cast(vs[i]); - attn_acc_value_inner(out_ptr, weights[i], v_ptr, vec_len); - } - } -} - -} // namespace XARCH -} // namespace Cpu -} // namespace Extensions -} // namespace InferenceEngine \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.cpp new file mode 100644 index 00000000000000..6c2b35b94f49ed --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.cpp @@ -0,0 +1,100 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include + +#include +#include +#include +#include +#include + +#if defined(HAVE_AVX2) || defined(HAVE_AVX512F) +# include +#endif + +#include "openvino/core/type/bfloat16.hpp" +#include "openvino/core/parallel.hpp" +#include "common.hpp" +#include "attn_memcpy.hpp" + +namespace InferenceEngine { +namespace Extensions { +namespace Cpu { +namespace XARCH { + +using namespace ov; + +// float16 <- float +template +void attn_copy(TA* a, TB* b, size_t n) { + size_t i = 0; +#if defined(HAVE_AVX512F) + for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512) { + auto vb = mm512_uni_loadu_ps(b + i); + mm512_uni_storeu_ps(a + i, vb); + } +#elif defined(HAVE_AVX2) + for (; i + vec_len_f32_avx2 <= n; i += vec_len_f32_avx2) { + auto vb = mm256_uni_loadu_ps(b + i); + mm256_uni_storeu_ps(a + i, vb); + } +#endif + for (; i < n; i++) { + a[i] = b[i]; + } +} + +template +void attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input, + const ov::intel_cpu::PlainTensor& v_input, + const ov::intel_cpu::PlainTensor& past_k_output, + const 
ov::intel_cpu::PlainTensor& past_v_output) { + size_t B = k_input.m_dims[0], H = k_input.m_dims[1], L1 = k_input.m_dims[2], S = k_input.m_dims[3]; + parallel_for3d(B, H, L1, [&](size_t b, size_t h, size_t m) { + attn_copy(&past_k_output.at({b, h, m, 0}), + &k_input.at({b, h, m, 0}), + S); + attn_copy(&past_v_output.at({b, h, m, 0}), + &v_input.at({b, h, m, 0}), + S); + }); +} + +template +void attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input, + const ov::intel_cpu::PlainTensor& v_input, + const ov::intel_cpu::PlainTensor& past_k_output, + const ov::intel_cpu::PlainTensor& past_v_output) { + size_t B = k_input.m_dims[0], H = k_input.m_dims[1], L1 = k_input.m_dims[2], S = k_input.m_dims[3]; + parallel_for3d(B, H, L1, [&](size_t b, size_t h, size_t m) { + memcpy(&past_k_output.at({b, h, m, 0}), + &k_input.at({b, h, m, 0}), + S * sizeof(T)); + memcpy(&past_v_output.at({b, h, m, 0}), + &v_input.at({b, h, m, 0}), + S * sizeof(T)); + }); +} + +void attn_memcpy(const ov::intel_cpu::PlainTensor& k_input, + const ov::intel_cpu::PlainTensor& v_input, + const ov::intel_cpu::PlainTensor& past_k_output, + const ov::intel_cpu::PlainTensor& past_v_output) { + if (past_k_output.get_precision() == k_input.get_precision()) { + if (past_k_output.get_precision() == ov::element::bf16) { + attn_memcpy_kernel(k_input, v_input, past_k_output, past_v_output); + } else { + assert(past_k_output.get_precision() == ov::element::f16); + attn_memcpy_kernel(k_input, v_input, past_k_output, past_v_output); + } + } else if (past_k_output.get_precision() == ov::element::f16) { + attn_memcpy_kernel(k_input, v_input, past_k_output, past_v_output); + } else { + attn_memcpy_kernel(k_input, v_input, past_k_output, past_v_output); + } +} +} // namespace XARCH +} // namespace Cpu +} // namespace Extensions +} // namespace InferenceEngine \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/acc_value.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.hpp similarity index 57% rename from src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/acc_value.hpp rename to src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.hpp index c4ef60b3028ae0..ad1c5c69db098c 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/acc_value.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.hpp @@ -1,19 +1,24 @@ // Copyright (C) 2018-2023 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // +#pragma once #include #include #include #include #include +#include "utils/plain_tensor.hpp" namespace InferenceEngine { namespace Extensions { namespace Cpu { namespace XARCH { -void attn_acc_values(float** outs, float* weights, void** vs, size_t vec_num, size_t vec_len, ov::element::Type input_precision); +void attn_memcpy(const ov::intel_cpu::PlainTensor& k_input, + const ov::intel_cpu::PlainTensor& v_input, + const ov::intel_cpu::PlainTensor& past_k_output, + const ov::intel_cpu::PlainTensor& past_v_output); } // namespace XARCH } // namespace Cpu diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp index b9d4ecbba2bf8c..7b729b274c3be6 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp @@ -1,12 +1,16 @@ // Copyright (C) 2018-2023 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // +#pragma once #include #include #include #include +#include 
"openvino/core/type/bfloat16.hpp" +#include "openvino/core/type/float16.hpp" + namespace InferenceEngine { namespace Extensions { namespace Cpu { @@ -55,6 +59,15 @@ static constexpr size_t vec_len_f32_avx2 = vec_len_avx2 / sizeof(float); x = _mm512_mask_blend_epi32(mask, nan, x); // Check NaN before converting back to bf16 _mm512_mask_cvtepi32_storeu_epi16(addr, mask_addr, x); } + + inline __m512 mm512_uni_loadu_ps(ov::float16* a) { + auto vec_f16 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a)); + return _mm512_cvtph_ps(vec_f16); + } + inline void mm512_uni_storeu_ps(ov::float16* addr, __m512 v) { + __m256i vec_f16 = _mm512_cvtps_ph(v, 0); + _mm256_storeu_si256(reinterpret_cast<__m256i *>(addr), vec_f16); + } #endif #ifdef HAVE_AVX2 @@ -87,6 +100,16 @@ static constexpr size_t vec_len_f32_avx2 = vec_len_avx2 / sizeof(float); _mm_storeu_si128(reinterpret_cast<__m128i *>(addr), bf16_o); } + inline __m256 mm256_uni_loadu_ps(ov::float16* a) { + auto vec_f16 = _mm_loadu_si128(reinterpret_cast<__m128i*>(a)); + auto o = _mm256_cvtph_ps(vec_f16); + return o; + } + inline void mm256_uni_storeu_ps(ov::float16* a, __m256 v) { + __m128i vec_f16 = _mm256_cvtps_ph(v, 0); + _mm_storeu_si128(reinterpret_cast<__m128i *>(a), vec_f16); + } + inline void hsum(__m256& x) { __m256 y; // x: 0 1 2 3 4 5 6 7 y = _mm256_permute_ps(x, 0x39); // y: 1 2 3 0 5 6 7 4 diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/dot_product.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/dot_product.cpp deleted file mode 100644 index 3a963b68842d16..00000000000000 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/dot_product.cpp +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (C) 2018-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include - -#include -#include -#include -#include -#include - -#if defined(HAVE_AVX2) || defined(HAVE_AVX512F) -# include -#endif - -#include "openvino/core/type/bfloat16.hpp" -#include "common.hpp" -#include "dot_product.hpp" - -namespace InferenceEngine { -namespace Extensions { -namespace Cpu { -namespace XARCH { - -template -float dot_product_inner(T* a, T* b, size_t n) { - size_t i = 0; - float sum = 0.0f; -#if defined(HAVE_AVX512F) - auto vsum = _mm512_setzero_ps(); - for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512) { - auto va = mm512_uni_loadu_ps(a + i); - auto vb = mm512_uni_loadu_ps(b + i); - vsum = _mm512_fmadd_ps(va, vb, vsum); - } - sum = _mm512_reduce_add_ps(vsum); -#elif defined(HAVE_AVX2) - auto vsum = _mm256_set1_ps(0.0f); - for (; i + vec_len_f32_avx2 <= n; i += vec_len_f32_avx2) { - auto va = mm256_uni_loadu_ps(a + i); - auto vb = mm256_uni_loadu_ps(b + i); - vsum = _mm256_fmadd_ps(va, vb, vsum); - } - hsum(vsum); - sum = _mm256_cvtss_f32(vsum); -#endif - for (; i < n; i++) { - sum += a[i] * b[i]; - } - return sum; -} - -void attn_dot_products(void** a, void** b, void**c, size_t vec_num, size_t vec_len, ov::element::Type input_precision) { - if (input_precision == ov::element::f32) { - for (size_t i = 0; i < vec_num; i++) { - auto a_ptr = static_cast(a[i]); - auto b_ptr = static_cast(b[i]); - auto c_ptr = static_cast(c[i]); - c_ptr[0] = dot_product_inner(a_ptr, b_ptr, vec_len); - } - } else { - for (size_t i = 0; i < vec_num; i++) { - auto a_ptr = static_cast(a[i]); - auto b_ptr = static_cast(b[i]); - auto c_ptr = static_cast(c[i]); - c_ptr[0] = dot_product_inner(a_ptr, b_ptr, vec_len); - } - } -} - -} // namespace XARCH -} // namespace Cpu -} // namespace Extensions -} // namespace InferenceEngine \ No newline at end 
of file diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/dot_product.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/dot_product.hpp deleted file mode 100644 index 161fd9b890cda2..00000000000000 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/dot_product.hpp +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (C) 2018-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include -#include -#include -#include -#include - -namespace InferenceEngine { -namespace Extensions { -namespace Cpu { -namespace XARCH { - -void attn_dot_products(void** a, void** b, void**c, size_t vec_num, size_t vec_len, ov::element::Type input_precision); - -} // namespace XARCH -} // namespace Cpu -} // namespace Extensions -} // namespace InferenceEngine \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp new file mode 100644 index 00000000000000..8ac1aca6c1467e --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp @@ -0,0 +1,288 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include + +#include +#include +#include +#include +#include + +#if defined(HAVE_AVX2) || defined(HAVE_AVX512F) +# include +#endif + +#include "openvino/core/type/bfloat16.hpp" +#include "openvino/core/parallel.hpp" +#include "mha_single_token.hpp" +#include "common.hpp" +#include "softmax_kernel.hpp" + +namespace InferenceEngine { +namespace Extensions { +namespace Cpu { +namespace XARCH { + +using namespace ov; + +template +void attn_acc_value(float* out, float weight, T* v, size_t S) { + size_t i = 0; +#if defined(HAVE_AVX512F) + auto attn_w_vec_fp32 = _mm512_set1_ps(weight); + for (; i + vec_len_f32_avx512 <= S; i += vec_len_f32_avx512) { + auto v_value = mm512_uni_loadu_ps(v + i); + auto v_out = mm512_uni_loadu_ps(out + i); + v_out = _mm512_fmadd_ps(attn_w_vec_fp32, v_value, v_out); + _mm512_storeu_ps(out + i, v_out); + } +#elif defined(HAVE_AVX2) + auto attn_w_vec_fp32 = _mm256_set1_ps(weight); + for (; i + vec_len_f32_avx2 <= S; i += vec_len_f32_avx2) { + auto v_value = mm256_uni_loadu_ps(v + i); + auto v_out = mm256_uni_loadu_ps(out + i); + v_out = _mm256_fmadd_ps(attn_w_vec_fp32, v_value, v_out); + mm256_uni_storeu_ps(out + i, v_out); + } +#endif + for (; i < S; i++) { + out[i] += weight * v[i]; + } +} + +template +float dot_product(TA* a, TB* b, size_t n) { + size_t i = 0; + float sum = 0.0f; +#if defined(HAVE_AVX512F) + auto vsum = _mm512_setzero_ps(); + for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512) { + auto va = mm512_uni_loadu_ps(a + i); + auto vb = mm512_uni_loadu_ps(b + i); + vsum = _mm512_fmadd_ps(va, vb, vsum); + } + sum = _mm512_reduce_add_ps(vsum); +#elif defined(HAVE_AVX2) + auto vsum = _mm256_set1_ps(0.0f); + for (; i + vec_len_f32_avx2 <= n; i += vec_len_f32_avx2) { + auto va = mm256_uni_loadu_ps(a + i); + auto vb = mm256_uni_loadu_ps(b + i); + vsum = _mm256_fmadd_ps(va, vb, vsum); + } + hsum(vsum); + sum = _mm256_cvtss_f32(vsum); +#endif + for (; i < n; i++) { + sum += a[i] * b[i]; + } + return sum; +} + +template +void attn_reduce(T* dst, float* temp, size_t M, size_t S, size_t temp_stride) { + size_t i = 0; +#if defined(HAVE_AVX512F) + for (; i + vec_len_f32_avx512 <= S; i+= vec_len_f32_avx512) { + auto* src = temp + i; + auto result_vec_fp32 = _mm512_setzero_ps(); + for (size_t m = 0; m < M; m++) { + //auto* temp = &m_temp.at({ithr, b, pq, 
h, 0}); + auto o_vec_fp32 = _mm512_loadu_ps(src); + result_vec_fp32 = _mm512_add_ps(result_vec_fp32, o_vec_fp32); + src += temp_stride; + } + // save to bf16 + mm512_uni_storeu_ps(dst + i, result_vec_fp32); + } +#elif defined(HAVE_AVX2) + for (; i + vec_len_f32_avx2 <= S; i += vec_len_f32_avx2) { + auto* src = temp + i; + auto result_vec_fp32 = _mm256_set1_ps(0.0f); + for (size_t m = 0; m < M; m++) { + auto o_vec_fp32 = mm256_uni_loadu_ps(src); + result_vec_fp32 = _mm256_add_ps(result_vec_fp32, o_vec_fp32); + src += temp_stride; + } + mm256_uni_storeu_ps(dst + i, result_vec_fp32); + } +#endif + for (; i < S; i++) { + auto* src = temp + i; + float sum = 0.0f; + // sum result from all threads partition + for (size_t m = 0; m < M; m++) { + sum += src[0]; + src += temp_stride; + } + dst[i] = sum; + } +} + +template +void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query, + const ov::intel_cpu::PlainTensor& present_key, + const ov::intel_cpu::PlainTensor& present_value, + const ov::intel_cpu::PlainTensor& alibi_mask, + const ov::intel_cpu::PlainTensor& attention_mask, + const ov::intel_cpu::PlainTensor& beams, + ov::intel_cpu::PlainTensor& output_emb, + ov::intel_cpu::PlainTensor& buf_attn_w, + ov::intel_cpu::PlainTensor& buf_attn_score, + bool has_out_transpose, + bool auto_causal, + float d_scale) { + ov::intel_cpu::PlainTensor causal_mask; + bool select_nfltmax_at_0 = false; + auto B = query.size(0); + auto H = query.size(1); + auto q_len = query.size(2); + auto S = query.size(3); + auto kv_len = present_key.size(2); + + if (d_scale == 0.0f) + d_scale = 1.0f / sqrt(S); + + // use per-token kernel, for each k,v token + // attn mask is a matrix of q_len(kv_len) + buf_attn_w.resize({B, H, q_len, kv_len}); + + bool is_abcd = present_key.stride(1) >= present_key.stride(2); + size_t dim0 = is_abcd ? B : kv_len; + size_t dim1 = is_abcd ? H : B; + size_t dim2 = is_abcd ? kv_len : H; + + parallel_for3d(dim0, dim1, dim2, [&](size_t d0, size_t d1, size_t d2) { + size_t b = is_abcd ? d0 : d1; + size_t h = is_abcd ? d1 : d2; + size_t pk = is_abcd ? d2 : d0; + + // which batch item should be used at postion pk? + auto b_kv = beams ? beams.at({b, pk}) : b; + for (size_t pq = 0; pq < q_len; pq++) { + buf_attn_w.at({b, h, pq, pk}) = dot_product(&query.at({b, h, pq, 0}), + &present_key.at({b_kv, h, pk, 0}, true), + S); + } + }); + + parallel_for3d(B, H, q_len, [&](size_t b, size_t h, size_t pq) { + // apply attention mask & sofmax + auto ncausal = auto_causal ? (kv_len - q_len + pq + 1) : kv_len; + float* alibi_ptr = alibi_mask ? &alibi_mask.at({b, h, pq, 0}, true) : nullptr; + float* attn_mask_ptr = attention_mask ? &attention_mask.at({b, h, pq, 0}, true) : nullptr; + uint8_t* cmask_ptr = causal_mask ? 
&causal_mask.at({b, h, pq, 0}, true) : nullptr; + attn_softmax_kernel(&buf_attn_w.at({b, h, pq, 0}), + &buf_attn_w.at({b, h, pq, 0}), + d_scale, + alibi_ptr, + attn_mask_ptr, + cmask_ptr, + select_nfltmax_at_0, + ncausal, + kv_len, + ov::element::f32); + }); + + // attn_w * V + auto nthr = parallel_get_max_threads(); + buf_attn_score.resize({static_cast(nthr), B, q_len, H, S}); + // buf_attn_w {B, H, q_len, kv_len} + parallel_nt_static(nthr, [&](const size_t ithr, const size_t nthr) { + size_t start{0}, end{0}; + splitter(B * H * kv_len, nthr, ithr, start, end); + + memset(&buf_attn_score.at({ithr, 0, 0, 0, 0}), 0, buf_attn_score.stride(0) * sizeof(float)); + + size_t b, h, pv; + if (start < end) { + if (is_abcd) + parallel_it_init(start, b, B, h, H, pv, kv_len); + else + parallel_it_init(start, pv, kv_len, b, B, h, H); + for (size_t iwork = start; iwork < end; ++iwork) { + auto b_kv = beams ? beams.at({b, pv}) : b; + auto* v = &present_value.at({b_kv, h, pv, 0}, true); + for (size_t pq = 0; pq < q_len; pq++) { + attn_acc_value(&buf_attn_score.at({ithr, b, pq, h, 0}), + buf_attn_w.at({b, h, pq, pv}), + v, + S); + } + if (is_abcd) + parallel_it_step(b, B, h, H, pv, kv_len); + else + parallel_it_step(pv, kv_len, b, B, h, H); + } + } + }); + + parallel_for3d(B, H, q_len, [&](size_t b, size_t h, size_t pq) { + auto* temp = &buf_attn_score.at({0, b, pq, h, 0}); + size_t temp_stride = buf_attn_score.stride(0); + auto* dst = has_out_transpose ? &output_emb.at({b, pq, h * S}) : &output_emb.at({b, h, pq}); + attn_reduce(dst, temp, nthr, S, temp_stride); + }); +} + +void mha_single_token(const ov::intel_cpu::PlainTensor& query, + const ov::intel_cpu::PlainTensor& present_key, + const ov::intel_cpu::PlainTensor& present_value, + const ov::intel_cpu::PlainTensor& alibi_mask, + const ov::intel_cpu::PlainTensor& attention_mask, + const ov::intel_cpu::PlainTensor& beams, + ov::intel_cpu::PlainTensor& output_emb, + ov::intel_cpu::PlainTensor& buf_attn_w, + ov::intel_cpu::PlainTensor& buf_attn_score, + bool has_out_transpose, + bool auto_causal, + float d_scale) { + if (query.get_precision() == ov::element::bf16) { + mha_single_token_kernel(query, + present_key, + present_value, + alibi_mask, + attention_mask, + beams, + output_emb, + buf_attn_w, + buf_attn_score, + has_out_transpose, + auto_causal, + d_scale); + } else if (query.get_precision() == ov::element::f32) { + if (present_key.get_precision() == ov::element::f16) { + mha_single_token_kernel(query, + present_key, + present_value, + alibi_mask, + attention_mask, + beams, + output_emb, + buf_attn_w, + buf_attn_score, + has_out_transpose, + auto_causal, + d_scale); + } else { + mha_single_token_kernel(query, + present_key, + present_value, + alibi_mask, + attention_mask, + beams, + output_emb, + buf_attn_w, + buf_attn_score, + has_out_transpose, + auto_causal, + d_scale); + } + } else { + OPENVINO_THROW("Unsupported precision: ", query.get_precision()); + } +} +} // namespace XARCH +} // namespace Cpu +} // namespace Extensions +} // namespace InferenceEngine \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.hpp new file mode 100644 index 00000000000000..543f7f1d9217a0 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.hpp @@ -0,0 +1,34 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include +#include +#include 
+#include +#include +#include "utils/plain_tensor.hpp" + +namespace InferenceEngine { +namespace Extensions { +namespace Cpu { +namespace XARCH { + +void mha_single_token(const ov::intel_cpu::PlainTensor& query, + const ov::intel_cpu::PlainTensor& present_key, + const ov::intel_cpu::PlainTensor& present_value, + const ov::intel_cpu::PlainTensor& alibi_mask, + const ov::intel_cpu::PlainTensor& attention_mask, + const ov::intel_cpu::PlainTensor& beams, + ov::intel_cpu::PlainTensor& output_emb, + ov::intel_cpu::PlainTensor& buf_attn_w, + ov::intel_cpu::PlainTensor& buf_attn_score, + bool has_out_transpose, + bool auto_causal, + float d_scale); + +} // namespace XARCH +} // namespace Cpu +} // namespace Extensions +} // namespace InferenceEngine \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/reduce.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/reduce.cpp deleted file mode 100644 index ad5e502dfa8907..00000000000000 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/reduce.cpp +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright (C) 2018-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include - -#include -#include -#include -#include -#include - -#if defined(HAVE_AVX2) || defined(HAVE_AVX512F) - #include -#endif - -#include "openvino/core/type/bfloat16.hpp" -#include "common.hpp" -#include "reduce.hpp" - -namespace InferenceEngine { -namespace Extensions { -namespace Cpu { -namespace XARCH { - -template -void attn_reduce_inner(T* dst, float* temp, size_t M, size_t S, size_t temp_stride) { - size_t i = 0; -#if defined(HAVE_AVX512F) - for (; i + vec_len_f32_avx512 <= S; i+= vec_len_f32_avx512) { - auto* src = temp + i; - auto result_vec_fp32 = _mm512_setzero_ps(); - for (size_t m = 0; m < M; m++) { - //auto* temp = &m_temp.at({ithr, b, pq, h, 0}); - auto o_vec_fp32 = _mm512_loadu_ps(src); - result_vec_fp32 = _mm512_add_ps(result_vec_fp32, o_vec_fp32); - src += temp_stride; - } - // save to bf16 - mm512_uni_storeu_ps(dst + i, result_vec_fp32); - } -#elif defined(HAVE_AVX2) - for (; i + vec_len_f32_avx2 <= S; i += vec_len_f32_avx2) { - auto* src = temp + i; - auto result_vec_fp32 = _mm256_set1_ps(0.0f); - for (size_t m = 0; m < M; m++) { - auto o_vec_fp32 = mm256_uni_loadu_ps(src); - result_vec_fp32 = _mm256_add_ps(result_vec_fp32, o_vec_fp32); - src += temp_stride; - } - mm256_uni_storeu_ps(dst + i, result_vec_fp32); - } -#endif - for (; i < S; i++) { - auto* src = temp + i; - float sum = 0.0f; - // sum result from all threads partition - for (size_t m = 0; m < M; m++) { - sum += src[0]; - src += temp_stride; - } - dst[i] = sum; - } -} - -void attn_reduce(void* dst, float* temp, size_t M, size_t S, size_t temp_stride, ov::element::Type input_precision) { - if (input_precision == ov::element::f32) { - auto dst_ptr = static_cast(dst); - attn_reduce_inner(dst_ptr, temp, M, S, temp_stride); - } else { - assert(input_precision == ov::element::bf16); - auto dst_ptr = static_cast(dst); - attn_reduce_inner(dst_ptr, temp, M, S, temp_stride); - } -} - -} // namespace XARCH -} // namespace Cpu -} // namespace Extensions -} // namespace InferenceEngine \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/reduce.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/reduce.hpp deleted file mode 100644 index baa9b6d1835f1d..00000000000000 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/reduce.hpp +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (C) 2018-2023 Intel Corporation -// 
SPDX-License-Identifier: Apache-2.0 -// - -#include -#include -#include -#include -#include - -namespace InferenceEngine { -namespace Extensions { -namespace Cpu { -namespace XARCH { - -void attn_reduce(void* dst, float* temp, size_t M, size_t S, size_t temp_stride, ov::element::Type input_precision); - -} // namespace XARCH -} // namespace Cpu -} // namespace Extensions -} // namespace InferenceEngine \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax.cpp index 5c2bea26b71556..7ebec7f9bae3ef 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax.cpp @@ -15,522 +15,13 @@ #include "openvino/core/type/bfloat16.hpp" #include "softmax.hpp" +#include "softmax_kernel.hpp" #include "common.hpp" namespace InferenceEngine { namespace Extensions { namespace Cpu { namespace XARCH { -#if defined(HAVE_AVX2) -inline __m256i get_mask(int N7) { - static __m256i mask[] = { - _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, 0), - _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1), - _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1), - _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1), - _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1), - _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1), - _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1), - _mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1), - _mm256_set_epi32(-1, -1, -1, -1, -1, -1, -1, -1), - }; - return _mm256_loadu_si256(&mask[N7]); -} - -inline void hmax(__m256& x) { - __m256 y; // x: 0 1 2 3 4 5 6 7 - y = _mm256_permute_ps(x, 0x39); // y: 1 2 3 0 5 6 7 4 - x = _mm256_max_ps(x, y); // X: 01 12 23 30 45 56 67 74 - y = _mm256_permute_ps(x, 0x4e); // y: 23 30 01 12 67 74 45 56 - x = _mm256_max_ps(x, y); // x: 0123 x x x 4567 x x x - y = _mm256_permute2f128_ps(x, x, 1); // y: 4567 x x x 0123 x x x - x = _mm256_max_ps(x, y); // x: 01234567 x x x x x x x -} - -inline void exp_ps_avx2(__m256& src) { - static __m256 exp_ln_flt_min_f = _mm256_castsi256_ps(_mm256_set1_epi32(0xc2aeac50)); // log(FLT_MIN) - static __m256 exp_ln_flt_max_f = _mm256_castsi256_ps(_mm256_set1_epi32(0x42b17218)); // log(FLT_MAX) - static __m256 exp_log2ef = _mm256_castsi256_ps(_mm256_set1_epi32(0x3fb8aa3b)); // log2(e) - static __m256 half = _mm256_castsi256_ps(_mm256_set1_epi32(0x3f000000)); // 0.5f - static __m256 ln2f = _mm256_castsi256_ps(_mm256_set1_epi32(0x3f317218)); // ln(2) - static __m256 one = _mm256_castsi256_ps(_mm256_set1_epi32(0x3f800000)); // 1.0f - static __m256i exponent_bias = _mm256_set1_epi32(0x0000007f); // 127 - static constexpr int n_mantissa_bits = 23; - static __m256 exp_pol1 = _mm256_castsi256_ps(_mm256_set1_epi32(0x3f7ffffb)); // p1 = 0.999999701f - static __m256 exp_pol2 = _mm256_castsi256_ps(_mm256_set1_epi32(0x3efffee3)); // p2 = 0.499991506f - static __m256 exp_pol3 = _mm256_castsi256_ps(_mm256_set1_epi32(0x3e2aad40)); // p3 = 0.166676521f - static __m256 exp_pol4 = _mm256_castsi256_ps(_mm256_set1_epi32(0x3d2b9d0d)); // p4 = 0.0418978221f - static __m256 exp_pol5 = _mm256_castsi256_ps(_mm256_set1_epi32(0x3c07cfce)); // p5 = 0.00828929059f - static __m256 two = _mm256_castsi256_ps(_mm256_set1_epi32(0x40000000)); // 2 - // exp(x) = - // = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem - // = 2^n * exp(r) // simplify the exp(n*ln(2)) expression - - // get mask of values lower than log(FLT_MIN) to zero them in the output - auto zero_mask = _mm256_cmp_ps(src, exp_ln_flt_min_f, _CMP_LT_OS); - 
- // clip src - src = _mm256_min_ps(src, exp_ln_flt_max_f); - src = _mm256_max_ps(src, exp_ln_flt_min_f); - - // aux1 : r - auto aux1 = src; - - // calculate exp(x) - // fx = x * log2(e) + 0.5 - src = _mm256_mul_ps(src, exp_log2ef); - src = _mm256_add_ps(src, half); - - // tmp = floorf(fx) - src = _mm256_floor_ps(src); - - // aux1 = x - fx * ln2 - aux1 = _mm256_fnmadd_ps(src, ln2f, aux1); - - // We do not count 2^n here, because n can reach 128 and 2^128 is not - // representable by fp32, so to get around this problem, instead of computing - // 2^n * exp(r) will be counted 2*2^(n-1)*exp(r), because 2^127 - // and 2 are numbers representable in fp32. - - // compute 2^(n-1) - src = _mm256_sub_ps(src, one); - auto aux2_i = _mm256_cvtps_epi32(src); - aux2_i = _mm256_add_epi32(aux2_i, exponent_bias); - aux2_i = _mm256_slli_epi32(aux2_i, n_mantissa_bits); - - // set zeroes at those points which were < log(FLT_MIN) - auto zero = _mm256_setzero_ps(); - auto aux2 = _mm256_blendv_ps(_mm256_castsi256_ps(aux2_i), zero, zero_mask); - - // compute polynomial - src = exp_pol5; - src = _mm256_fmadd_ps(src, aux1, exp_pol4); - src = _mm256_fmadd_ps(src, aux1, exp_pol3); - src = _mm256_fmadd_ps(src, aux1, exp_pol2); - src = _mm256_fmadd_ps(src, aux1, exp_pol1); - src = _mm256_fmadd_ps(src, aux1, one); - - // y = y * 2^n - src = _mm256_mul_ps(src, aux2); - src = _mm256_mul_ps(src, two); -} -#endif - -inline void scale_add_reduce_max(float* a, const float scale, const float* b, const size_t size, float& max) { -#if defined(HAVE_AVX512F) - auto v_max = _mm512_set1_ps(std::numeric_limits::lowest()); - auto v_scale = _mm512_set1_ps(scale); - auto v_a = v_max; - auto v_b = v_max; - size_t i = 0; - // process vector body - while (i + vec_len_f32_avx512 <= size) { - v_a = _mm512_loadu_ps(a + i); - v_b = _mm512_loadu_ps(b + i); - v_a = _mm512_fmadd_ps(v_a, v_scale, v_b); - v_max = _mm512_max_ps(v_max, v_a); - _mm512_storeu_ps(a + i, v_a); - i += vec_len_f32_avx512; - } - - // process tails - if (i < size) { - __mmask16 mask = (1 << (size - i)) - 1; - v_a = _mm512_maskz_loadu_ps(mask, a + i); - v_b = _mm512_maskz_loadu_ps(mask, b + i); - v_a = _mm512_fmadd_ps(v_a, v_scale, v_b); - v_max = _mm512_mask_max_ps(v_max, mask, v_a, v_max); - _mm512_mask_storeu_ps(a + i, mask, v_a); - } - - max = _mm512_reduce_max_ps(v_max); -#elif defined(HAVE_AVX2) - auto v_max = _mm256_set1_ps(std::numeric_limits::lowest()); - auto v_scale = _mm256_set1_ps(scale); - auto v_a = v_max; - auto v_b = v_max; - size_t i = 0; - // process vector body - while (i + vec_len_f32_avx2 <= size) { - v_a = _mm256_loadu_ps(a + i); - v_b = _mm256_loadu_ps(b + i); - v_a = _mm256_fmadd_ps(v_a, v_scale, v_b); - v_max = _mm256_max_ps(v_max, v_a); - _mm256_storeu_ps(a + i, v_a); - i += vec_len_f32_avx2; - } - - // process tails - if (i < size) { - auto mask = get_mask(size - i); - v_a = _mm256_maskload_ps(a + i, mask); - v_b = _mm256_maskload_ps(b + i, mask); - v_a = _mm256_fmadd_ps(v_a, v_scale, v_b); - v_a = _mm256_blendv_ps(v_max, v_a, _mm256_castsi256_ps(mask)); - v_max = _mm256_max_ps(v_max, v_a); - _mm256_maskstore_ps(a + i, mask, v_a); - } - hmax(v_max); - max = _mm256_cvtss_f32(v_max); -#else - for (size_t i = 0; i < size; i++) { - a[i] *= scale; - a[i] += b[i]; - max = a[i] > max ? 
a[i] : max; - } -#endif -} -template -inline void scale_add2_reduce_max(float* a, - float scale, - const float* alibi, - const float* attn_mask, - const uint8_t* causal_mask, - bool select_nfltmax_at_0, // true: 0 in mask set -FLT_MAX - size_t size, - float& max) { -#if defined(HAVE_AVX512F) - auto v_max = _mm512_set1_ps(std::numeric_limits::lowest()); - auto v_scale = _mm512_set1_ps(scale); - auto v_a = v_max; - size_t i = 0; - auto v_zeroi32 = _mm512_setzero_epi32(); - auto v_nfltmax = _mm512_set1_ps(-FLT_MAX); - auto kmask_xor = _cvtu32_mask16(select_nfltmax_at_0 ? 0xFFFF : 0); - // process vector body - while (i + vec_len_f32_avx512 <= size) { - v_a = _mm512_loadu_ps(a + i); - v_a = _mm512_mul_ps(v_a, v_scale); - - if (has_alibi) { - auto v_mask = _mm512_loadu_ps(alibi + i); - v_a = _mm512_add_ps(v_a, v_mask); - } - - if (has_attn_mask) { - auto v_mask = _mm512_loadu_ps(attn_mask + i); - v_a = _mm512_add_ps(v_a, v_mask); - } - - if (has_causal_mask) { - auto v_maski8 = _mm_loadu_si128(reinterpret_cast<__m128i const*>(causal_mask + i)); - auto v_maski32 = _mm512_cvtepi8_epi32(v_maski8); - auto kmask = _mm512_cmp_epi32_mask(v_maski32, v_zeroi32, _MM_CMPINT_NE); // !=0 - kmask = _kxor_mask16(kmask, kmask_xor); // reverse, mask at ==0 - v_a = _mm512_mask_blend_ps(kmask, v_a, v_nfltmax); // mask => -FLT_MAX - } - v_max = _mm512_max_ps(v_max, v_a); - _mm512_storeu_ps(a + i, v_a); - i += vec_len_f32_avx512; - } - - // process tails - if (i < size) { - __mmask16 mask = (1 << (size - i)) - 1; - v_a = _mm512_maskz_loadu_ps(mask, a + i); - v_a = _mm512_mul_ps(v_a, v_scale); - - if (has_alibi) { - auto v_mask = _mm512_maskz_loadu_ps(mask, alibi + i); - v_a = _mm512_add_ps(v_a, v_mask); - } - - if (has_attn_mask) { - auto v_mask = _mm512_maskz_loadu_ps(mask, attn_mask + i); - v_a = _mm512_add_ps(v_a, v_mask); - } - - if (has_causal_mask) { - auto v_maski8 = _mm_loadu_si128(reinterpret_cast<__m128i const*>(causal_mask + i)); - auto v_maski32 = _mm512_cvtepi8_epi32(v_maski8); - auto kmask = _mm512_cmp_epi32_mask(v_maski32, v_zeroi32, _MM_CMPINT_NE); // !=0 - kmask = _kxor_mask16(kmask, kmask_xor); // reverse, mask at ==0 - v_a = _mm512_mask_blend_ps(kmask, v_a, v_nfltmax); // mask => -FLT_MAX - } - v_max = _mm512_mask_max_ps(v_max, mask, v_a, v_max); - _mm512_mask_storeu_ps(a + i, mask, v_a); - } - - max = _mm512_reduce_max_ps(v_max); -#elif defined(HAVE_AVX2) - auto v_max = _mm256_set1_ps(std::numeric_limits::lowest()); - auto v_scale = _mm256_set1_ps(scale); - auto v_a = v_max; - auto v_zeroi32 = _mm256_setzero_si256(); - auto v_mask_xor = _mm256_set1_epi32(select_nfltmax_at_0 ? 
-1 : 0); - auto v_nfltmax = _mm256_set1_ps(-FLT_MAX); - size_t i = 0; - // process vector body - while (i + vec_len_f32_avx2 <= size) { - v_a = _mm256_loadu_ps(a + i); - v_a = _mm256_mul_ps(v_a, v_scale); - - if (has_alibi) { - auto v_mask = _mm256_loadu_ps(alibi + i); - v_a = _mm256_add_ps(v_a, v_mask); - } - - if (has_attn_mask) { - auto v_mask = _mm256_loadu_ps(attn_mask + i); - v_a = _mm256_add_ps(v_a, v_mask); - } - - if (has_causal_mask) { - auto v_maski8 = _mm_loadu_si128(reinterpret_cast<__m128i const*>(causal_mask + i)); - auto v_maski32 = _mm256_cvtepi8_epi32(v_maski8); - v_maski32 = _mm256_cmpeq_epi32(v_maski32, v_zeroi32); // ==0 - v_maski32 = _mm256_xor_si256(v_maski32, v_mask_xor); // reverse, mask at ==0 - v_a = _mm256_blendv_ps(v_nfltmax, v_a, _mm256_castsi256_ps(v_maski32)); // mask => -FLT_MAX - } - - v_max = _mm256_max_ps(v_max, v_a); - _mm256_storeu_ps(a + i, v_a); - i += vec_len_f32_avx2; - } - - // process tails - if (i < size) { - auto mask = get_mask(size - i); - v_a = _mm256_maskload_ps(a + i, mask); - v_a = _mm256_mul_ps(v_a, v_scale); - - if (has_alibi) { - auto v_mask = _mm256_maskload_ps(alibi + i, mask); - v_a = _mm256_add_ps(v_a, v_mask); - } - - if (has_attn_mask) { - auto v_mask = _mm256_maskload_ps(attn_mask + i, mask); - v_a = _mm256_add_ps(v_a, v_mask); - } - - if (has_causal_mask) { - auto v_maski8 = _mm_loadu_si128(reinterpret_cast<__m128i const*>(causal_mask + i)); - auto v_maski32 = _mm256_cvtepi8_epi32(v_maski8); - v_maski32 = _mm256_cmpeq_epi32(v_maski32, v_zeroi32); // ==0 - v_maski32 = _mm256_xor_si256(v_maski32, v_mask_xor); // reverse, mask at ==0 - v_a = _mm256_blendv_ps(v_nfltmax, v_a, _mm256_castsi256_ps(v_maski32)); // mask => -FLT_MAX - } - - v_a = _mm256_blendv_ps(v_max, v_a, _mm256_castsi256_ps(mask)); - v_max = _mm256_max_ps(v_max, v_a); - _mm256_maskstore_ps(a + i, mask, v_a); - } - hmax(v_max); - max = _mm256_cvtss_f32(v_max); -#else - for (size_t i = 0; i < size; i++) { - a[i] *= scale; - if (has_alibi) - a[i] += alibi[i]; - - if (has_attn_mask) - a[i] += attn_mask[i]; - - if (has_causal_mask) { - if (select_nfltmax_at_0) { - if (causal_mask[i] == 0) - a[i] = -FLT_MAX; - } else { - if (causal_mask[i] != 0) - a[i] = -FLT_MAX; - } - } - - max = a[i] > max ? 
a[i] : max; - } -#endif -} - -#if defined(HAVE_AVX512F) -inline void exp_ps_avx512(__m512& src) { - static __m512 exp_ln_flt_min_f = _mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50)); // log(FLT_MIN) - static __m512 exp_ln_flt_max_f = _mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); // log(FLT_MAX) - static __m512 exp_log2ef = _mm512_castsi512_ps(_mm512_set1_epi32(0x3fb8aa3b)); // log2(e) - static __m512 half = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f000000)); // 0.5f - static __m512 ln2f = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f317218)); // ln(2) - static __m512 one = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f800000)); // 1.0f - static __m512i exponent_bias = _mm512_set1_epi32(0x0000007f); // 127 - static constexpr int n_mantissa_bits = 23; - static __m512 exp_pol1 = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f7ffffb)); // p1 = 0.999999701f - static __m512 exp_pol2 = _mm512_castsi512_ps(_mm512_set1_epi32(0x3efffee3)); // p2 = 0.499991506f - static __m512 exp_pol3 = _mm512_castsi512_ps(_mm512_set1_epi32(0x3e2aad40)); // p3 = 0.166676521f - static __m512 exp_pol4 = _mm512_castsi512_ps(_mm512_set1_epi32(0x3d2b9d0d)); // p4 = 0.0418978221f - static __m512 exp_pol5 = _mm512_castsi512_ps(_mm512_set1_epi32(0x3c07cfce)); // p5 = 0.00828929059f - static __m512 two = _mm512_castsi512_ps(_mm512_set1_epi32(0x40000000)); // 2 - // exp(x) = - // = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem - // = 2^n * exp(r) // simplify the exp(n*ln(2)) expression - - // get mask of values lower than log(FLT_MIN) to zero them in the output - auto zero_mask = _mm512_cmp_ps_mask(src, exp_ln_flt_min_f, _CMP_LT_OS); - - // clip src - src = _mm512_min_ps(src, exp_ln_flt_max_f); - src = _mm512_max_ps(src, exp_ln_flt_min_f); - - // aux1 : r - auto aux1 = src; - - // calculate exp(x) - // fx = x * log2(e) + 0.5 - src = _mm512_mul_ps(src, exp_log2ef); - src = _mm512_add_ps(src, half); - - // tmp = floorf(fx) - src = _mm512_floor_ps(src); - - // aux1 = x - fx * ln2 - aux1 = _mm512_fnmadd_ps(src, ln2f, aux1); - // We do not count 2^n here, because n can reach 128 and 2^128 is not - // representable by fp32, so to get around this problem, instead of computing - // 2^n * exp(r) will be counted 2*2^(n-1)*exp(r), because 2^127 - // and 2 are numbers representable in fp32. 
- - // compute 2^(n-1) - src = _mm512_sub_ps(src, one); - auto aux2_i = _mm512_cvtps_epi32(src); - aux2_i = _mm512_add_epi32(aux2_i, exponent_bias); - aux2_i = _mm512_slli_epi32(aux2_i, n_mantissa_bits); - - // set zeroes at those points which were < log(FLT_MIN) - auto zero = _mm512_setzero_ps(); - auto aux2 = _mm512_mask_blend_ps(zero_mask, _mm512_castsi512_ps(aux2_i), zero); - - // compute polynomial - src = exp_pol5; - src = _mm512_fmadd_ps(src, aux1, exp_pol4); - src = _mm512_fmadd_ps(src, aux1, exp_pol3); - src = _mm512_fmadd_ps(src, aux1, exp_pol2); - src = _mm512_fmadd_ps(src, aux1, exp_pol1); - src = _mm512_fmadd_ps(src, aux1, one); - - // y = y * 2^n - src = _mm512_mul_ps(src, aux2); - src = _mm512_mul_ps(src, two); -} -#endif - -inline void exp_reduce_sum(float* a, const float max, const size_t size, float& sum) { -#if defined(HAVE_AVX512F) - size_t i = 0; - __m512 v_a; - auto v_max = _mm512_set1_ps(max); - auto v_sum = _mm512_set1_ps(0.0f); - while (i + vec_len_f32_avx512 <= size) { - v_a = _mm512_loadu_ps(a + i); - v_a = _mm512_sub_ps(v_a, v_max); - exp_ps_avx512(v_a); - v_sum = _mm512_add_ps(v_sum, v_a); - _mm512_storeu_ps(a + i, v_a); - i += vec_len_f32_avx512; - } - - if (i < size) { - __mmask16 mask = (1 << (size - i)) - 1; - v_a = _mm512_maskz_loadu_ps(mask, a + i); - v_a = _mm512_sub_ps(v_a, v_max); - exp_ps_avx512(v_a); - v_sum = _mm512_mask_add_ps(v_sum, mask, v_a, v_sum); - _mm512_mask_storeu_ps(a + i, mask, v_a); - } - sum = _mm512_reduce_add_ps(v_sum); -#elif defined(HAVE_AVX2) - size_t i = 0; - __m256 v_a; - auto v_max = _mm256_set1_ps(max); - auto v_sum = _mm256_set1_ps(0.0f); - while (i + vec_len_f32_avx2 <= size) { - v_a = _mm256_loadu_ps(a + i); - v_a = _mm256_sub_ps(v_a, v_max); - exp_ps_avx2(v_a); - v_sum = _mm256_add_ps(v_sum, v_a); - _mm256_storeu_ps(a + i, v_a); - i += vec_len_f32_avx2; - } - - if (i < size) { - auto mask = get_mask(size - i); - v_a = _mm256_maskload_ps(a + i, mask); - v_a = _mm256_sub_ps(v_a, v_max); - exp_ps_avx2(v_a); - v_a = _mm256_blendv_ps(_mm256_setzero_ps(), v_a, _mm256_castsi256_ps(mask)); - v_sum = _mm256_add_ps(v_a, v_sum); - _mm256_maskstore_ps(a + i, mask, v_a); - } - hsum(v_sum); - sum = _mm256_cvtss_f32(v_sum); -#else - for (size_t i = 0; i < size; i++) { - a[i] = exp(a[i] - max); - sum += a[i]; - } -#endif -} - -inline void multiply_scalar(float* a, float* a_dst, const float val, const size_t size) { -#if defined(HAVE_AVX512F) - auto v_scale = _mm512_set1_ps(val); - __m512 v_a = {0}; - size_t i = 0; - while (i + vec_len_f32_avx512 <= size) { - v_a = _mm512_loadu_ps(a + i); - v_a = _mm512_mul_ps(v_a, v_scale); - _mm512_storeu_ps(a_dst + i, v_a); - i += vec_len_f32_avx512; - } - if (i < size) { - __mmask16 mask = (1 << (size - i)) - 1; - v_a = _mm512_maskz_loadu_ps(mask, a + i); - v_a = _mm512_mul_ps(v_a, v_scale); - _mm512_mask_storeu_ps(a_dst + i, mask, v_a); - } -#elif defined(HAVE_AVX2) - auto v_scale = _mm256_set1_ps(val); - __m256 v_a = {0}; - size_t i = 0; - while (i + vec_len_f32_avx2 <= size) { - v_a = _mm256_loadu_ps(a + i); - v_a = _mm256_mul_ps(v_a, v_scale); - _mm256_storeu_ps(a_dst + i, v_a); - i += vec_len_f32_avx2; - } - if (i < size) { - auto mask = get_mask(size - i); - v_a = _mm256_maskload_ps(a + i, mask); - v_a = _mm256_mul_ps(v_a, v_scale); - _mm256_maskstore_ps(a_dst + i, mask, v_a); - } -#else - for (size_t i = 0; i < size; i++) { - a_dst[i] = a[i] * val; - } -#endif -} - -inline void multiply_scalar(float* a, ov::bfloat16* a_dst, const float val, const size_t size) { -#if defined(HAVE_AVX512F) - auto 
v_scale = _mm512_set1_ps(val); - __m512 v_a = {0}; - size_t i = 0; - while (i + vec_len_f32_avx512 <= size) { - v_a = _mm512_loadu_ps(a + i); - v_a = _mm512_mul_ps(v_a, v_scale); - mm512_uni_storeu_ps(a_dst + i, v_a); - i += vec_len_f32_avx512; - } - if (i < size) { - __mmask16 mask = (1 << (size - i)) - 1; - v_a = _mm512_maskz_loadu_ps(mask, a + i); - v_a = _mm512_mul_ps(v_a, v_scale); - mm512_uni_mask_storeu_ps(a_dst + i, mask, v_a); - } -#else - for (size_t i = 0; i < size; i++) { - a_dst[i] = a[i] * val; - } -#endif -} void attn_softmax(float* a, void* a_dst, @@ -542,38 +33,9 @@ void attn_softmax(float* a, size_t len, size_t total_size, ov::element::Type dst_precision) { - using func_type = void (*)(float*, float, const float*, const float*, const uint8_t*, bool, size_t, float&); - static func_type funcs[] = { - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max - }; - int dispatch = (alibi ? 0b100 : 0) | (attn_mask ? 0b010 : 0) | (causal_mask ? 0b001 : 0); - float max = std::numeric_limits::lowest(); - funcs[dispatch](a, scale, alibi, attn_mask, causal_mask, select_nfltmax_at_0, len, max); - - float sum = 0.0f; - // exp sum - exp_reduce_sum(a, max, len, sum); - // divide sum - float scalar = 1.0f / sum; - if (dst_precision == ov::element::f32) { - multiply_scalar(a, static_cast(a_dst), scalar, len); - // apply causual mask to final result instead of attn_score - if (total_size > len) - memset(static_cast(a_dst) + len, 0, sizeof(float) * (total_size - len)); - } else { - multiply_scalar(a, static_cast(a_dst), scalar, len); - // apply causual mask to final result instead of attn_score - if (total_size > len) - memset(static_cast(a_dst) + len, 0, sizeof(ov::bfloat16) * (total_size - len)); - } + attn_softmax_kernel(a, a_dst, scale, alibi, attn_mask, causal_mask, select_nfltmax_at_0, len, total_size, dst_precision); } + } // namespace XARCH } // namespace Cpu } // namespace Extensions diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax.hpp index 46fedd048c32d2..0cfd1591e1e641 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax.hpp @@ -1,6 +1,7 @@ // Copyright (C) 2018-2023 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // +#pragma once #include #include diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp new file mode 100644 index 00000000000000..e698b199872b72 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp @@ -0,0 +1,576 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include +#include +#include +#include +#include + +#include "common.hpp" + +namespace InferenceEngine { +namespace Extensions { +namespace Cpu { +namespace XARCH { + +#if defined(HAVE_AVX2) +inline __m256i get_mask(int N7) { + static __m256i mask[] = { + _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, 0), + _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1), + _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1), + _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1), + _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1), + _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1), + _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1), + 
_mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1), + _mm256_set_epi32(-1, -1, -1, -1, -1, -1, -1, -1), + }; + return _mm256_loadu_si256(&mask[N7]); +} + +inline void hmax(__m256& x) { + __m256 y; // x: 0 1 2 3 4 5 6 7 + y = _mm256_permute_ps(x, 0x39); // y: 1 2 3 0 5 6 7 4 + x = _mm256_max_ps(x, y); // X: 01 12 23 30 45 56 67 74 + y = _mm256_permute_ps(x, 0x4e); // y: 23 30 01 12 67 74 45 56 + x = _mm256_max_ps(x, y); // x: 0123 x x x 4567 x x x + y = _mm256_permute2f128_ps(x, x, 1); // y: 4567 x x x 0123 x x x + x = _mm256_max_ps(x, y); // x: 01234567 x x x x x x x +} + +inline void exp_ps_avx2(__m256& src) { + static __m256 exp_ln_flt_min_f = _mm256_castsi256_ps(_mm256_set1_epi32(0xc2aeac50)); // log(FLT_MIN) + static __m256 exp_ln_flt_max_f = _mm256_castsi256_ps(_mm256_set1_epi32(0x42b17218)); // log(FLT_MAX) + static __m256 exp_log2ef = _mm256_castsi256_ps(_mm256_set1_epi32(0x3fb8aa3b)); // log2(e) + static __m256 half = _mm256_castsi256_ps(_mm256_set1_epi32(0x3f000000)); // 0.5f + static __m256 ln2f = _mm256_castsi256_ps(_mm256_set1_epi32(0x3f317218)); // ln(2) + static __m256 one = _mm256_castsi256_ps(_mm256_set1_epi32(0x3f800000)); // 1.0f + static __m256i exponent_bias = _mm256_set1_epi32(0x0000007f); // 127 + static constexpr int n_mantissa_bits = 23; + static __m256 exp_pol1 = _mm256_castsi256_ps(_mm256_set1_epi32(0x3f7ffffb)); // p1 = 0.999999701f + static __m256 exp_pol2 = _mm256_castsi256_ps(_mm256_set1_epi32(0x3efffee3)); // p2 = 0.499991506f + static __m256 exp_pol3 = _mm256_castsi256_ps(_mm256_set1_epi32(0x3e2aad40)); // p3 = 0.166676521f + static __m256 exp_pol4 = _mm256_castsi256_ps(_mm256_set1_epi32(0x3d2b9d0d)); // p4 = 0.0418978221f + static __m256 exp_pol5 = _mm256_castsi256_ps(_mm256_set1_epi32(0x3c07cfce)); // p5 = 0.00828929059f + static __m256 two = _mm256_castsi256_ps(_mm256_set1_epi32(0x40000000)); // 2 + // exp(x) = + // = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem + // = 2^n * exp(r) // simplify the exp(n*ln(2)) expression + + // get mask of values lower than log(FLT_MIN) to zero them in the output + auto zero_mask = _mm256_cmp_ps(src, exp_ln_flt_min_f, _CMP_LT_OS); + + // clip src + src = _mm256_min_ps(src, exp_ln_flt_max_f); + src = _mm256_max_ps(src, exp_ln_flt_min_f); + + // aux1 : r + auto aux1 = src; + + // calculate exp(x) + // fx = x * log2(e) + 0.5 + src = _mm256_mul_ps(src, exp_log2ef); + src = _mm256_add_ps(src, half); + + // tmp = floorf(fx) + src = _mm256_floor_ps(src); + + // aux1 = x - fx * ln2 + aux1 = _mm256_fnmadd_ps(src, ln2f, aux1); + + // We do not count 2^n here, because n can reach 128 and 2^128 is not + // representable by fp32, so to get around this problem, instead of computing + // 2^n * exp(r) will be counted 2*2^(n-1)*exp(r), because 2^127 + // and 2 are numbers representable in fp32. 
+ + // compute 2^(n-1) + src = _mm256_sub_ps(src, one); + auto aux2_i = _mm256_cvtps_epi32(src); + aux2_i = _mm256_add_epi32(aux2_i, exponent_bias); + aux2_i = _mm256_slli_epi32(aux2_i, n_mantissa_bits); + + // set zeroes at those points which were < log(FLT_MIN) + auto zero = _mm256_setzero_ps(); + auto aux2 = _mm256_blendv_ps(_mm256_castsi256_ps(aux2_i), zero, zero_mask); + + // compute polynomial + src = exp_pol5; + src = _mm256_fmadd_ps(src, aux1, exp_pol4); + src = _mm256_fmadd_ps(src, aux1, exp_pol3); + src = _mm256_fmadd_ps(src, aux1, exp_pol2); + src = _mm256_fmadd_ps(src, aux1, exp_pol1); + src = _mm256_fmadd_ps(src, aux1, one); + + // y = y * 2^n + src = _mm256_mul_ps(src, aux2); + src = _mm256_mul_ps(src, two); +} +#endif + +inline void scale_add_reduce_max(float* a, const float scale, const float* b, const size_t size, float& max) { +#if defined(HAVE_AVX512F) + auto v_max = _mm512_set1_ps(std::numeric_limits::lowest()); + auto v_scale = _mm512_set1_ps(scale); + auto v_a = v_max; + auto v_b = v_max; + size_t i = 0; + // process vector body + while (i + vec_len_f32_avx512 <= size) { + v_a = _mm512_loadu_ps(a + i); + v_b = _mm512_loadu_ps(b + i); + v_a = _mm512_fmadd_ps(v_a, v_scale, v_b); + v_max = _mm512_max_ps(v_max, v_a); + _mm512_storeu_ps(a + i, v_a); + i += vec_len_f32_avx512; + } + + // process tails + if (i < size) { + __mmask16 mask = (1 << (size - i)) - 1; + v_a = _mm512_maskz_loadu_ps(mask, a + i); + v_b = _mm512_maskz_loadu_ps(mask, b + i); + v_a = _mm512_fmadd_ps(v_a, v_scale, v_b); + v_max = _mm512_mask_max_ps(v_max, mask, v_a, v_max); + _mm512_mask_storeu_ps(a + i, mask, v_a); + } + + max = _mm512_reduce_max_ps(v_max); +#elif defined(HAVE_AVX2) + auto v_max = _mm256_set1_ps(std::numeric_limits::lowest()); + auto v_scale = _mm256_set1_ps(scale); + auto v_a = v_max; + auto v_b = v_max; + size_t i = 0; + // process vector body + while (i + vec_len_f32_avx2 <= size) { + v_a = _mm256_loadu_ps(a + i); + v_b = _mm256_loadu_ps(b + i); + v_a = _mm256_fmadd_ps(v_a, v_scale, v_b); + v_max = _mm256_max_ps(v_max, v_a); + _mm256_storeu_ps(a + i, v_a); + i += vec_len_f32_avx2; + } + + // process tails + if (i < size) { + auto mask = get_mask(size - i); + v_a = _mm256_maskload_ps(a + i, mask); + v_b = _mm256_maskload_ps(b + i, mask); + v_a = _mm256_fmadd_ps(v_a, v_scale, v_b); + v_a = _mm256_blendv_ps(v_max, v_a, _mm256_castsi256_ps(mask)); + v_max = _mm256_max_ps(v_max, v_a); + _mm256_maskstore_ps(a + i, mask, v_a); + } + hmax(v_max); + max = _mm256_cvtss_f32(v_max); +#else + for (size_t i = 0; i < size; i++) { + a[i] *= scale; + a[i] += b[i]; + max = a[i] > max ? a[i] : max; + } +#endif +} +template +inline void scale_add2_reduce_max(float* a, + float scale, + const float* alibi, + const float* attn_mask, + const uint8_t* causal_mask, + bool select_nfltmax_at_0, // true: 0 in mask set -FLT_MAX + size_t size, + float& max) { +#if defined(HAVE_AVX512F) + auto v_max = _mm512_set1_ps(std::numeric_limits::lowest()); + auto v_scale = _mm512_set1_ps(scale); + auto v_a = v_max; + size_t i = 0; + auto v_zeroi32 = _mm512_setzero_epi32(); + auto v_nfltmax = _mm512_set1_ps(-FLT_MAX); + auto kmask_xor = _cvtu32_mask16(select_nfltmax_at_0 ? 
0xFFFF : 0); + // process vector body + while (i + vec_len_f32_avx512 <= size) { + v_a = _mm512_loadu_ps(a + i); + v_a = _mm512_mul_ps(v_a, v_scale); + + if (has_alibi) { + auto v_mask = _mm512_loadu_ps(alibi + i); + v_a = _mm512_add_ps(v_a, v_mask); + } + + if (has_attn_mask) { + auto v_mask = _mm512_loadu_ps(attn_mask + i); + v_a = _mm512_add_ps(v_a, v_mask); + } + + if (has_causal_mask) { + auto v_maski8 = _mm_loadu_si128(reinterpret_cast<__m128i const*>(causal_mask + i)); + auto v_maski32 = _mm512_cvtepi8_epi32(v_maski8); + auto kmask = _mm512_cmp_epi32_mask(v_maski32, v_zeroi32, _MM_CMPINT_NE); // !=0 + kmask = _kxor_mask16(kmask, kmask_xor); // reverse, mask at ==0 + v_a = _mm512_mask_blend_ps(kmask, v_a, v_nfltmax); // mask => -FLT_MAX + } + v_max = _mm512_max_ps(v_max, v_a); + _mm512_storeu_ps(a + i, v_a); + i += vec_len_f32_avx512; + } + + // process tails + if (i < size) { + __mmask16 mask = (1 << (size - i)) - 1; + v_a = _mm512_maskz_loadu_ps(mask, a + i); + v_a = _mm512_mul_ps(v_a, v_scale); + + if (has_alibi) { + auto v_mask = _mm512_maskz_loadu_ps(mask, alibi + i); + v_a = _mm512_add_ps(v_a, v_mask); + } + + if (has_attn_mask) { + auto v_mask = _mm512_maskz_loadu_ps(mask, attn_mask + i); + v_a = _mm512_add_ps(v_a, v_mask); + } + + if (has_causal_mask) { + auto v_maski8 = _mm_loadu_si128(reinterpret_cast<__m128i const*>(causal_mask + i)); + auto v_maski32 = _mm512_cvtepi8_epi32(v_maski8); + auto kmask = _mm512_cmp_epi32_mask(v_maski32, v_zeroi32, _MM_CMPINT_NE); // !=0 + kmask = _kxor_mask16(kmask, kmask_xor); // reverse, mask at ==0 + v_a = _mm512_mask_blend_ps(kmask, v_a, v_nfltmax); // mask => -FLT_MAX + } + v_max = _mm512_mask_max_ps(v_max, mask, v_a, v_max); + _mm512_mask_storeu_ps(a + i, mask, v_a); + } + + max = _mm512_reduce_max_ps(v_max); +#elif defined(HAVE_AVX2) + auto v_max = _mm256_set1_ps(std::numeric_limits::lowest()); + auto v_scale = _mm256_set1_ps(scale); + auto v_a = v_max; + auto v_zeroi32 = _mm256_setzero_si256(); + auto v_mask_xor = _mm256_set1_epi32(select_nfltmax_at_0 ? 
-1 : 0); + auto v_nfltmax = _mm256_set1_ps(-FLT_MAX); + size_t i = 0; + // process vector body + while (i + vec_len_f32_avx2 <= size) { + v_a = _mm256_loadu_ps(a + i); + v_a = _mm256_mul_ps(v_a, v_scale); + + if (has_alibi) { + auto v_mask = _mm256_loadu_ps(alibi + i); + v_a = _mm256_add_ps(v_a, v_mask); + } + + if (has_attn_mask) { + auto v_mask = _mm256_loadu_ps(attn_mask + i); + v_a = _mm256_add_ps(v_a, v_mask); + } + + if (has_causal_mask) { + auto v_maski8 = _mm_loadu_si128(reinterpret_cast<__m128i const*>(causal_mask + i)); + auto v_maski32 = _mm256_cvtepi8_epi32(v_maski8); + v_maski32 = _mm256_cmpeq_epi32(v_maski32, v_zeroi32); // ==0 + v_maski32 = _mm256_xor_si256(v_maski32, v_mask_xor); // reverse, mask at ==0 + v_a = _mm256_blendv_ps(v_nfltmax, v_a, _mm256_castsi256_ps(v_maski32)); // mask => -FLT_MAX + } + + v_max = _mm256_max_ps(v_max, v_a); + _mm256_storeu_ps(a + i, v_a); + i += vec_len_f32_avx2; + } + + // process tails + if (i < size) { + auto mask = get_mask(size - i); + v_a = _mm256_maskload_ps(a + i, mask); + v_a = _mm256_mul_ps(v_a, v_scale); + + if (has_alibi) { + auto v_mask = _mm256_maskload_ps(alibi + i, mask); + v_a = _mm256_add_ps(v_a, v_mask); + } + + if (has_attn_mask) { + auto v_mask = _mm256_maskload_ps(attn_mask + i, mask); + v_a = _mm256_add_ps(v_a, v_mask); + } + + if (has_causal_mask) { + auto v_maski8 = _mm_loadu_si128(reinterpret_cast<__m128i const*>(causal_mask + i)); + auto v_maski32 = _mm256_cvtepi8_epi32(v_maski8); + v_maski32 = _mm256_cmpeq_epi32(v_maski32, v_zeroi32); // ==0 + v_maski32 = _mm256_xor_si256(v_maski32, v_mask_xor); // reverse, mask at ==0 + v_a = _mm256_blendv_ps(v_nfltmax, v_a, _mm256_castsi256_ps(v_maski32)); // mask => -FLT_MAX + } + + v_a = _mm256_blendv_ps(v_max, v_a, _mm256_castsi256_ps(mask)); + v_max = _mm256_max_ps(v_max, v_a); + _mm256_maskstore_ps(a + i, mask, v_a); + } + hmax(v_max); + max = _mm256_cvtss_f32(v_max); +#else + for (size_t i = 0; i < size; i++) { + a[i] *= scale; + if (has_alibi) + a[i] += alibi[i]; + + if (has_attn_mask) + a[i] += attn_mask[i]; + + if (has_causal_mask) { + if (select_nfltmax_at_0) { + if (causal_mask[i] == 0) + a[i] = -FLT_MAX; + } else { + if (causal_mask[i] != 0) + a[i] = -FLT_MAX; + } + } + + max = a[i] > max ? 
a[i] : max; + } +#endif +} + +#if defined(HAVE_AVX512F) +inline void exp_ps_avx512(__m512& src) { + static __m512 exp_ln_flt_min_f = _mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50)); // log(FLT_MIN) + static __m512 exp_ln_flt_max_f = _mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); // log(FLT_MAX) + static __m512 exp_log2ef = _mm512_castsi512_ps(_mm512_set1_epi32(0x3fb8aa3b)); // log2(e) + static __m512 half = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f000000)); // 0.5f + static __m512 ln2f = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f317218)); // ln(2) + static __m512 one = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f800000)); // 1.0f + static __m512i exponent_bias = _mm512_set1_epi32(0x0000007f); // 127 + static constexpr int n_mantissa_bits = 23; + static __m512 exp_pol1 = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f7ffffb)); // p1 = 0.999999701f + static __m512 exp_pol2 = _mm512_castsi512_ps(_mm512_set1_epi32(0x3efffee3)); // p2 = 0.499991506f + static __m512 exp_pol3 = _mm512_castsi512_ps(_mm512_set1_epi32(0x3e2aad40)); // p3 = 0.166676521f + static __m512 exp_pol4 = _mm512_castsi512_ps(_mm512_set1_epi32(0x3d2b9d0d)); // p4 = 0.0418978221f + static __m512 exp_pol5 = _mm512_castsi512_ps(_mm512_set1_epi32(0x3c07cfce)); // p5 = 0.00828929059f + static __m512 two = _mm512_castsi512_ps(_mm512_set1_epi32(0x40000000)); // 2 + // exp(x) = + // = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem + // = 2^n * exp(r) // simplify the exp(n*ln(2)) expression + + // get mask of values lower than log(FLT_MIN) to zero them in the output + auto zero_mask = _mm512_cmp_ps_mask(src, exp_ln_flt_min_f, _CMP_LT_OS); + + // clip src + src = _mm512_min_ps(src, exp_ln_flt_max_f); + src = _mm512_max_ps(src, exp_ln_flt_min_f); + + // aux1 : r + auto aux1 = src; + + // calculate exp(x) + // fx = x * log2(e) + 0.5 + src = _mm512_mul_ps(src, exp_log2ef); + src = _mm512_add_ps(src, half); + + // tmp = floorf(fx) + src = _mm512_floor_ps(src); + + // aux1 = x - fx * ln2 + aux1 = _mm512_fnmadd_ps(src, ln2f, aux1); + // We do not count 2^n here, because n can reach 128 and 2^128 is not + // representable by fp32, so to get around this problem, instead of computing + // 2^n * exp(r) will be counted 2*2^(n-1)*exp(r), because 2^127 + // and 2 are numbers representable in fp32. 
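// [Editorial note, not part of the patch] The AVX-512 variant below relies on the same
// identity as the AVX2 one. A quick standalone check of that identity using the standard
// library only (illustrative; the kernel itself never calls std::exp):
#include <cmath>
#include <cstdio>

int main() {
    for (float x : {-10.0f, -1.0f, 0.0f, 1.0f, 10.0f, 80.0f}) {
        const float n = std::floor(x * 1.44269502f + 0.5f);    // n = floor(x*log2(e) + 0.5)
        const float r = x - n * 0.693147182f;                   // r = x - n*ln(2)
        const float recon = 2.0f * std::ldexp(std::exp(r), static_cast<int>(n) - 1);
        std::printf("x=%6.1f  exp(x)=%.6g  2*2^(n-1)*exp(r)=%.6g\n", x, std::exp(x), recon);
    }
    return 0;
}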
+ + // compute 2^(n-1) + src = _mm512_sub_ps(src, one); + auto aux2_i = _mm512_cvtps_epi32(src); + aux2_i = _mm512_add_epi32(aux2_i, exponent_bias); + aux2_i = _mm512_slli_epi32(aux2_i, n_mantissa_bits); + + // set zeroes at those points which were < log(FLT_MIN) + auto zero = _mm512_setzero_ps(); + auto aux2 = _mm512_mask_blend_ps(zero_mask, _mm512_castsi512_ps(aux2_i), zero); + + // compute polynomial + src = exp_pol5; + src = _mm512_fmadd_ps(src, aux1, exp_pol4); + src = _mm512_fmadd_ps(src, aux1, exp_pol3); + src = _mm512_fmadd_ps(src, aux1, exp_pol2); + src = _mm512_fmadd_ps(src, aux1, exp_pol1); + src = _mm512_fmadd_ps(src, aux1, one); + + // y = y * 2^n + src = _mm512_mul_ps(src, aux2); + src = _mm512_mul_ps(src, two); +} +#endif + +inline void exp_reduce_sum(float* a, const float max, const size_t size, float& sum) { +#if defined(HAVE_AVX512F) + size_t i = 0; + __m512 v_a; + auto v_max = _mm512_set1_ps(max); + auto v_sum = _mm512_set1_ps(0.0f); + while (i + vec_len_f32_avx512 <= size) { + v_a = _mm512_loadu_ps(a + i); + v_a = _mm512_sub_ps(v_a, v_max); + exp_ps_avx512(v_a); + v_sum = _mm512_add_ps(v_sum, v_a); + _mm512_storeu_ps(a + i, v_a); + i += vec_len_f32_avx512; + } + + if (i < size) { + __mmask16 mask = (1 << (size - i)) - 1; + v_a = _mm512_maskz_loadu_ps(mask, a + i); + v_a = _mm512_sub_ps(v_a, v_max); + exp_ps_avx512(v_a); + v_sum = _mm512_mask_add_ps(v_sum, mask, v_a, v_sum); + _mm512_mask_storeu_ps(a + i, mask, v_a); + } + sum = _mm512_reduce_add_ps(v_sum); +#elif defined(HAVE_AVX2) + size_t i = 0; + __m256 v_a; + auto v_max = _mm256_set1_ps(max); + auto v_sum = _mm256_set1_ps(0.0f); + while (i + vec_len_f32_avx2 <= size) { + v_a = _mm256_loadu_ps(a + i); + v_a = _mm256_sub_ps(v_a, v_max); + exp_ps_avx2(v_a); + v_sum = _mm256_add_ps(v_sum, v_a); + _mm256_storeu_ps(a + i, v_a); + i += vec_len_f32_avx2; + } + + if (i < size) { + auto mask = get_mask(size - i); + v_a = _mm256_maskload_ps(a + i, mask); + v_a = _mm256_sub_ps(v_a, v_max); + exp_ps_avx2(v_a); + v_a = _mm256_blendv_ps(_mm256_setzero_ps(), v_a, _mm256_castsi256_ps(mask)); + v_sum = _mm256_add_ps(v_a, v_sum); + _mm256_maskstore_ps(a + i, mask, v_a); + } + hsum(v_sum); + sum = _mm256_cvtss_f32(v_sum); +#else + for (size_t i = 0; i < size; i++) { + a[i] = exp(a[i] - max); + sum += a[i]; + } +#endif +} + +inline void multiply_scalar(float* a, float* a_dst, const float val, const size_t size) { +#if defined(HAVE_AVX512F) + auto v_scale = _mm512_set1_ps(val); + __m512 v_a = {0}; + size_t i = 0; + while (i + vec_len_f32_avx512 <= size) { + v_a = _mm512_loadu_ps(a + i); + v_a = _mm512_mul_ps(v_a, v_scale); + _mm512_storeu_ps(a_dst + i, v_a); + i += vec_len_f32_avx512; + } + if (i < size) { + __mmask16 mask = (1 << (size - i)) - 1; + v_a = _mm512_maskz_loadu_ps(mask, a + i); + v_a = _mm512_mul_ps(v_a, v_scale); + _mm512_mask_storeu_ps(a_dst + i, mask, v_a); + } +#elif defined(HAVE_AVX2) + auto v_scale = _mm256_set1_ps(val); + __m256 v_a = {0}; + size_t i = 0; + while (i + vec_len_f32_avx2 <= size) { + v_a = _mm256_loadu_ps(a + i); + v_a = _mm256_mul_ps(v_a, v_scale); + _mm256_storeu_ps(a_dst + i, v_a); + i += vec_len_f32_avx2; + } + if (i < size) { + auto mask = get_mask(size - i); + v_a = _mm256_maskload_ps(a + i, mask); + v_a = _mm256_mul_ps(v_a, v_scale); + _mm256_maskstore_ps(a_dst + i, mask, v_a); + } +#else + for (size_t i = 0; i < size; i++) { + a_dst[i] = a[i] * val; + } +#endif +} + +inline void multiply_scalar(float* a, ov::bfloat16* a_dst, const float val, const size_t size) { +#if defined(HAVE_AVX512F) + auto 
v_scale = _mm512_set1_ps(val); + __m512 v_a = {0}; + size_t i = 0; + while (i + vec_len_f32_avx512 <= size) { + v_a = _mm512_loadu_ps(a + i); + v_a = _mm512_mul_ps(v_a, v_scale); + mm512_uni_storeu_ps(a_dst + i, v_a); + i += vec_len_f32_avx512; + } + if (i < size) { + __mmask16 mask = (1 << (size - i)) - 1; + v_a = _mm512_maskz_loadu_ps(mask, a + i); + v_a = _mm512_mul_ps(v_a, v_scale); + mm512_uni_mask_storeu_ps(a_dst + i, mask, v_a); + } +#else + for (size_t i = 0; i < size; i++) { + a_dst[i] = a[i] * val; + } +#endif +} + +inline void attn_softmax_kernel(float* a, + void* a_dst, + float scale, + float* alibi, + float* attn_mask, + uint8_t* causal_mask, + bool select_nfltmax_at_0, + size_t len, + size_t total_size, + ov::element::Type dst_precision) { + using func_type = void (*)(float*, float, const float*, const float*, const uint8_t*, bool, size_t, float&); + static func_type funcs[] = { + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max + }; + int dispatch = (alibi ? 0b100 : 0) | (attn_mask ? 0b010 : 0) | (causal_mask ? 0b001 : 0); + float max = std::numeric_limits::lowest(); + funcs[dispatch](a, scale, alibi, attn_mask, causal_mask, select_nfltmax_at_0, len, max); + + float sum = 0.0f; + // exp sum + exp_reduce_sum(a, max, len, sum); + // divide sum + float scalar = 1.0f / sum; + if (dst_precision == ov::element::f32) { + multiply_scalar(a, static_cast(a_dst), scalar, len); + // apply causual mask to final result instead of attn_score + if (total_size > len) + memset(static_cast(a_dst) + len, 0, sizeof(float) * (total_size - len)); + } else { + multiply_scalar(a, static_cast(a_dst), scalar, len); + // apply causual mask to final result instead of attn_score + if (total_size > len) + memset(static_cast(a_dst) + len, 0, sizeof(ov::bfloat16) * (total_size - len)); + } +} + +} // namespace XARCH +} // namespace Cpu +} // namespace Extensions +} // namespace InferenceEngine \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/rope.cpp b/src/plugins/intel_cpu/src/nodes/rope.cpp index dafd7f2829a58c..189b9a05023706 100644 --- a/src/plugins/intel_cpu/src/nodes/rope.cpp +++ b/src/plugins/intel_cpu/src/nodes/rope.cpp @@ -38,11 +38,11 @@ struct RoPE::RoPEExecutorRotateHalf : public RoPE::Executor { const RoPENode::Config& config, const std::vector& inputs, const std::vector& outputs) override { - ov::intel_cpu::PlainTensor t_src(inputs[0]); - ov::intel_cpu::PlainTensor t_cos(inputs[1]); - ov::intel_cpu::PlainTensor t_sin(inputs[2]); - ov::intel_cpu::PlainTensor t_dst(outputs[0]); - ov::intel_cpu::PlainTensor gather; + ov::intel_cpu::PlainTensor t_src(inputs[0]); + ov::intel_cpu::PlainTensor t_cos(inputs[1]); + ov::intel_cpu::PlainTensor t_sin(inputs[2]); + ov::intel_cpu::PlainTensor t_dst(outputs[0]); + ov::intel_cpu::PlainTensor gather; if (config.slice_stop - config.slice_start > 0) { t_src = t_src.slice(3, config.slice_start, config.slice_stop); @@ -73,14 +73,14 @@ struct RoPE::RoPEExecutorRotateHalf : public RoPE::Executor { auto cos_pos = p; if (gather) { if (gather.m_rank == 4) - cos_pos = gather.at({b, h, p, 0}, true); + cos_pos = gather.at({b, h, p, 0}, true); else - cos_pos = gather.at({b, p}, true); + cos_pos = gather.at({b, p}, true); } - auto* src = &t_src.at({b, h, p, 0}); - auto* cos = &t_cos.at({b, h, cos_pos, 0}, true); - auto* sin = &t_sin.at({b, h, cos_pos, 0}, true); - auto* dst = &t_dst.at({b, h, p, 0}); + 
auto* src = &t_src.at({b, h, p, 0}); + auto* cos = &t_cos.at({b, h, cos_pos, 0}, true); + auto* sin = &t_sin.at({b, h, cos_pos, 0}, true); + auto* dst = &t_dst.at({b, h, p, 0}); size_t i = 0; for (; i < half_rotary_dims; i++) { @@ -102,9 +102,9 @@ struct RoPE::RoPEExecutorInterleaved : public RoPE::Executor { const RoPENode::Config& config, const std::vector& inputs, const std::vector& outputs) override { - ov::intel_cpu::PlainTensor t_src(inputs[0]); - ov::intel_cpu::PlainTensor t_sin_cos(inputs[1]); - ov::intel_cpu::PlainTensor t_dst(outputs[0]); + ov::intel_cpu::PlainTensor t_src(inputs[0]); + ov::intel_cpu::PlainTensor t_sin_cos(inputs[1]); + ov::intel_cpu::PlainTensor t_dst(outputs[0]); auto batch_size = t_src.size(0); auto seq_len = t_src.size(1); @@ -114,10 +114,10 @@ struct RoPE::RoPEExecutorInterleaved : public RoPE::Executor { auto rotary_dims = config.rotary_ndims; auto half_rotary_dims = rotary_dims / 2; parallel_for3d(batch_size, seq_len, head_cnt, [&](size_t b, size_t p, size_t h) { - auto* x = &t_src.at({b, p, h, 0}); - float* sin = &t_sin_cos.at({b, p, 0}, true); - float* cos = &t_sin_cos.at({b, p, half_rotary_dims}, true); - auto* dst = &t_dst.at({b, h, p, 0}); + auto* x = &t_src.at({b, p, h, 0}); + float* sin = &t_sin_cos.at({b, p, 0}, true); + float* cos = &t_sin_cos.at({b, p, half_rotary_dims}, true); + auto* dst = &t_dst.at({b, h, p, 0}); size_t i = 0; for (size_t j = 0; i < rotary_dims; i += 2, j++) { @@ -137,9 +137,9 @@ struct RoPE::RoPEExecutorChatGLM : public RoPE::Executor { const RoPENode::Config& config, const std::vector& inputs, const std::vector& outputs) override { - ov::intel_cpu::PlainTensor t_src(inputs[0]); - ov::intel_cpu::PlainTensor t_cos_sin(inputs[1]); - ov::intel_cpu::PlainTensor t_dst(outputs[0]); + ov::intel_cpu::PlainTensor t_src(inputs[0]); + ov::intel_cpu::PlainTensor t_cos_sin(inputs[1]); + ov::intel_cpu::PlainTensor t_dst(outputs[0]); // [seq_len, batch_size, (hidden_states_q + hidden_states_k + hidden_states_v)] if (config.slice_stop - config.slice_start > 0) { @@ -154,10 +154,10 @@ struct RoPE::RoPEExecutorChatGLM : public RoPE::Executor { auto rotary_dims = config.rotary_ndims; parallel_for3d(seq_len, batch_size, head_cnt, [&](size_t p, size_t b, size_t h) { - auto* src = &t_src.at({p, b, h * head_size}); + auto* src = &t_src.at({p, b, h * head_size}); // [length, batch_size, ndims//2, 2] - auto* cos_sin = &t_cos_sin.at({p, b, 0, 0}, true); - auto* dst = &t_dst.at({p, b, h, 0}); + auto* cos_sin = &t_cos_sin.at({p, b, 0, 0}, true); + auto* dst = &t_dst.at({p, b, h, 0}); size_t i = 0; for (; i < rotary_dims; i += 2) { diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp index 2ba51286e79bd2..c2d1ef17143337 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp @@ -28,12 +28,12 @@ #include "utils/plain_tensor.hpp" #include "kernels/scaled_attn/softmax.hpp" -#include "kernels/scaled_attn/dot_product.hpp" -#include "kernels/scaled_attn/acc_value.hpp" -#include "kernels/scaled_attn/reduce.hpp" +#include "kernels/scaled_attn/mha_single_token.hpp" +#include "kernels/scaled_attn/attn_memcpy.hpp" using namespace InferenceEngine; using namespace InferenceEngine::Extensions::Cpu::XARCH; +using namespace dnnl::impl::cpu::x64; namespace ov { namespace intel_cpu { @@ -77,9 +77,9 @@ struct MHAKernel { } } - PlainTensor causal_mask; + PlainTensor causal_mask; bool select_nfltmax_at_0; // set attn_score to -FLT_MAX when causal_mask[...] 
equal to this - void set_causal_mask(PlainTensor mask, bool _select_nfltmax_at_0) { + void set_causal_mask(PlainTensor mask, bool _select_nfltmax_at_0) { causal_mask = mask; select_nfltmax_at_0 = _select_nfltmax_at_0; } @@ -91,12 +91,12 @@ struct MHAKernel { // attention_mask [B, 1, q_len, kv_len] // output_emb [B, q_len, H*S] void operator()(dnnl::stream strm, - PlainTensor& query, - PlainTensor& present_key, - PlainTensor& present_value, - const PlainTensor& alibi_mask, - const PlainTensor& attention_mask, - PlainTensor& output_emb, + PlainTensor& query, + PlainTensor& present_key, + PlainTensor& present_value, + const PlainTensor& alibi_mask, + const PlainTensor& attention_mask, + PlainTensor& output_emb, bool has_out_transpose, bool auto_causal, float d_scale = 0.0f) { @@ -117,27 +117,27 @@ struct MHAKernel { for (size_t m = 0; m < q_len; m++) { // dot-product to get attention scores - auto* q = &query.at({b, h, m, 0}); + auto* q = &query.at({b, h, m, 0}); // how many key/values can be accessed causally auto ncausal = kv_len; // no causall mask is set and it's not fused into attention_mask if (auto_causal) ncausal = kv_len - q_len + m + 1; for (size_t n = 0; n < ncausal; n++) { - auto* k = &present_key.at({b, h, n, 0}, true); + auto* k = &present_key.at({b, h, n, 0}, true); attn_score[n] = dot_product(q, k, head_size, k_stride_s) * d_scale; // apply alibi tensor if (alibi_mask) - attn_score[n] += alibi_mask.at({b, h, m, n}, true); + attn_score[n] += alibi_mask.at({b, h, m, n}, true); // apply attention mask (maybe combined with causal_mask) if (attention_mask) - attn_score[n] += attention_mask.at({b, h, m, n}, true); + attn_score[n] += attention_mask.at({b, h, m, n}, true); // apply causal_mask if (causal_mask) { - bool is_zero = causal_mask.at({b, h, m, n}, true) == 0; + bool is_zero = causal_mask.at({b, h, m, n}, true) == 0; if (select_nfltmax_at_0) { if (is_zero) attn_score[n] = -FLT_MAX; @@ -155,12 +155,12 @@ struct MHAKernel { // linearly combine value word_vec.assign(head_size, 0.0f); for (size_t n = 0; n < ncausal; n++) { - auto* v = &present_value.at({b, h, n, 0}, true); + auto* v = &present_value.at({b, h, n, 0}, true); accumulate(word_vec.data(), v, head_size, attn_score[n]); } // output [B, L1, H*head_size] - auto* out = has_out_transpose ? &output_emb.at({b, m, h * head_size}) : &output_emb.at({b, h, m}); + auto* out = has_out_transpose ? 
&output_emb.at({b, m, h * head_size}) : &output_emb.at({b, h, m}); std::copy(word_vec.begin(), word_vec.end(), out); } }); @@ -218,23 +218,23 @@ struct MHAKernel { } } - void exec_qk(dnnl::stream strm, PlainTensor& query, PlainTensor& present_key) { - dnnl::memory q(q_md, strm.get_engine(), query.data()); - dnnl::memory k(k_md, strm.get_engine(), present_key.data()); + void exec_qk(dnnl::stream strm, PlainTensor& query, PlainTensor& present_key) { + dnnl::memory q(q_md, strm.get_engine(), query.data()); + dnnl::memory k(k_md, strm.get_engine(), present_key.data()); qk_prim.execute(strm, {{DNNL_ARG_SRC, q}, {DNNL_ARG_WEIGHTS, k}, {DNNL_ARG_DST, attn_score}}); } - void exec_kv(dnnl::stream strm, PlainTensor& present_value, PlainTensor& output_emb) { - dnnl::memory v(v_md, strm.get_engine(), present_value.data()); - dnnl::memory out(out_md, strm.get_engine(), output_emb.data()); + void exec_kv(dnnl::stream strm, PlainTensor& present_value, PlainTensor& output_emb) { + dnnl::memory v(v_md, strm.get_engine(), present_value.data()); + dnnl::memory out(out_md, strm.get_engine(), output_emb.data()); wv_prim.execute(strm, {{DNNL_ARG_SRC, attn_weight}, {DNNL_ARG_WEIGHTS, v}, {DNNL_ARG_DST, out}}); } - PlainTensor causal_mask; + PlainTensor causal_mask; bool select_nfltmax_at_0 = false; // set attn_score to -FLT_MAX when causal_mask[...] equal to this - void set_causal_mask(PlainTensor mask, bool _select_nfltmax_at_0) { + void set_causal_mask(PlainTensor mask, bool _select_nfltmax_at_0) { causal_mask = mask; select_nfltmax_at_0 = _select_nfltmax_at_0; } @@ -247,12 +247,12 @@ struct MHAKernel { // alibi [B, H, q_len, kv_len] // output_emb [B, L1, H*S] void operator()(dnnl::stream strm, - PlainTensor& query, - PlainTensor& present_key, - PlainTensor& present_value, - const PlainTensor& alibi_mask, - const PlainTensor& attention_mask, - PlainTensor& output_emb, + PlainTensor& query, + PlainTensor& present_key, + PlainTensor& present_value, + const PlainTensor& alibi_mask, + const PlainTensor& attention_mask, + PlainTensor& output_emb, bool has_out_transpose, bool auto_causal, float d_scale = 0.0f) { @@ -269,20 +269,20 @@ struct MHAKernel { prepare_prim(strm, B, H, Hk, q_len, kv_len, head_size, has_out_transpose); exec_qk(strm, query, present_key); - PlainTensor score; + PlainTensor score; score.resize({B, H, q_len, kv_len}, static_cast(attn_score.get_data_handle())); - PlainTensor weight; + PlainTensor weight; weight.resize({B, H, q_len, kv_len}, static_cast(attn_weight.get_data_handle())); // softmax parallel_for3d(B, H, q_len, [&](size_t b, size_t h, size_t m) { // apply attention mask & sofmax auto ncausal = auto_causal ? (kv_len - q_len + m + 1) : kv_len; - attn_softmax(&score.at({b, h, m, 0}), - &weight.at({b, h, m, 0}), + attn_softmax(&score.at({b, h, m, 0}), + &weight.at({b, h, m, 0}), d_scale, - alibi_mask ? &alibi_mask.at({b, h, m, 0}, true) : nullptr, - attention_mask ? &attention_mask.at({b, h, m, 0}, true) : nullptr, - causal_mask ? &causal_mask.at({b, h, m, 0}, true) : nullptr, + alibi_mask ? &alibi_mask.at({b, h, m, 0}, true) : nullptr, + attention_mask ? &attention_mask.at({b, h, m, 0}, true) : nullptr, + causal_mask ? 
&causal_mask.at({b, h, m, 0}, true) : nullptr, select_nfltmax_at_0, ncausal, kv_len, @@ -297,17 +297,17 @@ template <> struct MHAKernel { size_t m_block_size; // buffer to hold qk temp - std::vector> qk_buffers; + std::vector qk_buffers; MHAKernel() { m_block_size = 4; select_nfltmax_at_0 = false; - qk_buffers.resize(parallel_get_max_threads(), PlainTensor(true)); + qk_buffers.resize(parallel_get_max_threads(), PlainTensor(true)); } - PlainTensor causal_mask; + PlainTensor causal_mask; bool select_nfltmax_at_0; // set attn_score to -FLT_MAX when causal_mask[...] equal to this - void set_causal_mask(PlainTensor mask, bool _select_nfltmax_at_0) { + void set_causal_mask(PlainTensor mask, bool _select_nfltmax_at_0) { causal_mask = mask; select_nfltmax_at_0 = _select_nfltmax_at_0; } @@ -320,12 +320,12 @@ struct MHAKernel { // alibi // output_emb [B, L1, H*S] void operator()(dnnl::stream strm, - PlainTensor& query, - PlainTensor& present_key, - PlainTensor& present_value, - const PlainTensor& alibi_mask, - const PlainTensor& attention_mask, - PlainTensor& output_emb, + PlainTensor& query, + PlainTensor& present_key, + PlainTensor& present_value, + const PlainTensor& alibi_mask, + const PlainTensor& attention_mask, + PlainTensor& output_emb, bool has_out_transpose, bool auto_causal, float d_scale = 0.0f) { @@ -354,34 +354,34 @@ struct MHAKernel { auto m_cnt = m_end - m_start; auto kv_len_cache_align = (((kv_len * sizeof(float)) + 63) / 64 * 64) / sizeof(float); - qk_buf.resize({m_block_size, kv_len_cache_align}); - const float* q_ptr = &query.at({b, h, m_start, 0}); - const float* k_ptr = &present_key.at({b, h / h_each_group_len, 0, 0}); - const float* v_ptr = &present_value.at({b, h / h_each_group_len, 0, 0}); + qk_buf.resize({m_block_size, kv_len_cache_align}); + const float* q_ptr = &query.at({b, h, m_start, 0}); + const float* k_ptr = &present_key.at({b, h / h_each_group_len, 0, 0}); + const float* v_ptr = &present_value.at({b, h / h_each_group_len, 0, 0}); float* alibi_ptr = nullptr; auto alibi_stride = 0; if (alibi_mask) { - alibi_ptr = &alibi_mask.at({b, h, 0, 0}, true); + alibi_ptr = &alibi_mask.at({b, h, 0, 0}, true); if (alibi_mask.size(2) > 1) alibi_stride = alibi_mask.stride(2); } float* attn_mask_ptr = nullptr; auto attn_mask_stride = 0; if (attention_mask) { - attn_mask_ptr = &attention_mask.at({b, h, 0, 0}, true); + attn_mask_ptr = &attention_mask.at({b, h, 0, 0}, true); if (attention_mask.size(2) > 1) attn_mask_stride = attention_mask.stride(2); } uint8_t* cmask_ptr = nullptr; auto cmask_stride = 0; if (causal_mask) { - cmask_ptr = &causal_mask.at({b, h, 0, 0}, true); + cmask_ptr = &causal_mask.at({b, h, 0, 0}, true); if (causal_mask.size(2) > 1) cmask_stride = causal_mask.stride(2); } - float* qk = &(qk_buf.at({0, 0})); + float* qk = &(qk_buf.at({0, 0})); auto qk_m_stride = qk_buf.stride(0); if (k_stride_s == 1) @@ -440,7 +440,7 @@ struct MHAKernel { v_ptr, present_value.stride(2), 0.f, - has_out_transpose ? &output_emb.at({b, m_start, h * head_size}) : &output_emb.at({b, h, m_start}), + has_out_transpose ? &output_emb.at({b, m_start, h * head_size}) : &output_emb.at({b, h, m_start}), has_out_transpose ? 
output_emb.stride(1) : output_emb.stride(2), 1); }); @@ -449,143 +449,45 @@ struct MHAKernel { #endif // 2nd token case : only 1 token in query -template struct MHASingleToken { - PlainTensor m_attn_w; - PlainTensor m_temp; + PlainTensor m_attn_w; + PlainTensor m_temp; - MHASingleToken() : m_attn_w(true), m_temp(true), select_nfltmax_at_0(false) {} - - PlainTensor causal_mask; - bool select_nfltmax_at_0; // set attn_score to -FLT_MAX when causal_mask[...] equal to this - void set_causal_mask(PlainTensor mask, bool _select_nfltmax_at_0) { - causal_mask = mask; - select_nfltmax_at_0 = _select_nfltmax_at_0; - } + MHASingleToken() : m_attn_w(true), m_temp(true) {} // Q, K, V is ready, do attention // query [B, H, q_len, S] // present_key [B, H, kv_len, S] stride of last dim maybe > 1 // present_value [B, H, kv_len, S] - // attention_mask [B, 1, q_len, kv_len] // alibi - // output_emb [B, L1, H*S] - void operator()(PlainTensor& query, - PlainTensor& present_key, - PlainTensor& present_value, - const PlainTensor& alibi_mask, - const PlainTensor& attention_mask, - PlainTensor& output_emb, - const PlainTensor& beams, + // attention_mask [B, 1, q_len, kv_len] + // output_emb [B, L1, H, S] + void operator()(PlainTensor& query, + PlainTensor& present_key, + PlainTensor& present_value, + const PlainTensor& alibi_mask, + const PlainTensor& attention_mask, + PlainTensor& output_emb, + const PlainTensor& beams, bool has_out_transpose, bool auto_causal, float d_scale = 0.0f) { - auto B = query.size(0); - auto H = query.size(1); - auto q_len = query.size(2); - auto S = query.size(3); - auto kv_len = present_key.size(2); - - if (d_scale == 0.0f) - d_scale = 1.0f / sqrt(S); - - // use per-token kernel, for each k,v token - // attn mask is a matrix of q_len(kv_len) - m_attn_w.resize({B, H, q_len, kv_len}); - - parallel_for3d(B, H, kv_len, [&](size_t b, size_t h, size_t pk) { - // which batch item should be used at postion pk? - auto b_kv = beams ? beams.at({b, pk}) : b; - std::vector as(q_len), bs(q_len); - std::vector cs(q_len); - for (size_t pq = 0; pq < q_len; pq++) { - as[pq] = &query.at({b, h, pq, 0}); - bs[pq] = &present_key.at({b_kv, h, pk, 0}, true); - cs[pq] = &m_attn_w.at({b, h, pq, pk}); - } - attn_dot_products(reinterpret_cast(as.data()), - reinterpret_cast(bs.data()), - reinterpret_cast(cs.data()), - q_len, - S, - precision_of::value); - }); - - parallel_for3d(B, H, q_len, [&](size_t b, size_t h, size_t pq) { - // apply attention mask & sofmax - auto ncausal = auto_causal ? (kv_len - q_len + pq + 1) : kv_len; - float* alibi_ptr = alibi_mask ? &alibi_mask.at({b, h, pq, 0}, true) : nullptr; - float* attn_mask_ptr = attention_mask ? &attention_mask.at({b, h, pq, 0}, true) : nullptr; - uint8_t* cmask_ptr = causal_mask ? 
&causal_mask.at({b, h, pq, 0}, true) : nullptr; - attn_softmax(&m_attn_w.at({b, h, pq, 0}), - &m_attn_w.at({b, h, pq, 0}), - d_scale, - alibi_ptr, - attn_mask_ptr, - cmask_ptr, - select_nfltmax_at_0, - ncausal, - kv_len, - ov::element::f32); - }); - - // attn_w * V - auto nthr = parallel_get_max_threads(); - m_temp.resize({static_cast(nthr), B, q_len, H, S}); - // m_attn_w {B, H, q_len, kv_len} - parallel_nt_static(nthr, [&](const size_t ithr, const size_t nthr) { - size_t start{0}, end{0}; - splitter(B * H * kv_len, nthr, ithr, start, end); - - memset(&m_temp.at({ithr, 0, 0, 0, 0}), 0, m_temp.stride(0) * sizeof(float)); - - size_t b, h, pv; - if (start < end) { - parallel_it_init(start, b, B, h, H, pv, kv_len); - std::vector vs(q_len * (end - start)); - std::vector weights(q_len * (end - start)); - std::vector outs(q_len * (end - start)); - size_t idx = 0; - for (size_t iwork = start; iwork < end; ++iwork) { - auto b_kv = beams ? beams.at({b, pv}) : b; - auto* v = &present_value.at({b_kv, h, pv, 0}, true); - for (size_t pq = 0; pq < q_len; pq++) { - outs[idx] = &m_temp.at({ithr, b, pq, h, 0}); - weights[idx] = m_attn_w.at({b, h, pq, pv}); - vs[idx] = v; - idx++; - } - parallel_it_step(b, B, h, H, pv, kv_len); - } - attn_acc_values(outs.data(), - weights.data(), - reinterpret_cast(vs.data()), - q_len * (end - start), - S, - precision_of::value); - } - }); - - parallel_for3d(B, H, q_len, [&](size_t b, size_t h, size_t pq) { - auto* temp = &m_temp.at({0, b, pq, h, 0}); - size_t temp_stride = m_temp.stride(0); - auto* dst = has_out_transpose ? &output_emb.at({b, pq, h*S}) : &output_emb.at({b, h, pq}); - attn_reduce(dst, temp, nthr, S, temp_stride, precision_of::value); - }); + mha_single_token(query, present_key, present_value, alibi_mask, attention_mask, beams, output_emb, + m_attn_w, m_temp, has_out_transpose, auto_causal, d_scale); } }; template struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAttention::Executor { - PlainTensor q_input; // f32[B, H, L1, S] - PlainTensor k_input; // f32[B, H|1, L1, S] / [B, H|1, L0+L1, S] - PlainTensor v_input; // f32[B, H|1, L1, S] / [B, H|1, L0+L1, S] - PlainTensor beam_table; // i32[B, max_kvLen] - PlainTensor attn_buf; // f32[[B|1],[H|1], L1|1, L0+L1] + PlainTensor q_input; // f32[B, H, L1, S] + PlainTensor k_input; // f32[B, H|1, L1, S] / [B, H|1, L0+L1, S] + PlainTensor v_input; // f32[B, H|1, L1, S] / [B, H|1, L0+L1, S] + PlainTensor beam_table; // i32[B, max_kvLen] + PlainTensor attn_buf; // f32[[B|1],[H|1], L1|1, L0+L1] float scale_input = 0.0f; MHAKernel kernel; - MHASingleToken kernel_single_token; + MHASingleToken kernel_single_token; size_t B, H, L1, L0, S; @@ -593,49 +495,33 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt AttentionExecutor(const Config& _config) : attn_buf(true), config(_config) {} void prepare_attn_mask(MemoryPtr attn_input) { - attn_buf.resize(attn_input->getStaticDims()); + attn_buf.resize(attn_input->getStaticDims()); auto p = reinterpret_cast(attn_input->getData()); for (size_t i = 0; i < attn_input->getSize(); i++) - attn_buf.data()[i] = p[i] ? 0.0f : -FLT_MAX; + attn_buf.data()[i] = p[i] ? 
0.0f : -FLT_MAX; } void concat_pastkv(const std::vector& inputs, const std::vector& outputs, - const PlainTensor& k_input, - const PlainTensor& v_input, - PlainTensor& past_k_output, - PlainTensor& past_v_output) { + const PlainTensor& k_input, + const PlainTensor& v_input, + PlainTensor& past_k_output, + PlainTensor& past_v_output) { if (config.config.fuse_concat) { k_input.assert_dims({B, 0, L1, S}, true); v_input.assert_dims({B, 0, L1, S}, true); auto past_k_idx = inputs.size() - 2; auto past_k_mem = inputs[past_k_idx + 0]; L0 = past_k_mem->getStaticDims()[2]; - // k,v may support multiquery - auto Hk = past_k_mem->getStaticDims()[1]; // [B, H, L0, S] past_k_output.reset(outputs[1]); past_v_output.reset(outputs[2]); - parallel_for3d(B, Hk, L1, [&](size_t b, size_t h, size_t m) { - std::memcpy(&past_k_output.at({b, h, m + L0, 0}), - &k_input.at({b, h, m, 0}), - S * sizeof(T)); - std::memcpy(&past_v_output.at({b, h, m + L0, 0}), - &v_input.at({b, h, m, 0}), - S * sizeof(T)); - }); + attn_memcpy(k_input, v_input, past_k_output.slice(2, L0, L0 + L1), past_v_output.slice(2, L0, L0 + L1)); if (!config.is_concat_inplaced) { - PlainTensor past_k_input, past_v_input; + PlainTensor past_k_input, past_v_input; past_k_input.reset(past_k_mem); past_v_input.reset(inputs[past_k_idx + 1]); - parallel_for3d(B, Hk, L0, [&](size_t b, size_t h, size_t m) { - std::memcpy(&past_k_output.at({b, h, m, 0}), - &past_k_input.at({b, h, m, 0}), - S * sizeof(T)); - std::memcpy(&past_v_output.at({b, h, m, 0}), - &past_v_input.at({b, h, m, 0}), - S * sizeof(T)); - }); + attn_memcpy(past_k_input, past_v_input, past_k_output, past_v_output); } } else { // k,v inputs are already concatenated @@ -657,7 +543,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt q_input.reset(inputs[0]); k_input.reset(inputs[1]); v_input.reset(inputs[2]); - PlainTensor attn_mask; + PlainTensor attn_mask; if (input_num > 3) { // attn_mask if (inputs[3]->getDesc().getPrecision() == ov::element::u8) { @@ -679,10 +565,10 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt L1 = q_input.size(2); S = q_input.size(-1); - PlainTensor present_key, present_value; + PlainTensor present_key, present_value; concat_pastkv(inputs, outputs, k_input, v_input, present_key, present_value); - ov::intel_cpu::PlainTensor output_emb(outputs[0]); + ov::intel_cpu::PlainTensor output_emb(outputs[0]); bool auto_causal; bool use_attn_mask; @@ -715,7 +601,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt if (L1 > 1) { // multi-token version - kernel(strm, q_input, k_input, v_input, {}, use_attn_mask ? attn_mask : PlainTensor(), + kernel(strm, q_input, k_input, v_input, {}, use_attn_mask ? attn_mask : PlainTensor(), output_emb, has_out_transpose, auto_causal, scale_input); } else { // 1-token version @@ -723,7 +609,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt // 1, in matrix mutiply, using AMX is not efficency because the M dimension of A will alway be 1 // 2, using float will save the repack cost which typically is required for bf16/int8 opt // 3, using dot product can leverage the SIMD while easily adapt to indirect kv cache - kernel_single_token(q_input, present_key, present_value, {}, use_attn_mask ? attn_mask : PlainTensor(), + kernel_single_token(q_input, present_key, present_value, {}, use_attn_mask ? 
attn_mask : PlainTensor(), output_emb, beam_table, has_out_transpose, auto_causal, scale_input); } } @@ -740,7 +626,7 @@ ScaledDotProductAttention::ScaledDotProductAttention(const std::shared_ptrget_causal(); } else { - const auto node = std::dynamic_pointer_cast(op); + const auto node = std::dynamic_pointer_cast(op); m_config.config = node->get_config(); } } @@ -749,10 +635,14 @@ void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() { if (!supportedPrimitiveDescriptors.empty()) return; auto rtPrecision = getOriginalInputPrecisionAtPort(0); + auto orginSDPInputNumber = getOriginalInputsNumber() - (m_config.config.fuse_concat ? 2 : 0); + + bool enableKVCacheFP16 = m_config.config.fuse_concat && mayiuse(cpu_isa_t::avx2) && rtPrecision != ov::element::bf16; + + auto kvCachePrecision = enableKVCacheFP16 ? ov::element::f16 : rtPrecision; NodeConfig config; auto& creatorsMap = BlockedDescCreator::getCommonCreators(); - auto orginSDPInputNumber = getOriginalInputsNumber() - (m_config.config.fuse_concat ? 2 : 0); config.inConfs.resize(getOriginalInputsNumber()); config.outConfs.resize(getOriginalOutputsNumber()); config.inConfs[0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( @@ -782,15 +672,15 @@ void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() { ArbitraryOrderDescCreator cabdDescCreator({2, 0, 1, 3}); config.inConfs[orginSDPInputNumber + 0].setMemDesc(cabdDescCreator.createSharedDesc( - rtPrecision, getInputShapeAtPort(orginSDPInputNumber + 0))); + kvCachePrecision, getInputShapeAtPort(orginSDPInputNumber + 0))); config.inConfs[orginSDPInputNumber + 1].setMemDesc(cabdDescCreator.createSharedDesc( - rtPrecision, getInputShapeAtPort(orginSDPInputNumber + 1))); + kvCachePrecision, getInputShapeAtPort(orginSDPInputNumber + 1))); config.outConfs[1].setMemDesc(cabdDescCreator.createSharedDesc( - rtPrecision, getOutputShapeAtPort(1))); + kvCachePrecision, getOutputShapeAtPort(1))); config.outConfs[1].inPlace(orginSDPInputNumber + 0); config.outConfs[2].setMemDesc(cabdDescCreator.createSharedDesc( - rtPrecision, getOutputShapeAtPort(2))); + kvCachePrecision, getOutputShapeAtPort(2))); config.outConfs[2].inPlace(orginSDPInputNumber + 1); } @@ -801,14 +691,14 @@ void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() { // may fallback to abcd without inplace if (m_config.config.fuse_concat) { config.inConfs[orginSDPInputNumber + 0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - rtPrecision, getInputShapeAtPort(orginSDPInputNumber + 0))); + kvCachePrecision, getInputShapeAtPort(orginSDPInputNumber + 0))); config.inConfs[orginSDPInputNumber + 1].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - rtPrecision, getInputShapeAtPort(orginSDPInputNumber + 1))); + kvCachePrecision, getInputShapeAtPort(orginSDPInputNumber + 1))); config.outConfs[1].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - rtPrecision, getOutputShapeAtPort(1))); + kvCachePrecision, getOutputShapeAtPort(1))); config.outConfs[1].inPlace(-1); config.outConfs[2].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - rtPrecision, getOutputShapeAtPort(2))); + kvCachePrecision, getOutputShapeAtPort(2))); config.outConfs[2].inPlace(-1); supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::ref_any); } @@ -851,8 +741,8 @@ void ScaledDotProductAttention::execute(dnnl::stream strm) { bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { try { if 
(!std::dynamic_pointer_cast(op) && - !std::dynamic_pointer_cast(op)) { - errorMessage = "Only ScaledDotProductAttention or ScaledDotProductAttentionStub operation are supported"; + !std::dynamic_pointer_cast(op)) { + errorMessage = "Only ScaledDotProductAttention or ScaledDotProductAttentionWithKVCache operation are supported"; return false; } // expect shape of q: [B, H, L, S] @@ -862,7 +752,7 @@ bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr(op->get_input_size()); - const auto node = std::dynamic_pointer_cast(op); + const auto node = std::dynamic_pointer_cast(op); if (node) { if (node->get_config().fuse_concat) { orgSDPAInput -= 2; diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.h b/src/plugins/intel_cpu/src/nodes/scaled_attn.h index 7c08ef99faf1d4..78bc9d4231478f 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.h +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.h @@ -10,7 +10,7 @@ #include #include -#include "transformations/cpu_opset/common/op/sdp.hpp" +#include "transformations/cpu_opset/common/op/sdpa.hpp" namespace ov { namespace intel_cpu { @@ -26,7 +26,7 @@ class ScaledDotProductAttention : public Node { } // pastkv may have zero dimension bool isExecutable() const override { - return true; + return !isInputTensorAtPortEmpty(0) && !isInputTensorAtPortEmpty(1) && !isInputTensorAtPortEmpty(2); } bool needPrepareParams() const override { return false; @@ -47,7 +47,7 @@ class ScaledDotProductAttention : public Node { }; struct Config { - ScaledDotProductAttentionStub::Config config; + ScaledDotProductAttentionWithKVCache::Config config; bool is_concat_inplaced = false; }; diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdp.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.cpp similarity index 71% rename from src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdp.cpp rename to src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.cpp index e433d5ad34fb5a..4dc5ba799dd4eb 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdp.cpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.cpp @@ -2,27 +2,27 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "sdp.hpp" +#include "sdpa.hpp" #include #include "transformations/itt.hpp" -ov::intel_cpu::ScaledDotProductAttentionStub::ScaledDotProductAttentionStub(const OutputVector& args, const Config& cfg) +ov::intel_cpu::ScaledDotProductAttentionWithKVCache::ScaledDotProductAttentionWithKVCache(const OutputVector& args, const Config& cfg) : Op(args), m_config(cfg) { constructor_validate_and_infer_types(); } -std::shared_ptr ov::intel_cpu::ScaledDotProductAttentionStub::clone_with_new_inputs( +std::shared_ptr ov::intel_cpu::ScaledDotProductAttentionWithKVCache::clone_with_new_inputs( const ov::OutputVector& new_args) const { - INTERNAL_OP_SCOPE(ScaledDotProductAttentionStub_with_new_inputs); + INTERNAL_OP_SCOPE(ScaledDotProductAttentionWithKVCache_with_new_inputs); check_new_args_count(this, new_args); - return std::make_shared(new_args, m_config); + return std::make_shared(new_args, m_config); } -void ov::intel_cpu::ScaledDotProductAttentionStub::validate_and_infer_types() { - INTERNAL_OP_SCOPE(ScaledDotProductAttentionStub_validate_and_infer_types); +void ov::intel_cpu::ScaledDotProductAttentionWithKVCache::validate_and_infer_types() { + INTERNAL_OP_SCOPE(ScaledDotProductAttentionWithKVCache_validate_and_infer_types); auto input_num = get_input_size(); // [B, H, L1, S] auto q_ps = 
get_input_partial_shape(0); @@ -45,8 +45,8 @@ void ov::intel_cpu::ScaledDotProductAttentionStub::validate_and_infer_types() { set_output_type(2, get_input_element_type(input_num - 1), past_kv_ps); } -bool ov::intel_cpu::ScaledDotProductAttentionStub::visit_attributes(ov::AttributeVisitor& visitor) { - INTERNAL_OP_SCOPE(ScaledDotProductAttentionStub_visit_attributes); +bool ov::intel_cpu::ScaledDotProductAttentionWithKVCache::visit_attributes(ov::AttributeVisitor& visitor) { + INTERNAL_OP_SCOPE(ScaledDotProductAttentionWithKVCache_visit_attributes); visitor.start_structure("config"); visitor.on_attribute("output_BLHxS", m_config.output_BLHxS); visitor.on_attribute("fuse_causal_attn", m_config.fuse_causal_attn); diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdp.hpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.hpp similarity index 78% rename from src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdp.hpp rename to src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.hpp index 7cf45b24bd7368..94406caeab016e 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdp.hpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.hpp @@ -14,11 +14,11 @@ namespace intel_cpu { /// /// \ingroup ov_ops_cpp_api -class ScaledDotProductAttentionStub : public ov::op::Op { +class ScaledDotProductAttentionWithKVCache : public ov::op::Op { public: - OPENVINO_OP("ScaledDotProductAttentionStub", "cpu_plugin_opset"); + OPENVINO_OP("ScaledDotProductAttentionWithKVCache", "cpu_plugin_opset"); - ScaledDotProductAttentionStub() = default; + ScaledDotProductAttentionWithKVCache() = default; struct Config { bool output_BLHxS = false; // true implies that output is [B,L,H*S] @@ -28,7 +28,7 @@ class ScaledDotProductAttentionStub : public ov::op::Op { bool fuse_concat = false; // fuse (concat->sdp) ==> sdp }; - ScaledDotProductAttentionStub(const OutputVector& args, const Config& cfg); + ScaledDotProductAttentionWithKVCache(const OutputVector& args, const Config& cfg); std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; bool visit_attributes(AttributeVisitor& visitor) override; diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdp_fusion.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp similarity index 94% rename from src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdp_fusion.cpp rename to src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp index a1f9dd24ddcb81..683609e968c900 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdp_fusion.cpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "stateful_sdp_fusion.hpp" +#include "stateful_sdpa_fusion.hpp" #include #include @@ -17,13 +17,13 @@ #include "itt.hpp" #include "ov_ops/type_relaxed.hpp" -#include "transformations/cpu_opset/common/op/sdp.hpp" +#include "transformations/cpu_opset/common/op/sdpa.hpp" namespace ov { namespace intel_cpu { -StatefulSDPFusion::StatefulSDPFusion() { - MATCHER_SCOPE(StatefulSDPFusion); +StatefulSDPAFusion::StatefulSDPAFusion() { + MATCHER_SCOPE(StatefulSDPAFusion); using namespace ov::pass::pattern; auto past_k = wrap_type(); @@ -91,13 +91,13 @@ StatefulSDPFusion::StatefulSDPFusion() { args[2] = 
concat_v_node->input_value(1); args.push_back(read_cvt_k_node ? read_cvt_k_node->output(0) : past_k_node->output(0)); args.push_back(read_cvt_v_node ? read_cvt_v_node->output(0) : past_v_node->output(0)); - ov::intel_cpu::ScaledDotProductAttentionStub::Config config; + ov::intel_cpu::ScaledDotProductAttentionWithKVCache::Config config; config.is_causal = sdp_node->get_causal(); config.fuse_concat = true; auto old_node = sdp_node; - auto new_node = std::make_shared(args, config); + auto new_node = std::make_shared(args, config); new_node->set_friendly_name(old_node->get_friendly_name()); ov::replace_node(old_node, {new_node->output(0)}); if (assign_cvt_k_node) diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdp_fusion.hpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.hpp similarity index 64% rename from src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdp_fusion.hpp rename to src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.hpp index 21f5250868164e..7de7e018036baa 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdp_fusion.hpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.hpp @@ -8,10 +8,10 @@ namespace ov { namespace intel_cpu { -class StatefulSDPFusion : public ov::pass::MatcherPass { +class StatefulSDPAFusion : public ov::pass::MatcherPass { public: - OPENVINO_RTTI("StatefulSDPFusion", "0"); - StatefulSDPFusion(); + OPENVINO_RTTI("StatefulSDPAFusion", "0"); + StatefulSDPAFusion(); }; } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index ebc3bc61bffd24..9f2afa638877ef 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -112,7 +112,7 @@ #include "transformations/cpu_opset/common/pass/move_eltwise_up_data_movement.hpp" #include "transformations/cpu_opset/common/pass/swap_convert_transpose.hpp" #include "transformations/cpu_opset/common/pass/rope_fusion.hpp" -#include "transformations/cpu_opset/common/pass/stateful_sdp_fusion.hpp" +#include "transformations/cpu_opset/common/pass/stateful_sdpa_fusion.hpp" // Snippets #include "snippets/pass/tokenization.hpp" @@ -661,7 +661,7 @@ void Transformations::PostLpt() { CPU_REGISTER_PASS_X64(postLPTPassManager, EliminateStridedSlice); CPU_REGISTER_PASS_X64(postLPTPassManager, RoPEFusion); - CPU_REGISTER_PASS_X64(postLPTPassManager, StatefulSDPFusion); + CPU_REGISTER_PASS_X64(postLPTPassManager, StatefulSDPAFusion); postLPTPassManager.run_passes(model); } diff --git a/src/plugins/intel_cpu/src/utils/plain_tensor.hpp b/src/plugins/intel_cpu/src/utils/plain_tensor.hpp index f6441e18a68f47..945e872dedf5cf 100644 --- a/src/plugins/intel_cpu/src/utils/plain_tensor.hpp +++ b/src/plugins/intel_cpu/src/utils/plain_tensor.hpp @@ -4,7 +4,6 @@ #pragma once -#include #include #include @@ -58,26 +57,6 @@ inline void assert_dt(ov::element::Type dt) { OPENVINO_ASSERT(dt == ov::element::f16); } -template -struct data_type_name { - static constexpr const char* value = "?"; -}; - -template <> -struct data_type_name { - static constexpr const char* value = "float"; -}; - -template <> -struct data_type_name { - static constexpr const char* value = "bfloat16"; -}; - -template <> -struct data_type_name { - static 
constexpr const char* value = "uint8_t"; -}; - template struct precision_of { static constexpr ov::element::Type_t value = ov::element::Type_t::undefined; @@ -110,7 +89,6 @@ struct precision_of { #define PLAINTENSOR_RANK_MAX 8 -template struct PlainTensor { size_t m_strides[PLAINTENSOR_RANK_MAX]; size_t m_dims[PLAINTENSOR_RANK_MAX]; @@ -118,6 +96,8 @@ struct PlainTensor { void* m_ptr = nullptr; size_t m_capacity = 0; bool with_storage = false; + size_t m_element_size = 0; + ov::element::Type_t m_dt = ov::element::Type_t::undefined; MemoryPtr m_mem; // hold memory ptr reference operator bool() const { @@ -149,12 +129,14 @@ struct PlainTensor { } // copy construct (always not take ownership) - PlainTensor
<DT>& operator=(const PlainTensor<DT>
& other) { + PlainTensor operator=(const PlainTensor& other) { OPENVINO_ASSERT(!with_storage); memcpy(&m_strides, &other.m_strides, sizeof(m_strides)); memcpy(&m_dims, &other.m_dims, sizeof(m_dims)); m_rank = other.m_rank; m_ptr = other.m_ptr; + m_dt = other.m_dt; + m_element_size = other.m_element_size; return *this; } @@ -169,23 +151,21 @@ struct PlainTensor { } void reset(MemoryPtr mem) { - const auto& mem_desc = mem->getDesc(); - assert_dt
(mem_desc.getPrecision()); - const auto* desc_ptr = mem_desc.as(); + auto mem_desc = mem->getDescWithType(); // not support block layout - OPENVINO_ASSERT(desc_ptr && desc_ptr->getOrder().size() == mem->getStaticDims().size()); + OPENVINO_ASSERT(mem_desc && mem_desc->getOrder().size() == mem->getStaticDims().size()); m_mem = mem; - VectorDims strides(desc_ptr->getStrides().size()); - const auto& orders = desc_ptr->getOrder(); + VectorDims strides(mem_desc->getStrides().size()); + const auto& orders = mem_desc->getOrder(); for (size_t i = 0; i < orders.size(); i++) { - strides[orders[i]] = desc_ptr->getStrides()[i]; + strides[orders[i]] = mem_desc->getStrides()[i]; } // this reshape_to() can do reshape w/o additional cost - resize(mem->getStaticDims(), reinterpret_cast(mem->getData()), &strides); + resize(mem->getStaticDims(), mem_desc->getPrecision().size(), mem_desc->getPrecision(), mem->getData(), strides.data()); } - ov::element::Type get_precision(void) { - return precision_of
::value; + ov::element::Type get_precision() const { + return m_dt; } struct tensor_index { @@ -223,8 +203,8 @@ struct PlainTensor { } }; - PlainTensor
index(const std::initializer_list& indices) { - PlainTensor
sub_tensor; + PlainTensor index(const std::initializer_list& indices) { + PlainTensor sub_tensor; assert(indices.size() <= m_rank); int i_src = 0; int i_dst = 0; @@ -246,13 +226,15 @@ struct PlainTensor { i_src++; } sub_tensor.m_rank = i_dst; // index may imply squeeze - sub_tensor.m_ptr = reinterpret_cast(reinterpret_cast(m_ptr) + off); + sub_tensor.m_ptr = reinterpret_cast(reinterpret_cast(m_ptr) + off * m_element_size); + sub_tensor.m_dt = m_dt; + sub_tensor.m_element_size = m_element_size; return sub_tensor; } // slice: return a sub-view (w/o ownership/refcount to original data) - PlainTensor
slice(int axis, int start, int end, int step = 1) const { - PlainTensor
sub_tensor; + PlainTensor slice(int axis, int start, int end, int step = 1) const { + PlainTensor sub_tensor; assert(axis >= 0 && static_cast::type>(axis) < m_rank); sub_tensor.m_capacity = 0; @@ -277,8 +259,10 @@ struct PlainTensor { } auto off = start * m_strides[axis]; - auto* data = reinterpret_cast(m_ptr) + off; + auto* data = reinterpret_cast(m_ptr) + off * m_element_size; sub_tensor.m_ptr = reinterpret_cast(data); + sub_tensor.m_dt = m_dt; + sub_tensor.m_element_size = m_element_size; return sub_tensor; } @@ -310,21 +294,22 @@ struct PlainTensor { simplified form is when whole tensor is dense */ - PlainTensor
reshape(const std::vector& target_shape) const { + PlainTensor reshape(const std::vector& target_shape) const { // only valid for dense memory - PlainTensor
new_tensor_view; + PlainTensor new_tensor_view; assert(is_dense()); - //assert(shape_size(target_shape) == shape_size(m_dims)); - new_tensor_view.resize(VectorDims(target_shape), reinterpret_cast(m_ptr)); + new_tensor_view.resize(target_shape, m_element_size, m_dt, m_ptr); return new_tensor_view; } - PlainTensor
permute(const std::vector& order) const { - PlainTensor
new_tensor_view; + PlainTensor permute(const std::vector& order) const { + PlainTensor new_tensor_view; assert(order.size() == m_rank); new_tensor_view.m_capacity = 0; new_tensor_view.m_ptr = m_ptr; new_tensor_view.m_rank = m_rank; + new_tensor_view.m_dt = m_dt; + new_tensor_view.m_element_size = m_element_size; auto it_order = order.begin(); // also should check order has no repeat element for (size_t i = 0; i < m_rank; i++) { @@ -336,19 +321,21 @@ struct PlainTensor { return new_tensor_view; } - void resize(const VectorDims& new_dims, DT* data = nullptr, const VectorDims* strides = nullptr) { + void resize(const VectorDims& new_dims, size_t element_size, ov::element::Type_t dt, void* data = nullptr, const size_t* strides = nullptr) { + m_element_size = element_size; + m_dt = dt; // initialize strides for compact/dense tensor m_rank = new_dims.size(); assert(m_rank <= PLAINTENSOR_RANK_MAX); size_t stride = 1; for (int i = m_rank - 1; i >= 0; i--) { m_dims[i] = new_dims[i]; - m_strides[i] = strides ? (*strides)[i] : stride; + m_strides[i] = strides ? strides[i] : stride; stride *= new_dims[i]; } if (!data) { - auto capacity_new = m_strides[0] * m_dims[0] * sizeof(DT); + auto capacity_new = m_strides[0] * m_dims[0] * m_element_size; if (capacity_new > m_capacity) { if (!with_storage) { throw std::bad_alloc(); @@ -368,11 +355,18 @@ struct PlainTensor { } } + template + void resize(const VectorDims& new_dims, DT* data = nullptr, const size_t* strides = nullptr) { + resize(new_dims, sizeof(DT), precision_of
::value, data, strides); + } + + template DT* data() const { return reinterpret_cast(m_ptr); } // when allow_broadcast is true, index to size-1 dim will always access 0. + template DT& at(const std::initializer_list& index, bool allow_broadcast = false) const { size_t off = 0; auto it = index.begin(); @@ -389,7 +383,8 @@ struct PlainTensor { return reinterpret_cast(m_ptr)[off]; } - PlainTensor
& operator=(const DT& value) { + template + PlainTensor& operator=(const DT& value) { // assign every element to value std::vector index(m_rank, 0); auto* dst = reinterpret_cast(m_ptr); @@ -412,8 +407,9 @@ struct PlainTensor { return *this; } + template DT& operator()(const std::initializer_list& index, bool allow_broadcast = false) const { - return at(index, allow_broadcast); + return at
(index, allow_broadcast); } void assert_dims(const std::initializer_list& expect_dims, bool special_zero = false) const { @@ -452,7 +448,7 @@ struct PlainTensor { return "{empty}"; } std::stringstream ss; - ss << data_type_name
::value << " shape=["; + ss << m_dt << " shape=["; const char* sep = ""; size_t sz = 1; for (size_t i = 0; i < m_rank; i++) { @@ -475,7 +471,6 @@ struct PlainTensor { size_t cur_line_elecnt = 0; size_t cur_row_elecnt = 0; size_t i; - auto* p = reinterpret_cast(m_ptr); for (i = 0; i < sz && max_total_lines > 0; i++) { if ((i % last_dim_size) == 0) { ss << row_id << ":\t\t"; @@ -485,10 +480,20 @@ struct PlainTensor { // display current element if we still have buget if (cur_row_lines_left > 0) { - if (std::is_integral
-                    ss << static_cast<int64_t>(p[i]) << ",";
+                if (m_dt == ov::element::Type_t::f32)
+                    ss << reinterpret_cast<float*>(m_ptr)[i] << ",";
+                else if (m_dt == ov::element::Type_t::bf16)
+                    ss << reinterpret_cast<ov::bfloat16*>(m_ptr)[i] << ",";
+                else if (m_dt == ov::element::Type_t::f16)
+                    ss << reinterpret_cast<ov::float16*>(m_ptr)[i] << ",";
+                else if (m_dt == ov::element::Type_t::i32)
+                    ss << reinterpret_cast<int32_t*>(m_ptr)[i] << ",";
+                else if (m_dt == ov::element::Type_t::i8)
+                    ss << static_cast<int32_t>(reinterpret_cast<int8_t*>(m_ptr)[i]) << ",";
+                else if (m_dt == ov::element::Type_t::u8)
+                    ss << static_cast<int32_t>(reinterpret_cast<uint8_t*>(m_ptr)[i]) << ",";
                 else
-                    ss << p[i] << ",";
+                    ss << "?,";
                 cur_line_elecnt++;
                 cur_row_elecnt++;
                 if ((cur_line_elecnt % 16) == 15 || (cur_row_elecnt == last_dim_size)) {
@@ -514,12 +519,10 @@ struct PlainTensor {
         return ss.str();
     }

-    template <typename DT>
-    friend std::ostream& operator<<(std::ostream& os, const PlainTensor& dt);
+    friend std::ostream& operator<<(std::ostream& os, const PlainTensor& dt);
 };

-template <typename DT>
-std::ostream& operator<<(std::ostream& os, const PlainTensor& dt) {
+inline std::ostream& operator<<(std::ostream& os, const PlainTensor& dt) {
     os << dt.repr();
     return os;
 }
diff --git a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_sdp.cpp b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_sdp.cpp
index bf5dac2d822ba6..7f7ea0f30f9997 100644
--- a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_sdp.cpp
+++ b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_sdp.cpp
@@ -73,7 +73,7 @@ class ConcatSDPTest : public testing::WithParamInterface<ConcatSDPTestParams>, v
         bool hasShapeOf;
         std::tie(inType, inputShapes, hasShapeOf) = this->GetParam();
         targetDevice = ov::test::utils::DEVICE_CPU;
-        rel_threshold = 1e-4f;
+        rel_threshold = 1e-2f;
         if (inType == ElementType::bf16) {
             configuration.insert({"ENFORCE_BF16", "YES"});
             rel_threshold = 0.01f;
@@ -170,7 +170,6 @@ class ConcatSDPTest : public testing::WithParamInterface<ConcatSDPTestParams>, v
         for (auto&& state : inferRequest.query_state()) {
             state.reset();
         }
-        inferRequest = ov::InferRequest();
     }
     std::vector<ov::Tensor> run_test(std::shared_ptr<ov::Model> model) {
         function = model;
diff --git a/src/plugins/intel_cpu/tests/unit/graph/scaled_attn.cpp b/src/plugins/intel_cpu/tests/unit/graph/scaled_attn.cpp
index 081d9ddfc7b42d..91673df7d05a86 100644
--- a/src/plugins/intel_cpu/tests/unit/graph/scaled_attn.cpp
+++ b/src/plugins/intel_cpu/tests/unit/graph/scaled_attn.cpp
@@ -22,6 +22,7 @@
 #include "ov_models/builders.hpp"
 #include "nodes/scaled_attn.h"
 #include "nodes/input.h"
+#include "nodes/convert.h"
 #include "graph.h"
 #include "cpu_tensor.h"

@@ -31,21 +32,25 @@ TEST(ScaledAttnGraphTest, smoke_Check_Scaled_Concat_Noplace) {
     auto build_graph = [](const ov::Shape& shape, float* qkv_val, float* past_kv_val) {
         auto qkv = ov::op::v0::Constant::create(ov::element::f32, shape, qkv_val);
         qkv->set_friendly_name("qkv_const");
-        auto pastkv = ov::op::v0::Constant::create(ov::element::f32, shape, past_kv_val);
+        auto pastkv_f32 = ov::op::v0::Constant::create(ov::element::f32, shape, past_kv_val);
+        pastkv_f32->set_friendly_name("pastkv_const_f32");
+        auto pastkv = std::make_shared<ov::op::v0::Convert>(pastkv_f32, ov::element::f16);
         pastkv->set_friendly_name("pastkv_const");
         // only need a dynamic parameter but its value will not be used
         auto attn = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{-1});
         attn->set_friendly_name("attn");

-        ov::intel_cpu::ScaledDotProductAttentionStub::Config config;
+        ov::intel_cpu::ScaledDotProductAttentionWithKVCache::Config config;
         config.fuse_concat = true;
         config.is_causal = true;
-        auto sdpa = std::make_shared<ov::intel_cpu::ScaledDotProductAttentionStub>(ov::OutputVector{qkv, qkv, qkv, attn, pastkv, pastkv}, config);
+        auto sdpa = std::make_shared<ov::intel_cpu::ScaledDotProductAttentionWithKVCache>(ov::OutputVector{qkv, qkv, qkv, attn, pastkv, pastkv}, config);
+        auto out_pastk_convert = std::make_shared<ov::op::v0::Convert>(sdpa->output(1), ov::element::f32);
+        auto out_pastv_convert = std::make_shared<ov::op::v0::Convert>(sdpa->output(2), ov::element::f32);
         auto out_qkv = std::make_shared<ov::op::v0::Result>(sdpa->output(0));
         out_qkv->set_friendly_name("qkv");
-        auto out_pastk = std::make_shared<ov::op::v0::Result>(sdpa->output(1));
+        auto out_pastk = std::make_shared<ov::op::v0::Result>(out_pastk_convert);
         out_pastk->set_friendly_name("pastk");
-        auto out_pastv = std::make_shared<ov::op::v0::Result>(sdpa->output(2));
+        auto out_pastv = std::make_shared<ov::op::v0::Result>(out_pastv_convert);
         out_pastv->set_friendly_name("pastv");

         std::unordered_set<NodePtr> nodes_set;
@@ -65,9 +70,12 @@ TEST(ScaledAttnGraphTest, smoke_Check_Scaled_Concat_Noplace) {
         auto context = std::make_shared<GraphContext>(conf, nullptr, nullptr, false);

         auto qkv_node = std::make_shared<node::Input>(qkv, context);
-        auto pastkv_node = std::make_shared<node::Input>(pastkv, context);
+        auto pastkv_f32_node = std::make_shared<node::Input>(pastkv_f32, context);
         auto attn_node = std::make_shared<node::Input>(attn, context);
+        auto pastkv_node = std::make_shared<node::Convert>(pastkv, context);
         auto sdpa_node = std::make_shared<node::ScaledDotProductAttention>(sdpa, context);
+        auto out_pastk_node_convert = std::make_shared<node::Convert>(out_pastk_convert, context);
+        auto out_pastv_node_convert = std::make_shared<node::Convert>(out_pastv_convert, context);
         auto out_qkv_node = std::make_shared<node::Input>(out_qkv, context);
         auto out_pastk_node = std::make_shared<node::Input>(out_pastk, context);
         auto out_pastv_node = std::make_shared<node::Input>(out_pastv, context);
@@ -76,11 +84,14 @@ TEST(ScaledAttnGraphTest, smoke_Check_Scaled_Concat_Noplace) {
         add_edge(qkv_node, sdpa_node, 0, 1);
         add_edge(qkv_node, sdpa_node, 0, 2);
         add_edge(attn_node, sdpa_node, 0, 3);
+        add_edge(pastkv_f32_node, pastkv_node, 0, 0);
         add_edge(pastkv_node, sdpa_node, 0, 4);
         add_edge(pastkv_node, sdpa_node, 0, 5);
         add_edge(sdpa_node, out_qkv_node, 0, 0);
-        add_edge(sdpa_node, out_pastk_node, 1, 0);
-        add_edge(sdpa_node, out_pastv_node, 2, 0);
+        add_edge(sdpa_node, out_pastk_node_convert, 1, 0);
+        add_edge(sdpa_node, out_pastv_node_convert, 2, 0);
+        add_edge(out_pastk_node_convert, out_pastk_node, 0, 0);
+        add_edge(out_pastv_node_convert, out_pastv_node, 0, 0);

         std::vector<NodePtr> graph_nodes(nodes_set.begin(), nodes_set.end());
@@ -104,7 +115,7 @@ TEST(ScaledAttnGraphTest, smoke_Check_Scaled_Concat_Noplace) {
     auto check_graph = [] (Graph& graph, std::map<std::string, std::vector<float>>& expected) {
         auto& outputNodesMap = graph.GetOutputNodesMap();
         auto is_same = [] (float a, float b) {
-            return std::abs(a - b) < 0.0001f;
+            return std::abs(a - b) < 0.01f;
         };
         for (auto &outputMap : outputNodesMap) {
             auto name = outputMap.first;
diff --git a/src/plugins/intel_cpu/tests/unit/transformations/state_concat_sdpa.cpp b/src/plugins/intel_cpu/tests/unit/transformations/state_concat_sdpa.cpp
index 6f30e3390f4d37..1ce6263cb2d631 100644
--- a/src/plugins/intel_cpu/tests/unit/transformations/state_concat_sdpa.cpp
+++ b/src/plugins/intel_cpu/tests/unit/transformations/state_concat_sdpa.cpp
@@ -8,8 +8,8 @@
 #include
 #include
-#include <transformations/cpu_opset/common/op/sdp.hpp>
-#include <transformations/cpu_opset/common/pass/stateful_sdp_fusion.hpp>
+#include <transformations/cpu_opset/common/op/sdpa.hpp>
+#include <transformations/cpu_opset/common/pass/stateful_sdpa_fusion.hpp>
 #include
 #include
 #include
@@ -38,9 +38,9 @@ static std::shared_ptr<ov::Model> makeSDPA(const ov::PartialShape& inputShape, b
         pastv = std::make_shared<ov::op::v0::Convert>(pastv, element::f32);
     }
     if (isRef) {
-        ov::intel_cpu::ScaledDotProductAttentionStub::Config config;
+        ov::intel_cpu::ScaledDotProductAttentionWithKVCache::Config config;
         config.fuse_concat = true;
-        auto new_node = std::make_shared<ov::intel_cpu::ScaledDotProductAttentionStub>(OutputVector{q, k, v, pastk, pastv}, config);
+        auto new_node = std::make_shared<ov::intel_cpu::ScaledDotProductAttentionWithKVCache>(OutputVector{q, k, v, pastk, pastv}, config);
         sdp = new_node->output(0);
         concatK = new_node->output(1);
         concatV = new_node->output(2);
@@ -71,7 +71,7 @@ TEST(TransformationTests, StateConcatSDPA) {
         f = makeSDPA(inputShape);
         pass::Manager m;
         m.register_pass<ov::pass::InitNodeInfo>();
-        m.register_pass<ov::intel_cpu::StatefulSDPFusion>();
+        m.register_pass<ov::intel_cpu::StatefulSDPAFusion>();
         m.run_passes(f);
     }
     //construct ref interaction
@@ -92,7 +92,7 @@ TEST(TransformationTests, StateConcatSDPAWithConvert) {
         f = makeSDPA(inputShape, false, true);
         pass::Manager m;
         m.register_pass<ov::pass::InitNodeInfo>();
-        m.register_pass<ov::intel_cpu::StatefulSDPFusion>();
+        m.register_pass<ov::intel_cpu::StatefulSDPAFusion>();
         m.run_passes(f);
     }
     //construct ref interaction
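
Editor's illustration (not part of the patch): the plain_tensor.hpp hunks above trade the compile-time DT template parameter for a runtime check on the stored ov::element::Type, so a single non-templated PlainTensor and one operator<< can print f32/bf16/f16 KV-cache buffers. The sketch below shows that dispatch pattern in isolation. It assumes only OpenVINO's public element_type header; the helper name dump_buffer and the subset of handled types are illustrative and do not exist in the plugin.

// Standalone sketch of runtime element-type dispatch, analogous to the new PlainTensor::repr().
#include <openvino/core/type/element_type.hpp>

#include <cstdint>
#include <cstdio>
#include <sstream>
#include <string>

// "dump_buffer" is a hypothetical helper, not an OpenVINO API.
static std::string dump_buffer(const void* ptr, ov::element::Type_t dt, size_t count) {
    std::stringstream ss;
    ss << ov::element::Type(dt) << " [";  // ov::element::Type is streamable
    for (size_t i = 0; i < count; i++) {
        if (dt == ov::element::Type_t::f32)
            ss << reinterpret_cast<const float*>(ptr)[i] << ",";
        else if (dt == ov::element::Type_t::i32)
            ss << reinterpret_cast<const int32_t*>(ptr)[i] << ",";
        else if (dt == ov::element::Type_t::i8)
            ss << static_cast<int32_t>(reinterpret_cast<const int8_t*>(ptr)[i]) << ",";
        else if (dt == ov::element::Type_t::u8)
            ss << static_cast<int32_t>(reinterpret_cast<const uint8_t*>(ptr)[i]) << ",";
        else
            ss << "?,";  // unhandled element type: same "?" fallback as repr()
    }
    ss << "]";
    return ss.str();
}

int main() {
    float data[4] = {1.0f, 2.5f, -3.0f, 0.125f};
    std::puts(dump_buffer(data, ov::element::Type_t::f32, 4).c_str());  // f32 [1,2.5,-3,0.125,]
    return 0;
}

Keeping the branch at the element level rather than the type level is what lets the mixed-precision graph test above feed an f16 KV cache through the same tensor utility without instantiating a separate template.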