adding support for smem_epilogue
protonu committed Dec 12, 2024
1 parent 8bf8997 commit b9df449
Showing 3 changed files with 80 additions and 135 deletions.
204 changes: 80 additions & 124 deletions csrc/scheduler/hopper_multi_matmul.cpp
@@ -420,76 +420,77 @@ void HopperMultipleMatmulScheduler::scheduleMmaResults() {
}
}

void HopperMultipleMatmulScheduler::scheduleOutputTensor(TensorView* c) {
// Block Schedule and Parallelize
blockTileTensors({c});
parallelizeBlocks({c});

// Apply mma common transformation
c->split(-2, getM(params_->mma_macro));
c->split(-1, getN(params_->mma_macro));
// [..., Mo, No, Mio, Mii, Nio, Nii]
// -> [..., Mo, No, Mio, Nio, Mii, Nii]
c->reorder({{-3, -2}});
c->merge(-4);

// [...., Mii, Nii] ->
// [..., Mii/16, Miioi(2), Miii(8), Nii/8, Niio(4), Niii(2)] ->
// [.., Mii/16, Miii(8), Niio(4), Nii/8, Miioi(2), Niii(2) ]
// [..., Mii/16 * Miii(8) * Niio(4), Nii/8, Miioi(2), Niii(2) ]
// Mii/16 * Miii(8) * Niio(4) is 128 for Hopper and this is parallelized as
// TIDx.
auto s =
mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(c->getLoopDomain());
c->setLoopDomain(s.as<IterDomain*>());
c->axis(-5)->parallelize(ParallelType::TIDy);
}
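
A minimal standalone sketch of the arithmetic in the layout comment above, assuming a wgmma macro with M = 64 so that Mii = 64 (the concrete macro is not fixed by this hunk): the merged Mii/16 * Miii(8) * Niio(4) dimension works out to 128, the TIDx extent of one Hopper warp group.

    #include <cstdint>

    // Extents named after the comment above; kMii = 64 is an assumption.
    constexpr int64_t kMii = 64;              // assumed macro M
    constexpr int64_t kMiiOuter = kMii / 16;  // Mii/16 = 4
    constexpr int64_t kMiii = 8;              // Miii
    constexpr int64_t kNiio = 4;              // Niio
    static_assert(
        kMiiOuter * kMiii * kNiio == 128,
        "the merged dimension parallelized as TIDx spans a 128-thread warp group");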

void HopperMultipleMatmulScheduler::scheduleEpilogue() {
if (!params_->use_smem_epilogue) {
for (Val* dv : fusion_->outputs()) {
auto* d = dv->as<TensorView>();
NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());

// Schedule the output TV and propagate it back to the outputs of the Mma
// op.
scheduleOutputTensor(d);
constexpr int64_t stmatrix_tile_m = 16;
constexpr int64_t stmatrix_tile_n = 16;

// TODO: Support tma tile sizes that are a multiple of mma_macro.
// The wgmma operation creates an output matrix of mma_macro size. The TMA
// tile is a multiple of the macro size because stmatrix stores results from
// wgmma to shared memory. For maximum inlining and to reduce shared memory
// usage, the tma tile is mma_macro size.
const int64_t tma_m = getM(params_->mma_macro);
const int64_t tma_n = getN(params_->mma_macro);

auto& c_tvs = tensor_roles_.at(MatmulTensorRole::EPILOGUE_INPUT);
std::vector<TensorView*> cached_tvs;
for (auto* c : c_tvs) {
cached_tvs.push_back(c->cacheAfter());
}

auto applyCommonMmaTransforms = [this](TensorView* tv) {
// Apply mma common transformation
tv->split(-2, getM(params_->mma_macro));
tv->split(-1, getN(params_->mma_macro));
// [..., Mo, No, Mio, Mii, Nio, Nii]
// -> [..., Mo, No, Mio, Nio, Mii, Nii]
tv->reorder({{-3, -2}});
tv->merge(-4);
// After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
};

auto scheduleAsMmaOutputAllocation = [](TensorView* tv) {
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv->getLoopDomain());
tv->setLoopDomain(s.as<IterDomain*>());
};

if (params_->use_smem_epilogue) {
fusion_->manage("st_matrix_m_tile", stmatrix_tile_m);
fusion_->manage("st_matrix_n_tile", stmatrix_tile_n);
fusion_->manage("st_matrix_m", tma_m);
fusion_->manage("st_matrix_n", tma_n);
}

std::vector<TensorView*> propagate_to = mma_results_;
if (tensor_roles_.count(MatmulTensorRole::EPILOGUE_INPUT)) {
propagate_to.insert(propagate_to.end(), c_tvs.begin(), c_tvs.end());
}

for (Val* dv : fusion_->outputs()) {
auto* d = dv->as<TensorView>();
NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());

if (!params_->use_smem_epilogue) {
blockTileTensors({d});
parallelizeBlocks({d});
applyCommonMmaTransforms(d);
scheduleAsMmaOutputAllocation(d);
d->axis(-5)->parallelize(ParallelType::TIDy);
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
d,
-1,
mma_results_,
propagate_to,
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());

// We don't respect vectorization_factor as yet. We vectorize the
// inner-dim with extent 2.
// TODO: support vectorization_factor.
d->axis(-1)->parallelize(ParallelType::Vectorize);
}
scheduleFusionInputsForEpilogue();
} else {
constexpr int64_t stmatrix_tile_m = 16;
constexpr int64_t stmatrix_tile_n = 16;

// TODO: Support tma tile sizes that are a multiple of mma_macro.
// The wgmma operation creates an output matrix of mma_macro size. The TMA
// tile is a multiple of the macro size because stmatrix stores results from
// wgmma to shared memory. For maximum inlining and to reduce shared memory
// usage, the tma tile is mma_macro size.
const int64_t tma_m = getM(params_->mma_macro);
const int64_t tma_n = getN(params_->mma_macro);

fusion_->manage("st_matrix_m_tile", stmatrix_tile_m);
fusion_->manage("st_matrix_n_tile", stmatrix_tile_n);
fusion_->manage("st_matrix_m", tma_m);
fusion_->manage("st_matrix_n", tma_n);

// Manually schedule register cache and output TensorView
for (Val* dv : fusion_->outputs()) {
auto* d = dv->as<TensorView>();
NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());
scheduler_utils::parallelizeAllLike(d, -1, cached_tvs);
} else {
auto* dc = d->definition()->input(0)->as<TensorView>();

// NOTE: cacheBefore does not work with blockTileTensors
TensorView* d_smem = cacheAfter(dc, LoadStoreOpType::Set);

@@ -503,31 +504,28 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
d->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::CpAsyncBulkTensorTile);

// Block Schedule and Parallelize
blockTileTensors({dc, d_smem, d});
parallelizeBlocks({dc, d_smem, d});

// Apply mma common transformation
for (auto tv : {dc, d_smem, d}) {
// Original: [..., Mo, No, Mi, Ni]
tv->split(-2, getM(params_->mma_macro));
tv->split(-1, getN(params_->mma_macro));
// After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
tv->reorder({{-3, -2}});
// After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
tv->merge(-4);
// After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
tv->axis(-3)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
}
blockTileTensors({d});
parallelizeBlocks({d});
applyCommonMmaTransforms(d);
d->axis(-3)->parallelize(ParallelType::TIDy);

// Schedule register cache; Output from epilogue
{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
dc->getLoopDomain());
dc->setLoopDomain(s.as<IterDomain*>());
dc->setAllocationDomain(s.as<IterDomain*>(), true);
}
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
d,
-1,
{dc},
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType()
.propagateToBoundary());

scheduleAsMmaOutputAllocation(dc);
dc->setAllocationDomain(dc->getLoopDomain(), true);

scheduler_utils::BoundedDirectionalTransformPropagator::backward(
dc,
-1,
propagate_to,
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());

MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(d_smem);

@@ -541,48 +539,6 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
}
}
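
A small worked example of the tile bookkeeping above, assuming (hypothetically) a 64x128 wgmma macro, i.e. getM(params_->mma_macro) == 64 and getN(params_->mma_macro) == 128: the TMA tile equals the macro size, and the 16x16 stmatrix tiles partition it evenly.

    #include <cstdint>

    constexpr int64_t kStMatrixTileM = 16;  // matches stmatrix_tile_m above
    constexpr int64_t kStMatrixTileN = 16;  // matches stmatrix_tile_n above
    constexpr int64_t kTmaM = 64;           // assumed getM(params_->mma_macro)
    constexpr int64_t kTmaN = 128;          // assumed getN(params_->mma_macro)

    static_assert(
        kTmaM % kStMatrixTileM == 0 && kTmaN % kStMatrixTileN == 0,
        "stmatrix tiles must evenly partition the TMA tile");

    // Each TMA store moves (64/16) * (128/16) = 4 * 8 = 32 stmatrix tiles'
    // worth of data from shared memory to global memory.
    constexpr int64_t kStTilesPerTma =
        (kTmaM / kStMatrixTileM) * (kTmaN / kStMatrixTileN);
    static_assert(kStTilesPerTma == 32, "4 * 8 stmatrix tiles per TMA tile");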

//! Propagates transformations from fusion output to fusion tv inputs that are
//! producers in the epilogue. Transformations' propagation aims at input tvs
//! which are not assigned to core roles, that is, are not MMA inputs.
void HopperMultipleMatmulScheduler::scheduleFusionInputsForEpilogue() {
std::vector<TensorView*> cached_tvs;

// Handle transformations in fusion input tvs with the EPILOGUE_INPUT role
// by propagating fusion output transformations through cached views of
// these inputs and by vectorizing the innermost iterdomain of the cached
// views.
if (tensor_roles_.count(MatmulTensorRole::EPILOGUE_INPUT)) {
auto& c_tvs = tensor_roles_.at(MatmulTensorRole::EPILOGUE_INPUT);

// The system supports only the scenario where there is a single fusion
// output with the OUTPUT role assigned. This condition has already been
// verified, so no additional checks are needed here.
auto output_d = tensor_roles_.at(MatmulTensorRole::OUTPUT).front();
for (auto* c : c_tvs) {
cached_tvs.push_back(c->cacheAfter());
}

scheduler_utils::BoundedDirectionalTransformPropagator::backward(
output_d, -1, c_tvs);

std::unordered_set<ParallelType> parallel_types = {};
if (params_->use_smem_epilogue) {
// When the smem epilogue feature is enabled, vectorization of domains
// would be propagated to fusion inputs that are epilogue inputs, which
// may result in unaligned memory reads. Vectorization is explicitly
// excluded from the parallelization types to avoid this issue. This
// should be changed once vectorization analysis is available and enabled
// for the matmul scheduler.
parallel_types = allParallelTypesExcept({ParallelType::Vectorize});
}
scheduler_utils::parallelizeAllLike(
output_d, -1, cached_tvs, parallel_types);

// The cached EPILOGUE_INPUT tvs are not needed anymore
cached_tvs.clear();
}
}

void HopperMultipleMatmulScheduler::scheduleSplitKSum() {
if (params_->splitk_factor == 1) {
return;
7 changes: 0 additions & 7 deletions csrc/scheduler/hopper_multi_matmul.h
@@ -163,15 +163,8 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler {

void scheduleMmaResults();

void scheduleOutputTensor(TensorView* c);

void scheduleEpilogue();

//! Propagates transformations from fusion output to fusion tv inputs that are
//! producers in the epilogue. Transformations' propagation aims at input tvs
//! which are not assigned to core roles, that is, are not MMA inputs.
void scheduleFusionInputsForEpilogue();

void scheduleSplitKSum();

void setUpInlining();
4 changes: 0 additions & 4 deletions tests/cpp/test_matmul_scheduler.cpp
@@ -3298,10 +3298,6 @@ TEST_P(HopperMatmulSchedulerTest, FusedMultiplySum) {
}

TEST_P(HopperMatmulSchedulerTest, FusedMultiplySumBiasNeg) {
if (use_smem_epilogue) {
GTEST_SKIP()
<< "TODO: We don't support smem epilogue in the Hopper matmul scheduler right now";
}
const auto& [A, B] =
matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype));
const auto& C = matmulAtInput2D(
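
For context on what the removed guard controlled, here is a hedged, simplified sketch of a value-parameterized gtest fixture; the names are illustrative only, and the real HopperMatmulSchedulerTest parameterizes over layout, dtype, and problem sizes in addition to this flag. With the GTEST_SKIP gone, the use_smem_epilogue == true instances of FusedMultiplySumBiasNeg now run as well.

    #include <gtest/gtest.h>

    // Hypothetical, simplified fixture: only the smem-epilogue flag is
    // parameterized here, unlike the real test suite.
    class SmemEpilogueParamTest : public ::testing::TestWithParam<bool> {};

    TEST_P(SmemEpilogueParamTest, RunsBothEpiloguePaths) {
      const bool use_smem_epilogue = GetParam();
      // Previously the smem-epilogue instance was skipped:
      //   if (use_smem_epilogue) GTEST_SKIP() << "...";
      // Now both parameter values run.
      EXPECT_TRUE(use_smem_epilogue || !use_smem_epilogue);  // placeholder body
    }

    INSTANTIATE_TEST_SUITE_P(
        HopperMatmul,
        SmemEpilogueParamTest,
        ::testing::Bool());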
