diff --git a/src/cpu/gemm/gemm.cpp b/src/cpu/gemm/gemm.cpp index 2e0e2fb8e2c..ba60de123ea 100644 --- a/src/cpu/gemm/gemm.cpp +++ b/src/cpu/gemm/gemm.cpp @@ -239,7 +239,7 @@ dnnl_status_t gemm_s8s8s32(const char *transa, const char *transb, if (*M == 0 || *N == 0 || *K == 0) return dnnl_success; #if DNNL_X64 && !__BUILD_GEMM_NONE - bool use_jit = mayiuse(avx512_core); + bool use_jit = mayiuse(avx512_core) && __BUILD_GEMM_AVX512; bool use_s8u8 = true && utils::everyone_is(0, *ao, *bo) // so far a requirement && IMPLICATION(USE_MKL_IGEMM == 0, mayiuse(sse41)); @@ -297,7 +297,7 @@ dnnl_status_t gemm_bf16bf16f32(const char *transa, const char *transb, bfloat16_t *dummy_bo = nullptr; float *dummy_co = nullptr; - if (mayiuse(avx512_core)) { + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { auto status = gemm_driver(transa, transb, dummyOffsetC, M, N, K, alpha, (const bfloat16_t *)A, lda, dummy_ao, (const bfloat16_t *)B, ldb, dummy_bo, beta, (float *)C, ldc, dummy_co, false); diff --git a/src/cpu/gemm/gemm.hpp b/src/cpu/gemm/gemm.hpp index c568648f3f8..f02961ce170 100644 --- a/src/cpu/gemm/gemm.hpp +++ b/src/cpu/gemm/gemm.hpp @@ -30,11 +30,11 @@ #include "cpu/x64/cpu_isa_traits.hpp" // Kernels ISA section for configuring knobs. -#define __BUILD_GEMM_AMX BUILD_GEMM_KERNELS_ALL -#define __BUILD_GEMM_AVX512 __BUILD_GEMM_AMX || BUILD_GEMM_AVX512 -#define __BUILD_GEMM_AVX2 __BUILD_GEMM_AVX512 || BUILD_GEMM_AVX2 -#define __BUILD_GEMM_SSE41 __BUILD_GEMM_AVX2 || BUILD_GEMM_SSE41 -#define __BUILD_GEMM_NONE BUILD_GEMM_KERNELS_NONE +#define __BUILD_GEMM_AMX (BUILD_GEMM_KERNELS_ALL) +#define __BUILD_GEMM_AVX512 (__BUILD_GEMM_AMX || BUILD_GEMM_AVX512) +#define __BUILD_GEMM_AVX2 (__BUILD_GEMM_AVX512 || BUILD_GEMM_AVX2) +#define __BUILD_GEMM_SSE41 (__BUILD_GEMM_AVX2 || BUILD_GEMM_SSE41) +#define __BUILD_GEMM_NONE (BUILD_GEMM_KERNELS_NONE) #else #define __BUILD_GEMM_AMX 0 #define __BUILD_GEMM_AVX512 0 diff --git a/src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.cpp b/src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.cpp index 19b26c94803..d0fb52fa6c3 100644 --- a/src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.cpp @@ -36,7 +36,7 @@ int jit_avx2_kernel_sgemm_kern::next_acc(int idx, int um, int un) const { void jit_avx2_kernel_sgemm_kern::prefetchB_beforeBload( int um, int un, int k_idx, int n_idx) { - if (!mayiuse(avx512_core)) { + if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512)) { if ((n_idx == 0) && (k_idx == 0) && (un == unroll_n_) && (um != 16)) { prefetcht0(ptr[BO_ + elt_size_ * (PREFETCHSIZEB_ + offb_)]); offb_ += 16; @@ -46,7 +46,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchB_beforeBload( void jit_avx2_kernel_sgemm_kern::prefetchB_beforeFMA( int um, int un, int k_idx, int n_idx, int m_idx) { - if (!mayiuse(avx512_core)) { + if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512)) { if ((um == 16) || (un < unroll_n_)) { if ((k_idx + m_idx + n_idx) == 0) { prefetcht0(ptr[BO_ + elt_size_ * (PREFETCHSIZEB_ + offb_)]); @@ -63,7 +63,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchB_beforeFMA( void jit_avx2_kernel_sgemm_kern::prefetchA_afterFMA( int um, int un, int k_idx, int n_idx, int m_idx) { - if (mayiuse(avx512_core)) { + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { if ((um < unroll_m_) && (m_idx == 0)) { if (((k_idx % (nb_zmm_a_ / unroll_m_reg_) == 0) && (n_idx % 6 == 0)) || ((k_idx % (nb_zmm_a_ / unroll_m_reg_) == 1) @@ -87,7 +87,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchA_afterFMA( void jit_avx2_kernel_sgemm_kern::prefetchA_afterBload( int um, int un, int k_idx, int n_idx) { - if (!mayiuse(avx512_core)) { + if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512)) { if ((um == unroll_m_) && (un == 2)) { if (k_idx % 3 == 0) { if (n_idx == 1) { @@ -111,7 +111,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchA_afterBload( void jit_avx2_kernel_sgemm_kern::prefetchB_afterFMA( int k_idx, int n_idx, int m_idx) { - if (mayiuse(avx512_core)) { + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { if (((m_idx + (k_idx % (nb_zmm_a_ / unroll_m_reg_)) * unroll_m_reg_) == 0) && (n_idx == 1)) { @@ -126,7 +126,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchB_afterFMA( void jit_avx2_kernel_sgemm_kern::prefetchA_beforeFMA( int um, int un, int k_idx, int n_idx, int m_idx) { - if (!mayiuse(avx512_core)) { + if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512)) { if ((um == unroll_m_) && (un == unroll_n_)) { if (((k_idx == 0) && (n_idx % 2 == 1) && (m_idx == 0)) || ((k_idx == 1) && (n_idx == 2) && (m_idx == 0)) @@ -160,7 +160,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchA_beforeFMA( void jit_avx2_kernel_sgemm_kern::prefetchC_afterBload( int um, int un, int k_idx, int n_idx) { - if (mayiuse(avx512_core)) { + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { if (um == unroll_m_) { if (n_idx == std::min(1, un - 1)) { if (k_idx == unroll_k_ - 1) @@ -173,7 +173,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchC_afterBload( } void jit_avx2_kernel_sgemm_kern::prefetchC_beforeKloop(int um) { - if (mayiuse(avx512_core)) { + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { if (um < unroll_m_) { prefetchw(ptr[CO2_ + elt_size_ * 0]); prefetchw(ptr[CO2_ + elt_size_ * 8]); @@ -228,7 +228,7 @@ void jit_avx2_kernel_sgemm_kern::generate() { mov(C_, ptr[rsp + get_size_of_abi_save_regs() + C_off]); mov(LDC_, ptr[rsp + get_size_of_abi_save_regs() + LDC_off]); - if (mayiuse(avx512_core)) { + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { for (i = zmm_acc_idx_; i < unroll_m_reg_ * unroll_n_ + zmm_acc_idx_; i++) vpxorq(Xbyak::Zmm(i), Xbyak::Zmm(i), Xbyak::Zmm(i)); @@ -267,7 +267,8 @@ void jit_avx2_kernel_sgemm_kern::generate() { add(AA_, A_); mov(CO1_, C_); - if ((unroll_x == unroll_m_) || (!mayiuse(avx512_core))) + if ((unroll_x == unroll_m_) + || !(mayiuse(avx512_core) && __BUILD_GEMM_AVX512)) lea(CO2_, ptr[C_ + LDC_ * 2]); add(C_, unroll_x * elt_size_); @@ -292,12 +293,12 @@ void jit_avx2_kernel_sgemm_kern::generate() { T_NEAR); } - if (!mayiuse(avx512_core)) + if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512)) prefetcht2(ptr[AA_ - addr_off_ * elt_size_]); switch (unroll_x) { case 8: - if (mayiuse(avx512_core)) { + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { loop(unroll_x, unroll_y, &Xbyak::CodeGenerator::vbroadcastf64x4, @@ -319,7 +320,7 @@ void jit_avx2_kernel_sgemm_kern::generate() { break; case 4: - if (mayiuse(avx512_core)) { + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { loop(unroll_x, unroll_y, &Xbyak::CodeGenerator::vbroadcastf32x4, @@ -340,7 +341,7 @@ void jit_avx2_kernel_sgemm_kern::generate() { break; case 2: - if (mayiuse(avx512_core)) { + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { loop(unroll_x, unroll_y, &Xbyak::CodeGenerator::vbroadcastsd, @@ -357,7 +358,7 @@ void jit_avx2_kernel_sgemm_kern::generate() { &Xbyak::CodeGenerator::vmovsd); break; case 1: - if (mayiuse(avx512_core)) { + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { loop(unroll_x, unroll_y, &Xbyak::CodeGenerator::vbroadcastss, @@ -377,7 +378,7 @@ void jit_avx2_kernel_sgemm_kern::generate() { break; default: - if (mayiuse(avx512_core)) { + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { loop(unroll_x, unroll_y, &Xbyak::CodeGenerator::vmovups, @@ -400,7 +401,7 @@ void jit_avx2_kernel_sgemm_kern::generate() { break; } - if (mayiuse(avx512_core)) { + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { sub(AA_, -16 * elt_size_); } else { if ((unroll_y != unroll_n_) || (unroll_x <= 4)) { diff --git a/src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.hpp b/src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.hpp index 766b07bc5c9..c51d429c3e8 100644 --- a/src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.hpp +++ b/src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2022 Intel Corporation +* Copyright 2019-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,8 @@ #ifndef CPU_X64_GEMM_F32_JIT_AVX2_KERNEL_SGEMM_KERN_HPP #define CPU_X64_GEMM_F32_JIT_AVX2_KERNEL_SGEMM_KERN_HPP +#include "cpu/gemm/gemm.hpp" + #include "cpu/x64/jit_generator.hpp" #define MAX_UNROLL_M 48 @@ -32,15 +34,18 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_kernel_sgemm_kern); const int elt_size_ = 4; const int elt_size_bin_ = 2; - int nelt_per_vecreg_ = mayiuse(avx512_core) ? 16 : 8; + int nelt_per_vecreg_ = mayiuse(avx512_core) && __BUILD_GEMM_AVX512 ? 16 : 8; const int unroll_m_reg_ = 3; int unroll_m_ = unroll_m_reg_ * nelt_per_vecreg_; - const int unroll_n_ = mayiuse(avx512_core) ? 8 : 4; + const int unroll_n_ = mayiuse(avx512_core) && __BUILD_GEMM_AVX512 ? 8 : 4; const int unroll_k_ = 4; const int unroll_k_bin_ = 2; - const int unroll_m_bin_ = mayiuse(avx512_core) ? 6 : 5; - const int second_fetch_ = mayiuse(avx512_core) ? 32 : 34; - unsigned int unroll_n_bin_ = mayiuse(avx512_core) ? 3 : 2; + const int unroll_m_bin_ + = mayiuse(avx512_core) && __BUILD_GEMM_AVX512 ? 6 : 5; + const int second_fetch_ + = mayiuse(avx512_core) && __BUILD_GEMM_AVX512 ? 32 : 34; + unsigned int unroll_n_bin_ + = mayiuse(avx512_core) && __BUILD_GEMM_AVX512 ? 3 : 2; bool beta_zero_; Xbyak::Reg64 M_ = rdi, N_ = rsi, K_ = rdx, A_ = r8, B_ = r9, C_ = r10, @@ -48,15 +53,21 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { Xbyak::Reg64 I_ = r12, J_ = r13, AA_ = rcx, KK_ = K_, BO_ = rbp, CO1_ = r14, CO2_ = r15; Xbyak::Reg64 AO_ = rbx, LL_ = rax; - int zmm_a_idx_ = 0, zmm_b_idx_ = mayiuse(avx512_core) ? 6 : 3, - zmm_acc_idx_ = mayiuse(avx512_core) ? 8 : 4; - int nb_zmm_a_ = mayiuse(avx512_core) ? unroll_m_reg_ * 2 : unroll_m_reg_, - nb_zmm_b_ = mayiuse(avx512_core) ? 2 : 1; - - int addr_off_ = mayiuse(avx512_core) ? 128 : 32; - int PREFETCHSIZEB_ = mayiuse(avx512_core) ? (-128 + 16 * 8) : 64; - int PREFETCHSIZEA_ = mayiuse(avx512_core) ? (-128 + 16 * 2) - : (PREFETCHSIZEB_ * 2 + 16); + int zmm_a_idx_ = 0, + zmm_b_idx_ = mayiuse(avx512_core) && __BUILD_GEMM_AVX512 ? 6 : 3, + zmm_acc_idx_ = mayiuse(avx512_core) && __BUILD_GEMM_AVX512 ? 8 : 4; + int nb_zmm_a_ = mayiuse(avx512_core) && __BUILD_GEMM_AVX512 + ? unroll_m_reg_ * 2 + : unroll_m_reg_, + nb_zmm_b_ = mayiuse(avx512_core) && __BUILD_GEMM_AVX512 ? 2 : 1; + + int addr_off_ = mayiuse(avx512_core) && __BUILD_GEMM_AVX512 ? 128 : 32; + int PREFETCHSIZEB_ = mayiuse(avx512_core) && __BUILD_GEMM_AVX512 + ? (-128 + 16 * 8) + : 64; + int PREFETCHSIZEA_ = mayiuse(avx512_core) && __BUILD_GEMM_AVX512 + ? (-128 + 16 * 2) + : (PREFETCHSIZEB_ * 2 + 16); int off_ = 0, offb_ = 0; int next_acc(int idx, int um, int un) const; @@ -74,10 +85,11 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { void loadA_betweenFMAs(int um, int un, int k_idx, int n_idx, int m_idx, void (Xbyak::CodeGenerator::*aload)( const T_desta &, const T_srca &)) { - int next_zmm_a = mayiuse(avx512_core) + int next_zmm_a = mayiuse(avx512_core) && __BUILD_GEMM_AVX512 ? unroll_m_reg_ : std::max(1, um / nelt_per_vecreg_); - if (!(mayiuse(avx512_core) || (um <= 8) || ((um == 16) && (un == 4)))) { + if (!((mayiuse(avx512_core) && __BUILD_GEMM_AVX512) || (um <= 8) + || ((um == 16) && (un == 4)))) { if (n_idx == un - 1) { (this->*aload)(T_reg(zmm_a_idx_ + m_idx + (k_idx % (nb_zmm_a_ / unroll_m_reg_)) @@ -100,10 +112,11 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { const T_desta &, const T_srca &)) { int i; - int next_zmm_a = mayiuse(avx512_core) + int next_zmm_a = mayiuse(avx512_core) && __BUILD_GEMM_AVX512 ? unroll_m_reg_ : std::max(1, um / nelt_per_vecreg_); - if (mayiuse(avx512_core) || (um <= 8) || ((um == 16) && (un == 4))) { + if ((mayiuse(avx512_core) && __BUILD_GEMM_AVX512) || (um <= 8) + || ((um == 16) && (un == 4))) { for (i = 0; i < std::max(um / nelt_per_vecreg_, 1); i++) { (this->*aload)(T_reg(zmm_a_idx_ + i + (k_idx % (nb_zmm_a_ / unroll_m_reg_)) @@ -130,31 +143,38 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { Xbyak::Label K_loop_body_label; int i, j, p, b_idx; - int addb_off = ((!mayiuse(avx512_core)) && (nb_zmm_b_ == 2)) ? 1 : 0; + int addb_off = (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512) + && (nb_zmm_b_ == 2)) + ? 1 + : 0; - int next_zmm_a = mayiuse(avx512_core) + int next_zmm_a = mayiuse(avx512_core) && __BUILD_GEMM_AVX512 ? unroll_m_reg_ : std::max(1, um / nelt_per_vecreg_); off_ = 0, offb_ = 0; - if (mayiuse(avx512_core)) L_aligned(K_loop_body_label); + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) + L_aligned(K_loop_body_label); if (cfetch) prefetchC_beforeKloop(um); - if (!mayiuse(avx512_core)) L_aligned(K_loop_body_label); + if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512)) + L_aligned(K_loop_body_label); for (p = 0; p < unroll_k_; p++) { - if (mayiuse(avx512_core)) { + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { if ((um == unroll_m_) && (p == unroll_k_ - 1)) { prefetcht2(ptr[AA_ - elt_size_ * 128]); } } for (j = 0; j < un; j++) { - b_idx = mayiuse(avx512_core) ? j % nb_zmm_b_ : p % nb_zmm_b_; + b_idx = mayiuse(avx512_core) && __BUILD_GEMM_AVX512 + ? j % nb_zmm_b_ + : p % nb_zmm_b_; - if (!mayiuse(avx512_core)) { + if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512)) { if ((um == unroll_m_) && (un == unroll_n_)) { if ((j == un - 1) && (p == unroll_k_ - 1)) sub(BO_, -un * unroll_k_ * elt_size_); @@ -182,9 +202,9 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { prefetchB_beforeBload(um, un, p, j); - if (!mayiuse(avx512_core) && (um == unroll_m_) - && (un == unroll_n_) && (j == un - 1) - && (p == unroll_k_ - 1)) { + if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512) + && (um == unroll_m_) && (un == unroll_n_) + && (j == un - 1) && (p == unroll_k_ - 1)) { (this->*bload)(T_reg(zmm_b_idx_ + b_idx), ptr[BO_ + elt_size_ @@ -205,14 +225,14 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { if (cfetch) prefetchC_afterBload(um, un, p, j); - if (mayiuse(avx512_core)) { + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { if ((um == unroll_m_) && (p == unroll_k_ - 1) && (j == std::min(un - 1, 3))) lea(AA_, ptr[AA_ + elt_size_ * unroll_n_]); } } - if (mayiuse(avx512_core)) { + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { for (j = un; j < unroll_n_; j++) { if (um < unroll_m_) { if (((p % (nb_zmm_a_ / unroll_m_reg_) == 0) @@ -228,7 +248,7 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { loadA_after(um, un, p, aload); } - if (mayiuse(avx512_core)) { + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { lea(AO_, ptr[AO_ + um * unroll_k_ * elt_size_]); lea(BO_, ptr[BO_ + un * unroll_k_ * elt_size_]); } else { @@ -261,7 +281,7 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { T_reg(zmm_b_idx_ + (j % nb_zmm_b_)), T_reg(i + zmm_a_idx_)); - if (mayiuse(avx512_core)) { + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { if (i == 0) { if (j % 3 == 0) { prefetcht0(ptr[AO_ @@ -290,17 +310,18 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { - j)]); } - if (mayiuse(avx512_core) && (un < 2)) + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512 && (un < 2)) prefetcht0(ptr[BO_ + elt_size_ * (PREFETCHSIZEB_)]); - if (mayiuse(avx512_core)) { + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { for (i = un; i < 8; i += 4) { prefetcht0(ptr[AO_ + elt_size_ * (PREFETCHSIZEA_ + off_)]); off_ += 16; } } - if (mayiuse(avx512_core) || (um <= nelt_per_vecreg_)) { + if ((mayiuse(avx512_core) && __BUILD_GEMM_AVX512) + || (um <= nelt_per_vecreg_)) { for (i = 0; i < std::max(um / nelt_per_vecreg_, 1); i++) { (this->*aload)(T_reg(zmm_a_idx_ + i), ptr[AO_ @@ -310,7 +331,7 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { } } - if (mayiuse(avx512_core)) { + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { lea(AO_, ptr[AO_ + um * elt_size_]); lea(BO_, ptr[BO_ + un * elt_size_]); } else { @@ -334,14 +355,15 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { Xbyak::Label end_K_loop_label, end_main_K_loop_label; Xbyak::Label K_loop_with_prefetch_label, K_loop_with_prefetch_rem_label; - Xbyak::Reg64 A_reg = (mayiuse(avx512_core)) ? AO_ - : ((um == unroll_m_) && (un == unroll_n_)) ? A_ + Xbyak::Reg64 A_reg = (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) ? AO_ + : ((um == unroll_m_) && (un == unroll_n_)) ? A_ : AO_; - if (mayiuse(avx512_core) || (unroll_m_ != um) || (unroll_n_ != un)) + if ((mayiuse(avx512_core) && __BUILD_GEMM_AVX512) || (unroll_m_ != um) + || (unroll_n_ != un)) mov(AO_, A_); - if (!mayiuse(avx512_core)) { + if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512)) { nb_zmm_a_ = unroll_m_reg_; nb_zmm_b_ = 1; @@ -366,10 +388,11 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { zmm_acc_idx_ = zmm_b_idx_ + nb_zmm_b_; acc_idx = 0; - if (!mayiuse(avx512_core)) { + if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512)) { j = zmm_b_idx_; for (k = 0; k < nb_zmm_b_; k++) { - if (!mayiuse(avx512_core) && (un > 1)) { + if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512) + && (un > 1)) { acc_idx = next_acc(acc_idx, um, un); vxorps(T_reg(zmm_acc_idx_ + acc_idx), T_reg(zmm_acc_idx_ + acc_idx), @@ -383,14 +406,14 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { } for (k = 0; k < nb_zmm_a_ / unroll_m_reg_; k++) { - if (mayiuse(avx512_core)) + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) j = zmm_a_idx_ + k * unroll_m_reg_; else j = zmm_a_idx_ + k * std::max(1, um / nelt_per_vecreg_); for (i = nelt_per_vecreg_; i <= std::max(um, nelt_per_vecreg_); i += nelt_per_vecreg_) { - if (!mayiuse(avx512_core)) { + if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512)) { acc_idx = next_acc(acc_idx, um, un); vxorps(T_reg(zmm_acc_idx_ + acc_idx), T_reg(zmm_acc_idx_ + acc_idx), @@ -405,10 +428,11 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { } } - if (mayiuse(avx512_core)) { + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { j = zmm_b_idx_; for (k = 0; k < nb_zmm_b_; k++) { - if (!mayiuse(avx512_core) && (un > 1)) { + if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512) + && (un > 1)) { acc_idx = next_acc(acc_idx, um, un); vxorps(T_reg(zmm_acc_idx_ + acc_idx), T_reg(zmm_acc_idx_ + acc_idx), @@ -421,7 +445,7 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { } } - if (!mayiuse(avx512_core)) { + if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512)) { if (un > 1) { if ((um == unroll_m_) @@ -490,16 +514,17 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { vxorps(T_reg(i), T_reg(i), T_reg(i)); } - if (!((mayiuse(avx512_core) || (unroll_m_ != um) || (unroll_n_ != un)))) + if (!((mayiuse(avx512_core) && __BUILD_GEMM_AVX512) || (unroll_m_ != um) + || (unroll_n_ != un))) mov(AO_, A_); mov(LL_, KK_); sar(LL_, unroll_k_bin_); jle(end_main_K_loop_label, T_NEAR); - if (mayiuse(avx512_core) - || (!mayiuse(avx512_core) && (un == unroll_n_) - && (um == unroll_m_))) { + if ((mayiuse(avx512_core) && __BUILD_GEMM_AVX512) + || (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512) + && (un == unroll_n_) && (um == unroll_m_))) { sub(LL_, second_fetch_); jle(K_loop_with_prefetch_label, T_NEAR); } @@ -507,26 +532,26 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { k_loop_body( 0, um, un, aload, bload); - if (mayiuse(avx512_core) - || (!mayiuse(avx512_core) && (un == unroll_n_) - && (um == unroll_m_))) { + if ((mayiuse(avx512_core) && __BUILD_GEMM_AVX512) + || (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512) + && (un == unroll_n_) && (um == unroll_m_))) { L_aligned(K_loop_with_prefetch_label); } - if (mayiuse(avx512_core)) { + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { lea(CO2_, ptr[CO1_ + (nelt_per_vecreg_ - 1) * elt_size_]); add(LL_, un); jle(K_loop_with_prefetch_rem_label, T_NEAR); } - if (mayiuse(avx512_core) - || (!mayiuse(avx512_core) && (un == unroll_n_) - && (um == unroll_m_))) { + if ((mayiuse(avx512_core) && __BUILD_GEMM_AVX512) + || (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512) + && (un == unroll_n_) && (um == unroll_m_))) { k_loop_body( 1, um, un, aload, bload); } - if (mayiuse(avx512_core)) { + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { L_aligned(K_loop_with_prefetch_rem_label); add(LL_, second_fetch_ - un); jle(end_main_K_loop_label, T_NEAR); @@ -537,7 +562,7 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { L_aligned(end_main_K_loop_label); - if (!mayiuse(avx512_core)) { + if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512)) { if ((un == unroll_n_) && ((um == 16) || (um == 8))) { prefetcht2(ptr[AA_ - 16 * elt_size_]); } @@ -568,7 +593,7 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { if ((um < unroll_m_) && (um >= nelt_per_vecreg_)) offAA = 32 - (un / 2) * 16; - if (mayiuse(avx512_core)) + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) lea(CO2_, ptr[CO1_ + LDC_]); else { if ((um == nelt_per_vecreg_) && (un == unroll_n_)) { @@ -578,7 +603,7 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { } for (j = 0; j < un; j++) { - if (mayiuse(avx512_core)) { + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { reg_C = (j == 0) ? CO1_ : CO2_; if (j >= 2) { add(CO2_, LDC_); } } else @@ -588,7 +613,8 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { if (!is_beta_zero) { if (sepload) { for (i = 0; i < std::max(um / nelt_per_vecreg_, 1); i++) { - if (!mayiuse(avx512_core) && (j % 2 == 1)) { + if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512) + && (j % 2 == 1)) { (this->*sload)(vec_reg_t(i), ptr[reg_C + LDC_ + elt_size_ * i @@ -614,7 +640,7 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { } } - if (!mayiuse(avx512_core)) { + if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512)) { if (j > 0) { prefetcht2(ptr[AA_ + elt_size_ * offAA]); offAA += 16; @@ -623,7 +649,8 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { // store accumulated value in C_ for (i = 0; i < std::max(um / nelt_per_vecreg_, 1); i++) { - if (!mayiuse(avx512_core) && (j % 2 == 1)) { + if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512) + && (j % 2 == 1)) { (this->*store)(ptr[reg_C + LDC_ + elt_size_ * i * nelt_per_vecreg_], vec_reg_t(zmm_acc_idx_ + j + i * unroll_n_)); @@ -632,21 +659,21 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { ptr[reg_C + elt_size_ * i * nelt_per_vecreg_], vec_reg_t(zmm_acc_idx_ + j + i * unroll_n_)); } - if (mayiuse(avx512_core)) { + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { vpxorq(vec_reg_t(zmm_acc_idx_ + j + i * unroll_n_), vec_reg_t(zmm_acc_idx_ + j + i * unroll_n_), vec_reg_t(zmm_acc_idx_ + j + i * unroll_n_)); } } - if (!mayiuse(avx512_core)) { + if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512)) { if ((um == unroll_m_) && (un == 1)) { prefetcht2(ptr[AA_ + elt_size_ * offAA]); offAA += 16; } } - if (!mayiuse(avx512_core)) { + if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512)) { if (j == std::min(1, un - 1)) { if (j == 0) add(CO1_, LDC_); @@ -662,9 +689,10 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { } } - if (mayiuse(avx512_core)) lea(CO1_, ptr[CO2_ + LDC_]); + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) + lea(CO1_, ptr[CO2_ + LDC_]); - if (!mayiuse(avx512_core)) { + if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512)) { if ((um >= nelt_per_vecreg_) && (un < unroll_n_)) { prefetcht2(ptr[AA_ + elt_size_ * offAA]); offAA += 16; diff --git a/src/cpu/x64/gemm/f32/jit_avx512_common_gemm_f32.cpp b/src/cpu/x64/gemm/f32/jit_avx512_common_gemm_f32.cpp index 3f06f576067..007bc74fc2e 100644 --- a/src/cpu/x64/gemm/f32/jit_avx512_common_gemm_f32.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx512_common_gemm_f32.cpp @@ -26,6 +26,7 @@ #include "cpu/gemm/f32/gemm_utils_f32.hpp" #include "cpu/gemm/f32/ref_gemm_f32.hpp" +#include "cpu/gemm/gemm.hpp" #include "cpu/gemm/gemm_msan_unpoison.hpp" #include "cpu/x64/jit_generator.hpp" @@ -1626,7 +1627,7 @@ dnnl_status_t sgemm_nocopy_driver(const char *transa, const char *transb, if (utils::any_null(ker_bn, ker_b1, ker_b0)) return dnnl_runtime_error; dim_t BM = 4032, BN, BK; - if (mayiuse(avx512_core)) { + if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { BN = isTransA ? 384 : 64; BK = 384; } else { diff --git a/src/cpu/x64/gemm/f32/jit_avx512_core_gemm_smalln_tn_f32_kern.cpp b/src/cpu/x64/gemm/f32/jit_avx512_core_gemm_smalln_tn_f32_kern.cpp index 66600803fad..6c675d189d3 100644 --- a/src/cpu/x64/gemm/f32/jit_avx512_core_gemm_smalln_tn_f32_kern.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx512_core_gemm_smalln_tn_f32_kern.cpp @@ -22,6 +22,9 @@ #include "common/utils.hpp" #include "cpu/platform.hpp" + +#include "cpu/gemm/gemm.hpp" + #include "cpu/x64/gemm/f32/jit_avx512_core_gemm_smalln_tn_f32_kern.hpp" #include "cpu/x64/gemm/gemm_info.hpp" #include "cpu/x64/jit_generator.hpp" @@ -688,7 +691,8 @@ dnnl_status_t jump_to_gemm_smalln_tn( const gemm_info_t *arg) { if ((arg->n < 16 && arg->n > 1 && arg->transa == do_trans && arg->transb != do_trans) - && mayiuse(avx512_core) && arg->co == nullptr) { + && mayiuse(avx512_core) && __BUILD_GEMM_AVX512 + && arg->co == nullptr) { auto transa_char = (arg->transa != do_trans) ? "N" : "T"; auto transb_char = (arg->transb != do_trans) ? "N" : "T"; return jit_avx512_core_gemm_smalln_tn_f32(transa_char, transb_char, diff --git a/src/cpu/x64/gemm/f32/jit_sse41_gemv_n_f32_kern.cpp b/src/cpu/x64/gemm/f32/jit_sse41_gemv_n_f32_kern.cpp index 9c506950ce4..cb195f55006 100644 --- a/src/cpu/x64/gemm/f32/jit_sse41_gemv_n_f32_kern.cpp +++ b/src/cpu/x64/gemm/f32/jit_sse41_gemv_n_f32_kern.cpp @@ -19,6 +19,8 @@ #include "common/math_utils.hpp" #include "common/utils.hpp" +#include "cpu/gemm/gemm.hpp" + #include "cpu/x64/cpu_isa_traits.hpp" #include "cpu/x64/jit_generator.hpp" @@ -313,10 +315,10 @@ void jit_sse41_gemv_n_f32_kern::generate() { // Function signature: gemv(*m, *n, *alpha, *a, *lda, *x, *incx, *y, *incy) jit_sse41_gemv_n_f32_kern::jit_sse41_gemv_n_f32_kern(void) : jit_generator(jit_name()) - , has_avx512_(mayiuse(avx512_core)) - , has_avx2_(mayiuse(avx2)) - , has_avx_(mayiuse(avx)) - , has_sse41_(mayiuse(sse41)) + , has_avx512_(mayiuse(avx512_core) && __BUILD_GEMM_AVX512) + , has_avx2_(mayiuse(avx2) && __BUILD_GEMM_AVX2) + , has_avx_(mayiuse(avx) && __BUILD_GEMM_AVX2) + , has_sse41_(mayiuse(sse41) && __BUILD_GEMM_SSE41) , arg_lda_(0) , arg_x_(0) , arg_incx_(0) diff --git a/src/cpu/x64/gemm/gemm_driver.cpp b/src/cpu/x64/gemm/gemm_driver.cpp index 4ca5685e989..dae0d417f46 100644 --- a/src/cpu/x64/gemm/gemm_driver.cpp +++ b/src/cpu/x64/gemm/gemm_driver.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2023 Intel Corporation +* Copyright 2018-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -74,14 +74,25 @@ struct alignas(64) gemm_per_thread_t { template int get_vector_length() { - int v_bytes; + int v_bytes = 0; - if (mayiuse(avx512_core)) + if (false) { + //dummy if +#if __BUILD_GEMM_AVX512 + } else if (mayiuse(avx512_core)) { v_bytes = cpu_isa_traits::vlen; - else if (mayiuse(avx)) +#endif +#if __BUILD_GEMM_AVX2 + } else if (mayiuse(avx)) { v_bytes = cpu_isa_traits::vlen; - else +#endif +#if __BUILD_GEMM_SSE41 + } else if (mayiuse(sse41)) { v_bytes = cpu_isa_traits::vlen; +#endif + } else { + assert(!"not supposed to be reached."); + } return v_bytes / sizeof(T); } @@ -391,7 +402,7 @@ void gemm_kernel(dim_t m, dim_t n, const dim_t k, const float alpha, constexpr bool is_int8 = utils::one_of( data_traits::data_type, data_type::s8, data_type::u8); constexpr bool is_f32 = data_traits::data_type == data_type::f32; - bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx); + bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx) && __BUILD_GEMM_AMX; dim_t m_stk = col_offset_ws ? 1 : m; dim_t n_stk = row_offset_ws ? 1 : n; @@ -538,8 +549,9 @@ static dnnl_status_t gemm_kernel_driver(int ithr, dim_t m, dim_t n, dim_t k, constexpr bool is_int8 = utils::one_of( data_traits::data_type, data_type::s8, data_type::u8); constexpr bool is_bf16 = data_traits::data_type == data_type::bf16; - bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx); - bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_amx); + + bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx) && __BUILD_GEMM_AMX; + bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_amx) && __BUILD_GEMM_AMX; bool is_amx = is_int8_amx || is_bf16_amx; const std::shared_ptr &a_packed = arg->a_packed; @@ -816,10 +828,10 @@ static dnnl_status_t kernel_driver_parallel_acopiedbcopy(int ithr, dim_t m, constexpr bool is_int8 = utils::one_of( data_traits::data_type, data_type::s8, data_type::u8); constexpr bool is_bf16 = data_traits::data_type == data_type::bf16; - bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx); - bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_amx); - bool is_amx = is_int8_amx || is_bf16_amx; + bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx) && __BUILD_GEMM_AMX; + bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_amx) && __BUILD_GEMM_AMX; + bool is_amx = is_int8_amx || is_bf16_amx; // B buffer needs to be large due to zero-padding. if (is_amx) b_buf_nelems @@ -1040,25 +1052,32 @@ static inline bool nocopy_checker( if (data_traits::data_type != data_type::f32) return false; - if (!mayiuse(avx)) return false; + if (!(mayiuse(avx) && __BUILD_GEMM_AVX2)) return false; if (arg->force_nocopy) return true; auto m = arg->m, n = arg->n, k = arg->k; + UNUSED(m), UNUSED(n), UNUSED(k); auto lda = arg->lda, ldb = arg->ldb, ldc = arg->ldc; + UNUSED(lda), UNUSED(ldb), UNUSED(ldc); auto transa = arg->transa, transb = arg->transb; + UNUSED(transa), UNUSED(transb); auto packing = arg->packing; if (packing != pack_type::none) ldc = 64; - if (arg->a_packed || arg->b_packed) - return false; + if (arg->a_packed || arg->b_packed) return false; +#if __BUILD_GEMM_AVX512 else if (mayiuse(avx512_core)) return nocopy_checker_avx512( nthr, transa, transb, m, n, k, lda, ldb, ldc); +#endif +#if __BUILD_GEMM_AVX2 else return nocopy_checker_avx2( nthr, transa, transb, m, n, k, lda, ldb, ldc); +#endif + return false; } template @@ -1091,26 +1110,32 @@ static inline void set_thread_opts_nopack(int nthrs, int nthrs_spawn, bool condition_2D_bsrc = false; if (isSgemm) { // If m is large and n is small then do 1D partitioning for AVX2. - if (!mayiuse(avx512_core) && n <= N2D_MAX && (m >= nthrs * M2D_MIN)) + if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512) && n <= N2D_MAX + && (m >= nthrs * M2D_MIN)) { condition_2D_bsrc = false; - else + } else { condition_2D_bsrc = ((n > nthrs * N2D_MAX) || (n <= nthrs * N2D_MAX / 2)) && (m >= 2 * M2D_MIN); + } } else { - int scale = mayiuse(avx512_core) ? nthrs : 20; + int scale = (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) ? nthrs : 20; condition_2D_bsrc = (256 * m > scale * n) && (scale * m < 256 * n); } // TODO Check if we should use k-partitioning. int condition_1D_copya = false; - if (mayiuse(avx512_core)) { + if (false) { + // dummy if +#if __BUILD_GEMM_AVX512 + } else if (mayiuse(avx512_core)) { const dim_t thresh = isSgemm ? N2D_MAX / 4 : 68; if (m >= 1000 && (n >= nthrs * thresh)) { condition_2D_bsrc = false; condition_1D_copya = true; } +#endif } else { if (m >= 1000 && n >= 4000) { condition_2D_bsrc = false; @@ -1123,7 +1148,9 @@ static inline void set_thread_opts_nopack(int nthrs, int nthrs_spawn, // TODO: the reasons seems to be in copy_sum_bx routines. At least, // after simple optimization of copy_sum_ax for avx512, similar // restriction on offset B became unnecessary. Revisit. - if (is_int8 && arg->ao != 0 && (arg->bo != 0 || mayiuse(avx512_core))) { + if (is_int8 && arg->ao != 0 + && (arg->bo != 0 + || (mayiuse(avx512_core) && __BUILD_GEMM_AVX512))) { condition_2D_bsrc = false; condition_1D_copya = true; } @@ -1166,7 +1193,7 @@ static inline void set_thread_opts_nopack(int nthrs, int nthrs_spawn, } else if ((n <= 64 || n >= 256)) { while (((nthrs_n > 1) && (n / nthrs_n < arg->un) && (m / nthrs_m >= 2 * arg->um) - && mayiuse(avx512_core)) + && mayiuse(avx512_core) && __BUILD_GEMM_AVX512) || ((nthrs_n % 2 == 0) && (n / nthrs > N2D_MAX || n / nthrs_n <= N2D_MAX / 2) @@ -1294,7 +1321,8 @@ static inline void set_thread_opts_pack(int nthrs, choose_k_blocking(); // Choose m/n blocking. - auto min_mblk = mayiuse(avx512_core) ? (MBLK / 2) : arg->um; + auto min_mblk = (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) ? (MBLK / 2) + : arg->um; min_mblk = do_m_blocking ? min_mblk : m; min_mblk = do_m_blocking_only ? arg->um : min_mblk; auto min_nblk = do_n_blocking ? NBLK / 2 : n; @@ -1348,9 +1376,13 @@ static inline int set_thread_opts(int nthrs, int nthrs_spawn, dim_t BK = 0; auto m = arg->m, n = arg->n, k = arg->k; - if (mayiuse(avx512_core)) { + if (false) { + // dummy if +#if __BUILD_GEMM_AVX512 + } else if (mayiuse(avx512_core)) { cpu::gemm_utils::calc_nthr_nocopy_avx512_common(m, n, k, nthrs, &nthrs_m, &nthrs_n, &nthrs_k, &BM, &BN, &BK); +#endif } else { cpu::gemm_utils::calc_nthr_nocopy_avx(m, n, k, nthrs, &nthrs_m, &nthrs_n, &nthrs_k, &BM, &BN, &BK); @@ -1422,10 +1454,9 @@ static dnnl_status_t parallel_a_copy(const int ithr, const int nthrs, constexpr bool is_int8 = utils::one_of( data_traits::data_type, data_type::s8, data_type::u8); constexpr bool is_bf16 = data_traits::data_type == data_type::bf16; - bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx); - bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_amx); + bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx) && __BUILD_GEMM_AMX; + bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_amx) && __BUILD_GEMM_AMX; bool is_amx = is_int8_amx || is_bf16_amx; - const std::shared_ptr &a_packed = arg->a_packed; // Scaling C matrix. @@ -1579,9 +1610,9 @@ static inline void adjust_thread_count(dim_t m, dim_t n, dim_t k, int *nthrs) { const bool is_f32 = data_traits::data_type == data_type::f32; - const bool is_avx512 = mayiuse(avx512_core); - const bool is_avx = mayiuse(avx); - const bool is_only_avx2 = mayiuse(avx2) && !is_avx512; + const bool is_avx512 = mayiuse(avx512_core) && __BUILD_GEMM_AVX512; + const bool is_avx = mayiuse(avx) && __BUILD_GEMM_AVX2; + const bool is_only_avx2 = mayiuse(avx2) && __BUILD_GEMM_AVX2 && !is_avx512; // Some sgemm cases still benefit from using all threads. const bool use_all_threads = is_f32 && n > 50 @@ -1668,15 +1699,17 @@ static dnnl_status_t call_no_copy_sgemm( auto transb_char = (arg->transb != do_trans) ? "N" : "T"; #endif - if (mayiuse(avx512_core)) { + if (false) { + // dummy if #if __BUILD_GEMM_AVX512 + } else if (mayiuse(avx512_core)) { return jit_avx512_common_gemm_f32(nthrs, transa_char, transb_char, &arg->m, &arg->n, &arg->k, &arg->alpha, (float *)arg->a, &arg->lda, (float *)arg->b, &arg->ldb, &arg->beta, (float *)arg->c, &arg->ldc, (float *)arg->co); #endif - } else { #if __BUILD_GEMM_AVX2 + } else if (mayiuse(avx)) { return jit_avx_gemm_f32(nthrs, transa_char, transb_char, &arg->m, &arg->n, &arg->k, &arg->alpha, (float *)arg->a, &arg->lda, (float *)arg->b, &arg->ldb, &arg->beta, (float *)arg->c, @@ -1942,8 +1975,10 @@ static dnnl_status_t gemm_threading_driver( == data_type::f32); assert(arg->packing == pack_type::none); - if (mayiuse(avx512_core)) { + if (false) { + // dummy if #if __BUILD_GEMM_AVX512 + } else if (mayiuse(avx512_core)) { thread_arg[ithr].result = avx512_common_gemm_f32:: sgemm_nocopy_driver( arg->transa == no_trans ? "N" : "T", @@ -1953,8 +1988,8 @@ static dnnl_status_t gemm_threading_driver( &beta_eff, (float *)c_eff, ldc_eff, nullptr); #endif - } else { #if __BUILD_GEMM_AVX2 + } else if (mayiuse(avx)) { thread_arg[ithr].result = avx_gemm_f32::sgemm_nocopy_driver( arg->transa == no_trans ? "N" : "T", @@ -2016,25 +2051,33 @@ dnnl_status_t gemm_driver(const char *transA, const char *transB, data_traits::data_type, data_type::s8, data_type::u8); MAYBE_UNUSED(is_int8); +#if __BUILD_GEMM_AVX512 // gemm_driver supports bfloat16 gemm for Intel AVX512 and // Intel AVX512 BF16. assert(IMPLICATION(data_traits::data_type == data_type::bf16, mayiuse(avx512_core) && !force_nocopy)); +#endif +#if __BUILD_GEMM_SSE41 // gemm_driver supports 8-bit integer Intel AVX512, Intel AVX2, Intel AVX, // Intel SSE4.1 and Intel DL Boost. assert(IMPLICATION(is_int8, mayiuse(sse41))); +#endif +#if __BUILD_GEMM_SSE41 // gemm_driver supports sgemm for Intel AVX512, Intel AVX2, Intel AVX, // and Intel SSE4.1 assert(IMPLICATION( data_traits::data_type == data_type::f32, mayiuse(sse41))); +#endif // 8-bit integer gemm doesn't support nocopy kernels. assert(IMPLICATION(is_int8, !force_nocopy)); +#if __BUILD_GEMM_AVX2 // gemm_driver can only dispatch nocopy for avx and above. assert(IMPLICATION(force_nocopy, mayiuse(avx))); +#endif gemm_info_t args(transA, transB, offsetC, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc, force_nocopy, diff --git a/src/cpu/x64/gemm/gemm_info.cpp b/src/cpu/x64/gemm/gemm_info.cpp index acab57ca9eb..cd227f30306 100644 --- a/src/cpu/x64/gemm/gemm_info.cpp +++ b/src/cpu/x64/gemm/gemm_info.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2023 Intel Corporation +* Copyright 2019-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -76,7 +76,7 @@ void prepare_bo(int32_t &bo_gemm_info, const uint8_t *bo_orig) { template <> void prepare_bo(int32_t &bo_gemm_info, const int8_t *bo_orig) { int bo_s32 = bo_orig ? *bo_orig : 0; - if (!mayiuse(avx512_core_amx)) bo_s32 += 128; + if (!mayiuse(avx512_core_amx) && __BUILD_GEMM_AMX) bo_s32 += 128; bo_gemm_info = bo_s32; } @@ -160,7 +160,8 @@ gemm_info_t::gemm_info_t(const char *transA, const char *transB, // Copy-based sgemm doesn't support force-nocopy for ISAs older // than Intel AVX. - this->force_nocopy = is_sgemm && force_nocopy && mayiuse(avx); + this->force_nocopy + = is_sgemm && force_nocopy && mayiuse(avx) && __BUILD_GEMM_AVX2; if (!this->force_nocopy || is_gemv) { this->jit_init(); } } @@ -208,19 +209,24 @@ template void gemm_info_t::jit_init(void) { bool use_bf16_ymm = false; + UNUSED(use_bf16_ymm); // TODO: Add dispatching for 1-fma SKUs with support to bf16 // instructions for AMX kernel. { constexpr bool is_bf16 = data_traits::data_type == data_type::bf16; - const bool max_isa_supports_bf16_ymm - = mayiuse(avx512_core_bf16_ymm) && !mayiuse(avx512_core_amx); + const bool max_isa_supports_bf16_ymm = mayiuse(avx512_core_bf16_ymm) + && __BUILD_GEMM_AVX512 + && !(mayiuse(avx512_core_amx) && __BUILD_GEMM_AMX); use_bf16_ymm = is_bf16 && max_isa_supports_bf16_ymm; } switch (data_traits::data_type) { case data_type::s8: - if (mayiuse(avx512_core_amx)) { + if (false) { + // dummy if +#if __BUILD_GEMM_AMX + } else if (mayiuse(avx512_core_amx)) { this->um = 32; this->un = 32; this->uk = 64; @@ -231,6 +237,8 @@ void gemm_info_t::jit_init(void) { this->bk_traditional = 0; this->blocking_small_k = 0; this->bn_small_k = 0; +#endif +#if __BUILD_GEMM_AVX512 } else if (mayiuse(avx512_core)) { this->um = 48; this->un = 8; @@ -242,6 +250,8 @@ void gemm_info_t::jit_init(void) { this->bk_traditional = 384; this->blocking_small_k = 48; this->bn_small_k = 24; +#endif +#if __BUILD_GEMM_AVX2 } else if (mayiuse(avx2)) { this->um = mayiuse(avx2_vnni) ? 24 : 16; this->un = 4; @@ -253,6 +263,8 @@ void gemm_info_t::jit_init(void) { this->bk_traditional = 256; this->blocking_small_k = 48; this->bn_small_k = 24; +#endif +#if __BUILD_GEMM_AVX2 } else if (mayiuse(avx)) { this->um = 16; this->un = 2; @@ -264,6 +276,8 @@ void gemm_info_t::jit_init(void) { this->bk_traditional = 256; this->blocking_small_k = 48; this->bn_small_k = 24; +#endif +#if __BUILD_GEMM_SSE41 } else if (mayiuse(sse41)) { this->um = 16; this->un = 2; @@ -275,11 +289,15 @@ void gemm_info_t::jit_init(void) { this->bk_traditional = 256; this->blocking_small_k = 48; this->bn_small_k = 24; +#endif } break; case data_type::bf16: - if (mayiuse(avx512_core_amx)) { + if (false) { + // dummy if +#if __BUILD_GEMM_AMX + } else if (mayiuse(avx512_core_amx)) { this->um = 32; this->un = 32; this->uk = 32; @@ -290,6 +308,8 @@ void gemm_info_t::jit_init(void) { this->bk_traditional = 0; this->blocking_small_k = 0; this->bn_small_k = 0; +#endif +#if __BUILD_GEMM_AVX512 } else if (mayiuse(avx512_core)) { this->um = use_bf16_ymm ? 24 : 48; this->un = 8; @@ -301,11 +321,15 @@ void gemm_info_t::jit_init(void) { this->bk_traditional = 384; this->blocking_small_k = 48; this->bn_small_k = 24; +#endif } break; case data_type::f32: - if (mayiuse(avx512_core)) { + if (false) { + // dummy if +#if __BUILD_GEMM_AVX512 + } else if (mayiuse(avx512_core)) { this->um = 48; this->un = 8; this->uk = 1; @@ -316,6 +340,8 @@ void gemm_info_t::jit_init(void) { this->bk_traditional = 384; this->blocking_small_k = 48; this->bn_small_k = 24; +#endif +#if __BUILD_GEMM_AVX2 } else if (mayiuse(avx2)) { this->um = 24; this->un = 4; @@ -327,6 +353,8 @@ void gemm_info_t::jit_init(void) { this->bk_traditional = 256; this->blocking_small_k = 48; this->bn_small_k = 24; +#endif +#if __BUILD_GEMM_AVX2 } else if (mayiuse(avx)) { this->um = 16; this->un = 4; @@ -338,6 +366,8 @@ void gemm_info_t::jit_init(void) { this->bk_traditional = 256; this->blocking_small_k = 48; this->bn_small_k = 24; +#endif +#if __BUILD_GEMM_SSE41 } else if (mayiuse(sse41)) { this->um = 8; this->un = 4; @@ -349,6 +379,7 @@ void gemm_info_t::jit_init(void) { this->bk_traditional = 256; this->blocking_small_k = 48; this->bn_small_k = 24; +#endif } break; default: assert(!"unsupported data type!"); @@ -360,14 +391,15 @@ void gemm_info_t::jit_init(void) { static std::once_flag initialized; static std::atomic st(dnnl_success); std::call_once(initialized, [&, um] { -#if __BUILD_GEMM_AVX512 const bool b_is_s8 = data_traits::data_type == data_type::s8; -#endif + UNUSED(b_is_s8); constexpr bool is_int8 = utils::one_of( data_traits::data_type, data_type::s8, data_type::u8); constexpr bool is_bf16 = data_traits::data_type == data_type::bf16; - bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx); - bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_amx); + bool is_int8_amx + = is_int8 && mayiuse(avx512_core_amx) && __BUILD_GEMM_AMX; + bool is_bf16_amx + = is_bf16 && mayiuse(avx512_core_amx) && __BUILD_GEMM_AMX; bool is_amx = is_int8_amx || is_bf16_amx; static maybe_unique_ptr copy_a[2][2] = {{nullptr}}; @@ -375,8 +407,10 @@ void gemm_info_t::jit_init(void) { switch (data_traits::data_type) { case data_type::s8: - if (mayiuse(amx_int8)) { + if (false) { + // dummy if #if __BUILD_GEMM_AMX + } else if (mayiuse(amx_int8)) { for (int isTrans : {no_trans, do_trans}) { copy_a[isTrans][no_sum].reset( new jit_avx512_core_amx_copy_kern( @@ -387,8 +421,8 @@ void gemm_info_t::jit_init(void) { false, isTrans, sizeof(b_t))); } #endif - } else if (mayiuse(avx512_core)) { #if __BUILD_GEMM_AVX512 + } else if (mayiuse(avx512_core)) { copy_a[no_trans][no_sum].reset( new jit_avx512_core_u8_copy_an_kern()); copy_a[do_trans][no_sum].reset( @@ -409,8 +443,8 @@ void gemm_info_t::jit_init(void) { copy_b[do_trans][do_sum].reset( new jit_avx512_core_u8_copy_sum_bt_kern(b_is_s8)); #endif - } else if (mayiuse(avx2_vnni)) { #if __BUILD_GEMM_AVX2 + } else if (mayiuse(avx2_vnni)) { copy_a[no_trans][no_sum].reset( new jit_avx2_vnni_u8_copy_an_kern()); copy_a[do_trans][no_sum].reset( @@ -431,8 +465,8 @@ void gemm_info_t::jit_init(void) { copy_b[do_trans][do_sum].reset( new jit_avx2_vnni_u8_copy_sum_bt_kern()); #endif - } else if (mayiuse(avx2)) { #if __BUILD_GEMM_AVX2 + } else if (mayiuse(avx2)) { copy_a[no_trans][no_sum].reset( new jit_avx2_u8_copy_an_kern()); copy_a[do_trans][no_sum].reset( @@ -453,8 +487,8 @@ void gemm_info_t::jit_init(void) { copy_b[do_trans][do_sum].reset( new jit_avx2_u8_copy_sum_bt_kern()); #endif - } else if (mayiuse(avx)) { #if __BUILD_GEMM_AVX2 + } else if (mayiuse(avx)) { copy_a[no_trans][no_sum].reset( new jit_avx_u8_copy_an_kern()); copy_a[do_trans][no_sum].reset( @@ -475,8 +509,8 @@ void gemm_info_t::jit_init(void) { copy_b[do_trans][do_sum].reset( new jit_avx_u8_copy_sum_bt_kern()); #endif - } else if (mayiuse(sse41)) { #if __BUILD_GEMM_SSE41 + } else if (mayiuse(sse41)) { copy_a[no_trans][no_sum].reset( new jit_sse41_u8_copy_an_kern()); copy_a[do_trans][no_sum].reset( @@ -501,8 +535,10 @@ void gemm_info_t::jit_init(void) { break; case data_type::bf16: - if (mayiuse(amx_bf16)) { + if (false) { + // dummy if #if __BUILD_GEMM_AMX + } else if (mayiuse(amx_bf16)) { for (int isTrans : {no_trans, do_trans}) { copy_a[isTrans][no_sum].reset( new jit_avx512_core_amx_copy_kern( @@ -513,8 +549,8 @@ void gemm_info_t::jit_init(void) { false, isTrans, sizeof(b_t))); } #endif - } else if (mayiuse(avx512_core) && !use_bf16_ymm) { #if __BUILD_GEMM_AVX512 + } else if (mayiuse(avx512_core) && !use_bf16_ymm) { copy_a[no_trans][no_sum].reset( new jit_avx512_core_s16_48x8_copy_an_kern()); copy_a[do_trans][no_sum].reset( @@ -525,8 +561,8 @@ void gemm_info_t::jit_init(void) { copy_b[do_trans][no_sum].reset( new jit_avx512_core_s16_48x8_copy_bt_kern()); #endif - } else if (mayiuse(avx512_core) && use_bf16_ymm) { #if __BUILD_GEMM_AVX512 + } else if (mayiuse(avx512_core) && use_bf16_ymm) { copy_a[no_trans][no_sum].reset( new jit_avx512_core_s16_24x8_copy_an_kern()); copy_a[do_trans][no_sum].reset( @@ -541,8 +577,10 @@ void gemm_info_t::jit_init(void) { break; case data_type::f32: - if (mayiuse(avx512_core)) { + if (false) { + // dummy if #if __BUILD_GEMM_AVX512 + } else if (mayiuse(avx512_core)) { copy_a[no_trans][no_sum].reset( new jit_avx512_core_f32_copy_an_kern()); copy_a[do_trans][no_sum].reset( @@ -553,8 +591,8 @@ void gemm_info_t::jit_init(void) { copy_b[do_trans][no_sum].reset( new jit_avx512_core_f32_copy_bt_kern()); #endif - } else if (mayiuse(avx2)) { #if __BUILD_GEMM_AVX2 + } else if (mayiuse(avx2)) { copy_a[no_trans][no_sum].reset( new jit_avx2_f32_copy_an_kern()); copy_a[do_trans][no_sum].reset( @@ -565,8 +603,8 @@ void gemm_info_t::jit_init(void) { copy_b[do_trans][no_sum].reset( new jit_avx2_f32_copy_bt_kern()); #endif - } else if (mayiuse(avx)) { #if __BUILD_GEMM_AVX2 + } else if (mayiuse(avx)) { copy_a[no_trans][no_sum].reset( new jit_avx_f32_copy_an_kern()); copy_a[do_trans][no_sum].reset( @@ -577,8 +615,8 @@ void gemm_info_t::jit_init(void) { copy_b[do_trans][no_sum].reset( new jit_avx_f32_copy_bt_kern()); #endif +#if __BUILD_GEMM_SSE41 } else if (mayiuse(sse41)) { -#if __BUILD_GEMM_AVX2 copy_a[no_trans][no_sum].reset( new jit_sse41_f32_copy_an_kern()); copy_a[do_trans][no_sum].reset( @@ -595,26 +633,29 @@ void gemm_info_t::jit_init(void) { default: break; } -#if __BUILD_GEMM_AMX constexpr bool is_a_s8 = data_traits::data_type == data_type::s8; constexpr bool is_b_s8 = data_traits::data_type == data_type::s8; constexpr bool is_c_s32 = data_traits::data_type == data_type::s32; -#endif + UNUSED(is_a_s8); + UNUSED(is_b_s8); + UNUSED(is_c_s32); static maybe_unique_ptr kernel[2][2][2][2] = {{{{nullptr}}}}; switch (data_traits::data_type) { case data_type::s8: - if (mayiuse(avx512_core_amx)) { + if (false) { + // dummy if #if __BUILD_GEMM_AMX + } else if (mayiuse(avx512_core_amx)) { for (int isBeta0 : {no_beta0, do_beta0}) { kernel[isBeta0][do_alpha1][no_sum][no_sum].reset( new jit_avx512_core_amx_gemm_kern( is_a_s8, is_b_s8, is_c_s32, isBeta0)); } #endif - } else if (mayiuse(avx512_core)) { #if __BUILD_GEMM_AVX512 + } else if (mayiuse(avx512_core)) { for (int isBeta0 : {no_beta0, do_beta0}) for (int doColSum : {no_sum, do_sum}) for (int doRowSum : {no_sum, do_sum}) { @@ -623,8 +664,8 @@ void gemm_info_t::jit_init(void) { isBeta0, doColSum, doRowSum)); } #endif - } else if (mayiuse(avx2)) { #if __BUILD_GEMM_AVX2 + } else if (mayiuse(avx2)) { for (int isBeta0 : {no_beta0, do_beta0}) for (int doColSum : {no_sum, do_sum}) for (int doRowSum : {no_sum, do_sum}) { @@ -634,8 +675,8 @@ void gemm_info_t::jit_init(void) { um)); } #endif - } else if (mayiuse(avx)) { #if __BUILD_GEMM_AVX2 + } else if (mayiuse(avx)) { kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( new jit_avx_kernel_gemm_s8u8s32_kern()); kernel[no_beta0][do_alpha1][do_sum][no_sum].reset( @@ -654,8 +695,8 @@ void gemm_info_t::jit_init(void) { kernel[do_beta0][do_alpha1][do_sum][do_sum].reset( new jit_avx_kernel_b0_b_gemm_s8u8s32_kern()); #endif - } else if (mayiuse(sse41)) { #if __BUILD_GEMM_SSE41 + } else if (mayiuse(sse41)) { kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( new jit_sse41_kernel_gemm_s8u8s32_kern()); kernel[no_beta0][do_alpha1][do_sum][no_sum].reset( @@ -678,16 +719,18 @@ void gemm_info_t::jit_init(void) { break; case data_type::bf16: - if (mayiuse(avx512_core_amx)) { + if (false) { + // dummy if #if __BUILD_GEMM_AMX + } else if (mayiuse(avx512_core_amx)) { for (int isBeta0 : {no_beta0, do_beta0}) { kernel[isBeta0][do_alpha1][no_sum][no_sum].reset( new jit_avx512_core_amx_gemm_kern( is_a_s8, is_b_s8, is_c_s32, isBeta0)); } #endif - } else if (mayiuse(avx512_core)) { #if __BUILD_GEMM_AVX512 + } else if (mayiuse(avx512_core)) { for (int isBeta0 : {no_beta0, do_beta0}) for (int isAlpha1 : {no_alpha1, do_alpha1}) { kernel[isBeta0][isAlpha1][no_sum][no_sum].reset( @@ -699,22 +742,24 @@ void gemm_info_t::jit_init(void) { break; case data_type::f32: - if (mayiuse(avx2)) { + if (false) { + // dummy if #if __BUILD_GEMM_AVX2 + } else if (mayiuse(avx2)) { for (int isBeta0 : {no_beta0, do_beta0}) { kernel[isBeta0][do_alpha1][no_sum][no_sum].reset( new jit_avx2_kernel_sgemm_kern(isBeta0)); } #endif - } else if (mayiuse(avx)) { #if __BUILD_GEMM_AVX2 + } else if (mayiuse(avx)) { kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( new jit_avx_kernel_sgemm_kern()); kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( new jit_avx_kernel_b0_sgemm_kern()); #endif - } else if (mayiuse(sse41)) { #if __BUILD_GEMM_SSE41 + } else if (mayiuse(sse41)) { kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( new jit_sse41_kernel_sgemm_kern()); kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( @@ -732,8 +777,10 @@ void gemm_info_t::jit_init(void) { static maybe_unique_ptr gemv_u8s8s32_kernel = nullptr; switch (data_traits::data_type) { case data_type::s8: - if (mayiuse(avx512_core)) { + if (false) { + // dummy if #if __BUILD_GEMM_AVX512 + } else if (mayiuse(avx512_core)) { gemv_s8s8s32_kernel.reset( new jit_avx512_core_gemv_s8x8s32_kern(ver_t::s8s8)); gemv_s8u8s32_kernel.reset( @@ -745,8 +792,10 @@ void gemm_info_t::jit_init(void) { break; case data_type::bf16: - if (mayiuse(avx512_core)) { + if (false) { + // dummy if #if __BUILD_GEMM_AVX512 + } else if (mayiuse(avx512_core)) { for (int isTrans : {no_trans, do_trans}) gemv_kernel[isTrans].reset( new jit_avx512_core_gemv_bf16bf16f32_kern( @@ -756,14 +805,16 @@ void gemm_info_t::jit_init(void) { break; case data_type::f32: - if (mayiuse(avx)) { + if (false) { + // dummy if #if __BUILD_GEMM_AVX2 + } else if (mayiuse(avx)) { gemv_kernel[no_trans].reset( new jit_sse41_gemv_n_f32_kern()); gemv_kernel[do_trans].reset(new jit_avx_gemv_t_f32_kern()); #endif - } else if (mayiuse(sse41)) { #if __BUILD_GEMM_SSE41 + } else if (mayiuse(sse41)) { gemv_kernel[no_trans].reset( new jit_sse41_gemv_n_f32_kern()); gemv_kernel[do_trans].reset( @@ -816,6 +867,7 @@ void gemm_info_t::jit_init(void) { = (gemm_fptr_t)p_kernel->jit_ker(); } } + // Override compute kernel table with AMX kernels if (is_amx) { // AMX compute kernels don't support alpha scaling, row-offset or @@ -924,14 +976,17 @@ bool gemm_info_t::hasKernels(void) { if (!this->copyA || !this->copyB) return false; +#if __BUILD_GEMM_AVX512 if (mayiuse(avx512_core)) if (!this->gemv_s8u8s32_kernel || !this->gemv_u8s8s32_kernel || !this->gemv_s8s8s32_kernel) return false; +#endif } break; case data_type::bf16: +#if __BUILD_GEMM_AVX512 if (mayiuse(avx512_core)) { for (int isBeta0 : {no_beta0, do_beta0}) if (!this->kernel[isBeta0][no_sum][no_sum]) return false; @@ -941,6 +996,7 @@ bool gemm_info_t::hasKernels(void) { for (int isTrans : {no_trans, do_trans}) if (!this->gemv_kernel[isTrans]) return false; } +#endif break; case data_type::f32: