[CPU]SDPA uses Brgemm Kernel for BF16 (openvinotoolkit#22980)
### Details:
- *ScaledDotProductAttention uses brgemm kernel for 1st token calculation*

### Tickets:
 - *CVS-131056*
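
For orientation, the sketch below shows roughly how an SDPA first-token path could instantiate the extended BrgemmKernel for the Q x K' score matmul. The include path, the leading-dimension choices, and the function name are illustrative assumptions, not the node code from this commit (the SDPA-side wiring lives in the remaining changed files, which are not rendered below).

```cpp
// Hypothetical construction of the first-token Q x K' score GEMM.
// M = query length, N = key length, K = head size; bf16 in, f32 out.
#include <cstddef>
#include "nodes/kernels/x64/brgemm_kernel.hpp"  // path assumed

void make_first_token_qk_kernels(size_t M, size_t N, size_t K) {
    // Q is [M x K] (lda = K), the key matrix is [N x K] and consumed
    // transposed (ldb = K, b_transposed = true), scores are [M x N] f32 (ldc = N).
    ov::intel_cpu::BrgemmKernel qk_bf16(M, N, K,
                                        /*lda=*/K, /*ldb=*/K, /*ldc=*/N,
                                        /*b_transposed=*/true,
                                        ov::element::bf16);
    // The new inType argument lets the same class serve an f32 path as well.
    ov::intel_cpu::BrgemmKernel qk_f32(M, N, K,
                                       /*lda=*/K, /*ldb=*/K, /*ldc=*/N,
                                       /*b_transposed=*/true,
                                       ov::element::f32);
    (void)qk_bf16;
    (void)qk_f32;
}
```
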
zhangYiIntel authored Feb 28, 2024
1 parent 160b193 commit 96b72d0
Showing 4 changed files with 176 additions and 195 deletions.
39 changes: 26 additions & 13 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.cpp
@@ -17,26 +17,39 @@ using namespace dnnl::impl::cpu::x64::matmul;
namespace ov {
namespace intel_cpu {

- BrgemmKernel::BrgemmKernel(size_t M, size_t N, size_t K, size_t lda, size_t ldb, size_t ldc, bool b_transposed)
+ BrgemmKernel::BrgemmKernel(size_t M,
+ size_t N,
+ size_t K,
+ size_t lda,
+ size_t ldb,
+ size_t ldc,
+ bool b_transposed,
+ ov::element::Type inType)
: M(M),
K(K),
N(N),
lda(lda),
ldb(ldb),
ldc(ldc),
b_transposed(b_transposed),
- inType(ov::element::bf16) {
+ inType(inType) {
// blocking M
M_blk = matmulOptimalM;
M_tail = M % M_blk;
brgVnniFactor = 4 / inType.size();
- if (!mayiuse(avx512_core_bf16))
- THROW_ERROR("brgemm bf16bf16f32 kernel could only be used above avx512_bf16");
- bool isAMXSupported = mayiuse(avx512_core_amx);
- // blocing N

+ if (inType != ov::element::bf16 && inType != ov::element::f32)
+ THROW_ERROR("brgemm kernel only supports bf16, f32");
+ bool is_bf16 = inType == ov::element::bf16;
+ if (is_bf16 && !mayiuse(avx512_core_bf16))
+ THROW_ERROR("brgemm bf16 kernel could only be used above avx512_bf16");

+ bool isAMXSupported = is_bf16 && mayiuse(avx512_core_amx);
+ // blocking N
+ N_blk = is_bf16 ? 32 : N;
N_tail = N % N_blk;

- // blocing K
+ // blocking K
K_blk = isAMXSupported ? 32 : K;
K_tail = K % K_blk;
if (isAMXSupported && K_tail) {
@@ -60,7 +73,7 @@ BrgemmKernel::BrgemmKernel(size_t M, size_t N, size_t K, size_t lda, size_t ldb,
brgemmCtx.N = N_;
brgemmCtx.K = K_;
brgemmCtx.LDA = k ? K_blk : lda;
- brgemmCtx.LDB = rnd_up(N, N_blk); // B is copied with bf16
+ brgemmCtx.LDB = (is_bf16 || b_transposed) ? rnd_up(N, N_blk) : ldb; // bf16/b_transposed needs copy
brgemmCtx.LDC = ldc;
brgemmCtx.dt_in0 = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::ElementTypeToDataType(inType));
brgemmCtx.dt_in1 = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::ElementTypeToDataType(inType));
@@ -83,7 +96,7 @@ BrgemmKernel::BrgemmKernel(size_t M, size_t N, size_t K, size_t lda, size_t ldb,
packedASize = M_blk * rnd_up(K, K_blk) * inType.size();
}

- if (brgemmCtx0.is_with_amx || inType == ov::element::bf16) {
+ if (brgemmCtx0.is_with_amx || inType == ov::element::bf16 || b_transposed) {
size_t b_stride = 0;
// must set actual stride when stride is not K/N
if (b_transposed) {
@@ -267,11 +280,11 @@ void BrgemmKernel::copy_buffer_b(void* b, void* scratch_b) {
}
}

- void BrgemmKernel::executeGemmPackedB(bool is_M_tail, void* a, void* repacked_b, void* c, void* wsp, void* scratch_a) {
+ void BrgemmKernel::executeGemm(bool is_M_tail, void* a, void* b, void* c, void* wsp, void* scratch_a) {
auto ptr_A = reinterpret_cast<uint8_t*>(a);
auto ptr_C = reinterpret_cast<uint8_t*>(c);
auto ptr_scartch_a = reinterpret_cast<uint8_t*>(scratch_a);
- auto ptr_scartch_b = reinterpret_cast<uint8_t*>(repacked_b);
+ auto ptr_scartch_b = reinterpret_cast<uint8_t*>(b);
uint8_t* ptr_a_tail = nullptr;

size_t brgIdx0 = getBrgIdx(0, 0, 0);
@@ -340,8 +353,8 @@ void BrgemmKernel::executeGemm(void* a, void* b, void* c, void* wsp, void* scrat
for (size_t mb = 0; mb < div_up(M, M_blk); mb++) {
const bool is_M_tail = (M - mb * M_blk < M_blk);
auto ptr_a = ptr_A + (mb * M_blk * lda) * inType.size();
- auto ptr_c = ptr_C + (mb * M_blk * ldc) * inType.size();
- executeGemmPackedB(is_M_tail, ptr_a, scratch_b, wsp, ptr_c, scratch_a);
+ auto ptr_c = ptr_C + (mb * M_blk * ldc) * ov::element::f32.size();
+ executeGemm(is_M_tail, ptr_a, scratch_b, wsp, ptr_c, scratch_a);
}
}
void BrgemmKernel::callBrgemm(brgemmCtx& ctx,
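
To make the new blocking rules in the constructor above concrete, here is a small standalone paraphrase of that logic (not the real class); the example shape is arbitrary and AMX availability is assumed.

```cpp
// Standalone paraphrase of the blocking choices in BrgemmKernel's constructor.
#include <cstddef>
#include <cstdio>

int main() {
    const size_t M = 1024, N = 1024, K = 128;  // arbitrary first-token-like shape
    const bool is_bf16 = true;                 // inType == ov::element::bf16
    const bool is_amx = is_bf16;               // assume avx512_core_amx is available
    const size_t matmulOptimalM = 32;

    const size_t dtype_size = is_bf16 ? 2 : 4;
    const size_t brgVnniFactor = 4 / dtype_size;  // 2 for bf16, 1 for f32

    const size_t M_blk = matmulOptimalM, M_tail = M % M_blk;
    const size_t N_blk = is_bf16 ? 32 : N, N_tail = N % N_blk;  // f32 keeps N whole
    const size_t K_blk = is_amx ? 32 : K, K_tail = K % K_blk;   // AMX blocks K by 32

    std::printf("M: %zu blocks of %zu, tail %zu\n", M / M_blk, M_blk, M_tail);
    std::printf("N_blk=%zu tail=%zu, K_blk=%zu tail=%zu, vnni=%zu\n",
                N_blk, N_tail, K_blk, K_tail, brgVnniFactor);
    return 0;
}
```

For bf16 this prints 32 M-blocks of 32 rows with no tails and a VNNI factor of 2; switching is_bf16 to false collapses the N and K blocking to single full-size blocks, matching the new f32 branch.
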
16 changes: 11 additions & 5 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.hpp
@@ -21,11 +21,18 @@ class BrgemmKernel {
// ldb is the leading dimension for B matrix
// ldc is the leading dimension for C matrix
// b_transpose indicates wheter B matrix is transposed.
- BrgemmKernel(size_t M, size_t N, size_t K, size_t lda, size_t ldb, size_t ldc, bool b_transposed = false);
+ BrgemmKernel(size_t M,
+ size_t N,
+ size_t K,
+ size_t lda,
+ size_t ldb,
+ size_t ldc,
+ bool b_transposed = false,
+ ov::element::Type inType = ov::element::bf16);
// execute all M
void executeGemm(void* a, void* b, void* c, void* wsp, void* scratch_a, void* scratch_b);
- // execute m_blk
- void executeGemmPackedB(bool is_M_tail, void* a, void* repacked_b, void* c, void* wsp, void* scratch_a);
+ // execute by m_blk
+ void executeGemm(bool is_M_tail, void* a, void* b, void* c, void* wsp, void* scratch_a);

void copy_buffer_b(void* b, void* scratch_b);
// bytes needed to place scratch buffer a
@@ -44,14 +51,13 @@ class BrgemmKernel {

private:
size_t M = 0, M_blk = 0, M_tail = 0;
- size_t K = 0, K_blk = 0, K_tail = 0, N = 0, N_tail = 0;
+ size_t K = 0, K_blk = 0, K_tail = 0, N = 0, N_blk = 0, N_tail = 0;
size_t lda = 0, ldb = 0, ldc = 0;
bool b_transposed = false;
size_t brgVnniFactor = 0;
size_t packedBSize = 0;
size_t packedASize = 0;
ov::element::Type inType;
- static constexpr size_t N_blk = 32;
static constexpr size_t MHA_BRGEMM_KERNELS_NUM = 8;
static constexpr size_t matmulOptimalM = 32;
struct brgemmCtx {
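
And a caller-side sketch of the per-block overload declared above, mirroring the M-block loop from the .cpp diff; the scratch-size getter names, the workspace size, and the separate copy_buffer_b call are assumptions based on this excerpt, not the node's actual code.

```cpp
// Assumed caller-side driving loop: repack B once, then produce C in strips
// of M_blk rows (e.g. so a row-wise softmax can be fused per strip in SDPA).
#include <cstddef>
#include <cstdint>
#include <vector>
#include "nodes/kernels/x64/brgemm_kernel.hpp"  // path assumed

void gemm_by_strips(ov::intel_cpu::BrgemmKernel& gemm,
                    uint8_t* a, uint8_t* b, float* c,
                    size_t M, size_t lda, size_t ldc, size_t a_elt_size) {
    const size_t M_blk = 32;  // matmulOptimalM in the header above
    std::vector<uint8_t> scratch_a(gemm.get_scratch_a_size());  // getter names assumed
    std::vector<uint8_t> scratch_b(gemm.get_scratch_b_size());
    std::vector<uint8_t> wsp(4 * 1024, 0);  // AMX tile workspace, size assumed

    gemm.copy_buffer_b(b, scratch_b.data());  // repack/copy B once up front
    for (size_t mb = 0; mb * M_blk < M; mb++) {
        const bool is_M_tail = (M - mb * M_blk < M_blk);
        uint8_t* ptr_a = a + mb * M_blk * lda * a_elt_size;
        float* ptr_c = c + mb * M_blk * ldc;  // C is always f32
        gemm.executeGemm(is_M_tail, ptr_a, scratch_b.data(), ptr_c,
                         wsp.data(), scratch_a.data());
        // ... per-strip post-processing of ptr_c could run here ...
    }
}
```

The single-call executeGemm(void* a, ...) overload runs this same M-block loop internally, as the last hunk of the .cpp diff shows.
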
[The remaining changed files in this commit were not loaded in this capture.]
