diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp
index fa185665253..528f1a0b616 100644
--- a/tests/cpp/test_matmul.cpp
+++ b/tests/cpp/test_matmul.cpp
@@ -3687,6 +3687,8 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) {
   constexpr auto swizzle = MmaInputSmemSwizzle::B128;
   const auto dtype = DataType::Half;
 
+  constexpr bool use_smem_epilogue = false;
+
   constexpr int64_t stages = 4;
   constexpr int64_t prefetch = 3;
   const int64_t cta_m = 2 * getM(macro);
@@ -3705,7 +3707,6 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) {
   tv2->commitLeafToLogical();
 
   auto tv3 = castOp(DataType::Half, tv2);
-
   fusion.addOutput(tv3);
 
   auto mma_ops = ir_utils::getOpsOfType<MmaOp>(&fusion);
@@ -3722,7 +3723,21 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) {
   tv0c->setMemoryType(MemoryType::Shared);
   auto tv1c = tv1->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
   tv1c->setMemoryType(MemoryType::Shared);
-  auto tv3c = tv3->cacheBefore();
+
+  TensorView *tv3c = nullptr, *tv3_shmem = nullptr;
+  if (use_smem_epilogue) {
+    tv3_shmem = tv3->cacheBefore();
+    tv3c = tv3_shmem->cacheBefore();
+    tv3_shmem->setMemoryType(MemoryType::Shared);
+    tv3c->setMemoryType(MemoryType::Local);
+    tv3_shmem->definition()->as<LoadStoreOp>()->setOpType(
+        LoadStoreOpType::StMatrix);
+    tv3->definition()->as<LoadStoreOp>()->setOpType(
+        LoadStoreOpType::CpAsyncBulkTensorTile);
+  } else {
+    tv3c = tv3->cacheBefore();
+    tv3c->setMemoryType(MemoryType::Local);
+  }
 
   // gmem [K, M, 1] -TMA-> smem [K, M, 1]
   // gmem [K, 1, N] -TMA-> smem [K, 1, N]
@@ -3775,12 +3790,30 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) {
     tv2->axis(-3)->parallelize(ParallelType::Mma);
   }
 
-  for (auto tv : {tv3c, tv3}) {
+  if (!use_smem_epilogue) {
+    for (auto tv : {tv3c, tv3}) {
+      auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
+          tv->getLoopDomain());
+      tv->setLoopDomain(s.as<IterDomain*>());
+    }
+  } else {
     auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
-        tv->getLoopDomain());
-    tv->setLoopDomain(s.as<IterDomain*>());
+        tv3c->getLoopDomain());
+    tv3c->setLoopDomain(s.as<IterDomain*>());
+    tv3c->setAllocationDomain(s.as<IterDomain*>(), true);
+
+    // We'll use stmatrix.x4 to store from reg to shared memory
+    fusion.manage("st_matrix_m_tile", 16);
+    fusion.manage("st_matrix_n_tile", 16);
+    fusion.manage("st_matrix_m", getM(macro));
+    fusion.manage("st_matrix_n", getN(macro));
+
+    // This internally calls
+    // mma_utils::MmaSwizzler::scheduleMmaOutputAllocation
+    mma_utils::scheduleStMatrixForMmaOutput(tv3_shmem, 16, 16);
+
+    mma_utils::scheduleTMAStoreForMmaOutput(tv3, M, N);
   }
-  tv3->axis(-1)->parallelize(ParallelType::Vectorize);
 
   inlineMost();
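
A note on the 16x16 tile constants in the last hunk: stmatrix.x4 has each warp store four 8x8 matrices of 16-bit elements, i.e. one 16x16 half-precision tile per instruction, which is why 16 is passed for both tile dimensions to fusion.manage and scheduleStMatrixForMmaOutput. The standalone sketch below (not part of the patch) works out the resulting tile count; the 64x256 macro size is a hypothetical stand-in for whatever getM(macro)/getN(macro) return in the test.

#include <cstdint>
#include <cstdio>

int main() {
  // One stmatrix.x4 instruction writes four 8x8 b16 fragments,
  // i.e. a single 16x16 half tile per warp. These match the values
  // passed to fusion.manage("st_matrix_{m,n}_tile", 16) in the diff.
  constexpr int64_t st_matrix_m_tile = 16;
  constexpr int64_t st_matrix_n_tile = 16;
  // Hypothetical MMA macro output size for illustration only; the
  // test derives the real values from getM(macro) and getN(macro).
  constexpr int64_t st_matrix_m = 64;
  constexpr int64_t st_matrix_n = 256;
  // Number of stmatrix.x4 stores needed to cover the macro output.
  const int64_t tiles =
      (st_matrix_m / st_matrix_m_tile) * (st_matrix_n / st_matrix_n_tile);
  std::printf("stmatrix.x4 stores per MMA macro output: %lld\n",
              static_cast<long long>(tiles));
  return 0;
}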