[WIP] Accept Hopper matmuls and update default heuristic #3579

Draft: wants to merge 1 commit into base: main
185 changes: 138 additions & 47 deletions csrc/scheduler/matmul_utils.cpp
@@ -49,25 +49,41 @@ using ProblemShape = std::array<int64_t, 4>;
 inline std::optional<MmaMacro> getMmaOp(
     const int dev_version,
     const ProblemShape& problem) {
-  using MacroType = MmaMacro;
+  const int64_t n_extent = problem[(size_t)MatmulDimRole::N];
 
-  // NOTE: A temp condition
-  const ProblemShape::value_type n_extend = problem[(size_t)MatmulDimRole::N];
-  const bool use_small_n = ((n_extend % 8) == 0) && ((n_extend % 16) != 0);
+  MmaMacroEncode macro_encode{MmaMacroEncode::Arch::NoMma, 16, 8, 16};
 
   switch (dev_version) {
     case 75:
-      return (use_small_n) ? MacroType::Turing_16_8_16
-                           : MacroType::Turing_16_16_16;
+      macro_encode.arch = MmaMacroEncode::Arch::Turing;
+      if ((n_extent % 16) == 0) {
+        macro_encode.n = 16;
+      }
+      break;
     case 80:
     case 86:
     case 89:
-    case 90: // NOTE: temp use ampere matmul for hopper
-      return (use_small_n) ? MacroType::Ampere_16_8_16
-                           : MacroType::Ampere_16_16_16;
+      macro_encode.arch = MmaMacroEncode::Arch::Ampere;
+      if ((n_extent % 16) == 0) {
+        macro_encode.n = 16;
+      }
+      break;
+    case 90:
+      macro_encode.arch = MmaMacroEncode::Arch::Hopper;
+      macro_encode.m = 64;
+      // Find the largest instruction tile that divides the problem size
+      macro_encode.n = 256;
+      while (macro_encode.n >= 8) {
+        if (n_extent % macro_encode.n == 0) {
+          break;
+        }
+        macro_encode.n -= 8;
+      }
+      break;
     default:
       return std::nullopt;
   }
+  return macro_encode;
 }
 
 //! Find the number of circular buffer stages for shared memory operands, so
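A minimal standalone sketch of the new case 90 selection in getMmaOp, for review: it walks down from the widest Hopper wgmma N (256) in steps of 8 and keeps the first value that evenly divides the problem's N extent. pickInstructionN is a hypothetical stand-in for the macro_encode.n logic above, not an nvFuser API.

#include <cstdint>
#include <iostream>

// Hypothetical stand-in for the macro_encode.n loop in getMmaOp.
int64_t pickInstructionN(int64_t n_extent) {
  // Hopper wgmma instructions support N in multiples of 8 up to 256; take
  // the largest instruction tile that divides the problem size.
  int64_t n = 256;
  while (n >= 8) {
    if (n_extent % n == 0) {
      break;
    }
    n -= 8;
  }
  return n;
}

int main() {
  // Expected: 4096 -> 256, 2000 -> 200, 136 -> 136, 24 -> 24
  for (int64_t n_extent : {4096, 2000, 136, 24}) {
    std::cout << "N = " << n_extent << " -> instruction N = "
              << pickInstructionN(n_extent) << "\n";
  }
  return 0;
}

Note the fall-through: if n_extent is not a multiple of 8, the loop exits with n == 0, which would be an invalid macro, so this WIP heuristic implicitly assumes the N extent is a multiple of 8.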
@@ -105,52 +121,124 @@ inline bool initCoreHeuristics(
 
   using DimType = decltype(GemmTile::m);
 
-  // warp tile shape
-  {
-    // Initial target:
-    // - 1 MMA ops per thread in a warp (32 threads), warp tile should be
-    //   then 32x bigger than instruction tile,
-    // - start with [4, 4, 2] shape, later it should depend on problem
-    //   shape and have bigger impact on CTA tile shape
-
-    const DimType m_ratio = 4;
-    const DimType n_ratio = 4;
-    const DimType k_ratio = 2;
-
-    warp_tile = {
-        instruction_tile.m * m_ratio,
-        instruction_tile.n * n_ratio,
-        instruction_tile.k * k_ratio};
-  }
-
-  // cta tile shape
-  {
-    // Initial target:
-    // - 4 warp tiles per CTA
-    // - CTA k-dim should be same as warp tile k-dim
-
-    DimType m_ratio = 2;
-    DimType n_ratio = 2;
-
-    const auto mn_ratio = (double)problem_shape[(size_t)MatmulDimRole::M] /
-        (double)problem_shape[(size_t)MatmulDimRole::N];
-    if (mn_ratio < 0.5) {
-      m_ratio = 1;
-      n_ratio = 4;
-    } else if (mn_ratio > 2) {
-      m_ratio = 4;
-      n_ratio = 1;
-    }
-
-    cta_tile = {warp_tile.m * m_ratio, warp_tile.n * n_ratio, warp_tile.k};
-  }
+  if (isHopper(mparams->mma_macro)) {
+    // We typically use larger macros on Hopper. By default we will set the
+    // warp tile equal to the macro and increase the CTA tile until we hit
+    // a limit. The limits are given by the maximum number of threads per CTA.
+
+    // TODO: it might be advantageous in some cases to issue multiple wgmma
+    // instructions per warp group
+    warp_tile = instruction_tile;
+
+    // The MmaOp output is a 32-bit float which requires one register per value
+    const DimType registers_per_wg = warp_tile.m * warp_tile.n;
+    const DimType max_registers_per_sm = 512 * 100;
+
+    const auto ratiosValid = [&](const DimType m_ratio, const DimType n_ratio) {
+      DimType cta_m = warp_tile.m * m_ratio;
+      DimType cta_n = warp_tile.n * n_ratio;
+      DimType num_warp_groups = m_ratio * n_ratio;
+      return cta_n * cta_m < max_registers_per_sm
+          // Each warp group is 128 threads. We can only have a maximum of 1024
+          // threads per SM, or 8 warp groups.
+          && num_warp_groups <= 8 &&
+          // Don't extend the CTA tile beyond the problem size
+          warp_tile.m * (m_ratio + 1) <=
+              problem_shape[(size_t)MatmulDimRole::M] &&
+          warp_tile.n * (n_ratio + 1) <=
+              problem_shape[(size_t)MatmulDimRole::N];
+    };
+
+    DimType m_ratio = 1;
+    DimType n_ratio = 1;
+
+    bool increased = true;
+    while (increased) {
+      DimType cta_m = warp_tile.m * m_ratio;
+      DimType cta_n = warp_tile.n * n_ratio;
+      increased = false;
+
+      const auto tryIncreaseM = [&]() {
+        if (ratiosValid(m_ratio + 1, n_ratio)) {
+          m_ratio++;
+          increased = true;
+        }
+        return increased;
+      };
+      const auto tryIncreaseN = [&]() {
+        if (ratiosValid(m_ratio, n_ratio + 1)) {
+          n_ratio++;
+          increased = true;
+        }
+        return increased;
+      };
+
+      if (cta_m < cta_n) {
+        // Try to increase smaller tile dimension first since square tiles are
+        // optimal for reducing operand load redundancy
+        if (tryIncreaseM()) {
+          continue;
+        }
+        tryIncreaseN();
+      } else {
+        if (tryIncreaseN()) {
+          continue;
+        }
+        tryIncreaseM();
+      }
+    }
+
+    cta_tile = {warp_tile.m * m_ratio, warp_tile.n * n_ratio, warp_tile.k};
+  } else {
+    // warp tile shape
+    {
+      // Initial target:
+      // - 1 MMA ops per thread in a warp (32 threads), warp tile should be
+      //   then 32x bigger than instruction tile,
+      // - start with [4, 4, 2] shape, later it should depend on problem
+      //   shape and have bigger impact on CTA tile shape
+
+      const DimType m_ratio = isHopper(mparams->mma_macro) ? 1 : 4;
+      const DimType n_ratio = isHopper(mparams->mma_macro) ? 1 : 4;
+      const DimType k_ratio = isHopper(mparams->mma_macro) ? 1 : 2;
+
+      warp_tile = {
+          instruction_tile.m * m_ratio,
+          instruction_tile.n * n_ratio,
+          instruction_tile.k * k_ratio};
+    }
+
+    // cta tile shape
+    {
+      // Initial target:
+      // - 4 warp tiles per CTA
+      // - CTA k-dim should be same as warp tile k-dim
+
+      DimType m_ratio = 2;
+      DimType n_ratio = 2;
+
+      const auto mn_ratio = (double)problem_shape[(size_t)MatmulDimRole::M] /
+          (double)problem_shape[(size_t)MatmulDimRole::N];
+      if (mn_ratio < 0.5) {
+        m_ratio = 1;
+        n_ratio = 4;
+      } else if (mn_ratio > 2) {
+        m_ratio = 4;
+        n_ratio = 1;
+      }
+
+      cta_tile = {warp_tile.m * m_ratio, warp_tile.n * n_ratio, warp_tile.k};
+    }
+  }
 
   mparams->tile_sizes = {cta_tile, warp_tile};
 
   // stages and async mem copy
   {
     // NOTE: compilation errors when async is enabled on Turing devices
-    if (isAmpere(mparams->mma_macro)) {
+    if (!isTuring(mparams->mma_macro)) {
       constexpr int stages = 3;
 
       mparams->circular_buffer_options.circular_buffer_smem_write = true;
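As a sanity check on the greedy CTA-tile growth above, here is a self-contained simulation of the same loop under the limits stated in the diff (fewer than 512 * 100 accumulator values, at most 8 warp groups of 128 threads, and a CTA tile no larger than the problem). The 64x256 warp tile and 2048x2048 problem are assumptions picked for illustration; all names are local to the sketch, not nvFuser APIs.

#include <cstdint>
#include <iostream>

int main() {
  // Assumed inputs: the 64x256 Hopper macro from getMmaOp, a large problem.
  const int64_t warp_m = 64, warp_n = 256;
  const int64_t problem_m = 2048, problem_n = 2048;
  const int64_t max_registers_per_sm = 512 * 100;

  int64_t m_ratio = 1, n_ratio = 1;
  const auto ratiosValid = [&](int64_t m_r, int64_t n_r) {
    return (warp_m * m_r) * (warp_n * n_r) < max_registers_per_sm &&
        m_r * n_r <= 8 &&  // each warp group is 128 threads; max 1024 per CTA
        warp_m * (m_r + 1) <= problem_m && warp_n * (n_r + 1) <= problem_n;
  };

  bool increased = true;
  while (increased) {
    increased = false;
    // Grow the smaller CTA dimension first: near-square tiles minimize
    // operand load redundancy. Fall back to the other dimension otherwise.
    const bool grow_m_first = warp_m * m_ratio < warp_n * n_ratio;
    if (grow_m_first && ratiosValid(m_ratio + 1, n_ratio)) {
      m_ratio++;
      increased = true;
    } else if (!grow_m_first && ratiosValid(m_ratio, n_ratio + 1)) {
      n_ratio++;
      increased = true;
    } else if (ratiosValid(m_ratio + 1, n_ratio)) {
      m_ratio++;
      increased = true;
    } else if (ratiosValid(m_ratio, n_ratio + 1)) {
      n_ratio++;
      increased = true;
    }
  }
  std::cout << "CTA tile: " << warp_m * m_ratio << " x " << warp_n * n_ratio
            << " (" << m_ratio * n_ratio << " warp groups)\n";
  return 0;
}

With these numbers the register bound is what stops growth: the loop reaches a 192 x 256 CTA tile (49152 accumulator values, 3 warp groups), and both 256 x 256 and 192 x 512 would exceed the 51200-value budget.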
@@ -169,11 +257,14 @@ inline bool initCoreHeuristics(
       }
       return min_size_bytes;
     };
-    mparams->async_gmem_load_operands = isCpAsyncOperandLoadSupported(
-        mparams,
-        std::min(
-            roleMinDtypeSize(MatmulTensorRole::OPERAND_A),
-            roleMinDtypeSize(MatmulTensorRole::OPERAND_B)));
+    // Use TMA on Hopper+ or cp.async on Ampere if possible
+    mparams->async_gmem_load_operands =
+        isHopper(mparams->mma_macro) ||
+        isCpAsyncOperandLoadSupported(
+            mparams,
+            std::min(
+                roleMinDtypeSize(MatmulTensorRole::OPERAND_A),
+                roleMinDtypeSize(MatmulTensorRole::OPERAND_B)));
 
     if (!mparams->async_gmem_load_operands) {
       // Circular buffering requires async load. If we cannot use async load due
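The last hunk makes the operand-load policy explicit: Hopper macros always use asynchronous global-to-shared loads (TMA), while pre-Hopper architectures still depend on cp.async support for the smaller of the two operand dtypes. A tiny decision-table sketch, with hypothetical stand-ins for isHopper(mparams->mma_macro) and isCpAsyncOperandLoadSupported(...):

#include <iostream>

// Hypothetical stand-in for the async_gmem_load_operands assignment above.
bool useAsyncGmemLoad(bool is_hopper_macro, bool cp_async_supported) {
  // Use TMA on Hopper+ or cp.async on Ampere if possible.
  return is_hopper_macro || cp_async_supported;
}

int main() {
  std::cout << useAsyncGmemLoad(true, false)    // 1: Hopper -> TMA
            << useAsyncGmemLoad(false, true)    // 1: Ampere -> cp.async
            << useAsyncGmemLoad(false, false)   // 0: synchronous fallback
            << "\n";
  return 0;
}

As the trailing context notes, circular buffering requires an async load, so the synchronous fallback also disables circular buffering.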