diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.cpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.cpp
index eda7a3ae3192a7..94f8e2bed37b30 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.cpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.cpp
@@ -17,7 +17,14 @@ using namespace dnnl::impl::cpu::x64::matmul;
 namespace ov {
 namespace intel_cpu {
 
-BrgemmKernel::BrgemmKernel(size_t M, size_t N, size_t K, size_t lda, size_t ldb, size_t ldc, bool b_transposed)
+BrgemmKernel::BrgemmKernel(size_t M,
+                           size_t N,
+                           size_t K,
+                           size_t lda,
+                           size_t ldb,
+                           size_t ldc,
+                           bool b_transposed,
+                           ov::element::Type inType)
     : M(M),
       K(K),
       N(N),
@@ -25,18 +32,24 @@ BrgemmKernel::BrgemmKernel(size_t M, size_t N, size_t K, size_t lda, size_t ldb,
       ldb(ldb),
       ldc(ldc),
       b_transposed(b_transposed),
-      inType(ov::element::bf16) {
+      inType(inType) {
     // blocking M
     M_blk = matmulOptimalM;
     M_tail = M % M_blk;
     brgVnniFactor = 4 / inType.size();
-    if (!mayiuse(avx512_core_bf16))
-        THROW_ERROR("brgemm bf16bf16f32 kernel could only be used above avx512_bf16");
-    bool isAMXSupported = mayiuse(avx512_core_amx);
-    // blocing N
+
+    if (inType != ov::element::bf16 && inType != ov::element::f32)
+        THROW_ERROR("brgemm kernel only supports bf16, f32");
+    bool is_bf16 = inType == ov::element::bf16;
+    if (is_bf16 && !mayiuse(avx512_core_bf16))
+        THROW_ERROR("brgemm bf16 kernel could only be used above avx512_bf16");
+
+    bool isAMXSupported = is_bf16 && mayiuse(avx512_core_amx);
+    // blocking N
+    N_blk = is_bf16 ? 32 : N;
     N_tail = N % N_blk;
-    // blocing K
+    // blocking K
     K_blk = isAMXSupported ? 32 : K;
     K_tail = K % K_blk;
     if (isAMXSupported && K_tail) {
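For reference, the blocking scheme the revised constructor sets up reduces to the arithmetic below (a standalone sketch; the names Blocking and pick_blocking are illustrative and not part of the patch):

    #include <cstddef>

    struct Blocking {
        size_t M_blk, M_tail, N_blk, N_tail, K_blk, K_tail;
    };

    // bf16 keeps the fixed 32-wide N panel required by the repacked VNNI layout,
    // while f32 skips repacking and treats the whole row (N) as a single block.
    // K is only split into 32-element chunks when the AMX path is taken.
    Blocking pick_blocking(size_t M, size_t N, size_t K, bool is_bf16, bool is_amx) {
        Blocking b;
        b.M_blk = 32;  // matmulOptimalM
        b.M_tail = M % b.M_blk;
        b.N_blk = is_bf16 ? 32 : N;
        b.N_tail = N % b.N_blk;
        b.K_blk = is_amx ? 32 : K;
        b.K_tail = K % b.K_blk;
        return b;
    }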
@@ -60,7 +73,7 @@ BrgemmKernel::BrgemmKernel(size_t M, size_t N, size_t K, size_t lda, size_t ldb,
         brgemmCtx.N = N_;
         brgemmCtx.K = K_;
         brgemmCtx.LDA = k ? K_blk : lda;
-        brgemmCtx.LDB = rnd_up(N, N_blk);  // B is copied with bf16
+        brgemmCtx.LDB = (is_bf16 || b_transposed) ? rnd_up(N, N_blk) : ldb;  // bf16/b_transposed needs copy
         brgemmCtx.LDC = ldc;
         brgemmCtx.dt_in0 = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::ElementTypeToDataType(inType));
         brgemmCtx.dt_in1 = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::ElementTypeToDataType(inType));
@@ -83,7 +96,7 @@ BrgemmKernel::BrgemmKernel(size_t M, size_t N, size_t K, size_t lda, size_t ldb,
         packedASize = M_blk * rnd_up(K, K_blk) * inType.size();
     }
 
-    if (brgemmCtx0.is_with_amx || inType == ov::element::bf16) {
+    if (brgemmCtx0.is_with_amx || inType == ov::element::bf16 || b_transposed) {
         size_t b_stride = 0;
         // must set actual stride when stride is not K/N
         if (b_transposed) {
@@ -267,11 +280,11 @@ void BrgemmKernel::copy_buffer_b(void* b, void* scratch_b) {
     }
 }
 
-void BrgemmKernel::executeGemmPackedB(bool is_M_tail, void* a, void* repacked_b, void* c, void* wsp, void* scratch_a) {
+void BrgemmKernel::executeGemm(bool is_M_tail, void* a, void* b, void* c, void* wsp, void* scratch_a) {
     auto ptr_A = reinterpret_cast<uint8_t*>(a);
     auto ptr_C = reinterpret_cast<uint8_t*>(c);
     auto ptr_scartch_a = reinterpret_cast<uint8_t*>(scratch_a);
-    auto ptr_scartch_b = reinterpret_cast<uint8_t*>(repacked_b);
+    auto ptr_scartch_b = reinterpret_cast<uint8_t*>(b);
 
     uint8_t* ptr_a_tail = nullptr;
 
     size_t brgIdx0 = getBrgIdx(0, 0, 0);
@@ -340,8 +353,8 @@ void BrgemmKernel::executeGemm(void* a, void* b, void* c, void* wsp, void* scrat
     for (size_t mb = 0; mb < div_up(M, M_blk); mb++) {
         const bool is_M_tail = (M - mb * M_blk < M_blk);
         auto ptr_a = ptr_A + (mb * M_blk * lda) * inType.size();
-        auto ptr_c = ptr_C + (mb * M_blk * ldc) * inType.size();
-        executeGemmPackedB(is_M_tail, ptr_a, scratch_b, wsp, ptr_c, scratch_a);
+        auto ptr_c = ptr_C + (mb * M_blk * ldc) * ov::element::f32.size();
+        executeGemm(is_M_tail, ptr_a, scratch_b, wsp, ptr_c, scratch_a);
     }
 }
 void BrgemmKernel::callBrgemm(brgemmCtx& ctx,
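A rough caller-side sketch of the call sequence the new f32 path allows, modelled on the unit test at the end of this patch (the function name and the single-threaded loop are illustrative, not part of the change):

    #include <algorithm>
    #include <cstddef>
    #include <vector>
    #include "nodes/kernels/x64/brgemm_kernel.hpp"

    void run_f32_gemm(float* a, float* b, float* c, size_t M, size_t N, size_t K) {
        // f32 with an untransposed B needs no copy_buffer_b(); B is consumed in place.
        ov::intel_cpu::BrgemmKernel gemm(M, N, K, /*lda*/ K, /*ldb*/ N, /*ldc*/ N, false, ov::element::f32);
        std::vector<float> scratch_a(gemm.get_scratch_a_size() / sizeof(float));
        std::vector<float> wsp(4 * 1024, 0.0f);  // per-thread tile workspace, sized as in the test
        size_t m_blk = gemm.get_mblk_size();
        for (size_t mb = 0; mb * m_blk < M; mb++) {
            size_t m_cnt = std::min(m_blk, M - mb * m_blk);
            gemm.executeGemm(m_cnt < m_blk,        // is_M_tail
                             a + mb * m_blk * K,   // current M block of A
                             b,                    // unpacked B
                             c + mb * m_blk * N,   // f32 output block
                             wsp.data(),
                             scratch_a.data());
        }
    }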
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.hpp
index a0877a41019fba..566448840961a9 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.hpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.hpp
@@ -21,11 +21,18 @@ class BrgemmKernel {
     // ldb is the leading dimension for B matrix
     // ldc is the leading dimension for C matrix
     // b_transpose indicates whether B matrix is transposed.
-    BrgemmKernel(size_t M, size_t N, size_t K, size_t lda, size_t ldb, size_t ldc, bool b_transposed = false);
+    BrgemmKernel(size_t M,
+                 size_t N,
+                 size_t K,
+                 size_t lda,
+                 size_t ldb,
+                 size_t ldc,
+                 bool b_transposed = false,
+                 ov::element::Type inType = ov::element::bf16);
     // execute all M
     void executeGemm(void* a, void* b, void* c, void* wsp, void* scratch_a, void* scratch_b);
-    // execute m_blk
-    void executeGemmPackedB(bool is_M_tail, void* a, void* repacked_b, void* c, void* wsp, void* scratch_a);
+    // execute by m_blk
+    void executeGemm(bool is_M_tail, void* a, void* b, void* c, void* wsp, void* scratch_a);
 
     void copy_buffer_b(void* b, void* scratch_b);
     // bytes needed to place scratch buffer a
@@ -44,14 +51,13 @@ class BrgemmKernel {
 
 private:
     size_t M = 0, M_blk = 0, M_tail = 0;
-    size_t K = 0, K_blk = 0, K_tail = 0, N = 0, N_tail = 0;
+    size_t K = 0, K_blk = 0, K_tail = 0, N = 0, N_blk = 0, N_tail = 0;
     size_t lda = 0, ldb = 0, ldc = 0;
    bool b_transposed = false;
     size_t brgVnniFactor = 0;
     size_t packedBSize = 0;
     size_t packedASize = 0;
     ov::element::Type inType;
-    static constexpr size_t N_blk = 32;
     static constexpr size_t MHA_BRGEMM_KERNELS_NUM = 8;
     static constexpr size_t matmulOptimalM = 32;
     struct brgemmCtx {
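Because the new constructor argument is defaulted, existing bf16 call sites compile unchanged; only the f32 path has to spell out the precision. A minimal illustration (sizes are arbitrary):

    #include "nodes/kernels/x64/brgemm_kernel.hpp"

    // Both kernels compute a 33x32x33 GEMM with row-major A, B, C.
    ov::intel_cpu::BrgemmKernel legacy(33, 32, 33, 33, 32, 32, false);                  // bf16, as before
    ov::intel_cpu::BrgemmKernel fp32(33, 32, 33, 33, 32, 32, false, ov::element::f32);  // new f32 kernel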
diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
index 49dfed8e8749ba..f089f785ce7ebf 100644
--- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
+++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
@@ -131,7 +131,8 @@ struct MHAKernel {
         auto q_len = query.size(2);
         auto head_size = query.size(3);
         auto kv_len = present_key.size(2);
-
+        auto Hk = present_key.size(1);
+        size_t h_each_group_len = H / Hk;
         if (d_scale == 0.0f)
             d_scale = 1.0f / sqrt(head_size);
 
@@ -150,7 +151,7 @@ struct MHAKernel {
             if (auto_causal)
                 ncausal = kv_len - q_len + m + 1;
             for (size_t n = 0; n < ncausal; n++) {
-                auto* k = &present_key.at<T>({b, h, n, 0}, true);
+                auto* k = &present_key.at<T>({b, h / h_each_group_len, n, 0}, true);
                 attn_score[n] = dot_product(q, k, head_size, k_stride_s) * d_scale;
 
                 // apply alibi tensor
@@ -159,7 +160,7 @@ struct MHAKernel {
 
                 // apply attention mask (maybe combined with causal_mask)
                 if (attention_mask)
-                    attn_score[n] += attention_mask.at<T>({b, h, m, n}, true);
+                    attn_score[n] += attention_mask.at<T>({b, h, m, n}, true);
 
                 // apply causal_mask
                 if (causal_mask) {
@@ -181,7 +182,7 @@ struct MHAKernel {
                 // linearly combine value
                 word_vec.assign(head_size, 0.0f);
                 for (size_t n = 0; n < ncausal; n++) {
-                    auto* v = &present_value.at<T>({b, h, n, 0}, true);
+                    auto* v = &present_value.at<T>({b, h / h_each_group_len, n, 0}, true);
                     accumulate(word_vec.data(), v, head_size, attn_score[n]);
                 }
 
@@ -225,6 +226,7 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
         size_t ldb;
         size_t ldc;
         bool b_transposed;
+        ov::element::Type in_type;
         size_t hash() const {
             using namespace dnnl::impl;
             using namespace dnnl::impl::primitive_hashing;
@@ -236,11 +238,12 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
             seed = hash_combine(seed, ldb);
             seed = hash_combine(seed, ldc);
             seed = hash_combine(seed, b_transposed);
+            seed = hash_combine(seed, in_type.hash());
             return seed;
         }
         bool operator==(const brgemmKey& rhs) const {
             return (rhs.M == M) && (rhs.N == N) && (rhs.K == K) && (rhs.lda == lda) && (rhs.ldb == ldb) &&
-                   (rhs.ldc == ldc) && (rhs.b_transposed == b_transposed);
+                   (rhs.ldc == ldc) && (rhs.b_transposed == b_transposed) && (rhs.in_type == in_type);
         }
     };
 
@@ -258,23 +261,26 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
         return dnnl_dims;
     }
 
-    void prepare_multiquery_prim(dnnl::stream strm,
-                                 PlainTensor& query,
-                                 PlainTensor& present_key,
-                                 bool has_out_transpose) {
-        auto qkv_dt = precision_of<T>::value == ov::element::f32 ? dt::f32 : dt::bf16;
-        if (qkv_dt != dt::bf16)
-            OPENVINO_THROW("Brgemm multi-query kernel only supports BF16");
+    void prepare_brgemm_prim(dnnl::stream strm, PlainTensor& query, PlainTensor& present_key, bool has_out_transpose) {
+        auto in_type = precision_of<T>::value;
+        auto qkv_dt = in_type == ov::element::f32 ? dt::f32 : dt::bf16;
         auto B = query.size(0);
         auto H = query.size(1);
         auto q_len = query.size(2);
         auto head_size = query.size(3);
         auto kv_len = present_key.size(2);
         auto Hk = present_key.size(1);
-        brgemmKey qk_key = {q_len, kv_len, head_size, query.stride(2), present_key.stride(2), kv_len, true};
+        brgemmKey qk_key = {q_len, kv_len, head_size, query.stride(2), present_key.stride(2), kv_len, true, in_type};
 
         auto builder = [](const brgemmKey& key) -> std::shared_ptr<BrgemmKernel> {
-            return std::make_shared<BrgemmKernel>(key.M, key.N, key.K, key.lda, key.ldb, key.ldc, key.b_transposed);
+            return std::make_shared<BrgemmKernel>(key.M,
+                                                  key.N,
+                                                  key.K,
+                                                  key.lda,
+                                                  key.ldb,
+                                                  key.ldc,
+                                                  key.b_transposed,
+                                                  key.in_type);
         };
 
         auto cache = this->context->getParamsCache();
@@ -298,7 +304,8 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
                              kv_len,
                              present_key.stride(2),
                             static_cast<size_t>(out_md.get_strides()[ldc_index]),
-                             false};
+                             false,
+                             in_type};
 
         auto wv_result = cache->getOrCreate(wv_key, builder);
         if (!wv_result.first) {
@@ -315,11 +322,11 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
 
         // allocate scratch a/b, notice get_scratch_a_size/get_scratch_b_size returns in bytes
         size_t data_size = sizeof(T);
-        qk_scratch_a.resize<bfloat16>({nthr, qk_gemm_ptr->get_scratch_a_size() / data_size});
-        wv_scratch_a.resize<bfloat16>({nthr, wv_gemm_ptr->get_scratch_a_size() / data_size});
+        qk_scratch_a.resize<T>({nthr, qk_gemm_ptr->get_scratch_a_size() / data_size});
+        wv_scratch_a.resize<T>({nthr, wv_gemm_ptr->get_scratch_a_size() / data_size});
 
-        qk_scratch_b.resize<bfloat16>({B, Hk, qk_gemm_ptr->get_scratch_b_size() / data_size});
-        wv_scratch_b.resize<bfloat16>({B, Hk, wv_gemm_ptr->get_scratch_b_size() / data_size});
+        qk_scratch_b.resize<T>({B, Hk, qk_gemm_ptr->get_scratch_b_size() / data_size});
+        wv_scratch_b.resize<T>({B, Hk, wv_gemm_ptr->get_scratch_b_size() / data_size});
         if (!attn_score || attn_md.get_size() > attn_score.get_desc().get_size()) {
             attn_score = dnnl::memory(attn_md, strm.get_engine());
             attn_weight = dnnl::memory(weight_md, strm.get_engine());
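The caching above is the usual key/builder pattern: a brgemm kernel is JIT-compiled once per unique shape/stride/precision combination and then looked up, which is why in_type has to enter both hash() and operator==, or a bf16 kernel could be handed to an f32 caller. A generic sketch of the idea (hypothetical free-standing code; the real path goes through context->getParamsCache()):

    #include <functional>
    #include <memory>
    #include <unordered_map>

    template <typename Key, typename Value>
    std::shared_ptr<Value> get_or_create(std::unordered_map<size_t, std::shared_ptr<Value>>& cache,
                                         const Key& key,
                                         const std::function<std::shared_ptr<Value>(const Key&)>& builder) {
        auto it = cache.find(key.hash());
        if (it != cache.end())
            return it->second;      // reuse the kernel compiled for this exact key
        auto value = builder(key);  // compile once per unique (M, N, K, ld*, transpose, in_type)
        cache.emplace(key.hash(), value);
        return value;
    }

(The sketch keys on the raw hash only; the real cache additionally compares keys with operator== to rule out collisions.)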
@@ -332,53 +339,15 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
         return;
     }
 
-    void prepare_prim(dnnl::stream strm,
-                      PlainTensor& query,
-                      PlainTensor& present_key,
-                      PlainTensor& present_value,
-                      size_t B,
-                      size_t H,
-                      size_t Hk,
-                      size_t q_len,
-                      size_t kv_len,
-                      size_t S,
-                      bool has_out_transpose) {
-        auto qkv_dt = precision_of<T>::value == ov::element::f32 ? dt::f32 : dt::bf16;
-        dnnl::memory::desc cur_q_md(make_dnnl_dims({B, H, q_len, S}), qkv_dt, query.get_strides());
-        dnnl::memory::desc cur_k_md(make_dnnl_dims({B, Hk, kv_len, S}), qkv_dt, present_key.get_strides());
-        if (cur_q_md == q_md && cur_k_md == k_md)
-            return;
-
-        q_md = cur_q_md;
-        k_md = cur_k_md;
-        dnnl::memory::desc attn_md(make_dnnl_dims({B, H, q_len, kv_len}), dt::f32, tag::abcd);
-        k_md = k_md.permute_axes({0, 1, 3, 2});
-        auto qk_pd = dnnl::matmul::primitive_desc(strm.get_engine(), q_md, k_md, attn_md);
-        qk_prim = dnnl::matmul(qk_pd);
-
-        weight_md = dnnl::memory::desc(make_dnnl_dims({B, H, q_len, kv_len}), qkv_dt, tag::abcd);
-        v_md = dnnl::memory::desc(make_dnnl_dims({B, Hk, kv_len, S}), qkv_dt, present_value.get_strides());
-        out_md = dnnl::memory::desc(make_dnnl_dims({B, H, q_len, S}), qkv_dt, tag::abcd);
-        if (has_out_transpose)
-            out_md = out_md.permute_axes({0, 2, 1, 3});
-        auto wv_pd = dnnl::matmul::primitive_desc(strm.get_engine(), weight_md, v_md, out_md);
-        wv_prim = dnnl::matmul(wv_pd);
-
-        if (!attn_score || attn_md.get_size() > attn_score.get_desc().get_size()) {
-            attn_score = dnnl::memory(attn_md, strm.get_engine());
-            attn_weight = dnnl::memory(weight_md, strm.get_engine());
-        }
-    }
-
-    void exec_multiquery(PlainTensor& query,
-                         PlainTensor& present_key,
-                         PlainTensor& present_value,
-                         const PlainTensor& alibi_mask,
-                         const PlainTensor& attention_mask,
-                         PlainTensor& output_emb,
-                         bool has_out_transpose,
-                         bool auto_causal,
-                         float d_scale = 0.0f) {
+    void execute_brgemm(PlainTensor& query,
+                        PlainTensor& present_key,
+                        PlainTensor& present_value,
+                        const PlainTensor& alibi_mask,
+                        const PlainTensor& attention_mask,
+                        PlainTensor& output_emb,
+                        bool has_out_transpose,
+                        bool auto_causal,
+                        float d_scale = 0.0f) {
         const auto B = query.size(0);
         const auto H = query.size(1);
         const auto q_len = query.size(2);
@@ -388,15 +357,17 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
         size_t h_each_group_len = H / Hk;
         PlainTensor score, weight;
         score.resize({B, H, q_len, kv_len}, static_cast<float*>(attn_score.get_data_handle()));
-        weight.resize({B, H, q_len, kv_len}, static_cast<bfloat16*>(attn_weight.get_data_handle()));
+        weight.resize({B, H, q_len, kv_len}, static_cast<T*>(attn_weight.get_data_handle()));
         const size_t m_block_size = qk_gemm_ptr->get_mblk_size();
         auto m_blocks = (q_len + m_block_size - 1) / m_block_size;
+        bool is_bf16 = precision_of<T>::value == ov::element::bf16;
         // packed k, v
         parallel_for2d(B, Hk, [&](size_t b, size_t h) {
-            bfloat16* k_ptr = &present_key.at<bfloat16>({b, h, 0, 0});
-            bfloat16* v_ptr = &present_value.at<bfloat16>({b, h, 0, 0});
-            qk_gemm_ptr->copy_buffer_b(k_ptr, &qk_scratch_b.at<bfloat16>({b, h, 0}));
-            wv_gemm_ptr->copy_buffer_b(v_ptr, &wv_scratch_b.at<bfloat16>({b, h, 0}));
+            T* k_ptr = &present_key.at<T>({b, h, 0, 0});
+            T* v_ptr = &present_value.at<T>({b, h, 0, 0});
+            qk_gemm_ptr->copy_buffer_b(k_ptr, &qk_scratch_b.at<T>({b, h, 0}));
+            if (is_bf16)
+                wv_gemm_ptr->copy_buffer_b(v_ptr, &wv_scratch_b.at<T>({b, h, 0}));
         });
 
         // attention
@@ -405,14 +376,15 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
                 auto m_end = std::min(m_start + m_block_size, q_len);
                 auto m_cnt = m_end - m_start;
                 size_t tid = parallel_get_thread_num();
-                bfloat16* q_ptr = &query.at<bfloat16>({b, h, m_start, 0});
+                T* q_ptr = &query.at<T>({b, h, m_start, 0});
                 float* c_ptr = &score.at<float>({b, h, m_start, 0});
-                qk_gemm_ptr->executeGemmPackedB(m_cnt < m_block_size,
-                                                q_ptr,
-                                                &qk_scratch_b.at<bfloat16>({b, h / h_each_group_len, 0}),
-                                                c_ptr,
-                                                wsp.data() + tid * wsp_size_per_thread,
-                                                qk_scratch_a ? &qk_scratch_a.at<bfloat16>({tid, 0}) : nullptr);
+                T* k_ptr = &qk_scratch_b.at<T>({b, h / h_each_group_len, 0});
+                qk_gemm_ptr->executeGemm(m_cnt < m_block_size,
+                                         q_ptr,
+                                         k_ptr,
+                                         c_ptr,
+                                         wsp.data() + tid * wsp_size_per_thread,
+                                         qk_scratch_a ? &qk_scratch_a.at<T>({tid, 0}) : nullptr);
                 float* alibi_ptr = nullptr;
                 auto alibi_stride = 0;
                 if (alibi_mask) {
@@ -439,7 +411,7 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
                     // apply attention mask & softmax
                     auto ncausal = auto_causal ? (kv_len - q_len + m + 1) : kv_len;
                     attn_softmax(&score.at<float>({b, h, m, 0}),
-                                 &weight.at<bfloat16>({b, h, m, 0}),
+                                 &weight.at<T>({b, h, m, 0}),
                                  d_scale,
                                  alibi_ptr + m * alibi_stride,
                                  attn_mask_ptr + m * attn_mask_stride,
@@ -450,35 +422,26 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
                                  precision_of<T>::value,
                                  precision_of<T>::value);
                 }
-                bfloat16* w_ptr = &weight.at<bfloat16>({b, h, m_start, 0});
+                T* w_ptr = &weight.at<T>({b, h, m_start, 0});
+                PlainTensor& sdpa_out = is_bf16 ? fp32_out : output_emb;
                 float* fp32_out_ptr =
-                    has_out_transpose ? &fp32_out.at<float>({b, m_start, h, 0}) : &fp32_out.at<float>({b, h, m_start, 0});
-                wv_gemm_ptr->executeGemmPackedB(m_cnt < m_block_size,
-                                                w_ptr,
-                                                &wv_scratch_b.at<bfloat16>({b, h / h_each_group_len, 0}),
-                                                fp32_out_ptr,
-                                                wsp.data() + tid * wsp_size_per_thread,
-                                                wv_scratch_a ? &wv_scratch_a.at<bfloat16>({tid, 0}) : nullptr);
+                    has_out_transpose ? &sdpa_out.at<float>({b, m_start, h, 0}) : &sdpa_out.at<float>({b, h, m_start, 0});
+                T* v_ptr = is_bf16 ? &wv_scratch_b.at<T>({b, h / h_each_group_len, 0})
+                                   : &present_value.at<T>({b, h / h_each_group_len, 0, 0});
+                wv_gemm_ptr->executeGemm(m_cnt < m_block_size,
+                                         w_ptr,
+                                         v_ptr,
+                                         fp32_out_ptr,
+                                         wsp.data() + tid * wsp_size_per_thread,
+                                         wv_scratch_a ? &wv_scratch_a.at<T>({tid, 0}) : nullptr);
+                if (is_bf16) {
+                    cpu_convert(&fp32_out.at<float>({b, h, m_start, 0}),
+                                &output_emb.at<T>({b, h, m_start, 0}),
+                                ov::element::f32,
+                                ov::element::bf16,
+                                m_cnt * head_size);
+                }
             });
-        cpu_convert(&fp32_out.at<float>({0, 0, 0, 0}),
-                    &output_emb.at<bfloat16>({0, 0, 0, 0}),
-                    ov::element::f32,
-                    ov::element::bf16,
-                    B * H * q_len * head_size);
-    }
-
-    void exec_qk(dnnl::stream strm, PlainTensor& query, PlainTensor& present_key) {
-        dnnl::memory q(q_md, strm.get_engine(), query.ptr<T>());
-        dnnl::memory k(k_md, strm.get_engine(), present_key.ptr<T>());
-        qk_prim.execute(strm, {{DNNL_ARG_SRC, q}, {DNNL_ARG_WEIGHTS, k}, {DNNL_ARG_DST, attn_score}});
-    }
-
-    void exec_kv(dnnl::stream strm, PlainTensor& present_value, PlainTensor& output_emb) {
-        dnnl::memory v(v_md, strm.get_engine(), present_value.ptr<T>());
-        dnnl::memory out(out_md, strm.get_engine(), output_emb.ptr<T>());
-        wv_prim.execute(strm, {{DNNL_ARG_SRC, attn_weight}, {DNNL_ARG_WEIGHTS, v}, {DNNL_ARG_DST, out}});
     }
 
     PlainTensor causal_mask;
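The h / h_each_group_len indexing above is the grouped-query (GQA) head mapping, shown standalone for clarity (a hypothetical helper, not part of the patch):

    #include <cstddef>

    // With H = 8 query heads and Hk = 2 key/value heads, h_each_group_len = 4,
    // so query heads 0..3 read kv head 0 and heads 4..7 read kv head 1.
    // Assumes H is a multiple of Hk.
    inline size_t kv_head_for(size_t h, size_t H, size_t Hk) {
        size_t h_each_group_len = H / Hk;
        return h / h_each_group_len;
    }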
@@ -505,61 +468,20 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
                  bool has_out_transpose,
                  bool auto_causal,
                  float d_scale = 0.0f) {
-        auto B = query.size(0);
-        auto H = query.size(1);
-        auto q_len = query.size(2);
         auto head_size = query.size(3);
-        auto Hk = present_key.size(1);
-        auto kv_len = present_key.size(2);
-        bool is_multi_query = H != Hk;
 
         if (d_scale == 0.0f)
             d_scale = 1.0f / sqrt(head_size);
 
-        if (is_multi_query) {
-            prepare_multiquery_prim(strm, query, present_key, has_out_transpose);
-            exec_multiquery(query,
-                            present_key,
-                            present_value,
-                            alibi_mask,
-                            attention_mask,
-                            output_emb,
-                            has_out_transpose,
-                            auto_causal,
-                            d_scale);
-            return;
-        }
-        prepare_prim(strm, query, present_key, present_value, B, H, Hk, q_len, kv_len, head_size, has_out_transpose);
-        exec_qk(strm, query, present_key);
-
-        PlainTensor score;
-        score.resize({B, H, q_len, kv_len}, static_cast<float*>(attn_score.get_data_handle()));
-        PlainTensor weight;
-        weight.resize({B, H, q_len, kv_len}, static_cast<bfloat16*>(attn_weight.get_data_handle()));
-        // softmax
-        parallel_for3d(B, H, q_len, [&](size_t b, size_t h, size_t m) {
-            // apply attention mask & softmax
-            auto ncausal = auto_causal ? (kv_len - q_len + m + 1) : kv_len;
-            uint8_t* attn_mask_ptr = nullptr;
-            auto attn_mask_stride = 0;
-            if (attention_mask) {
-                attn_mask_ptr = reinterpret_cast<uint8_t*>(&attention_mask.at<T>({b, h, 0, 0}, true));
-                if (attention_mask.size(2) > 1)
-                    attn_mask_stride = attention_mask.stride(2) * sizeof(T);
-            }
-            attn_softmax(&score.at<float>({b, h, m, 0}),
-                         &weight.at<bfloat16>({b, h, m, 0}),
-                         d_scale,
-                         alibi_mask ? &alibi_mask.at<float>({b, h, m, 0}, true) : nullptr,
-                         attn_mask_ptr + m * attn_mask_stride,
-                         causal_mask ? &causal_mask.at<uint8_t>({b, h, m, 0}, true) : nullptr,
-                         select_nfltmax_at_0,
-                         ncausal,
-                         kv_len,
-                         precision_of<T>::value,
-                         precision_of<T>::value);
-        });
-
-        exec_kv(strm, present_value, output_emb);
+        prepare_brgemm_prim(strm, query, present_key, has_out_transpose);
+        execute_brgemm(query,
+                       present_key,
+                       present_value,
+                       alibi_mask,
+                       attention_mask,
+                       output_emb,
+                       has_out_transpose,
+                       auto_causal,
+                       d_scale);
     }
 };
 
@@ -980,7 +902,7 @@ void ScaledDotProductAttention::createPrimitive() {
     ScaledDotProductAttentionKey key = {rtPrecision};
 
     auto builder = [&](const ScaledDotProductAttentionKey& key) -> std::shared_ptr<Executor> {
-        std::shared_ptr<Executor> executor;
+        std::shared_ptr<Executor> executor = nullptr;
         if (rtPrecision == ov::element::bf16) {
 #ifdef OPENVINO_ARCH_X86_64
             executor = std::make_shared<AttentionExecutor<KT_ONEDNN, ov::bfloat16>>(context);
@@ -988,8 +910,14 @@ void ScaledDotProductAttention::createPrimitive() {
         } else {
 #ifdef OV_CPU_WITH_MLAS
             executor = std::make_shared<AttentionExecutor<KT_MLAS, float>>(context);
+#elif defined(OPENVINO_ARCH_X86_64)
+            if (with_cpu_x86_avx512_core()) {
+                executor = std::make_shared<AttentionExecutor<KT_ONEDNN, float>>(context);
+            } else {
+                executor = std::make_shared<AttentionExecutor<KT_REF, float>>(context);
+            }
 #else
-            executor = std::make_shared<AttentionExecutor<KT_ONEDNN, float>>(context);
+            executor = std::make_shared<AttentionExecutor<KT_REF, float>>(context);
 #endif
         }
         return executor;
@@ -998,7 +926,7 @@ void ScaledDotProductAttention::createPrimitive() {
     auto cache = context->getParamsCache();
     auto result = cache->getOrCreate(key, builder);
     if (!result.first) {
-        OPENVINO_THROW("ScaleDotAttention AttentionExecutor creation fails with precision " + rtPrecision.to_string());
+        OPENVINO_THROW("ScaledDotProductAttention AttentionExecutor creation fails with precision " + rtPrecision.to_string());
     }
     m_executor = result.first;
 }
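For orientation, the executor selection after this change can be condensed into the following sketch (illustrative only; the real logic is the preprocessor block above):

    #include "openvino/core/type/element_type.hpp"

    enum class SdpaPath { OneDnnBf16, MlasF32, OneDnnF32, RefF32 };

    // has_mlas stands in for the build-time OV_CPU_WITH_MLAS switch and
    // has_avx512 for the runtime with_cpu_x86_avx512_core() check.
    SdpaPath pick_path(ov::element::Type prec, bool has_mlas, bool has_avx512) {
        if (prec == ov::element::bf16)
            return SdpaPath::OneDnnBf16;         // brgemm bf16 executor (x86-64)
        if (has_mlas)
            return SdpaPath::MlasF32;            // MLAS keeps priority for f32
        return has_avx512 ? SdpaPath::OneDnnF32  // new f32 brgemm executor
                          : SdpaPath::RefF32;    // portable reference fallback
    }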
diff --git a/src/plugins/intel_cpu/tests/unit/brgemm_executor_test.cpp b/src/plugins/intel_cpu/tests/unit/brgemm_executor_test.cpp
index 1ffe7a4a85e919..1179c679fc3973 100644
--- a/src/plugins/intel_cpu/tests/unit/brgemm_executor_test.cpp
+++ b/src/plugins/intel_cpu/tests/unit/brgemm_executor_test.cpp
@@ -4,38 +4,52 @@
 #include <gtest/gtest.h>
 #include <vector>
 
+#include "common_test_utils/test_common.hpp"
 #include "nodes/kernels/x64/brgemm_kernel.hpp"
 #include "openvino/core/parallel.hpp"
 #include "openvino/runtime/system_conf.hpp"
 
-TEST(BrgemmKernel, simple_gemm_test) {
-    if (!ov::with_cpu_x86_bfloat16())
-        GTEST_SKIP();
+namespace brgemmUnitTest {
+class BrgemmKernelTest : public ov::test::TestsCommon, public testing::WithParamInterface<ov::element::Type> {
+public:
+    static std::string getTestCaseName(const testing::TestParamInfo<ov::element::Type>& obj) {
+        ov::element::Type rtPrec;
+        rtPrec = obj.param;
+        std::ostringstream result;
+        result << "Prec=" << rtPrec.to_string() << std::endl;
+        return result.str();
+    }
+};
+
+template <typename T>
+void run_test(ov::element::Type rtPrec) {
     size_t M = 33;
     size_t N = 32;
     size_t K = 33;
-    ov::intel_cpu::BrgemmKernel gemm(M, N, K, K, N, N, false);
-    std::vector<ov::bfloat16> a_data(M * K, (1.0f/33));
-    std::vector<ov::bfloat16> b_data(K * N, 4.0f);
+    ov::intel_cpu::BrgemmKernel gemm(M, N, K, K, N, N, false, rtPrec);
     size_t nthr = 8;
+    bool is_bf16 = (rtPrec == ov::element::bf16);
+    std::vector<T> a_data(M * K, (1.0f/33));
+    std::vector<T> b_data(K * N, 4.0f);
     std::vector<float> c_data(nthr * M * N, 0.0f);
     std::vector<float> wsp(nthr * 4 * 1024, 0.0f);
-    std::vector<ov::bfloat16> b_scracth(gemm.get_scratch_b_size(), 0.0f);
-    std::vector<ov::bfloat16> a_scracth(gemm.get_scratch_a_size(), 0.0f);
-
-    gemm.copy_buffer_b(b_data.data(), b_scracth.data());
+    std::vector<T> b_scracth(gemm.get_scratch_b_size(), 0.0f);
+    std::vector<T> a_scracth(gemm.get_scratch_a_size(), 0.0f);
+    if (is_bf16)
+        gemm.copy_buffer_b(b_data.data(), b_scracth.data());
     auto m_block_size = gemm.get_mblk_size();
     auto m_blocks = (M + gemm.get_mblk_size() - 1) / m_block_size;
+    T* b_ptr = is_bf16 ? b_scracth.data() : b_data.data();
     ov::parallel_for2d(nthr, m_blocks, [&](size_t i, size_t m_blk) {
         auto m_start = m_blk * m_block_size;
         auto m_end = std::min(m_start + m_block_size, M);
         auto m_cnt = m_end - m_start;
-        gemm.executeGemmPackedB(m_cnt < m_block_size,
-                                a_data.data() + m_start * K,
-                                b_scracth.data(),
-                                c_data.data() + i * M * N + m_start * N,
-                                wsp.data() + i * 4 * 1024,
-                                a_scracth.data());
+        gemm.executeGemm(m_cnt < m_block_size,
+                         a_data.data() + m_start * K,
+                         b_ptr,
+                         c_data.data() + i * M * N + m_start * N,
+                         wsp.data() + i * 4 * 1024,
+                         a_scracth.data());
     });
     ov::parallel_for(nthr, [&](size_t i){
         for (size_t m = 0; m < M; m++) {
@@ -51,4 +65,24 @@ TEST(BrgemmKernel, simple_gemm_test) {
             }
         }
     });
-}
\ No newline at end of file
+}
+
+TEST_P(BrgemmKernelTest, simpleGemmTest) {
+    ov::element::Type rtPrec = this->GetParam();
+    if (rtPrec == ov::element::bf16 && !ov::with_cpu_x86_bfloat16())
+        GTEST_SKIP();
+    if (rtPrec == ov::element::f32 && !ov::with_cpu_x86_avx512_core())
+        GTEST_SKIP();
+
+    if (rtPrec == ov::element::bf16) {
+        run_test<ov::bfloat16>(rtPrec);
+    } else {
+        run_test<float>(rtPrec);
+    }
+}
+
+INSTANTIATE_TEST_SUITE_P(BrgemmKernelUnitTest,
+                         BrgemmKernelTest,
+                         ::testing::Values(ov::element::f32, ov::element::bf16),
+                         BrgemmKernelTest::getTestCaseName);
+}  // namespace brgemmUnitTest
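As a sanity check on the test's constants: each output element is the dot product of a row of A (all 1.0f/33) with a column of B (all 4.0f) over K = 33 terms, i.e. 33 * (1.0f / 33) * 4.0f = 4.0f, which is what the hidden verification loop presumably compares against (a worked example, not code from the patch):

    #include <cstddef>

    // Accumulate the expected C value the way a naive reference loop would.
    inline float expected_c_value(size_t K) {
        float acc = 0.0f;
        for (size_t k = 0; k < K; k++)
            acc += (1.0f / 33) * 4.0f;  // == 4.0f for K == 33
        return acc;
    }

    // bf16 runs round the 1/33 operands before multiplying, so a small
    // tolerance rather than exact equality is the natural comparison there.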