diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp
index 1df6276ab0e..31fa5e70ef6 100644
--- a/csrc/scheduler/hopper_multi_matmul.cpp
+++ b/csrc/scheduler/hopper_multi_matmul.cpp
@@ -504,44 +504,44 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
       auto* d = dv->as<TensorView>();
       NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());
       auto* dc = d->definition()->input(0)->as<TensorView>();
-      // NOTE: cacheBefore does not work with blockTileTensors
       TensorView* d_smem = cacheAfter(dc, LoadStoreOpType::Set);
 
-      std::vector<TensorView*> tvs_to_schedule{d, d_smem};
-      if (std::find(mma_results_.begin(), mma_results_.end(), dc) ==
-          mma_results_.end()) {
-        // Skip scheduling dc if it is an mma_result. This can happen if we are
-        // not casting back to half-precision in the output
-        tvs_to_schedule.push_back(dc);
-      }
-
       // Set MemoryType
       dc->setMemoryType(MemoryType::Local);
       d_smem->setMemoryType(MemoryType::Shared);
 
       // Set LoadStoreOp
-      d_smem->definition()->as<LoadStoreOp>()->setOpType(
-          LoadStoreOpType::StMatrix);
+      if (dc->dtype() == DataType::Half && !dc->definition()->isA<MmaOp>()) {
+        d_smem->definition()->as<LoadStoreOp>()->setOpType(
+            LoadStoreOpType::StMatrix);
+      }
       d->definition()->as<LoadStoreOp>()->setOpType(
           LoadStoreOpType::CpAsyncBulkTensorTile);
 
-      // Block Schedule and Parallelize
-      blockTileTensors(tvs_to_schedule);
-      parallelizeBlocks(tvs_to_schedule);
+      blockTileTensors({d});
+      parallelizeBlocks({d});
+      transformLikeMmaOutput(d, /*is_mma_result=*/false);
 
-      // Apply mma common transformation
-      for (auto tv : tvs_to_schedule) {
-        transformLikeMmaOutput(tv, /*is_mma_result=*/false);
-      }
+      scheduler_utils::BoundedDirectionalTransformPropagator::backward(
+          d,
+          -1,
+          {dc},
+          scheduler_utils::BoundedDirectionalTransformPropagator::Options()
+              .propagateParallelType()
+              .propagateToBoundary());
 
-      // Schedule register cache; Output from epilogue
-      {
-        auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
-            dc->getLoopDomain());
-        dc->setLoopDomain(s.as<IterDomain*>());
-        dc->setAllocationDomain(s.as<IterDomain*>(), true);
-      }
+      auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
+          dc->getLoopDomain());
+      dc->setLoopDomain(s.as<IterDomain*>());
+      dc->setAllocationDomain(s.as<IterDomain*>(), true);
+
+      scheduler_utils::BoundedDirectionalTransformPropagator::backward(
+          dc,
+          -1,
+          propagate_to,
+          scheduler_utils::BoundedDirectionalTransformPropagator::Options()
+              .propagateParallelType());
 
       MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(d_smem);