diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp
index 91048d3374c..cedc7d262d5 100644
--- a/csrc/scheduler/hopper_multi_matmul.cpp
+++ b/csrc/scheduler/hopper_multi_matmul.cpp
@@ -520,27 +520,36 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
       d_smem->setMemoryType(MemoryType::Shared);

       // Set LoadStoreOp
+      // TODO: extend support when mma is not cast to half
+      NVF_ERROR(
+          dc->dtype() == DataType::Half,
+          "We support smem_epilogue on Hopper only when the output of mma is cast to half");
+
       d_smem->definition()->as<LoadStoreOp>()->setOpType(
           LoadStoreOpType::StMatrix);
       d->definition()->as<LoadStoreOp>()->setOpType(
           LoadStoreOpType::CpAsyncBulkTensorTile);

-      // Block Schedule and Parallelize
+      // Apply the common transforms to dc, d_smem, d
+      // After these transforms we schedule the inner two non-reduction loops
+      // (instruction tile) of dc and propagate it back to the outputs of mma.
       blockTileTensors(tvs_to_schedule);
       parallelizeBlocks(tvs_to_schedule);
-
-      // Apply mma common transformation
       for (auto tv : tvs_to_schedule) {
         transformLikeMmaOutput(tv, /*is_mma_result=*/false);
       }

-      // Schedule register cache; Output from epilogue
-      {
-        auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
-            dc->getLoopDomain());
-        dc->setLoopDomain(s.as<IterDomain*>());
-        dc->setAllocationDomain(s.as<IterDomain*>(), true);
-      }
+      auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
+          dc->getLoopDomain());
+      dc->setLoopDomain(s.as<IterDomain*>());
+      dc->setAllocationDomain(s.as<IterDomain*>(), true);
+
+      scheduler_utils::BoundedDirectionalTransformPropagator::backward(
+          dc,
+          -1,
+          propagate_to,
+          scheduler_utils::BoundedDirectionalTransformPropagator::Options()
+              .propagateParallelType());

       MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(d_smem);
diff --git a/tests/cpp/test_matmul_scheduler.cpp b/tests/cpp/test_matmul_scheduler.cpp
index 0ffde4364c1..c9e56706ca6 100644
--- a/tests/cpp/test_matmul_scheduler.cpp
+++ b/tests/cpp/test_matmul_scheduler.cpp
@@ -3379,10 +3379,6 @@ TEST_P(HopperMatmulSchedulerTest, FusedMultiplySum) {
 // TODO: Remove this test once the architecture agnostic can be
 // run on hopper.
 TEST_P(HopperMatmulSchedulerTest, FusedMultiplySumBiasNeg) {
-  if (use_smem_epilogue) {
-    GTEST_SKIP()
-        << "TODO: We don't support smem epilogue in the Hopper matmul scheduler right now";
-  }
   const auto& [A, B] =
       matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype));
   const auto& C = matmulAtInput2D(