From 5b37bca3740aeb3293f5344e876f58d2f2dea964 Mon Sep 17 00:00:00 2001
From: Ryan Spring
Date: Thu, 12 Dec 2024 15:20:17 -0800
Subject: [PATCH] Implement basic split-k gemm for hopper matmul scheduler (#3575)

This PR implements the `scheduleSplitKSum` function to support split-k gemm
with the Hopper matmul scheduler.
- It supports all operand formats such as TT, NT, TN, and NN.
---
 csrc/device_lower/utils.cpp            |  6 +-
 csrc/scheduler/hopper_multi_matmul.cpp | 89 +++++++++-----------------
 csrc/scheduler/hopper_multi_matmul.h   |  5 ++
 csrc/tensor_view.cpp                   | 13 +++-
 csrc/transform_rfactor.cpp             |  5 +-
 tests/cpp/test_matmul_scheduler.cpp    | 43 ++++++++++---
 6 files changed, 87 insertions(+), 74 deletions(-)

diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp
index a3d4323e761..35b825d5348 100644
--- a/csrc/device_lower/utils.cpp
+++ b/csrc/device_lower/utils.cpp
@@ -921,7 +921,11 @@ std::array<UnitDim, 2> getMmaLayout(const MmaOp* expr) {
   auto out_tv = ir_utils::getTv(expr->out());
 
   IterDomain* reduction_id = nullptr;
-  for (auto id : out_tv->getLogicalDomain()) {
+  // For hopper matmuls, the mma_result logical domain is reordered as [M, N, K]
+  // using commitLeafToLogical. In the split-k case, use the root domain for the
+  // mma layout because the k dimension is divided into two iterDomains in the
+  // logical domain.
+  for (auto id : out_tv->getMaybeRootDomain()) {
     if (id->isReduction()) {
       reduction_id = id;
       break;
diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp
index b2d8ec705ec..1efe75aeab2 100644
--- a/csrc/scheduler/hopper_multi_matmul.cpp
+++ b/csrc/scheduler/hopper_multi_matmul.cpp
@@ -29,6 +29,27 @@
 
 namespace nvfuser {
 
+void HopperMultipleMatmulScheduler::transformLikeMmaOutput(
+    TensorView* tv,
+    bool is_mma_result) {
+  // TODO Add constraints
+
+  auto apply_k_dim_offset = [is_mma_result](int64_t idx) constexpr {
+    return (is_mma_result) ? idx - 1 : idx;
+  };
+
+  // Original: [..., Mo, No, Mi, Ni]
+  tv->split(apply_k_dim_offset(-2), getM(params_->mma_macro));
+  tv->split(apply_k_dim_offset(-1), getN(params_->mma_macro));
+  // After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
+  tv->reorder({{apply_k_dim_offset(-3), apply_k_dim_offset(-2)}});
+  // After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
+  tv->merge(apply_k_dim_offset(-4));
+  // After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
+  tv->axis(apply_k_dim_offset(-3))->parallelize(ParallelType::TIDy);
+  // After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
+}
+
 MatmulDimRole HopperMultipleMatmulScheduler::findMatmulDimRole(IterDomain* id) {
   ValGroup vg = graph_->toGroup(id);
   auto it = id_roles_.find(vg);
@@ -397,22 +418,13 @@ void HopperMultipleMatmulScheduler::scheduleMmaResults() {
 
     // do split-K rFactor to define splitk_sum and smem_epilogue
     if (params_->splitk_factor != 1) {
-      // TODO: schedule split-K
-      NVF_THROW("Hopper split-K is not yet tested");
       // Note that the split-K split is already done in blockTileTensors
       TensorView* splitk_sum = mma_result->rFactor({-4, -1});
       std::swap(splitk_sum, mma_result);
       splitk_sums_.push_back(splitk_sum);
     }
 
-    mma_result->split(-3, getM(params_->mma_macro));
-    mma_result->split(-2, getN(params_->mma_macro));
-    // [Mo, No, Ko, Mio, Mii, Nio, Nii, Ki]
-    // -> [Mo, No, Ko, Mio, Nio, Mii, Nii, Ki]
-    mma_result->reorder({{-4, -3}});
-    mma_result->merge(-5);
-    mma_result->axis(-4)->parallelize(ParallelType::TIDy);
-
+    transformLikeMmaOutput(mma_result, /*is_mma_result=*/true);
     auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
         mma_result->getLoopDomain());
     mma_result->setAllocationDomain(s.as<IterDomain*>(), true);
@@ -509,17 +521,10 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
 
       // Apply mma common transformation
       for (auto tv : {dc, d}) {
-        // [..., Mo, No, Mi, Ni]
-        tv->split(-2, getM(params_->mma_macro));
-        tv->split(-1, getN(params_->mma_macro));
-        // [..., Mo, No, Mio, Mii, Nio, Nii]
-        // -> [..., Mo, No, Mio, Nio, Mii, Nii]
-        tv->reorder({{-3, -2}});
-        tv->merge(-4);
+        transformLikeMmaOutput(tv, /*is_mma_result=*/false);
         auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
             tv->getLoopDomain());
         tv->setLoopDomain(s.as<IterDomain*>());
-        tv->axis(-5)->parallelize(ParallelType::TIDy);
       }
       d->axis(-1)->parallelize(ParallelType::Vectorize);
     }
@@ -565,16 +570,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
 
       // Apply mma common transformation
      for (auto tv : {dc, d_smem, d}) {
-        // Original: [..., Mo, No, Mi, Ni]
-        tv->split(-2, getM(params_->mma_macro));
-        tv->split(-1, getN(params_->mma_macro));
-        // After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
-        tv->reorder({{-3, -2}});
-        // After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
-        tv->merge(-4);
-        // After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
-        tv->axis(-3)->parallelize(ParallelType::TIDy);
-        // After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
+        transformLikeMmaOutput(tv, /*is_mma_result=*/false);
       }
 
       // Schedule register cache; Output from epilogue
@@ -643,41 +639,14 @@ void HopperMultipleMatmulScheduler::scheduleSplitKSum() {
   if (params_->splitk_factor == 1) {
     return;
   }
-  NVF_THROW("Split-K scheduling is not yet implemented for Hopper matmul");
   for (TensorView* splitk_sum : splitk_sums_) {
     // Always use serial grid reduction for split-K sum
     splitk_sum->definition()->as<ReductionOp>()->requestSerialGridReduction();
-
-    if (params_->use_smem_epilogue) {
-      // Now that transforms are propagated backward to smem_epilogue, which
-      // is before splitk_sum, we can vectorize the inner-most non-trivial
-      // dimension of splitk_sum
-      //
-      // Note that the split-K reduction is the inner-most dimension.
-      Val* vec_ext = splitk_sum->axis(-2)->extent();
-      NVF_ERROR(vec_ext->isConstInt());
-      int64_t vec_ext_int = vec_ext->evaluate().as<int64_t>();
-      splitk_sum->axis(-1)->parallelize(ParallelType::BIDz);
-      splitk_sum->axis(-3)->parallelize(ParallelType::TIDx);
-      if (vec_ext_int * dataTypeSize(splitk_sum->dtype()) > 16) {
-        // NOTE: We might encounter an illegal vectorization size if we are
-        // using Float for this reduction and Half for output. So here we
-        // first check whether the vectorize size is at most 16 bytes. If not,
-        // then we split into an unrolled loop that will do multiple
-        // vectorized reads/writes instead. Note that we reorder such that the
-        // axes are in order UR TIDx V.
-        splitk_sum->split(
-            -2, 16 / dataTypeSize(splitk_sum->dtype()), /*inner_split=*/true);
-        splitk_sum->axis(-3)->parallelize(ParallelType::Unroll);
-        splitk_sum->reorder({{-4, -3}});
-        // In this case, we have [... iUR iTx rBz iS]
-      }
-      splitk_sum->reorder({{-2, -1}});
-    } else { // no smem epilogue
-      // Reorder to place the split-K reduction next to innermost [... rBz iS]
-      splitk_sum->reorder({{-9, -2}});
-    }
-    // Vectorize inner-most dimension [... (iUR iTx) rBz iV]
+    transformLikeMmaOutput(splitk_sum, /*is_mma_result=*/false);
+    auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
+        splitk_sum->getLoopDomain());
+    splitk_sum->setLoopDomain(s.as<IterDomain*>());
+    splitk_sum->axis(2)->parallelize(ParallelType::BIDz);
     splitk_sum->axis(-1)->parallelize(ParallelType::Vectorize);
   }
 }
diff --git a/csrc/scheduler/hopper_multi_matmul.h b/csrc/scheduler/hopper_multi_matmul.h
index 1d77785cc99..5eab0f4fbed 100644
--- a/csrc/scheduler/hopper_multi_matmul.h
+++ b/csrc/scheduler/hopper_multi_matmul.h
@@ -191,6 +191,11 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler {
   // Return MatmulDimRole for IterDomain
   MatmulDimRole findMatmulDimRole(IterDomain* id);
 
+  // Schedule a block-tiled TensorView like mma output.
+  // Why? WGMMA has a unique output format. TensorViews after the mma-result in
+  // registers must respect this format for correctness.
+  void transformLikeMmaOutput(TensorView* tv, bool is_mma_result);
+
  private:
   std::vector<ValGroup> canonical_dim_ordering_;
 
diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp
index 654ff601ac7..14e0b8f746f 100644
--- a/csrc/tensor_view.cpp
+++ b/csrc/tensor_view.cpp
@@ -804,9 +804,12 @@ TensorView* TensorView::rFactor(const std::vector<int64_t>& axes) {
       "Error rfactoring ",
       this,
       " its definition is either a nullptr or not a reduction.");
+  // For hopper matmuls, the mma_result logical domain is reordered as [M, N, K]
+  // using commitLeafToLogical. Thus, the original logical domain is moved to
+  // the root domain.
   NVF_CHECK(
-      !domain()->hasRoot(), "Cannot call rfactor on the same view twice.");
-
+      definition()->isA<MmaOp>() || !domain()->hasRoot(),
+      "Cannot call rfactor on the same view twice.");
   NVF_CHECK(
       !definition()->isA<GroupedReductionOp>(),
       "For GroupedReductionOp, use TensorView::rFactor(const std::vector<int64_t>& axes, const std::vector<TensorView*>& tvs)");
@@ -935,8 +938,12 @@ std::vector<TensorView*> TensorView::rFactor(
       this,
       " its definition is either a nullptr or not a GroupedReductionOp or a multi-output reduction op.");
 
+  // For hopper matmuls, the mma_result logical domain is reordered as [M, N, K]
+  // using commitLeafToLogical. Thus, the original logical domain is moved to
+  // the root domain.
   NVF_CHECK(
-      !domain()->hasRoot(), "Cannot call rfactor on the same view twice.");
+      definition()->isA<MmaOp>() || !domain()->hasRoot(),
+      "Cannot call rfactor on the same view twice.");
 
   NVF_CHECK(
       definition()->outputs().size() == tvs.size(),
diff --git a/csrc/transform_rfactor.cpp b/csrc/transform_rfactor.cpp
index 07799487eb0..709c5624935 100644
--- a/csrc/transform_rfactor.cpp
+++ b/csrc/transform_rfactor.cpp
@@ -340,7 +340,10 @@ std::pair<TensorDomain*, TensorDomain*> TransformRFactor::runReplay(
           [](IterDomain* id) { return id->maybePartial(); }),
       "rFactor of partial domains not allowed, but at least one found.");
 
-  auto original_td_root = original_td->logical();
+  // For hopper matmuls, the mma_result logical domain is reordered as [M, N, K]
+  // using commitLeafToLogical. Thus, the original logical domain is moved to
+  // the root domain. In this case, map from producer to consumer's root domain.
+  auto original_td_root = original_td->maybeRoot();
 
   // Generate a new TensorDomain and set up map from one root to this one.
   std::vector<IterDomain*> new_producer_root(original_td_root.size(), nullptr);
diff --git a/tests/cpp/test_matmul_scheduler.cpp b/tests/cpp/test_matmul_scheduler.cpp
index 3058ce59ad7..838f96cc140 100644
--- a/tests/cpp/test_matmul_scheduler.cpp
+++ b/tests/cpp/test_matmul_scheduler.cpp
@@ -3120,7 +3120,9 @@ using HopperMatmulSchedulerTestParams = std::tuple<
     int64_t, // M
     int64_t, // N
     int64_t, // K
-    MmaMacro>;
+    MmaMacro,
+    int64_t // SplitK Factor
+    >;
 
 std::string hopperTestName(
     const testing::TestParamInfo<HopperMatmulSchedulerTestParams>& info) {
@@ -3129,8 +3131,16 @@ std::string hopperTestName(
   bool a_k_inner, b_k_inner;
   int64_t M, N, K;
   MmaMacro mma_macro;
-  std::tie(use_smem_epilogue, a_k_inner, b_k_inner, M, N, K, mma_macro) =
-      info.param;
+  int64_t splitk_factor;
+  std::tie(
+      use_smem_epilogue,
+      a_k_inner,
+      b_k_inner,
+      M,
+      N,
+      K,
+      mma_macro,
+      splitk_factor) = info.param;
   os << (a_k_inner ? "K" : "M");
   os << (b_k_inner ? "K" : "N");
   os << "_" << M << "_" << N << "_" << K;
@@ -3138,6 +3148,9 @@ std::string hopperTestName(
   if (use_smem_epilogue) {
     os << "_tma_store";
   }
+  if (splitk_factor > 1) {
+    os << "_splitk_" << splitk_factor;
+  }
   return os.str();
 }
 
@@ -3162,8 +3175,15 @@ class HopperMatmulSchedulerTest
   void SetUp() {
     NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(9, 0, 10, 0);
 
-    std::tie(use_smem_epilogue, a_k_inner, b_k_inner, M, N, K, mma_macro) =
-        GetParam();
+    std::tie(
+        use_smem_epilogue,
+        a_k_inner,
+        b_k_inner,
+        M,
+        N,
+        K,
+        mma_macro,
+        splitk_factor) = GetParam();
 
     if (a_k_inner) {
       layout = b_k_inner ? MmaLayout::TN : MmaLayout::TT;
@@ -3192,11 +3212,12 @@ class HopperMatmulSchedulerTest
 
     mparams.use_smem_epilogue = use_smem_epilogue;
 
+    mparams.splitk_factor = splitk_factor;
     mparams.tile_sizes = gemm_tile;
     mparams.async_gmem_load_operands = true;
     mparams.circular_buffer_options.circular_buffer_smem_write = true;
     mparams.circular_buffer_options.circular_buffer_smem_read = true;
-    mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
+    mparams.circular_buffer_options.smem_circular_buffer_stage = 2;
   }
 
   void TearDown() {
@@ -3215,7 +3236,8 @@ class HopperMatmulSchedulerTest
     KernelExecutor ke;
     ke.compile(fusion, inputs, LaunchParams(), matmul_cparams);
     auto nvf_out = ke.run(inputs);
-    EXPECT_TRUE(at::allclose(nvf_out.at(0), tref, 1e-5, 1e-5));
+    // NOTE Relax tolerances for split-k case
+    EXPECT_TRUE(at::allclose(nvf_out.at(0), tref, 1e-3, 1e-3));
   }
 
  protected:
@@ -3223,6 +3245,7 @@ class HopperMatmulSchedulerTest
   bool a_k_inner, b_k_inner;
   int64_t M, N, K;
   MmaMacro mma_macro;
+  int64_t splitk_factor;
   std::unique_ptr<Fusion> fusion_up;
   Fusion* fusion;
   std::unique_ptr<FusionGuard> fusion_guard;
@@ -3304,7 +3327,8 @@ INSTANTIATE_TEST_SUITE_P(
         testing::Values(512), // M
         testing::Values(256), // N
         testing::Values(64), // K
-        testing::Values(MmaMacro::Hopper_64_128_16) // mma_macros
+        testing::Values(MmaMacro::Hopper_64_128_16), // mma_macros
+        testing::Values(1, 2) // SplitK Factor
         ),
     hopperTestName);
 
@@ -3323,7 +3347,8 @@ INSTANTIATE_TEST_SUITE_P(
             MmaMacro::Hopper_64_128_16,
             MmaMacro::Hopper_64_64_16,
             MmaMacro::Hopper_64_32_16,
-            MmaMacro::Hopper_64_16_16) // mma_macros
+            MmaMacro::Hopper_64_16_16), // mma_macros
+        testing::Values(1) // SplitK Factor
         ),
     hopperTestNameSwizzle);
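
Editor's note (not part of the patch): the subtle piece of this change is the relative-index arithmetic in transformLikeMmaOutput. When the tiling is applied to mma_result, the tensor still carries a trailing Ki reduction IterDomain, so every negative axis index must be shifted left by one; that is all apply_k_dim_offset does. The standalone C++ sketch below replays only that index bookkeeping on labeled axis names and prints the loop-domain orders quoted in the scheduler comments for both the epilogue path and the mma_result path. The Axes alias, the pos/split/reorder/merge/print helpers, and the axis labels are illustrative stand-ins, not nvfuser APIs; only the "idx - 1" offset mirrors the patch.

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

using Axes = std::vector<std::string>;

// Convert a (possibly negative) relative index into an absolute position.
int64_t pos(const Axes& a, int64_t idx) {
  return idx >= 0 ? idx : static_cast<int64_t>(a.size()) + idx;
}

// split: axis idx becomes {<name>o, <name>i}, mimicking tv->split(idx, factor).
void split(Axes& a, int64_t idx) {
  int64_t i = pos(a, idx);
  std::string name = a[i];
  a[i] = name + "o";
  a.insert(a.begin() + i + 1, name + "i");
}

// reorder: move the axis at `from` to `to`; for the adjacent axes used here
// this is a swap, mimicking tv->reorder({{from, to}}) in transformLikeMmaOutput.
void reorder(Axes& a, int64_t from, int64_t to) {
  std::swap(a[pos(a, from)], a[pos(a, to)]);
}

// merge: fuse axis idx with the axis to its right, mimicking tv->merge(idx).
void merge(Axes& a, int64_t idx) {
  int64_t i = pos(a, idx);
  a[i] = a[i] + "*" + a[i + 1];
  a.erase(a.begin() + i + 1);
}

void print(const std::string& label, const Axes& a) {
  std::cout << label << ": [";
  for (size_t i = 0; i < a.size(); ++i) {
    std::cout << (i ? ", " : "") << a[i];
  }
  std::cout << "]\n";
}

int main() {
  for (bool is_mma_result : {false, true}) {
    // Epilogue tensors end in [..., Mi, Ni]; mma_result keeps a trailing Ki
    // reduction axis, so its negative indices must be shifted left by one.
    Axes a = is_mma_result ? Axes{"Mo", "No", "Mi", "Ni", "Ki"}
                           : Axes{"Mo", "No", "Mi", "Ni"};
    auto off = [is_mma_result](int64_t idx) {
      return is_mma_result ? idx - 1 : idx;
    };
    split(a, off(-2));            // Mi -> Mio, Mii
    split(a, off(-1));            // Ni -> Nio, Nii
    reorder(a, off(-3), off(-2)); // ..., Mio, Nio, Mii, Nii
    merge(a, off(-4));            // ..., Mio*Nio, Mii, Nii
    print(is_mma_result ? "mma_result" : "epilogue", a);
  }
  return 0;
}

Running the sketch prints "[Mo, No, Mio*Nio, Mii, Nii]" for the epilogue case and "[Mo, No, Mio*Nio, Mii, Nii, Ki]" for the mma_result case, matching the "After Merge" comments in the patch. The real scheduler then parallelizes the merged Mio*Nio axis with TIDy and applies the MmaSwizzler allocation/loop-domain swizzle, which the sketch deliberately omits.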