Skip to content

Commit

Permalink
Create test
Browse files Browse the repository at this point in the history
  • Loading branch information
protonu committed Dec 11, 2024
1 parent 89c47f6 commit 99aace1
Show file tree
Hide file tree
Showing 4 changed files with 262 additions and 75 deletions.
7 changes: 7 additions & 0 deletions csrc/scheduler/ampere_multi_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <id_model/schedule.h>
#include <instrumentation.h>
#include <ir/utils.h>
#include <ir/graphviz.h>
#include <scheduler/ampere_multi_matmul.h>
#include <scheduler/debug_utils.h>
#include <scheduler/matmul.h>
Expand Down Expand Up @@ -1137,6 +1138,8 @@ void AmpereMultipleMatmulScheduler::scheduleOutputTensor(TensorView* c) {

void AmpereMultipleMatmulScheduler::scheduleEpilogue() {
std::vector<TensorView*> output_tvs;
IrGraphGenerator::print(
fusion_, "a_amp.dot", IrGraphGenerator::DetailLevel::Basic);
for (Val* v : fusion_->outputs()) {
if (auto tv = dynamic_cast<TensorView*>(v)) {
output_tvs.push_back(tv);
Expand Down Expand Up @@ -1206,6 +1209,8 @@ void AmpereMultipleMatmulScheduler::scheduleFusionInputsForEpilogue() {
cached_tvs.push_back(c->cacheAfter());
}

IrGraphGenerator::print(
fusion_, "a_cache_after.dot", IrGraphGenerator::DetailLevel::Basic);
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
output_d, -1, c_tvs);

Expand All @@ -1224,6 +1229,8 @@ void AmpereMultipleMatmulScheduler::scheduleFusionInputsForEpilogue() {

// The cached EPILOGUE_INPUT tvs are not needed anymore
cached_tvs.clear();
IrGraphGenerator::print(
fusion_, "a_cache_clear.dot", IrGraphGenerator::DetailLevel::Basic);
}
}

Expand Down
139 changes: 86 additions & 53 deletions csrc/scheduler/hopper_multi_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <disjoint_set.h>
#include <id_model/schedule.h>
#include <instrumentation.h>
#include <ir/graphviz.h>
#include <ir/utils.h>
#include <scheduler/debug_utils.h>
#include <scheduler/hopper_multi_matmul.h>
Expand Down Expand Up @@ -499,28 +500,30 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
for (Val* dv : fusion_->outputs()) {
auto* d = dv->as<TensorView>();
NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());
auto* dc = d->definition()->input(0)->as<TensorView>();

// Block Schedule and Parallelize
blockTileTensors({dc, d});
parallelizeBlocks({dc, d});

// Apply mma common transformation
for (auto tv : {dc, d}) {
// [..., Mo, No, Mi, Ni]
tv->split(-2, getM(params_->mma_macro));
tv->split(-1, getN(params_->mma_macro));
// [..., Mo, No, Mio, Mii, Nio, Nii]
// -> [..., Mo, No, Mio, Nio, Mii, Nii]
tv->reorder({{-3, -2}});
tv->merge(-4);
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv->getLoopDomain());
tv->setLoopDomain(s.as<IterDomain*>());
tv->axis(-5)->parallelize(ParallelType::TIDy);
}
blockTileTensors({d});
parallelizeBlocks({d});
d->split(-2, getM(params_->mma_macro));
d->split(-1, getN(params_->mma_macro));
// [..., Mo, No, Mio, Mii, Nio, Nii]
// -> [..., Mo, No, Mio, Nio, Mii, Nii]
d->reorder({{-3, -2}});
d->merge(-4);
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
d->getLoopDomain());
d->setLoopDomain(s.as<IterDomain*>());
d->axis(-5)->parallelize(ParallelType::TIDy);

scheduler_utils::BoundedDirectionalTransformPropagator::backward(
d,
-1,
mma_results_,
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());

d->axis(-1)->parallelize(ParallelType::Vectorize);
}
scheduleFusionInputsForEpilogue();
} else {
constexpr int64_t stmatrix_tile_m = 16;
constexpr int64_t stmatrix_tile_n = 16;
Expand Down Expand Up @@ -558,30 +561,42 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
LoadStoreOpType::CpAsyncBulkTensorTile);

// Block Schedule and Parallelize
blockTileTensors({dc, d_smem, d});
parallelizeBlocks({dc, d_smem, d});

// Apply mma common transformation
for (auto tv : {dc, d_smem, d}) {
// Original: [..., Mo, No, Mi, Ni]
tv->split(-2, getM(params_->mma_macro));
tv->split(-1, getN(params_->mma_macro));
// After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
tv->reorder({{-3, -2}});
// After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
tv->merge(-4);
// After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
tv->axis(-3)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
}
// blockTileTensors({x, dc, d_smem, d});
// parallelizeBlocks({x, dc, d_smem, d});
blockTileTensors({d});
parallelizeBlocks({d});

// Apply mma transformation
// Original: [..., Mo, No, Mi, Ni]
d->split(-2, getM(params_->mma_macro));
d->split(-1, getN(params_->mma_macro));
// After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
d->reorder({{-3, -2}});
// d After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
d->merge(-4);
// After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
d->axis(-3)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]

scheduler_utils::BoundedDirectionalTransformPropagator::backward(
d,
-1,
mma_results_,
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());

// Schedule register cache; Output from epilogue
{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
dc->getLoopDomain());
dc->setLoopDomain(s.as<IterDomain*>());
dc->setAllocationDomain(s.as<IterDomain*>(), true);
}
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
dc->getLoopDomain());
dc->setLoopDomain(s.as<IterDomain*>());
dc->setAllocationDomain(s.as<IterDomain*>(), true);

scheduler_utils::BoundedDirectionalTransformPropagator::backward(
dc,
-1,
mma_results_,
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());

MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(d_smem);

Expand All @@ -592,6 +607,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
// Schedule global memory output; Output from TMA Store
mma_utils::scheduleTMAStoreForMmaOutput(d, swizzle);
}
scheduleFusionInputsForEpilogue();
}
}

Expand All @@ -615,25 +631,42 @@ void HopperMultipleMatmulScheduler::scheduleFusionInputsForEpilogue() {
for (auto* c : c_tvs) {
cached_tvs.push_back(c->cacheAfter());
}
IrGraphGenerator::print(
fusion_, "a_cache_after.dot", IrGraphGenerator::DetailLevel::Basic);

scheduler_utils::BoundedDirectionalTransformPropagator::backward(
output_d, -1, c_tvs);

if (!params_->use_smem_epilogue) {
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
output_d, -1, c_tvs);
std::unordered_set<ParallelType> parallel_types = {};
if (params_->use_smem_epilogue) {
// In cases where smem epilogue feature is enabled, the vectorization
// of domains will be propagated to fusion inputs that are epilogue
// inputs, this may result in unaligned memory reads. Vectorization is
explicitly excluded from parallelization types to avoid this issue.
// This should be changed when vectorization analysis is available and
// enabled for matmul scheduler.
parallel_types = allParallelTypesExcept({ParallelType::Vectorize});
}
scheduler_utils::parallelizeAllLike(
output_d, -1, cached_tvs, parallel_types);
} else {
auto* d_smem = output_d->definition()->input(0)->as<TensorView>();
auto* dc = d_smem->definition()->input(0)->as<TensorView>();
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
dc,
-1,
c_tvs,
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());
}

// if (params_->use_smem_epilogue) {
// // In cases where smem epilogue feature is enabled, the vectorization
// // of domains will be propagated to fusion inputs that are epilogue
// // inputs, this may result in unaligned memory reads. Vectorization is
// // explicitly excluded from parallelization types to avoid this issue.
// // This should be changed when vectorization analysis is available and
// // enabled for matmul scheduler.
// parallel_types = allParallelTypesExcept({ParallelType::Vectorize});
// }
// scheduler_utils::parallelizeAllLike(
// output_d, -1, cached_tvs, parallel_types);

// The cached EPILOGUE_INPUT tvs are not needed anymore
cached_tvs.clear();
IrGraphGenerator::print(
fusion_, "a_cache_clear.dot", IrGraphGenerator::DetailLevel::Basic);
}
}

Expand Down
1 change: 1 addition & 0 deletions csrc/scheduler/multi_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ void scheduleMultipleMatmuls(Fusion* fusion, const MatmulParams* params) {
// conditions below.
const auto device_prop = at::cuda::getCurrentDeviceProperties();
const int cc = device_prop->major * 10 + device_prop->minor;
// AmpereMultipleMatmulScheduler(fusion, params).run();
if (cc >= 75 && cc < 90) {
AmpereMultipleMatmulScheduler(fusion, params).run();
} else if (cc >= 90 && cc < 100) {
Expand Down
Loading

0 comments on commit 99aace1

Please sign in to comment.