Implement basic split-k gemm for hopper matmul scheduler (#3575)
This PR implements `scheduleSplitKSum` function to support split-k gemm
with the hopper matmul schedule.

- It supports all operand formats: TT, NT, TN, and NN.
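
For readers new to the technique (this note and the formula below are explanatory additions, not part of the commit message): with a split-K factor of S, the reduction dimension is partitioned so that each split produces a partial sum, and the partial sums are combined afterwards by `scheduleSplitKSum`:

$$D_{mn} \;=\; \sum_{k=0}^{K-1} A_{mk} B_{kn} \;=\; \sum_{s=0}^{S-1} \; \sum_{k = s K / S}^{(s+1) K / S - 1} A_{mk} B_{kn}$$

Each inner partial sum is computed by one block along `BIDz`; the outer sum over `s` is the split-K sum.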
rdspring1 authored and jacobhinkle committed Dec 16, 2024
1 parent a8cab11 commit 5b37bca
Showing 6 changed files with 87 additions and 74 deletions.
6 changes: 5 additions & 1 deletion csrc/device_lower/utils.cpp
@@ -921,7 +921,11 @@ std::array<UnitDim, 2> getMmaLayout(const MmaOp* expr) {

auto out_tv = ir_utils::getTv(expr->out());
IterDomain* reduction_id = nullptr;
for (auto id : out_tv->getLogicalDomain()) {
// For hopper matmuls, the mma_result logical domain is reordered as [M, N, K]
// using commitLeafToLogical. In the split-k case, use the root domain for the
// mma layout because the k dimension is divided into two iterDomains in the
// logical domain.
for (auto id : out_tv->getMaybeRootDomain()) {
if (id->isReduction()) {
reduction_id = id;
break;
89 changes: 29 additions & 60 deletions csrc/scheduler/hopper_multi_matmul.cpp
@@ -29,6 +29,27 @@

namespace nvfuser {

void HopperMultipleMatmulScheduler::transformLikeMmaOutput(
TensorView* tv,
bool is_mma_result) {
// TODO Add constraints

auto apply_k_dim_offset = [is_mma_result](int64_t idx) constexpr {
return (is_mma_result) ? idx - 1 : idx;
};

// Original: [..., Mo, No, Mi, Ni]
tv->split(apply_k_dim_offset(-2), getM(params_->mma_macro));
tv->split(apply_k_dim_offset(-1), getN(params_->mma_macro));
// After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
tv->reorder({{apply_k_dim_offset(-3), apply_k_dim_offset(-2)}});
// After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
tv->merge(apply_k_dim_offset(-4));
// After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
tv->axis(apply_k_dim_offset(-3))->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
}
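
Illustration (an explanatory sketch, not part of the diff): `mma_result` still carries a trailing `Ki` reduction axis in its loop domain, so every relative position used above must shift left by one, while epilogue tensors end in `[..., Mi, Ni]` and need no shift. A standalone version of the offset helper:

```cpp
#include <cstdint>

// Standalone sketch of the lambda above; the real code captures is_mma_result.
constexpr int64_t apply_k_dim_offset(int64_t idx, bool is_mma_result) {
  return is_mma_result ? idx - 1 : idx;
}

// Epilogue tensors end in [..., Mi, Ni]; mma_result ends in [..., Mi, Ni, Ki].
static_assert(apply_k_dim_offset(-2, /*is_mma_result=*/false) == -2, "Mi of [..., Mi, Ni]");
static_assert(apply_k_dim_offset(-2, /*is_mma_result=*/true) == -3, "Mi of [..., Mi, Ni, Ki]");
```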

MatmulDimRole HopperMultipleMatmulScheduler::findMatmulDimRole(IterDomain* id) {
ValGroup vg = graph_->toGroup(id);
auto it = id_roles_.find(vg);
@@ -397,22 +418,13 @@ void HopperMultipleMatmulScheduler::scheduleMmaResults() {

// do split-K rFactor to define splitk_sum and smem_epilogue
if (params_->splitk_factor != 1) {
// TODO: schedule split-K
NVF_THROW("Hopper split-K is not yet tested");
// Note that the split-K split is already done in blockTileTensors
TensorView* splitk_sum = mma_result->rFactor({-4, -1});
std::swap(splitk_sum, mma_result);
splitk_sums_.push_back(splitk_sum);
}
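
The rFactor above conceptually yields the two-stage reduction sketched below (plain C++, not nvFuser IR; buffer layout and names are illustrative assumptions): the rFactored `mma_result` reduces within each K chunk, and `splitk_sum` later reduces across chunks.

```cpp
#include <cstdint>
#include <vector>

std::vector<float> splitKReference(
    const std::vector<float>& A, // M x K, row-major
    const std::vector<float>& B, // K x N, row-major
    int64_t M, int64_t N, int64_t K, int64_t S /* splitk_factor */) {
  const int64_t k_chunk = K / S; // assume K is divisible by S
  // Stage 1 ("mma_result" after rFactor): one partial tile per split.
  std::vector<float> partial(S * M * N, 0.0f);
  for (int64_t s = 0; s < S; ++s) { // each split maps to a block along BIDz
    for (int64_t m = 0; m < M; ++m) {
      for (int64_t n = 0; n < N; ++n) {
        for (int64_t k = s * k_chunk; k < (s + 1) * k_chunk; ++k) {
          partial[(s * M + m) * N + n] += A[m * K + k] * B[k * N + n];
        }
      }
    }
  }
  // Stage 2 ("splitk_sum"): reduce the partial tiles over the split axis.
  std::vector<float> D(M * N, 0.0f);
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      for (int64_t s = 0; s < S; ++s) {
        D[m * N + n] += partial[(s * M + m) * N + n];
      }
    }
  }
  return D;
}
```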

mma_result->split(-3, getM(params_->mma_macro));
mma_result->split(-2, getN(params_->mma_macro));
// [Mo, No, Ko, Mio, Mii, Nio, Nii, Ki]
// -> [Mo, No, Ko, Mio, Nio, Mii, Nii, Ki]
mma_result->reorder({{-4, -3}});
mma_result->merge(-5);
mma_result->axis(-4)->parallelize(ParallelType::TIDy);

transformLikeMmaOutput(mma_result, /*is_mma_result=*/true);
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
mma_result->getLoopDomain());
mma_result->setAllocationDomain(s.as<IterDomain*>(), true);
@@ -509,17 +521,10 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {

// Apply mma common transformation
for (auto tv : {dc, d}) {
// [..., Mo, No, Mi, Ni]
tv->split(-2, getM(params_->mma_macro));
tv->split(-1, getN(params_->mma_macro));
// [..., Mo, No, Mio, Mii, Nio, Nii]
// -> [..., Mo, No, Mio, Nio, Mii, Nii]
tv->reorder({{-3, -2}});
tv->merge(-4);
transformLikeMmaOutput(tv, /*is_mma_result=*/false);
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv->getLoopDomain());
tv->setLoopDomain(s.as<IterDomain*>());
tv->axis(-5)->parallelize(ParallelType::TIDy);
}
d->axis(-1)->parallelize(ParallelType::Vectorize);
}
@@ -565,16 +570,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {

// Apply mma common transformation
for (auto tv : {dc, d_smem, d}) {
// Original: [..., Mo, No, Mi, Ni]
tv->split(-2, getM(params_->mma_macro));
tv->split(-1, getN(params_->mma_macro));
// After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
tv->reorder({{-3, -2}});
// After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
tv->merge(-4);
// After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
tv->axis(-3)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
transformLikeMmaOutput(tv, /*is_mma_result=*/false);
}

// Schedule register cache; Output from epilogue
@@ -643,41 +639,14 @@ void HopperMultipleMatmulScheduler::scheduleSplitKSum() {
if (params_->splitk_factor == 1) {
return;
}
NVF_THROW("Split-K scheduling is not yet implemented for Hopper matmul");
for (TensorView* splitk_sum : splitk_sums_) {
// Always use serial grid reduction for split-K sum
splitk_sum->definition()->as<ReductionOp>()->requestSerialGridReduction();

if (params_->use_smem_epilogue) {
// Now that transforms are propagated backward to smem_epilogue, which
// is before splitk_sum, we can vectorize the inner-most non-trivial
// dimension of splitk_sum
//
// Note that the split-K reduction is the inner-most dimension.
Val* vec_ext = splitk_sum->axis(-2)->extent();
NVF_ERROR(vec_ext->isConstInt());
int64_t vec_ext_int = vec_ext->evaluate().as<int64_t>();
splitk_sum->axis(-1)->parallelize(ParallelType::BIDz);
splitk_sum->axis(-3)->parallelize(ParallelType::TIDx);
if (vec_ext_int * dataTypeSize(splitk_sum->dtype()) > 16) {
// NOTE: We might encounter an illegal vectorization size if we are
// using Float for this reduction and Half for output. So here we
// first check whether the vectorize size is at most 16 bytes. If not,
// then we split into an unrolled loop that will do multiple
// vectorized reads/writes instead. Note that we reorder such that the
// axes are in order UR TIDx V.
splitk_sum->split(
-2, 16 / dataTypeSize(splitk_sum->dtype()), /*inner_split=*/true);
splitk_sum->axis(-3)->parallelize(ParallelType::Unroll);
splitk_sum->reorder({{-4, -3}});
// In this case, we have [... iUR iTx rBz iS]
}
splitk_sum->reorder({{-2, -1}});
} else { // no smem epilogue
// Reorder to place the split-K reduction next to innermost [... rBz iS]
splitk_sum->reorder({{-9, -2}});
}
// Vectorize inner-most dimension [... (iUR iTx) rBz iV]
transformLikeMmaOutput(splitk_sum, /*is_mma_result=*/false);
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
splitk_sum->getLoopDomain());
splitk_sum->setLoopDomain(s.as<IterDomain*>());
splitk_sum->axis(2)->parallelize(ParallelType::BIDz);
splitk_sum->axis(-1)->parallelize(ParallelType::Vectorize);
}
}
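
Side note (a heavily hedged stand-in, not nvFuser's implementation): `requestSerialGridReduction()` above asks for the split-K partial results to be accumulated by the blocks along `BIDz` one after another into the same buffer, rather than through a separate reduction kernel. A plain C++/threads sketch of that ordering, where a ticket counter stands in for the GPU semaphore:

```cpp
#include <atomic>
#include <cstddef>
#include <thread>
#include <vector>

// Accumulate partial tiles into `out` in bidz order (illustrative names).
void serialGridReduce(
    std::vector<float>& out,
    const std::vector<std::vector<float>>& partials) {
  std::atomic<int> ticket{0};
  auto block = [&](int bidz) {
    // Spin until all earlier blocks along BIDz have finished accumulating.
    while (ticket.load(std::memory_order_acquire) != bidz) {
    }
    for (std::size_t i = 0; i < out.size(); ++i) {
      out[i] += partials[bidz][i];
    }
    ticket.store(bidz + 1, std::memory_order_release);
  };
  std::vector<std::thread> blocks;
  for (int bidz = 0; bidz < static_cast<int>(partials.size()); ++bidz) {
    blocks.emplace_back(block, bidz);
  }
  for (auto& t : blocks) {
    t.join();
  }
}
```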
5 changes: 5 additions & 0 deletions csrc/scheduler/hopper_multi_matmul.h
@@ -191,6 +191,11 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler {
// Return MatmulDimRole for IterDomain
MatmulDimRole findMatmulDimRole(IterDomain* id);

// Schedule a block-tiled TensorView like mma output.
// Why? WGMMA has a unique output format. TensorViews after the mma-result in
// registers must respect this format for correctness.
void transformLikeMmaOutput(TensorView* tv, bool is_mma_result);
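
A usage sketch based on the call sites in this commit (illustrative, not a standalone program): pass `is_mma_result = true` only for the mma result, whose loop domain still ends in the `Ki` reduction axis.

```cpp
// Call-site pattern from this commit:
transformLikeMmaOutput(mma_result, /*is_mma_result=*/true);  // [..., Mi, Ni, Ki]
transformLikeMmaOutput(dc, /*is_mma_result=*/false);         // epilogue: [..., Mi, Ni]
transformLikeMmaOutput(splitk_sum, /*is_mma_result=*/false);
```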

private:
std::vector<ValGroup> canonical_dim_ordering_;

13 changes: 10 additions & 3 deletions csrc/tensor_view.cpp
@@ -804,9 +804,12 @@ TensorView* TensorView::rFactor(const std::vector<int64_t>& axes) {
"Error rfactoring ",
this,
" its definition is either a nullptr or not a reduction.");
// For hopper matmuls, the mma_result logical domain is reordered as [M, N, K]
// using commitLeafToLogical. Thus, the original logical domain is moved to
// the root domain.
NVF_CHECK(
!domain()->hasRoot(), "Cannot call rfactor on the same view twice.");

definition()->isA<MmaOp>() || !domain()->hasRoot(),
"Cannot call rfactor on the same view twice.");
NVF_CHECK(
!definition()->isA<GroupedReductionOp>(),
"For GroupedReductionOp, use TensorView::rFactor(const std::vector<int64_t>& axes, const std::vector<TensorView*>& tvs)");
@@ -935,8 +938,12 @@ std::vector<TensorView*> TensorView::rFactor(
this,
" its definition is either a nullptr or not a GroupedReductionOp or a multi-output reduction op.");

// For hopper matmuls, the mma_result logical domain is reordered as [M, N, K]
// using commitLeafToLogical. Thus, the original logical domain is moved to
// the root domain.
NVF_CHECK(
!domain()->hasRoot(), "Cannot call rfactor on the same view twice.");
definition()->isA<MmaOp>() || !domain()->hasRoot(),
"Cannot call rfactor on the same view twice.");

NVF_CHECK(
definition()->outputs().size() == tvs.size(),
5 changes: 4 additions & 1 deletion csrc/transform_rfactor.cpp
@@ -340,7 +340,10 @@ std::pair<TensorDomain*, TensorDomain*> TransformRFactor::runReplay(
[](IterDomain* id) { return id->maybePartial(); }),
"rFactor of partial domains not allowed, but at least one found.");

auto original_td_root = original_td->logical();
// For hopper matmuls, the mma_result logical domain is reordered as [M, N, K]
// using commitLeafToLogical. Thus, the original logical domain is moved to
// the root domain. In this case, map from producer to consumer's root domain.
auto original_td_root = original_td->maybeRoot();

// Generate a new TensorDomain and set up map from one root to this one.
std::vector<IterDomain*> new_producer_root(original_td_root.size(), nullptr);
43 changes: 34 additions & 9 deletions tests/cpp/test_matmul_scheduler.cpp
@@ -3120,7 +3120,9 @@ using HopperMatmulSchedulerTestParams = std::tuple<
int64_t, // M
int64_t, // N
int64_t, // K
MmaMacro>;
MmaMacro,
int64_t // SplitK Factor
>;

std::string hopperTestName(
const testing::TestParamInfo<HopperMatmulSchedulerTestParams>& info) {
@@ -3129,15 +3131,26 @@ std::string hopperTestName(
bool a_k_inner, b_k_inner;
int64_t M, N, K;
MmaMacro mma_macro;
std::tie(use_smem_epilogue, a_k_inner, b_k_inner, M, N, K, mma_macro) =
info.param;
int64_t splitk_factor;
std::tie(
use_smem_epilogue,
a_k_inner,
b_k_inner,
M,
N,
K,
mma_macro,
splitk_factor) = info.param;
os << (a_k_inner ? "K" : "M");
os << (b_k_inner ? "K" : "N");
os << "_" << M << "_" << N << "_" << K;
os << "_MmaMacro_" << macroToString(mma_macro);
if (use_smem_epilogue) {
os << "_tma_store";
}
if (splitk_factor > 1) {
os << "_splitk_" << splitk_factor;
}
return os.str();
}
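
For example (the name format is inferred from the code above, so treat it as an assumption): a parameterization with `a_k_inner = false`, `b_k_inner = false`, M/N/K = 512/256/64, and `splitk_factor = 2` produces a test name like `MN_512_256_64_MmaMacro_<macro string>_splitk_2`, with `_tma_store` inserted before the split-K suffix when `use_smem_epilogue` is set.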

Expand All @@ -3162,8 +3175,15 @@ class HopperMatmulSchedulerTest
void SetUp() {
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(9, 0, 10, 0);

std::tie(use_smem_epilogue, a_k_inner, b_k_inner, M, N, K, mma_macro) =
GetParam();
std::tie(
use_smem_epilogue,
a_k_inner,
b_k_inner,
M,
N,
K,
mma_macro,
splitk_factor) = GetParam();

if (a_k_inner) {
layout = b_k_inner ? MmaLayout::TN : MmaLayout::TT;
@@ -3192,11 +3212,12 @@ class HopperMatmulSchedulerTest

mparams.use_smem_epilogue = use_smem_epilogue;

mparams.splitk_factor = splitk_factor;
mparams.tile_sizes = gemm_tile;
mparams.async_gmem_load_operands = true;
mparams.circular_buffer_options.circular_buffer_smem_write = true;
mparams.circular_buffer_options.circular_buffer_smem_read = true;
mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
mparams.circular_buffer_options.smem_circular_buffer_stage = 2;
}

void TearDown() {
@@ -3215,14 +3236,16 @@ class HopperMatmulSchedulerTest
KernelExecutor ke;
ke.compile(fusion, inputs, LaunchParams(), matmul_cparams);
auto nvf_out = ke.run(inputs);
EXPECT_TRUE(at::allclose(nvf_out.at(0), tref, 1e-5, 1e-5));
// NOTE Relax tolerances for split-k case
EXPECT_TRUE(at::allclose(nvf_out.at(0), tref, 1e-3, 1e-3));
}

protected:
bool use_smem_epilogue;
bool a_k_inner, b_k_inner;
int64_t M, N, K;
MmaMacro mma_macro;
int64_t splitk_factor;
std::unique_ptr<Fusion> fusion_up;
Fusion* fusion;
std::unique_ptr<FusionGuard> fusion_guard;
@@ -3304,7 +3327,8 @@ INSTANTIATE_TEST_SUITE_P(
testing::Values(512), // M
testing::Values(256), // N
testing::Values(64), // K
testing::Values(MmaMacro::Hopper_64_128_16) // mma_macros
testing::Values(MmaMacro::Hopper_64_128_16), // mma_macros
testing::Values(1, 2) // SplitK Factor
),
hopperTestName);

@@ -3323,7 +3347,8 @@ INSTANTIATE_TEST_SUITE_P(
MmaMacro::Hopper_64_128_16,
MmaMacro::Hopper_64_64_16,
MmaMacro::Hopper_64_32_16,
MmaMacro::Hopper_64_16_16) // mma_macros
MmaMacro::Hopper_64_16_16), // mma_macros
testing::Values(1) // SplitK Factor
),
hopperTestNameSwizzle);

