Adding support for scheduling the epilogue computation when the smem_epilogue parameter is true #3581

Open
wants to merge 9 commits into main
225 changes: 64 additions & 161 deletions csrc/scheduler/hopper_multi_matmul.cpp
@@ -434,107 +434,52 @@ void HopperMultipleMatmulScheduler::scheduleMmaResults() {
}
}

void HopperMultipleMatmulScheduler::scheduleOutputTensor(TensorView* c) {
const MatMulTileOptions& gemm_tile = params_->tile_sizes;
const int64_t vectorization_factor = params_->supported_vec_size.epilogue;
// input tensor is in the form of [Mo,No,cta_tile_m,cta_tile_n]
mma_utils::checkConcreteStaticDim(c->axis(-2));
mma_utils::checkConcreteStaticDim(c->axis(-1));
const int64_t tile_size_m = c->axis(-2)->extent()->evaluate().as<int64_t>();
const int64_t tile_size_n = c->axis(-1)->extent()->evaluate().as<int64_t>();
NVF_ERROR(
tile_size_m == gemm_tile.cta_tile.m,
"Actual tile size at axis(-2) in output tensor is different from CTA tile size! Expected: ",
gemm_tile.cta_tile.m,
", actual: ",
tile_size_m);
NVF_ERROR(
tile_size_n == gemm_tile.cta_tile.n,
"Actual tile size at axis(-1) in output tensor is different from CTA tile size! Expected: ",
gemm_tile.cta_tile.n,
", actual: ",
tile_size_n);
const int64_t tot_elements = tile_size_m * tile_size_n;
constexpr int64_t warp_size = 32l;
const int64_t tidx = warp_size;
const int64_t tidy = gemm_tile.cta_tile.n / gemm_tile.warp_tile.n;
const int64_t tidz = gemm_tile.cta_tile.m / gemm_tile.warp_tile.m;
// step-1, merge last 2 dims
c->merge(-2);
// [Mo, No, m*n]

// step-2, set vectorization to maximum
// tidx, tidy, and tidz are fixed, so the total number of tile elements must
// be divisible by tidx * tidy * tidz * vectorization_factor
NVF_ERROR(
tot_elements % (tidx * tidy * tidz * vectorization_factor) == 0,
"Output tensor cannot be fully vectorized! tot_elements:",
tot_elements,
", tidx: ",
tidx,
", tidy: ",
tidy,
", tidz: ",
tidz,
", vectorization_factor: ",
vectorization_factor);
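// Worked example with hypothetical tile sizes: for a 128x256 CTA tile, a
// 64x128 warp tile, and vectorization_factor = 4, tot_elements = 128*256 =
// 32768, tidx = 32, tidy = 256/128 = 2, tidz = 128/64 = 2, so the divisor is
// 32*2*2*4 = 512 and 32768 % 512 == 0, i.e. the check passes.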
c->split(-1, vectorization_factor);
c->axis(-1)->parallelize(ParallelType::Vectorize);
// [Mo, No, m*n/vect, vect]

// step-3, Split out a warp for TIDx
c->split(-2, tidx);
c->axis(-2)->parallelize(ParallelType::TIDx);
// [Mo, No, m*n/vect/TIDx, TIDx, vect]

// step-4, Split out for TIDy and TIDz
// TIDy = cta_tile_n/warp_tile_n
// TIDz = cta_tile_m/warp_tile_m
c->split(-3, tidy);
c->axis(-3)->parallelize(ParallelType::TIDy);

c->split(-4, tidz);
c->axis(-4)->parallelize(ParallelType::TIDz);
// [Mo, No, m*n/vect/TIDx/TIDy/TIDz, TIDz, TIDy, TIDx, vect]
void HopperMultipleMatmulScheduler::scheduleEpilogue() {
std::vector<TensorView*> cached_tvs;

for (TensorView* mma_result : mma_results_) {
// step-5, parallelize the first 2 dims the same as mma_result
scheduler_utils::parallelizeAllLike(
mma_result,
2,
{c},
{ParallelType::BIDx, ParallelType::BIDy, ParallelType::BIDz});
// Propagate to (but not including) the split-K sum outputs if split-K is
// used; otherwise this is just mma_results_.
std::vector<TensorView*> propagate_to =
splitk_sums_.empty() ? mma_results_ : splitk_sums_;
if (tensor_roles_.count(MatmulTensorRole::EPILOGUE_INPUT)) {
auto& c_tvs = tensor_roles_.at(MatmulTensorRole::EPILOGUE_INPUT);
// Load/cache the epilogue inputs if there are any.
for (auto* c : c_tvs) {
cached_tvs.push_back(c->cacheAfter());
}
propagate_to.insert(propagate_to.end(), c_tvs.begin(), c_tvs.end());
}
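// The cached copies created by cacheAfter() are what the epilogue reads;
// the original gmem epilogue inputs are appended to propagate_to so that
// they bound the backward transform propagation performed below.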
}

void HopperMultipleMatmulScheduler::scheduleEpilogue() {
// TODO: schedule epilogue by propagation backward from dc
if (!params_->use_smem_epilogue) {
for (Val* dv : fusion_->outputs()) {
auto* d = dv->as<TensorView>();
NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());
auto* dc = d->definition()->input(0)->as<TensorView>();

std::vector<TensorView*> tvs_to_schedule{d};
if (std::find(mma_results_.begin(), mma_results_.end(), dc) ==
mma_results_.end()) {
// Skip scheduling dc if it is an mma_result. This can happen if we are
// not casting back to half-precision in the output
tvs_to_schedule.push_back(dc);
}

// Block Schedule and Parallelize
blockTileTensors(tvs_to_schedule);
parallelizeBlocks(tvs_to_schedule);

// Apply mma common transformation
for (auto tv : tvs_to_schedule) {
transformLikeMmaOutput(tv, /*is_mma_result=*/false);
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv->getLoopDomain());
tv->setLoopDomain(s.as<IterDomain*>());
}
// Schedule the output TV and propagate it back to the outputs of the Mma
// op.
blockTileTensors({d});
parallelizeBlocks({d});
transformLikeMmaOutput(d, /*is_mma_result=*/false);

auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
d->getLoopDomain());
d->setLoopDomain(s.as<IterDomain*>());

// TODO: We need to check bank conflicts in this path.
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
d,
-1,
propagate_to,
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());
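// This propagates d's block tiling and parallelization back toward the mma
// results (or the split-K sums when split-K is used), which bound the
// propagation.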

// We don't respect vectorization_factor yet; we vectorize the inner
// dimension, which has extent 2.
// TODO: support vectorization_factor.
d->axis(-1)->parallelize(ParallelType::Vectorize);
if (!cached_tvs.empty()) {
scheduler_utils::parallelizeAllLike(d, -1, cached_tvs);
}
}
} else {
constexpr int64_t stmatrix_tile_m = 16;
@@ -558,44 +503,44 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
auto* d = dv->as<TensorView>();
NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());
auto* dc = d->definition()->input(0)->as<TensorView>();

// NOTE: cacheBefore does not work with blockTileTensors
TensorView* d_smem = cacheAfter(dc, LoadStoreOpType::Set);

std::vector<TensorView*> tvs_to_schedule{d, d_smem};
if (std::find(mma_results_.begin(), mma_results_.end(), dc) ==
mma_results_.end()) {
// Skip scheduling dc if it is an mma_result. This can happen if we are
// not casting back to half-precision in the output
tvs_to_schedule.push_back(dc);
}

// Set MemoryType
dc->setMemoryType(MemoryType::Local);
d_smem->setMemoryType(MemoryType::Shared);

// Set LoadStoreOp
d_smem->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::StMatrix);
if (dc->dtype() == DataType::Half && !dc->definition()->isA<MmaOp>()) {
d_smem->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::StMatrix);
}
d->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::CpAsyncBulkTensorTile);
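// With the smem epilogue, the register->smem copy is lowered to stmatrix
// (only when the data is half precision and dc is not the raw mma
// accumulator), and the smem->gmem store of the fusion output becomes a TMA
// (cp.async.bulk.tensor) store.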

// Block Schedule and Parallelize
blockTileTensors(tvs_to_schedule);
parallelizeBlocks(tvs_to_schedule);

// Apply mma common transformation
for (auto tv : tvs_to_schedule) {
transformLikeMmaOutput(tv, /*is_mma_result=*/false);
}

// Schedule the register cache holding the epilogue output
{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
dc->getLoopDomain());
dc->setLoopDomain(s.as<IterDomain*>());
dc->setAllocationDomain(s.as<IterDomain*>(), true);
}
blockTileTensors({d});
parallelizeBlocks({d});
transformLikeMmaOutput(d, /*is_mma_result=*/false);

scheduler_utils::BoundedDirectionalTransformPropagator::backward(
d,
-1,
{dc},
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType()
.propagateToBoundary());
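// Propagate d's schedule backward through d_smem toward dc, including
// parallel types.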

auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
dc->getLoopDomain());
dc->setLoopDomain(s.as<IterDomain*>());
dc->setAllocationDomain(s.as<IterDomain*>(), true);

scheduler_utils::BoundedDirectionalTransformPropagator::backward(
dc,
-1,
propagate_to,
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());
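// Then propagate dc's schedule further backward, bounded by the mma results
// (or the split-K sums when split-K is used), so the intermediate epilogue
// tensors pick up the same tiling.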

MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(d_smem);

Expand All @@ -609,48 +554,6 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
}
}

//! Propagates transformations from the fusion output to the fusion input tvs
//! that are producers in the epilogue. Propagation targets input tvs that are
//! not assigned to core roles, i.e., that are not MMA inputs.
void HopperMultipleMatmulScheduler::scheduleFusionInputsForEpilogue() {
std::vector<TensorView*> cached_tvs;

// Handle transformations of fusion input tvs with the EPILOGUE_INPUT role by
// propagating fusion output transformations through cached views of those
// inputs and by vectorizing the innermost iterdomain of the cached views.
if (tensor_roles_.count(MatmulTensorRole::EPILOGUE_INPUT)) {
auto& c_tvs = tensor_roles_.at(MatmulTensorRole::EPILOGUE_INPUT);

// Only the scenario with a single fusion output assigned the OUTPUT role is
// supported; that condition has already been verified, so no additional
// checks are needed here.
auto output_d = tensor_roles_.at(MatmulTensorRole::OUTPUT).front();
for (auto* c : c_tvs) {
cached_tvs.push_back(c->cacheAfter());
}

scheduler_utils::BoundedDirectionalTransformPropagator::backward(
output_d, -1, c_tvs);

std::unordered_set<ParallelType> parallel_types = {};
if (params_->use_smem_epilogue) {
// When the smem epilogue feature is enabled, the vectorization of domains
// would be propagated to fusion inputs that are epilogue inputs, which may
// result in unaligned memory reads. Vectorization is explicitly excluded
// from the parallel types to avoid this issue. This should be changed once
// vectorization analysis is available and enabled for the matmul scheduler.
parallel_types = allParallelTypesExcept({ParallelType::Vectorize});
}
scheduler_utils::parallelizeAllLike(
output_d, -1, cached_tvs, parallel_types);

// The cached EPILOGUE_INPUT tvs are not needed anymore
cached_tvs.clear();
}
}

void HopperMultipleMatmulScheduler::scheduleSplitKSum() {
if (params_->splitk_factor == 1) {
return;
7 changes: 0 additions & 7 deletions csrc/scheduler/hopper_multi_matmul.h
@@ -171,15 +171,8 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler {

void scheduleMmaResults();

void scheduleOutputTensor(TensorView* c);

void scheduleEpilogue();

//! Propagates transformations from the fusion output to the fusion input tvs
//! that are producers in the epilogue. Propagation targets input tvs that are
//! not assigned to core roles, i.e., that are not MMA inputs.
void scheduleFusionInputsForEpilogue();

void scheduleSplitKSum();

void setUpInlining();
72 changes: 70 additions & 2 deletions tests/cpp/test_matmul_scheduler.cpp
@@ -3296,8 +3296,7 @@ class HopperMatmulSchedulerTest
KernelExecutor ke;
ke.compile(fusion, inputs, LaunchParams(), matmul_cparams);
auto nvf_out = ke.run(inputs);
// NOTE Relax tolerances for split-k case
EXPECT_TRUE(at::allclose(nvf_out.at(0), tref, 1e-3, 1e-3));
EXPECT_TRUE(at::allclose(nvf_out.at(0), tref, 1e-2, 1e-2));
}

protected:
@@ -3377,6 +3376,75 @@ TEST_P(HopperMatmulSchedulerTest, FusedMultiplySum) {
tref = atMatmul(A.squeeze(), B.squeeze(), layout);
}

// TODO: Remove this test once the architecture-agnostic version can be run
// on Hopper.
TEST_P(HopperMatmulSchedulerTest, FusedMultiplySumBiasNeg) {
const auto& [A, B] =
matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype));
const auto& C = matmulAtInput2D(
layout, TensorMatmulPos::Bias, data_type_to_aten(dtype), M, N, K);
inputs = {A, B, C};

TensorView* tv0 = nullptr;
TensorView* tv1 = nullptr;
std::unordered_map<int64_t, int64_t> old2new;
int64_t k_axis = 0;

switch (layout) {
case MmaLayout::TT:
// Inner dims KN, order is MKN
tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype);
tv1 = makeContigConcreteTensor({1, -1, -1}, dtype);
old2new = {{-2, -1}, {-1, -2}};
k_axis = -2;
break;
case MmaLayout::TN:
// Inner dims KK, order is MNK
tv0 = makeContigConcreteTensor({-1, 1, -1}, dtype);
tv1 = makeContigConcreteTensor({1, -1, -1}, dtype);
old2new = {};
k_axis = -1;
break;
case MmaLayout::NT:
// Inner dims MN, order is KMN
tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype);
tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype);
old2new = {{-3, -1}};
k_axis = -3;
break;
case MmaLayout::NN:
// Inner dims MK, order is NKM
tv0 = makeContigConcreteTensor({1, -1, -1}, dtype);
tv1 = makeContigConcreteTensor({-1, -1, 1}, dtype);
old2new = {{-1, -3}};
k_axis = -2;
break;
}
TensorView* tv2 = makeContigConcreteTensor({-1}, dtype);

fusion->addInput(tv0);
fusion->addInput(tv1);
fusion->addInput(tv2);

auto tv3 = fusedMultiplySum(tv0, tv1, {k_axis});

// Reorder the accumulator as [M, N, K]
tv3->reorder(old2new);
tv3->commitLeafToLogical();
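// For example, in the TT case tv3 comes out ordered [M, K, N], and
// old2new = {{-2, -1}, {-1, -2}} swaps the last two axes, yielding [M, N, K].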

auto* tv4 = maybeCastOp(DataType::Float, tv2);
auto* tv5 = biasEpilogue(tv3, tv4);
auto* tv6 = neg(tv5);
auto* tv7 = castOp(dtype, tv6);
fusion->addOutput(tv7);

tref = atBiasEpilogue(
atMatmul(A.squeeze(), B.squeeze(), layout),
C.to(data_type_to_aten(DataType::Float)))
.neg_()
.to(data_type_to_aten(DataType::Half));
}
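// Aside: a minimal standalone ATen sketch (not part of this diff) of the
// reference math the test above checks, D = cast_to_half(-(A @ B + bias)),
// comparing a half-precision evaluation against the float reference with the
// same relaxed 1e-2 tolerances. The problem size and the per-row (length-M)
// bias broadcast are illustrative assumptions.
#include <ATen/ATen.h>

bool checkBiasNegEpilogueReference(int64_t M, int64_t N, int64_t K) {
  at::Tensor A = at::randn({M, K});
  at::Tensor B = at::randn({K, N});
  at::Tensor bias = at::randn({M});
  // Float reference, mirroring tref in the test above.
  at::Tensor ref = (at::matmul(A, B) + bias.unsqueeze(-1)).neg().to(at::kHalf);
  // Inputs rounded to half precision stand in for the kernel under test.
  at::Tensor out = (at::matmul(A.to(at::kHalf).to(at::kFloat),
                               B.to(at::kHalf).to(at::kFloat)) +
                    bias.unsqueeze(-1))
                       .neg()
                       .to(at::kHalf);
  return at::allclose(out.to(at::kFloat), ref.to(at::kFloat), 1e-2, 1e-2);
}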

INSTANTIATE_TEST_SUITE_P(
General,
HopperMatmulSchedulerTest,