From 7f7126d2d3bddaf22e0bcb77292dffa8c4958288 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Wed, 8 May 2024 12:06:36 -0700 Subject: [PATCH] Fix schedule and lowering of Hopper MMAs (#2176) This PR fixes the schedule and lowering of Hopper MMAs. Please see my own review on this PR for detail. After https://github.com/NVIDIA/Fuser/pull/2194, the schedule and lowering of MMA finally makes sense now. Note that MMA is still limited to have only one tile in smem. Supporting multiple tiles is the next step. --- csrc/device_lower/pass/index.cpp | 42 +++++++++++++++++---------- csrc/device_lower/pass/inline_ptx.cpp | 12 ++++---- csrc/ir/interface_nodes.h | 6 +--- csrc/scheduler/mma_utils.cpp | 20 ++----------- csrc/scheduler/mma_utils.h | 7 +---- csrc/tensor_view.cpp | 8 ++--- tests/cpp/test_mma.cpp | 16 ++-------- 7 files changed, 42 insertions(+), 69 deletions(-) diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index 7505d0bc361..c629afddc47 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -1575,6 +1575,7 @@ static MmaInputSmemSwizzle getSwizzleMode(TensorView* tv) { // Reference for smem strides: // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#strides void IndexLowering::handle(const MmaOp* mma) { + constexpr int64_t core_matrix_outer_size = 8; Val* a = nullptr; Val* b = nullptr; if (mma->inA()->as()->getMemoryType() == MemoryType::Shared) { @@ -1583,13 +1584,19 @@ void IndexLowering::handle(const MmaOp* mma) { auto tv = mma->inA()->as(); auto base_addr = IrBuilder::baseAddressExpr(tv); auto swizzle = getSwizzleMode(tv); - int64_t stride_bytes = - 8L * getBytesFromSwizzle(swizzle); // swizzle period in bytes - int64_t leading_bytes = /*8x8 items each core matrix*/ 64L * - /*number of core matrices*/ (getM(mma->macro()) / 8L) * - /*bytes per item*/ 2L; - if (swizzle != MmaInputSmemSwizzle::None) { - // TODO: why???!!! + int64_t leading_bytes = core_matrix_outer_size * + getBytesFromSwizzle(swizzle); // swizzle period in bytes + int64_t inner_size = + (mma->layout() == MmaLayout::TT || mma->layout() == MmaLayout::TN) + ? getK(mma->macro()) + : getM(mma->macro()); + int64_t stride_bytes = core_matrix_outer_size * + /*number of core matrices, rounded up to handle padding */ + roundUpToMultiple(inner_size * /*bytes per item*/ 2L, + getBytesFromSwizzle(swizzle)); + if (swizzle == MmaInputSmemSwizzle::None && + (mma->layout() == MmaLayout::NT || mma->layout() == MmaLayout::NN)) { + // tnspA and tnspB is ignored for NoSwizzle mode std::swap(leading_bytes, stride_bytes); } auto matrix_desc = constructMatrixDescriptor( @@ -1612,16 +1619,19 @@ void IndexLowering::handle(const MmaOp* mma) { auto tv = mma->inB()->as(); auto swizzle = getSwizzleMode(tv); auto base_addr = IrBuilder::baseAddressExpr(tv); - int64_t stride_bytes = - 8L * getBytesFromSwizzle(swizzle); // swizzle period in bytes - int64_t leading_bytes = /*8x8 items each core matrix*/ 64L * + int64_t leading_bytes = core_matrix_outer_size * + getBytesFromSwizzle(swizzle); // swizzle period in bytes + int64_t inner_size = + (mma->layout() == MmaLayout::TN || mma->layout() == MmaLayout::NN) + ? getK(mma->macro()) + : getN(mma->macro()); + int64_t stride_bytes = core_matrix_outer_size * /*number of core matrices, rounded up to handle padding */ - roundUpToMultiple(getN(mma->macro()) / 8L, - getBytesFromSwizzle(swizzle) / 16L) * - /*bytes per item*/ 2L; - if (swizzle != MmaInputSmemSwizzle::None && - (mma->layout() == MmaLayout::TT || mma->layout() == MmaLayout::TN)) { - // TODO: why???!!! + roundUpToMultiple(inner_size * /*bytes per item*/ 2L, + getBytesFromSwizzle(swizzle)); + if (swizzle == MmaInputSmemSwizzle::None && + (mma->layout() == MmaLayout::TT || mma->layout() == MmaLayout::NT)) { + // tnspA and tnspB is ignored for NoSwizzle mode std::swap(leading_bytes, stride_bytes); } auto matrix_desc = constructMatrixDescriptor( diff --git a/csrc/device_lower/pass/inline_ptx.cpp b/csrc/device_lower/pass/inline_ptx.cpp index 950bc0bcffc..8541bae7382 100644 --- a/csrc/device_lower/pass/inline_ptx.cpp +++ b/csrc/device_lower/pass/inline_ptx.cpp @@ -223,18 +223,18 @@ class LowerToInlinePtx : public kir::ExprMutator { /*scaleB=*/IrBuilder::create(1, DataType::Int32)}; auto layout = *mma->layout(); if (a_on_smem) { - // tnspA: if not K-major, then needs transpose + // tnspA if (layout == MmaLayout::TT || layout == MmaLayout::TN) { - inputs.push_back(IrBuilder::create(1, DataType::Int32)); - } else { inputs.push_back(IrBuilder::create(0, DataType::Int32)); + } else { + inputs.push_back(IrBuilder::create(1, DataType::Int32)); } } - // tnspB: if not K-major, then needs transpose + // tnspB if (layout == MmaLayout::TN || layout == MmaLayout::NN) { - inputs.push_back(IrBuilder::create(1, DataType::Int32)); - } else { inputs.push_back(IrBuilder::create(0, DataType::Int32)); + } else { + inputs.push_back(IrBuilder::create(1, DataType::Int32)); } registerInsertBefore( mma, diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h index 5a38a51c194..a6b4cfab4d5 100644 --- a/csrc/ir/interface_nodes.h +++ b/csrc/ir/interface_nodes.h @@ -419,11 +419,7 @@ class NVF_API TensorView : public Val { //! have a matching thread swizzle with the mma operand/result. //! More detail on usage see [WarpMmaSwizzler] in scheduler/mma_utils.h . void applyMmaSwizzle(MmaOperand operand); - // TODO: what is transpose 2? Why do we need it? - void applyMmaSwizzle( - MmaInputSmemSwizzle swizzle, - bool transpose, - bool transpose2 = false); + void applyMmaSwizzle(MmaInputSmemSwizzle swizzle); //! Returns if this tensor view has swizzle operator on its tensor domain. //! This is the temporary flag for indicating that the new swizzle diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 8661d3b1485..ee7be460059 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -853,12 +853,7 @@ void WarpMmaSwizzler::scheduleOperandRead(TensorView* tv, MmaOperand operand) { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-swizzling-modes void WarpMmaSwizzler::scheduleOperandRead( TensorView* tv, - MmaInputSmemSwizzle swizzle, - bool transpose, - bool transpose2) { - if (transpose) { - tv->reorder({{-2, -1}}); - } + MmaInputSmemSwizzle swizzle) { if (swizzle == MmaInputSmemSwizzle::None) { // For no-swizzle case, the entire tile are divided into 8x8 core matrices, // and each core matrix resides in a contiguous 8*8*2 bytes region in shared @@ -868,16 +863,9 @@ void WarpMmaSwizzler::scheduleOperandRead( // [Ko, K8, Mo, M8] tv->reorder({{-2, -3}}); // [Ko, Mo, K8, M8] - if (transpose2) { - tv->reorder({{-2, -1}}); - } } else { auto swizzle_size = getBytesFromSwizzle(swizzle) / 16; // For example, [K, M] - if (transpose2) { - tv->reorder({{-2, -1}}); - // [M, K] - } tv->split(-2, 8); tv->split(-1, 8); // For example transpose2 == false @@ -891,10 +879,8 @@ void WarpMmaSwizzler::scheduleOperandRead( tv->split(-4, 8 / swizzle_size); // [Ko, K2, K4, Moo, Mo2, M8] tv->swizzle(SwizzleType::XOR, -5, -2); - if (!transpose2) { - tv->reorder({{-3, -5}}); - // [Ko, Moo, K2, K4, Mo2, M8] - } + tv->reorder({{-3, -5}}); + // [Ko, Moo, K2, K4, Mo2, M8] } tv->setAllocationDomain(tv->getLeafDomain(), true); } diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index a0c5126b648..8b2a6861bdf 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -173,12 +173,7 @@ class WarpMmaSwizzler { //! after smem read. //! The rightmost iterdomains must follow the m,n,k convention before calling. static void scheduleOperandRead(TensorView* tv, MmaOperand operand); - // TODO: what is transpose2? Why do we need it? - static void scheduleOperandRead( - TensorView* tv, - MmaInputSmemSwizzle swizzle, - bool transpose, - bool transpose2); + static void scheduleOperandRead(TensorView* tv, MmaInputSmemSwizzle swizzle); //! Note [schedule of ldmatrix] //! If you look at the doc of ldmatrix and mma for Turing and Ampere: diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp index 56c107385df..9303656f1ee 100644 --- a/csrc/tensor_view.cpp +++ b/csrc/tensor_view.cpp @@ -1313,15 +1313,11 @@ void TensorView::applyMmaSwizzle(MmaOperand operand) { } } -void TensorView::applyMmaSwizzle( - MmaInputSmemSwizzle swizzle, - bool transpose, - bool transpose2) { +void TensorView::applyMmaSwizzle(MmaInputSmemSwizzle swizzle) { NVF_ERROR( getMemoryType() == MemoryType::Shared, "Shared memory swizzle is only supported for shared memory"); - mma_utils::WarpMmaSwizzler::scheduleOperandRead( - this, swizzle, transpose, transpose2); + mma_utils::WarpMmaSwizzler::scheduleOperandRead(this, swizzle); } void TensorView::commitLeafToRFactor() { diff --git a/tests/cpp/test_mma.cpp b/tests/cpp/test_mma.cpp index 1480ca6303a..bd3f1e563b3 100644 --- a/tests/cpp/test_mma.cpp +++ b/tests/cpp/test_mma.cpp @@ -349,9 +349,6 @@ class HopperRS : public HopperBase, dtype = std::get<1>(GetParam()); layout = std::get<2>(GetParam()); swizzle_b = std::get<3>(GetParam()); - if (layout != MmaLayout::TN) { - GTEST_SKIP() << "bugs to be fixed"; - } } }; @@ -417,7 +414,7 @@ TEST_P(HopperRS, SingleTile) { tv0->merge(1); tv0->axis(1)->parallelize(ParallelType::TIDx); - tv1->applyMmaSwizzle(swizzle_b, layout == MmaLayout::TN); + tv1->applyMmaSwizzle(swizzle_b); naivelyParallelize(tv1); @@ -490,9 +487,6 @@ class HopperSS : public HopperBase, layout = std::get<2>(GetParam()); swizzle_a = std::get<3>(GetParam()); swizzle_b = std::get<4>(GetParam()); - if (layout != MmaLayout::TN) { - GTEST_SKIP() << "bugs to be fixed"; - } } }; @@ -528,9 +522,6 @@ TEST_P(HopperSS, SingleTile) { Fusion fusion; FusionGuard fg(&fusion); - bool transpose_a = (layout == MmaLayout::NT || layout == MmaLayout::NN); - bool transpose_b = (layout == MmaLayout::TN || layout == MmaLayout::NN); - auto shapes = matmulAtInputShape3DHopperSS( getM(macro), getN(macro), getK(macro), layout); @@ -602,9 +593,8 @@ TEST_P(HopperSS, SingleTile) { moveInnerBroadcastLeft(tv1); // Hopper tensor core assumes K major, so we are using !transpose_a here. - tv0->applyMmaSwizzle(swizzle_a, !transpose_a); - tv1->setMemoryType(MemoryType::Shared); - tv1->applyMmaSwizzle(swizzle_b, transpose_b, transpose_a); + tv0->applyMmaSwizzle(swizzle_a); + tv1->applyMmaSwizzle(swizzle_b); naivelyParallelize(tv0); naivelyParallelize(tv1);