[WIP][DO NOT REVIEW] Schedule epilogue for Hopper - without smem epilogue #3565

Draft
wants to merge 1 commit into base: main
7 changes: 7 additions & 0 deletions csrc/scheduler/ampere_multi_matmul.cpp
@@ -10,6 +10,7 @@
#include <id_model/schedule.h>
#include <instrumentation.h>
#include <ir/utils.h>
#include <ir/graphviz.h>
#include <scheduler/ampere_multi_matmul.h>
#include <scheduler/debug_utils.h>
#include <scheduler/matmul.h>
@@ -1137,6 +1138,8 @@ void AmpereMultipleMatmulScheduler::scheduleOutputTensor(TensorView* c) {

void AmpereMultipleMatmulScheduler::scheduleEpilogue() {
std::vector<TensorView*> output_tvs;
IrGraphGenerator::print(
fusion_, "a_amp.dot", IrGraphGenerator::DetailLevel::Basic);
for (Val* v : fusion_->outputs()) {
if (auto tv = dynamic_cast<TensorView*>(v)) {
output_tvs.push_back(tv);
@@ -1206,6 +1209,8 @@ void AmpereMultipleMatmulScheduler::scheduleFusionInputsForEpilogue() {
cached_tvs.push_back(c->cacheAfter());
}

IrGraphGenerator::print(
fusion_, "a_cache_after.dot", IrGraphGenerator::DetailLevel::Basic);
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
output_d, -1, c_tvs);

@@ -1224,6 +1229,8 @@

// The cached EPILOGUE_INPUT tvs are not needed anymore
cached_tvs.clear();
IrGraphGenerator::print(
fusion_, "a_cache_clear.dot", IrGraphGenerator::DetailLevel::Basic);
}
}

139 changes: 86 additions & 53 deletions csrc/scheduler/hopper_multi_matmul.cpp
@@ -9,6 +9,7 @@
#include <disjoint_set.h>
#include <id_model/schedule.h>
#include <instrumentation.h>
#include <ir/graphviz.h>
#include <ir/utils.h>
#include <scheduler/debug_utils.h>
#include <scheduler/hopper_multi_matmul.h>
@@ -499,28 +500,30 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
for (Val* dv : fusion_->outputs()) {
auto* d = dv->as<TensorView>();
NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());
auto* dc = d->definition()->input(0)->as<TensorView>();

// Block Schedule and Parallelize
blockTileTensors({dc, d});
parallelizeBlocks({dc, d});

// Apply mma common transformation
for (auto tv : {dc, d}) {
// [..., Mo, No, Mi, Ni]
tv->split(-2, getM(params_->mma_macro));
tv->split(-1, getN(params_->mma_macro));
// [..., Mo, No, Mio, Mii, Nio, Nii]
// -> [..., Mo, No, Mio, Nio, Mii, Nii]
tv->reorder({{-3, -2}});
tv->merge(-4);
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv->getLoopDomain());
tv->setLoopDomain(s.as<IterDomain*>());
tv->axis(-5)->parallelize(ParallelType::TIDy);
}
blockTileTensors({d});
parallelizeBlocks({d});
d->split(-2, getM(params_->mma_macro));
d->split(-1, getN(params_->mma_macro));
// [..., Mo, No, Mio, Mii, Nio, Nii]
// -> [..., Mo, No, Mio, Nio, Mii, Nii]
d->reorder({{-3, -2}});
d->merge(-4);
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
d->getLoopDomain());
d->setLoopDomain(s.as<IterDomain*>());
d->axis(-5)->parallelize(ParallelType::TIDy);

scheduler_utils::BoundedDirectionalTransformPropagator::backward(
d,
-1,
mma_results_,
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());

d->axis(-1)->parallelize(ParallelType::Vectorize);
}
scheduleFusionInputsForEpilogue();
Collaborator review comment:

In this path, since we are not using TMA, we should respect the params_->supported_vec_size.epilogue parameter to set the vectorization width. See scheduleOutputTensor in this file, which we should either update and use or remove.
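Something like the following (untested sketch; the exact member layout of supported_vec_size and where this lands in scheduleEpilogue are assumptions, not code from this PR):

// Instead of unconditionally vectorizing the innermost loop of d,
// split it by the supported epilogue vectorization width first.
const int64_t vec = params_->supported_vec_size.epilogue;
if (vec > 1) {
  d->split(-1, vec);
  d->axis(-1)->parallelize(ParallelType::Vectorize);
}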

In the TMA path (the else branch) we might not be able to use scheduleFusionInputsForEpilogue as-is, because it currently propagates the schedule back from the output d, which is scheduled for a TMA store rather than a vectorized load. We could potentially just do that propagation, then check the innermost dims and merge/split/unroll as needed to hit the target vectorization for the epilogue.
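Roughly (untested sketch; cached_tvs comes from scheduleFusionInputsForEpilogue, and the exact merge/split sequence would depend on what the propagated innermost extents actually are):

// After propagating the TMA-store schedule back to the cached epilogue
// inputs, rebuild their innermost loop so the global loads can still be
// vectorized to the target width.
const int64_t target_vec = params_->supported_vec_size.epilogue;
for (TensorView* cached : cached_tvs) {
  cached->merge(-2);              // recombine the propagated inner dims
  cached->split(-1, target_vec);  // split out the target vectorization width
  cached->axis(-1)->parallelize(ParallelType::Vectorize);
}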

} else {
constexpr int64_t stmatrix_tile_m = 16;
constexpr int64_t stmatrix_tile_n = 16;
@@ -558,30 +561,42 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
LoadStoreOpType::CpAsyncBulkTensorTile);

// Block Schedule and Parallelize
blockTileTensors({dc, d_smem, d});
parallelizeBlocks({dc, d_smem, d});

// Apply mma common transformation
for (auto tv : {dc, d_smem, d}) {
// Original: [..., Mo, No, Mi, Ni]
tv->split(-2, getM(params_->mma_macro));
tv->split(-1, getN(params_->mma_macro));
// After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
tv->reorder({{-3, -2}});
// After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
tv->merge(-4);
// After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
tv->axis(-3)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
}
// blockTileTensors({x, dc, d_smem, d});
// parallelizeBlocks({x, dc, d_smem, d});
blockTileTensors({d});
parallelizeBlocks({d});

// Apply mma transformation
// Original: [..., Mo, No, Mi, Ni]
d->split(-2, getM(params_->mma_macro));
d->split(-1, getN(params_->mma_macro));
// After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
d->reorder({{-3, -2}});
// d After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
d->merge(-4);
// After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
d->axis(-3)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]

scheduler_utils::BoundedDirectionalTransformPropagator::backward(
d,
-1,
mma_results_,
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());

// Schedule register cache; Output from epilogue
{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
dc->getLoopDomain());
dc->setLoopDomain(s.as<IterDomain*>());
dc->setAllocationDomain(s.as<IterDomain*>(), true);
}
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
dc->getLoopDomain());
dc->setLoopDomain(s.as<IterDomain*>());
dc->setAllocationDomain(s.as<IterDomain*>(), true);

scheduler_utils::BoundedDirectionalTransformPropagator::backward(
dc,
-1,
mma_results_,
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());

MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(d_smem);

@@ -592,6 +607,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
// Schedule global memory output; Output from TMA Store
mma_utils::scheduleTMAStoreForMmaOutput(d, swizzle);
}
scheduleFusionInputsForEpilogue();
}
}

@@ -615,25 +631,42 @@
for (auto* c : c_tvs) {
cached_tvs.push_back(c->cacheAfter());
}
IrGraphGenerator::print(
fusion_, "a_cache_after.dot", IrGraphGenerator::DetailLevel::Basic);

scheduler_utils::BoundedDirectionalTransformPropagator::backward(
output_d, -1, c_tvs);

if (!params_->use_smem_epilogue) {
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
output_d, -1, c_tvs);
std::unordered_set<ParallelType> parallel_types = {};
if (params_->use_smem_epilogue) {
// In cases where smem epilogue feature is enabled, the vectorization
// of domains will be propagated to fusion inputs that are epilogue
// inputs, this may result in unaligned memory reads. Vectorization is
// explicitly excluded form parallelization types to avoid this issue.
// This should be changed when vectorization analysis is available and
// enabled for matmul scheduler.
parallel_types = allParallelTypesExcept({ParallelType::Vectorize});
}
scheduler_utils::parallelizeAllLike(
output_d, -1, cached_tvs, parallel_types);
} else {
auto* d_smem = output_d->definition()->input(0)->as<TensorView>();
auto* dc = d_smem->definition()->input(0)->as<TensorView>();
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
dc,
-1,
c_tvs,
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());
}

// if (params_->use_smem_epilogue) {
// // In cases where smem epilogue feature is enabled, the vectorization
// // of domains will be propagated to fusion inputs that are epilogue
// // inputs, this may result in unaligned memory reads. Vectorization is
// // explicitly excluded form parallelization types to avoid this issue.
// // This should be changed when vectorization analysis is available and
// // enabled for matmul scheduler.
// parallel_types = allParallelTypesExcept({ParallelType::Vectorize});
// }
// scheduler_utils::parallelizeAllLike(
// output_d, -1, cached_tvs, parallel_types);

// The cached EPILOGUE_INPUT tvs are not needed anymore
cached_tvs.clear();
IrGraphGenerator::print(
fusion_, "a_cache_clear.dot", IrGraphGenerator::DetailLevel::Basic);
}
}

1 change: 1 addition & 0 deletions csrc/scheduler/multi_matmul.cpp
@@ -110,6 +110,7 @@ void scheduleMultipleMatmuls(Fusion* fusion, const MatmulParams* params) {
// conditions below.
const auto device_prop = at::cuda::getCurrentDeviceProperties();
const int cc = device_prop->major * 10 + device_prop->minor;
// AmpereMultipleMatmulScheduler(fusion, params).run();
if (cc >= 75 && cc < 90) {
AmpereMultipleMatmulScheduler(fusion, params).run();
} else if (cc >= 90 && cc < 100) {