From c7c5de8c94d43374b44f45c62ed0a2aeffbc9348 Mon Sep 17 00:00:00 2001
From: "Gao, Xiang"
Date: Mon, 7 Oct 2024 09:35:57 -0700
Subject: [PATCH] Cleanup `applyMmaSwizzleForTMALoad` (#3077)

In the past, `applyMmaSwizzleForTMALoad` had a `permute_outer_dim` parameter
because the strides in the MmaOp's matrix descriptor were hard-coded. With
hard-coded strides there was no flexibility: we had to use multiple TMA
instructions to load the smem buffer into one very specific format. Since
https://github.com/NVIDIA/Fuser/pull/3002, the strides are computed from the
schedule, so we now have the flexibility to choose a memory layout that needs
fewer TMA instructions to load.
---
A usage sketch of the simplified call site follows the diff.

 csrc/ir/interface_nodes.h    |  4 +-
 csrc/scheduler/mma_utils.cpp | 18 +-------
 csrc/scheduler/mma_utils.h   |  3 +-
 csrc/tensor_view.cpp         |  7 +--
 tests/cpp/test_mma.cpp       | 85 ------------------------------------
 5 files changed, 6 insertions(+), 111 deletions(-)

diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h
index cda0c1a8f40..9a89d520781 100644
--- a/csrc/ir/interface_nodes.h
+++ b/csrc/ir/interface_nodes.h
@@ -427,9 +427,7 @@ class NVF_API TensorView : public Val {
   //! Transforms the innermost iterdomains according to the given mma swizzle,
   //! this should be used on the tvs that are inputs of a MmaOp or are loaded
   //! using TMA.
-  void applyMmaSwizzleForTMALoad(
-      MmaInputSmemSwizzle swizzle,
-      bool permute_outer_dim = true);
+  void applyMmaSwizzleForTMALoad(MmaInputSmemSwizzle swizzle);
 
   //! Returns if this tensor view has swizzle operator on its tensor domain.
   //! This is the temporary flag for indicating that the new swizzle
diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp
index 3b566d195a8..a38bebffcc1 100644
--- a/csrc/scheduler/mma_utils.cpp
+++ b/csrc/scheduler/mma_utils.cpp
@@ -929,14 +929,9 @@ void MmaSwizzler::parallelizeAsBulkSkippingFirstIDs(
   }
 }
 
-// Please note that we currently do not fully support
-// not splitting the outer dimension. This only works when
-// the inner-dimension is not split, that is the inner dim
-// is less or equal to the swizzle size (in bytes).
 void MmaSwizzler::scheduleTMALoadForMma(
     TensorView* tv,
-    MmaInputSmemSwizzle swizzle,
-    bool permute_outer_dim) {
+    MmaInputSmemSwizzle swizzle) {
   // In the comments below I have kept K as the outer dimension. That is
   // just to have a concrete running example - it can be inner or outer.
 
@@ -968,16 +963,7 @@ void MmaSwizzler::scheduleTMALoadForMma(
     // [NO, K, NI] ->
     // [NO, KO(2), KIO(2), KII(4), NIO(2), NII(8)]
     tv->swizzleTMABox(swizzle);
-
-    // If the outer dim is split, then we pull out KO to be outside NO
-    // and KO and NO are both not marked bulk parallel, else NO is outer
-    // and only NO is not marked bulk parallel.
-    if (permute_outer_dim) {
-      // [NO, KO(2), KIO(2), KII(4), NIO(2), NII(8)] ->
-      // [KO(2), NO(2), KIO(2), KII(4), NIO(2), NII(8)]
-      tv->reorder({{-6, -5}});
-    }
-    num_ids_to_skip += permute_outer_dim ? 2 : 1;
+    num_ids_to_skip += 1;
   }
 
   parallelizeAsBulkSkippingFirstIDs(tv, num_ids_to_skip);
diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h
index d4ab8a7f327..7bab9d4a9f7 100644
--- a/csrc/scheduler/mma_utils.h
+++ b/csrc/scheduler/mma_utils.h
@@ -233,8 +233,7 @@ class MmaSwizzler {
   //! outermost.
   static void scheduleTMALoadForMma(
       TensorView* tv,
-      MmaInputSmemSwizzle swizzle,
-      bool permute_outer_dim = true);
+      MmaInputSmemSwizzle swizzle);
 
   //! Parallelize all dims as bulk expect the first dims mentioned in the second
   //! param.
diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp
index d9fedc5dfd4..5c480d7b221 100644
--- a/csrc/tensor_view.cpp
+++ b/csrc/tensor_view.cpp
@@ -1388,9 +1388,7 @@ void TensorView::swizzleTMABox(MmaInputSmemSwizzle swizzle) {
   this->swizzle(SwizzleType::XOR, -4, -2);
 }
 
-void TensorView::applyMmaSwizzleForTMALoad(
-    MmaInputSmemSwizzle swizzle,
-    bool permute_outer_dim) {
+void TensorView::applyMmaSwizzleForTMALoad(MmaInputSmemSwizzle swizzle) {
   NVF_ERROR(
       getMemoryType() == MemoryType::Shared,
       "Shared memory swizzle is only supported for shared memory");
@@ -1398,8 +1396,7 @@ void TensorView::applyMmaSwizzleForTMALoad(
       definition()->as<LoadStoreOp>()->opType() ==
           LoadStoreOpType::CpAsyncBulkTensorTile,
       "Operation requires a TMA operation");
-  mma_utils::MmaSwizzler::scheduleTMALoadForMma(
-      this, swizzle, permute_outer_dim);
+  mma_utils::MmaSwizzler::scheduleTMALoadForMma(this, swizzle);
 }
 
 void TensorView::commitLeafToLogical() {
diff --git a/tests/cpp/test_mma.cpp b/tests/cpp/test_mma.cpp
index 271c662cd12..ce3b6806635 100644
--- a/tests/cpp/test_mma.cpp
+++ b/tests/cpp/test_mma.cpp
@@ -878,91 +878,6 @@ TEST_P(HopperRS, SingleTileWithTMALoadStore) {
   EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5));
 }
 
-TEST_P(HopperRS, SingleTileWithTMALoadOuterDimNotSplit) {
-  if (layout == MmaLayout::TT) {
-    GTEST_SKIP() << "Skipping test as we only handle TN layout in this test";
-  }
-
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto shapes = matmulAtInputShape3DHopperRS(
-      getM(macro), getN(macro), getK(macro), layout);
-
-  auto tv0 = makeContigConcreteTensor(shapes.first, dtype);
-  auto tv1 = makeContigConcreteTensor(shapes.second, dtype);
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  // Just doing a gmem->register copy
-  tv0 = set(tv0);
-  // Just doing a gmem->smem copy
-  tv1 = set(tv1);
-  tv1->setMemoryType(MemoryType::Shared);
-  tv1->definition()->as<LoadStoreOp>()->setOpType(
-      LoadStoreOpType::CpAsyncBulkTensorTile);
-
-  auto tv2 = fusedMultiplySum(tv0, tv1, {layout == MmaLayout::TT ? 1 : 2});
-
-  fusion.addOutput(tv2);
-
-  auto mma_ops = ir_utils::getOpsOfType<MmaOp>(&fusion);
-  NVF_CHECK(
-      1 == mma_ops.size(),
-      "Invalid number of MmaOp instances in fusion definition, expected 1, got ",
-      mma_ops.size());
-  mma_ops.front()->setMacro(macro);
-
-  auto tv2c = tv2->cacheBefore();
-
-  matmul_utils::moveInnerBroadcastLeft(tv0);
-  tv0->applyMmaSwizzle(MmaOperand::A);
-
-  tv0->merge(1);
-  tv0->merge(1);
-  tv0->axis(1)->parallelize(ParallelType::TIDx);
-
-  // In this case we don't split the outer dimension, thus having
-  // fewer TMA loads.
-  tv1->applyMmaSwizzleForTMALoad(swizzle_b, /* don't split outer dim*/ false);
-
-  {
-    auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
-        tv2c->getLoopDomain());
-    tv2c->setAllocationDomain(s.as<IterDomain*>(), true);
-    // Note: according to internal doc "Representing ldmatrix", we need both a
-    // read domain and a write domain to correctly represent MmaOp. Without this
-    // new mechanism, there is no correct loop domain, and the only choices are
-    // either we want to represent the smem read well, or represent the register
-    // write well. We choose to represent the smem read well here. Likely, this
-    // means we will not be able to have multiple tiles in register, but we can
-    // workaround this by always inlining the MmaOp most. We should fix this
-    // after we implemented the new read/write domain mechanism.
-    tv2c->axis(-1)->parallelize(ParallelType::Mma);
-    tv2c->axis(-2)->parallelize(ParallelType::Mma);
-    tv2c->axis(-3)->parallelize(ParallelType::Mma);
-  }
-  {
-    auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
-        tv2->getLoopDomain());
-    tv2->setLoopDomain(s.as<IterDomain*>());
-  }
-
-  auto inputs = matmulAtInput3DHopperRS(
-      getM(macro), getN(macro), getK(macro), layout, data_type_to_aten(dtype));
-
-  FusionExecutor fe;
-  fe.compileFusion(
-      &fusion, {inputs.first, inputs.second}, LaunchParams(), matmul_cparams);
-
-  auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
-  auto tref = atMatmul(
-      inputs.first.squeeze().to(at::kFloat),
-      inputs.second.squeeze().to(at::kFloat),
-      layout);
-  EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5));
-}
-
 TEST_P(HopperRS, MultipleTile) {
   Fusion fusion;
   FusionGuard fg(&fusion);
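
A usage sketch, not part of the patch: after the cleanup, call sites only pass the swizzle mode. The fragment below mirrors the TMA-load setup from the test code above; `tv1` (a shared-memory operand produced by `set`) and `swizzle_b` are assumed to exist exactly as in those tests.

  // Sketch only; same call pattern as the HopperRS tests, minus the removed
  // permute_outer_dim flag. tv1 and swizzle_b come from the test fixture.
  tv1->setMemoryType(MemoryType::Shared);
  tv1->definition()->as<LoadStoreOp>()->setOpType(
      LoadStoreOpType::CpAsyncBulkTensorTile);
  // The swizzle mode alone now determines the smem layout and TMA schedule;
  // the strides in the matrix descriptor are derived from this schedule.
  tv1->applyMmaSwizzleForTMALoad(swizzle_b);

The one-argument form matches the new declaration in csrc/ir/interface_nodes.h; the rest of the operand and output scheduling is unchanged by this patch.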