Skip to content

Commit

Permalink
apply more places
Browse files (browse the repository at this point in the history)
  • Loading branch information
rdspring1 committed Dec 12, 2024
1 parent 8cbb4ab commit f325260
Showing 1 changed file with 11 additions and 25 deletions.
36 changes: 11 additions & 25 deletions csrc/scheduler/hopper_multi_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,16 @@ namespace nvfuser {
  void HopperMultipleMatmulScheduler::transformLikeMmaOutput(TensorView* tv) {
  // TODO Add constraints

- // [..., Mo, No, Mi, Ni]
+ // Original: [..., Mo, No, Mi, Ni]
  tv->split(-2, getM(params_->mma_macro));
  tv->split(-1, getN(params_->mma_macro));
- // [..., Mo, No, Mio, Mii, Nio, Nii]
- // -> [..., Mo, No, Mio, Nio, Mii, Nii]
+ // After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
  tv->reorder({{-3, -2}});
+ // After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
  tv->merge(-4);
- auto s =
- mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(tv->getLoopDomain());
- tv->setLoopDomain(s.as<IterDomain*>());
- tv->axis(-5)->parallelize(ParallelType::TIDy);
+ // After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
+ tv->axis(-3)->parallelize(ParallelType::TIDy);
+ // After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
  }

  MatmulDimRole HopperMultipleMatmulScheduler::findMatmulDimRole(IterDomain* id) {
Expand Down Expand Up @@ -523,17 +522,10 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {

  // Apply mma common transformation
  for (auto tv : {dc, d}) {
- // [..., Mo, No, Mi, Ni]
- tv->split(-2, getM(params_->mma_macro));
- tv->split(-1, getN(params_->mma_macro));
- // [..., Mo, No, Mio, Mii, Nio, Nii]
- // -> [..., Mo, No, Mio, Nio, Mii, Nii]
- tv->reorder({{-3, -2}});
- tv->merge(-4);
+ transformLikeMmaOutput(tv);
  auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
  tv->getLoopDomain());
  tv->setLoopDomain(s.as<IterDomain*>());
- tv->axis(-5)->parallelize(ParallelType::TIDy);
  }
  d->axis(-1)->parallelize(ParallelType::Vectorize);
  }
Expand Down Expand Up @@ -579,16 +571,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {

  // Apply mma common transformation
  for (auto tv : {dc, d_smem, d}) {
- // Original: [..., Mo, No, Mi, Ni]
- tv->split(-2, getM(params_->mma_macro));
- tv->split(-1, getN(params_->mma_macro));
- // After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
- tv->reorder({{-3, -2}});
- // After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
- tv->merge(-4);
- // After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
- tv->axis(-3)->parallelize(ParallelType::TIDy);
- // After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
+ transformLikeMmaOutput(tv);
  }

  // Schedule register cache; Output from epilogue
Expand Down Expand Up @@ -661,6 +644,9 @@ void HopperMultipleMatmulScheduler::scheduleSplitKSum() {
  // Always use serial grid reduction for split-K sum
  splitk_sum->definition()->as<ReductionOp>()->requestSerialGridReduction();
  transformLikeMmaOutput(splitk_sum);
+ auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
+ splitk_sum->getLoopDomain());
+ splitk_sum->setLoopDomain(s.as<IterDomain*>());
  splitk_sum->axis(2)->parallelize(ParallelType::BIDz);
  splitk_sum->axis(-1)->parallelize(ParallelType::Vectorize);
  }
Expand Down

0 comments on commit f325260

Please sign in to comment.