Skip to content

Commit

Permalink
create transformLikeMmaOutput
Browse files Browse the repository at this point in the history
  • Loading branch information
rdspring1 committed Dec 12, 2024
1 parent e18741c commit 8cbb4ab
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 14 deletions.
31 changes: 17 additions & 14 deletions csrc/scheduler/hopper_multi_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,22 @@

namespace nvfuser {

void HopperMultipleMatmulScheduler::transformLikeMmaOutput(TensorView* tv) {
// TODO Add constraints

// [..., Mo, No, Mi, Ni]
tv->split(-2, getM(params_->mma_macro));
tv->split(-1, getN(params_->mma_macro));
// [..., Mo, No, Mio, Mii, Nio, Nii]
// -> [..., Mo, No, Mio, Nio, Mii, Nii]
tv->reorder({{-3, -2}});
tv->merge(-4);
auto s =
mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(tv->getLoopDomain());
tv->setLoopDomain(s.as<IterDomain*>());
tv->axis(-5)->parallelize(ParallelType::TIDy);
}

MatmulDimRole HopperMultipleMatmulScheduler::findMatmulDimRole(IterDomain* id) {
ValGroup vg = graph_->toGroup(id);
auto it = id_roles_.find(vg);
Expand Down Expand Up @@ -644,20 +660,7 @@ void HopperMultipleMatmulScheduler::scheduleSplitKSum() {
for (TensorView* splitk_sum : splitk_sums_) {
// Always use serial grid reduction for split-K sum
splitk_sum->definition()->as<ReductionOp>()->requestSerialGridReduction();

// [..., Mo, No, Mi, Ni]
splitk_sum->split(-2, getM(params_->mma_macro));
splitk_sum->split(-1, getN(params_->mma_macro));
// [..., Mo, No, Mio, Mii, Nio, Nii]
// -> [..., Mo, No, Mio, Nio, Mii, Nii]
splitk_sum->reorder({{-3, -2}});
splitk_sum->merge(-4);
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
splitk_sum->getLoopDomain());
splitk_sum->setLoopDomain(s.as<IterDomain*>());
splitk_sum->axis(-5)->parallelize(ParallelType::TIDy);

// splitk_sum->reorder({{2, -2}});
transformLikeMmaOutput(splitk_sum);
splitk_sum->axis(2)->parallelize(ParallelType::BIDz);
splitk_sum->axis(-1)->parallelize(ParallelType::Vectorize);
}
Expand Down
5 changes: 5 additions & 0 deletions csrc/scheduler/hopper_multi_matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,11 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler {
// Return MatmulDimRole for IterDomain
MatmulDimRole findMatmulDimRole(IterDomain* id);

// Schedule a block-tiled TensorView like mma output.
// Why? WGMMA has a unique output format. TensorViews after the mma-result in
// registers must respect this format for correctness.
void transformLikeMmaOutput(TensorView* tv);

private:
std::vector<ValGroup> canonical_dim_ordering_;

Expand Down

0 comments on commit 8cbb4ab

Please sign in to comment.