cpu: x64: gemm: update implementation for GEMM_ISA macro
Macros should include `mayiuse` under their body, as systems that support
the features will return `true` for `mayiuse` calls while the kernels will
be missing, and the dispatch logic will be skewed.
dzarukin committed Jul 12, 2024
1 parent 604f27f commit a4e21bd
Showing 9 changed files with 309 additions and 174 deletions.
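The change applies one pattern across all nine files: a runtime `mayiuse(isa)` check is combined with the corresponding build-time `__BUILD_GEMM_*` knob. A minimal standalone sketch of that pattern follows; `mayiuse_avx512_core`, `run_avx512_kernel`, and `run_reference` are made-up stand-ins rather than oneDNN symbols, and the knob value is hard-coded only for illustration.

    // Sketch only: why a runtime capability check alone is not enough once
    // kernels can be compiled out of the build.
    #include <cstdio>

    #define __BUILD_GEMM_AVX512 0 // pretend the AVX-512 GEMM kernels were not built

    static bool mayiuse_avx512_core() { return true; } // the CPU still reports support

    static void run_avx512_kernel() { std::puts("jit avx512 kernel"); }
    static void run_reference() { std::puts("reference/fallback path"); }

    int main() {
        // Before the commit: dispatch on the CPU feature alone, which would
        // select a kernel that was never compiled in.
        // if (mayiuse_avx512_core()) run_avx512_kernel();

        // Pattern enforced by the commit: CPU feature AND build knob.
        if (mayiuse_avx512_core() && __BUILD_GEMM_AVX512)
            run_avx512_kernel();
        else
            run_reference();
        return 0;
    }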
4 changes: 2 additions & 2 deletions src/cpu/gemm/gemm.cpp
@@ -239,7 +239,7 @@ dnnl_status_t gemm_s8s8s32(const char *transa, const char *transb,
if (*M == 0 || *N == 0 || *K == 0) return dnnl_success;

#if DNNL_X64 && !__BUILD_GEMM_NONE
- bool use_jit = mayiuse(avx512_core);
+ bool use_jit = mayiuse(avx512_core) && __BUILD_GEMM_AVX512;
bool use_s8u8 = true
&& utils::everyone_is(0, *ao, *bo) // so far a requirement
&& IMPLICATION(USE_MKL_IGEMM == 0, mayiuse(sse41));
@@ -297,7 +297,7 @@ dnnl_status_t gemm_bf16bf16f32(const char *transa, const char *transb,
bfloat16_t *dummy_bo = nullptr;
float *dummy_co = nullptr;

- if (mayiuse(avx512_core)) {
+ if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) {
auto status = gemm_driver(transa, transb, dummyOffsetC, M, N, K, alpha,
(const bfloat16_t *)A, lda, dummy_ao, (const bfloat16_t *)B,
ldb, dummy_bo, beta, (float *)C, ldc, dummy_co, false);
10 changes: 5 additions & 5 deletions src/cpu/gemm/gemm.hpp
@@ -30,11 +30,11 @@
#include "cpu/x64/cpu_isa_traits.hpp"

// Kernels ISA section for configuring knobs.
- #define __BUILD_GEMM_AMX BUILD_GEMM_KERNELS_ALL
- #define __BUILD_GEMM_AVX512 __BUILD_GEMM_AMX || BUILD_GEMM_AVX512
- #define __BUILD_GEMM_AVX2 __BUILD_GEMM_AVX512 || BUILD_GEMM_AVX2
- #define __BUILD_GEMM_SSE41 __BUILD_GEMM_AVX2 || BUILD_GEMM_SSE41
- #define __BUILD_GEMM_NONE BUILD_GEMM_KERNELS_NONE
+ #define __BUILD_GEMM_AMX (BUILD_GEMM_KERNELS_ALL)
+ #define __BUILD_GEMM_AVX512 (__BUILD_GEMM_AMX || BUILD_GEMM_AVX512)
+ #define __BUILD_GEMM_AVX2 (__BUILD_GEMM_AVX512 || BUILD_GEMM_AVX2)
+ #define __BUILD_GEMM_SSE41 (__BUILD_GEMM_AVX2 || BUILD_GEMM_SSE41)
+ #define __BUILD_GEMM_NONE (BUILD_GEMM_KERNELS_NONE)
#else
#define __BUILD_GEMM_AMX 0
#define __BUILD_GEMM_AVX512 0
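Presumably the reason these one-line definitions gained parentheses: the knobs are expanded inside larger boolean expressions such as `mayiuse(avx512_core) && __BUILD_GEMM_AVX512`, and `&&` binds tighter than `||`. The toy program below is not oneDNN code and the knob values are invented; it only demonstrates the regrouping.

    // Standalone illustration of the precedence issue the parentheses avoid.
    #include <cstdio>

    #define BUILD_GEMM_AVX512 1                     // knob requested at build time
    #define KNOB_UNPAREN 0 || BUILD_GEMM_AVX512     // old-style body (no parentheses)
    #define KNOB_PAREN (0 || BUILD_GEMM_AVX512)     // new-style body

    int main() {
        bool cpu_has_avx512 = false; // e.g. an AVX2-only machine
        // Expands to (cpu_has_avx512 && 0) || 1 -> 1: JIT path wrongly enabled.
        std::printf("unparenthesized: %d\n", cpu_has_avx512 && KNOB_UNPAREN);
        // Expands to cpu_has_avx512 && (0 || 1) -> 0: falls back as intended.
        std::printf("parenthesized:   %d\n", cpu_has_avx512 && KNOB_PAREN);
        return 0;
    }

As the chained definitions suggest, enabling a higher ISA knob also turns on every lower `__BUILD_GEMM_*` value (e.g. BUILD_GEMM_KERNELS_ALL enables them all), so the lower-ISA fallbacks stay available.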
35 changes: 18 additions & 17 deletions src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.cpp
@@ -36,7 +36,7 @@ int jit_avx2_kernel_sgemm_kern::next_acc(int idx, int um, int un) const {

void jit_avx2_kernel_sgemm_kern::prefetchB_beforeBload(
int um, int un, int k_idx, int n_idx) {
- if (!mayiuse(avx512_core)) {
+ if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512)) {
if ((n_idx == 0) && (k_idx == 0) && (un == unroll_n_) && (um != 16)) {
prefetcht0(ptr[BO_ + elt_size_ * (PREFETCHSIZEB_ + offb_)]);
offb_ += 16;
@@ -46,7 +46,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchB_beforeBload(

void jit_avx2_kernel_sgemm_kern::prefetchB_beforeFMA(
int um, int un, int k_idx, int n_idx, int m_idx) {
- if (!mayiuse(avx512_core)) {
+ if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512)) {
if ((um == 16) || (un < unroll_n_)) {
if ((k_idx + m_idx + n_idx) == 0) {
prefetcht0(ptr[BO_ + elt_size_ * (PREFETCHSIZEB_ + offb_)]);
@@ -63,7 +63,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchB_beforeFMA(

void jit_avx2_kernel_sgemm_kern::prefetchA_afterFMA(
int um, int un, int k_idx, int n_idx, int m_idx) {
- if (mayiuse(avx512_core)) {
+ if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) {
if ((um < unroll_m_) && (m_idx == 0)) {
if (((k_idx % (nb_zmm_a_ / unroll_m_reg_) == 0) && (n_idx % 6 == 0))
|| ((k_idx % (nb_zmm_a_ / unroll_m_reg_) == 1)
@@ -87,7 +87,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchA_afterFMA(

void jit_avx2_kernel_sgemm_kern::prefetchA_afterBload(
int um, int un, int k_idx, int n_idx) {
- if (!mayiuse(avx512_core)) {
+ if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512)) {
if ((um == unroll_m_) && (un == 2)) {
if (k_idx % 3 == 0) {
if (n_idx == 1) {
@@ -111,7 +111,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchA_afterBload(

void jit_avx2_kernel_sgemm_kern::prefetchB_afterFMA(
int k_idx, int n_idx, int m_idx) {
- if (mayiuse(avx512_core)) {
+ if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) {
if (((m_idx + (k_idx % (nb_zmm_a_ / unroll_m_reg_)) * unroll_m_reg_)
== 0)
&& (n_idx == 1)) {
@@ -126,7 +126,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchB_afterFMA(

void jit_avx2_kernel_sgemm_kern::prefetchA_beforeFMA(
int um, int un, int k_idx, int n_idx, int m_idx) {
- if (!mayiuse(avx512_core)) {
+ if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512)) {
if ((um == unroll_m_) && (un == unroll_n_)) {
if (((k_idx == 0) && (n_idx % 2 == 1) && (m_idx == 0))
|| ((k_idx == 1) && (n_idx == 2) && (m_idx == 0))
@@ -160,7 +160,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchA_beforeFMA(

void jit_avx2_kernel_sgemm_kern::prefetchC_afterBload(
int um, int un, int k_idx, int n_idx) {
- if (mayiuse(avx512_core)) {
+ if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) {
if (um == unroll_m_) {
if (n_idx == std::min(1, un - 1)) {
if (k_idx == unroll_k_ - 1)
@@ -173,7 +173,7 @@ }
}

void jit_avx2_kernel_sgemm_kern::prefetchC_beforeKloop(int um) {
- if (mayiuse(avx512_core)) {
+ if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) {
if (um < unroll_m_) {
prefetchw(ptr[CO2_ + elt_size_ * 0]);
prefetchw(ptr[CO2_ + elt_size_ * 8]);
@@ -228,7 +228,7 @@ void jit_avx2_kernel_sgemm_kern::generate() {
mov(C_, ptr[rsp + get_size_of_abi_save_regs() + C_off]);
mov(LDC_, ptr[rsp + get_size_of_abi_save_regs() + LDC_off]);

- if (mayiuse(avx512_core)) {
+ if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) {
for (i = zmm_acc_idx_; i < unroll_m_reg_ * unroll_n_ + zmm_acc_idx_;
i++)
vpxorq(Xbyak::Zmm(i), Xbyak::Zmm(i), Xbyak::Zmm(i));
@@ -267,7 +267,8 @@ void jit_avx2_kernel_sgemm_kern::generate() {
add(AA_, A_);
mov(CO1_, C_);

- if ((unroll_x == unroll_m_) || (!mayiuse(avx512_core)))
+ if ((unroll_x == unroll_m_)
+         || !(mayiuse(avx512_core) && __BUILD_GEMM_AVX512))
lea(CO2_, ptr[C_ + LDC_ * 2]);

add(C_, unroll_x * elt_size_);
@@ -292,12 +293,12 @@ void jit_avx2_kernel_sgemm_kern::generate() {
T_NEAR);
}

- if (!mayiuse(avx512_core))
+ if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512))
prefetcht2(ptr[AA_ - addr_off_ * elt_size_]);

switch (unroll_x) {
case 8:
- if (mayiuse(avx512_core)) {
+ if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) {
loop<Xbyak::Zmm, Xbyak::Zmm, Xbyak::Address, Xbyak::Xmm,
Xbyak::Operand>(unroll_x, unroll_y,
&Xbyak::CodeGenerator::vbroadcastf64x4,
@@ -319,7 +320,7 @@

break;
case 4:
- if (mayiuse(avx512_core)) {
+ if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) {
loop<Xbyak::Zmm, Xbyak::Ymm, Xbyak::Address, Xbyak::Xmm,
Xbyak::Operand>(unroll_x, unroll_y,
&Xbyak::CodeGenerator::vbroadcastf32x4,
@@ -340,7 +341,7 @@

break;
case 2:
- if (mayiuse(avx512_core)) {
+ if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) {
loop<Xbyak::Zmm, Xbyak::Ymm, Xbyak::Operand, Xbyak::Xmm,
Xbyak::Operand>(unroll_x, unroll_y,
&Xbyak::CodeGenerator::vbroadcastsd,
@@ -357,7 +358,7 @@
&Xbyak::CodeGenerator::vmovsd);
break;
case 1:
- if (mayiuse(avx512_core)) {
+ if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) {
loop<Xbyak::Zmm, Xbyak::Xmm, Xbyak::Operand, Xbyak::Xmm,
Xbyak::Operand>(unroll_x, unroll_y,
&Xbyak::CodeGenerator::vbroadcastss,
@@ -377,7 +378,7 @@

break;
default:
- if (mayiuse(avx512_core)) {
+ if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) {
loop<Xbyak::Zmm, Xbyak::Xmm, Xbyak::Operand, Xbyak::Xmm,
Xbyak::Operand>(unroll_x, unroll_y,
&Xbyak::CodeGenerator::vmovups,
@@ -400,7 +401,7 @@
break;
}

- if (mayiuse(avx512_core)) {
+ if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) {
sub(AA_, -16 * elt_size_);
} else {
if ((unroll_y != unroll_n_) || (unroll_x <= 4)) {