Skip to content

Commit

Permalink
feat(kernel): add AArch64 MLA int8 MK4 MatMul kernel
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 442b573a81b21f81c9d76e52a6d0eed3ec860e0d
  • Loading branch information
megvii-mge committed Apr 7, 2024
1 parent 9f5e6e3 commit b37dead
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 1 deletion.
3 changes: 2 additions & 1 deletion compiler/lib/KernelGen/Arm/Arm64/KernelPack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ struct AllA64Kernel {
std::make_shared<Arm64::Fp32MatMulM8N12>(),
std::make_shared<Arm64::Fp32MatMulM8N12K4>(),
std::make_shared<Arm64::Fp32MatMulM4N16K4>(),
std::make_shared<Arm64::Int8DotMatMulM8N12K4>()};
std::make_shared<Arm64::Int8DotMatMulM8N12K4>(),
std::make_shared<Arm64::Int8MK4MatMulM4N4K16>()};
inner_i8mm_map[KernelPack::KernType::MatrixMulKernel] = {
std::make_shared<Arm64::Int8I8mmMatMulM8N12K8MK4>()};

Expand Down
141 changes: 141 additions & 0 deletions compiler/lib/KernelGen/Arm/Arm64/MatMulKernel/Int8MK4MatMulM4N4K16.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
#include "../InternalKernel/InternalKernel.h"
#include "MatMul.h"
#include "Utils/StringTemplate.h"
#include "Utils/Utils.h"
#include "compiler/Common/Logger.h"

using namespace megcc;
using namespace KernelGen;
using namespace Arm64;
//! Build the context handed to the inner matmul generator (m_inner_matmul).
//! Every attribute is a fixed constant; the incoming \p ctx is not consulted.
std::shared_ptr<TContext> Int8MK4MatMulM4N4K16::GetInnerCtx(TContext* ctx) const {
    auto inner = std::make_shared<CodeGenContext>();
    auto& attrs = *inner;
    //! layout / transpose configuration of the inner kernel
    attrs.setAttr("format", "MK4");
    attrs.setAttr("with_bias", false);
    attrs.setAttr("transposeA", false);
    attrs.setAttr("transposeB", false);
    //! dtype tag understood by the inner kernel generator
    attrs.setAttr("dtype", "8832");
    return inner;
}

//! Decide whether this kernel can serve the matmul described by \p context.
//! Requirements: operand 0 and 1 are 8-bit integers and operand 2 is a 32-bit
//! integer, the format is MK4 with DEFAULT compute mode, operand 0 is rank 4
//! and operand 1 is rank 3, and neither input is transposed.
bool Int8MK4MatMulM4N4K16::IsAvailable(TContext* context) const {
    auto is_int = [&](const char* operand, int bits) {
        return Utils::is_int_dtype(context->getAttrOprand(operand).dtype, bits);
    };
    auto rank_of = [&](const char* operand) {
        return context->getAttrOprand(operand).shape.size();
    };

    //! all four groups are always evaluated, in this order, before combining
    const bool dtype_ok =
            is_int("operand:0", 8) && is_int("operand:1", 8) && is_int("operand:2", 32);
    const bool mode_ok = context->getAttrStr("format") == "MK4" &&
                         context->getAttrStr("compute_mode") == "DEFAULT";
    const bool rank_ok = rank_of("operand:0") == 4 && rank_of("operand:1") == 3;
    const bool not_transposed = !context->getAttrBool("transposeA") &&
                                !context->getAttrBool("transposeB");
    return dtype_ok && mode_ok && rank_ok && not_transposed;
}
//! kernel gen
//! Symbol of the generated kernel: a fixed prefix followed by one character
//! per input, 't' when that input is transposed and 'n' otherwise
//! (transposeA first, then transposeB).
std::string Int8MK4MatMulM4N4K16::GetKernelSymbol(TContext* context) const {
    std::string symbol = "Arm64_kernel_int8_matmul_m4n4k16_mk4_";
    symbol += context->getAttrBool("transposeA") ? "t" : "n";
    symbol += context->getAttrBool("transposeB") ? "t" : "n";
    return symbol;
}

//! Report the single internal kernel this kernel depends on: the inner matmul
//! generated by m_inner_matmul for the fixed inner context. The entry carries
//! its symbol, body, begin/end body guards and its own dependencies.
std::vector<KernelObj> Int8MK4MatMulM4N4K16::GetDependInternalSymbol(
        TContext* context) const {
    const auto inner_holder = GetInnerCtx(context);
    TContext* const inner = inner_holder.get();
    return {
            {m_inner_matmul.GetKernelSymbol(inner),
             m_inner_matmul.GetKernelBody(inner),
             m_inner_matmul.GetBodyGuardBegin(inner),
             m_inner_matmul.GetBodyGuardEnd(inner),
             m_inner_matmul.GetDependInternalSymbol(inner)}};
}

//! Emit the source of the workspace-size function.
//! \param ctx  context of the outer matmul, used for the function signature.
//! \param jit  when true the bodies of the inner pack-A / pack-B workspace
//!             helpers are emitted inline; when false only `extern`
//!             declarations of those helpers are emitted.
//! The emitted function stores into *workspace the sum of the pack-A and
//! pack-B workspace sizes reported by the inner kernel.
std::string Int8MK4MatMulM4N4K16::GetWorkspaceBodyCondition(
        TContext* ctx, bool jit) const {
    std::stringstream out;
    const auto inner_holder = GetInnerCtx(ctx);
    TContext* const inner = inner_holder.get();

    if (!jit) {
        //! helpers are defined elsewhere: declare them only
        out << "extern " << m_inner_matmul.GetPackAWorkspaceSignature(inner) << ";\n";
        out << "extern " << m_inner_matmul.GetPackBWorkspaceSignature(inner) << ";\n";
    } else {
        //! jit: helpers must be self-contained in this translation unit
        out << m_inner_matmul.GetPackAWorkspaceBody(inner) << ";\n";
        out << m_inner_matmul.GetPackBWorkspaceBody(inner) << ";\n";
    }
    out << GenCommonRet() << " " << GetWorkspaceSignature(ctx);

    //! M and K are rebuilt from the MK4-blocked dims of input 0 (block size 4),
    //! N is taken from dims[1] of input 1
    std::string size_func_template =
R"({
TINYNN_ASSERT(workspace);
const Layout a_layout = inputs[0]->layout;
const Layout b_layout = inputs[1]->layout;
const size_t M = a_layout.dims[0] * 4;
const size_t K = a_layout.dims[1] * 4;
const size_t N = b_layout.dims[1];
*workspace = ${packa_workspace_sym}(0, M, 0, K) + ${packb_workspace_sym}(0, N, 0, K);
return TinyNN_SUCCESS;
})";

    const auto packa_sym = m_inner_matmul.GetPackAWorkspaceSymbol(inner);
    const auto packb_sym = m_inner_matmul.GetPackBWorkspaceSymbol(inner);
    out << StringTemplate::StringTemplateArgs()
                    .add("packa_workspace_sym", packa_sym)
                    .add("packb_workspace_sym", packb_sym)
                    .render(size_func_template);
    return out.str();
}

//! Emit the source of the outer kernel function.
//! The generated code includes <arm_neon.h>, declares the inner matmul kernel
//! as `extern`, validates the MK4 layouts of A (inputs[0]), B (inputs[1]) and
//! C (outputs[0]) and then forwards to the inner kernel with the workspace
//! pointer.
//! \param context context of the outer matmul, used for the kernel signature.
//! \return the complete C source of the kernel as a string.
std::string Int8MK4MatMulM4N4K16::GetKernelBody(TContext* context) const {
    std::stringstream writer;
    auto inner_ctx = GetInnerCtx(context);
    writer << "#include <arm_neon.h>\n";
    writer << "extern " << m_inner_matmul.GetKernelSignature(inner_ctx.get()) << ";\n";
    writer << GenCommonRet() << " ";
    writer << GetKernelSignature(context);
    //! Shape checks in the template: A and C share dims[0], A.dims[1] matches
    //! B.dims[0], and B.dims[1] matches C.dims[1] (the N that is passed on).
    //! The last check used to compare b_layout.dims[1] with itself, which
    //! could never fail.
    std::string body_temp = R"({
int8_t* A = (int8_t*)inputs[0]->ptr;
int8_t* B = (int8_t*)inputs[1]->ptr;
int32_t* C = (int32_t*)outputs[0]->ptr;
TINYNN_ASSERT(A);
TINYNN_ASSERT(B);
TINYNN_ASSERT(C);
const Layout a_layout = inputs[0]->layout;
const Layout b_layout = inputs[1]->layout;
const Layout c_layout = outputs[0]->layout;
const size_t LDA = a_layout.stride[0];
const size_t LDB = b_layout.stride[0];
const size_t LDC = c_layout.stride[0];
const size_t M = a_layout.dims[0] * 4;
const size_t K = a_layout.dims[1] * 4;
const size_t N = c_layout.dims[1];
TINYNN_ASSERT(4 == a_layout.dims[3]);
TINYNN_ASSERT(4 == a_layout.dims[2]);
TINYNN_ASSERT(4 == b_layout.dims[2]);
TINYNN_ASSERT(4 == c_layout.dims[2]);
TINYNN_ASSERT(a_layout.dims[0] == c_layout.dims[0]);
TINYNN_ASSERT(a_layout.dims[1] == b_layout.dims[0]);
TINYNN_ASSERT(b_layout.dims[1] == c_layout.dims[1]);
void* workspace_ptr = workspace->ptr;
TINYNN_ASSERT(workspace_ptr);
${matmul_symbol}(A, LDA, B, LDB, C, LDC, M, N, K, 0, workspace_ptr, 1.f, 1.f, 1.f);
return TinyNN_SUCCESS;
})";

    writer << StringTemplate::StringTemplateArgs()
                      .add("matmul_symbol",
                           m_inner_matmul.GetKernelSymbol(inner_ctx.get()))
                      .render(body_temp);
    return writer.str();
}

// vim: syntax=cpp.doxygen
19 changes: 19 additions & 0 deletions compiler/lib/KernelGen/Arm/Arm64/MatMulKernel/MatMul.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,25 @@ class Int8DotMatMulM8N12K4 : public Arm64KernelFunc {
std::shared_ptr<TContext> GetInnerCtx(TContext* ctx) const;
};

//! AArch64 kernel generator for int8 MatMul in MK4 format. It wraps the
//! internal MatmulInt8M4N4K16MK4Kernel generator held in m_inner_matmul.
class Int8MK4MatMulM4N4K16 : public Arm64KernelFunc {
public:
    bool IsAvailable(TContext* context) const override;
    std::string GetKernelSymbol(TContext* context) const override;
    std::string GetKernelBody(TContext* context) const override;
    std::vector<KernelObj> GetDependInternalSymbol(TContext* context) const override;

    //! workspace-size function for the non-jit path
    std::string GetWorkspaceBody(TContext* ctx) const override {
        const bool jit = false;
        return GetWorkspaceBodyCondition(ctx, jit);
    }

    //! workspace-size function for the jit path
    std::string GetWorkspaceBodyAndJitExec(TContext* ctx) const override {
        const bool jit = true;
        return GetWorkspaceBodyCondition(ctx, jit);
    }

private:
    //! context passed to m_inner_matmul
    std::shared_ptr<TContext> GetInnerCtx(TContext* ctx) const;
    //! shared implementation of the two workspace-body getters above
    std::string GetWorkspaceBodyCondition(TContext* ctx, bool jit) const;
    //! internal kernel generator this kernel delegates to
    MatmulInt8M4N4K16MK4Kernel m_inner_matmul;
};

class Int8I8mmMatMulM8N12K8MK4 : public Arm64KernelFunc {
public:
bool IsAvailable(TContext* context) const override;
Expand Down
22 changes: 22 additions & 0 deletions compiler/test/kernel/opr/arm/Int8Matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,28 @@ TEST(AARCH64, Int8MatMulM8N12K4Dot) {
}
}

//! Checks the MK4 int8 matmul kernel (int8 x int8 -> int32) on ARM64 over a
//! grid of m / n / k sizes, with both inputs drawn uniformly from [-127, 127].
TEST(AARCH64, Int8MK4MatMulM4N4K16) {
    Checker<MatrixMulForward> checker(Arch::ARM64);
    MatrixMulForward::Param param;
    UniformIntRNG rng(-127, 127);
    checker.set_rng(0, &rng);
    checker.set_rng(1, &rng);

    checker.set_dtype(0, dtype::Int8());
    checker.set_dtype(1, dtype::Int8());
    checker.set_dtype(2, dtype::Int32());
    checker.set_kernel_symbol("Arm64_kernel_int8_matmul_m4n4k16_mk4_.*");

    //! the param fields are the same for every case
    param.transposeA = false;
    param.transposeB = false;
    param.format = param::MatrixMul::Format::MK4;

    for (size_t m : {4, 8, 16, 64}) {
        for (size_t n : {3, 8, 15, 56}) {
            for (size_t k : {8, 16, 32, 80}) {
                //! MK4 shapes: A is {m/4, k/4, 4, 4}, B is {k/4, n, 4};
                //! the output shape is left empty
                const size_t m_blocks = m / 4;
                const size_t k_blocks = k / 4;
                checker.set_param(param);
                checker.execs({{m_blocks, k_blocks, 4, 4}, {k_blocks, n, 4}, {}});
            }
        }
    }
}

TEST(AARCH64, Int8MatMulM8N12K8MK4I8mm) {
Checker<MatrixMulForward> checker(Arch::ARM64_WITH_I8MM);
MatrixMulForward::Param param;
Expand Down
9 changes: 9 additions & 0 deletions compiler/test/kernel/opr/arm/benchmark_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,15 @@ TEST(AARCH64, BenchmarkConv1x1NCHW4Dot) {
ConvBiasForward::Param ::NonlineMode::RELU);
}

//! Benchmark of a 1x1 convolution in NCHW44 format with RELU, restricted to
//! kernels matching "Arm64_kernel_conv2d_conv1x1_.*".
//! NOTE(review): the meaning of the numeric and boolean run_conv arguments is
//! defined by run_conv, which is not visible here -- confirm against it.
TEST(AARCH64, BenchmarkConv1x1NCHW44Int8) {
    std::string cc_algo = "Arm64_kernel_conv2d_conv1x1_.*";
    //! no algorithm pattern is given for the dnn side
    std::string dnn_algo = "";
    const auto format = ConvBiasForward::Param::Format::NCHW44;
    const auto nonline_mode = ConvBiasForward::Param::NonlineMode::RELU;
    run_conv(1, 120, 120, 96, 1, 1, 0, cc_algo, dnn_algo, format, true, nonline_mode);
}

TEST(AARCH64, BenchmarkConvDotNCHWNCHW44Stride1) {
Benchmarker<ConvBiasForward> benchmarker(Arch::ARM64);
ConvBiasForward::Param param;
Expand Down

0 comments on commit b37dead

Please sign in to comment.