Skip to content

Commit

Permalink
adding support for smem_epilogue
Browse files Browse the repository at this point in the history
  • Loading branch information
protonu committed Dec 14, 2024
1 parent 5ed7e3c commit c0cfa4e
Showing 1 changed file with 25 additions and 25 deletions.
50 changes: 25 additions & 25 deletions csrc/scheduler/hopper_multi_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -504,44 +504,44 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
auto* d = dv->as<TensorView>();
NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());
auto* dc = d->definition()->input(0)->as<TensorView>();

// NOTE: cacheBefore does not work with blockTileTensors
TensorView* d_smem = cacheAfter(dc, LoadStoreOpType::Set);

std::vector<TensorView*> tvs_to_schedule{d, d_smem};
if (std::find(mma_results_.begin(), mma_results_.end(), dc) ==
mma_results_.end()) {
// Skip scheduling dc if it is an mma_result. This can happen if we are
// not casting back to half-precision in the output
tvs_to_schedule.push_back(dc);
}

// Set MemoryType
dc->setMemoryType(MemoryType::Local);
d_smem->setMemoryType(MemoryType::Shared);

// Set LoadStoreOp
d_smem->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::StMatrix);
if (dc->dtype() == DataType::Half && !dc->definition()->isA<MmaOp>()) {
d_smem->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::StMatrix);
}
d->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::CpAsyncBulkTensorTile);

// Block Schedule and Parallelize
blockTileTensors(tvs_to_schedule);
parallelizeBlocks(tvs_to_schedule);
blockTileTensors({d});
parallelizeBlocks({d});
transformLikeMmaOutput(d, /*is_mma_result=*/false);

// Apply mma common transformation
for (auto tv : tvs_to_schedule) {
transformLikeMmaOutput(tv, /*is_mma_result=*/false);
}
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
d,
-1,
{dc},
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType()
.propagateToBoundary());

// Schedule register cache; Output from epilogue
{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
dc->getLoopDomain());
dc->setLoopDomain(s.as<IterDomain*>());
dc->setAllocationDomain(s.as<IterDomain*>(), true);
}
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
dc->getLoopDomain());
dc->setLoopDomain(s.as<IterDomain*>());
dc->setAllocationDomain(s.as<IterDomain*>(), true);

scheduler_utils::BoundedDirectionalTransformPropagator::backward(
dc,
-1,
propagate_to,
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());

MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(d_smem);

Expand Down

0 comments on commit c0cfa4e

Please sign in to comment.