
Commit

Schedule epilogue (for Hopper Matmul) by propagation backward from output - smem epilogue not supported. (#3580)

This adds support for scheduling the epilogue in the Hopper matmul
scheduler by propagating transformations backward from the output.
The smem epilogue is not supported yet.

We also don't honor the vectorization_factor yet for the store to the
output. That will be covered in a separate PR.
protonu authored Dec 14, 2024
1 parent b68a7e4 commit d53be45
Showing 3 changed files with 111 additions and 143 deletions.
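
For orientation before the diff: the following is a condensed sketch of the new non-smem epilogue path, assembled only from lines that appear in the hunks below. It is not the literal committed code; it assumes the surrounding HopperMultipleMatmulScheduler class context, and the smem-epilogue branch is omitted.

void HopperMultipleMatmulScheduler::scheduleEpilogue() {
  std::vector<TensorView*> cached_tvs;

  // Propagate to (not including) the split-K output if there is a split-K;
  // otherwise this is just mma_results_.
  std::vector<TensorView*> propagate_to =
      splitk_sums_.empty() ? mma_results_ : splitk_sums_;
  if (tensor_roles_.count(MatmulTensorRole::EPILOGUE_INPUT)) {
    auto& c_tvs = tensor_roles_.at(MatmulTensorRole::EPILOGUE_INPUT);
    // Load/cache the epilogue inputs (e.g. bias) and make them additional
    // stopping points for the backward propagation below.
    for (auto* c : c_tvs) {
      cached_tvs.push_back(c->cacheAfter());
    }
    propagate_to.insert(propagate_to.end(), c_tvs.begin(), c_tvs.end());
  }

  if (!params_->use_smem_epilogue) {
    for (Val* dv : fusion_->outputs()) {
      auto* d = dv->as<TensorView>();

      // Tile and parallelize the output TV like the mma result, then apply
      // the mma output swizzle to its loop domain.
      blockTileTensors({d});
      parallelizeBlocks({d});
      transformLikeMmaOutput(d, /*is_mma_result=*/false);
      auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
          d->getLoopDomain());
      d->setLoopDomain(s.as<IterDomain*>());

      // Propagate the output's transforms (and parallelization) backward,
      // stopping at the mma results / split-K sums and the cached epilogue
      // inputs collected above.
      scheduler_utils::BoundedDirectionalTransformPropagator::backward(
          d,
          -1,
          propagate_to,
          scheduler_utils::BoundedDirectionalTransformPropagator::Options()
              .propagateParallelType());

      // Vectorize the innermost dim; vectorization_factor is not honored yet.
      d->axis(-1)->parallelize(ParallelType::Vectorize);
      if (!cached_tvs.empty()) {
        scheduler_utils::parallelizeAllLike(d, -1, cached_tvs);
      }
    }
  }
  // (smem-epilogue branch omitted: not supported by this scheduler yet)
}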
171 changes: 37 additions & 134 deletions csrc/scheduler/hopper_multi_matmul.cpp
@@ -434,107 +434,52 @@ void HopperMultipleMatmulScheduler::scheduleMmaResults() {
}
}

void HopperMultipleMatmulScheduler::scheduleOutputTensor(TensorView* c) {
const MatMulTileOptions& gemm_tile = params_->tile_sizes;
const int64_t vectorization_factor = params_->supported_vec_size.epilogue;
// input tensor is in the form of [Mo,No,cta_tile_m,cta_tile_n]
mma_utils::checkConcreteStaticDim(c->axis(-2));
mma_utils::checkConcreteStaticDim(c->axis(-1));
const int64_t tile_size_m = c->axis(-2)->extent()->evaluate().as<int64_t>();
const int64_t tile_size_n = c->axis(-1)->extent()->evaluate().as<int64_t>();
NVF_ERROR(
tile_size_m == gemm_tile.cta_tile.m,
"Actual tile size at axis(-2) in output tensor is different from CTA tile size! Expected: ",
gemm_tile.cta_tile.m,
", actual: ",
tile_size_m);
NVF_ERROR(
tile_size_n == gemm_tile.cta_tile.n,
"Actual tile size at axis(-1) in output tensor is different from CTA tile size! Expected: ",
gemm_tile.cta_tile.n,
", actual: ",
tile_size_n);
const int64_t tot_elements = tile_size_m * tile_size_n;
constexpr int64_t warp_size = 32l;
const int64_t tidx = warp_size;
const int64_t tidy = gemm_tile.cta_tile.n / gemm_tile.warp_tile.n;
const int64_t tidz = gemm_tile.cta_tile.m / gemm_tile.warp_tile.m;
// step-1, merge last 2 dims
c->merge(-2);
// [Mo, No, m*n]

// step-2, set vectorization to maximum
// We have fixed tidx, tidy, and tidz, so we need to make sure that the
// output tensor is divisible by tidx * tidy * tidz * vectorization_factor
NVF_ERROR(
tot_elements % (tidx * tidy * tidz * vectorization_factor) == 0,
"Output tensor cannot be fully vectorized! tot_elements:",
tot_elements,
", tidx: ",
tidx,
", tidy: ",
tidy,
", tidz: ",
tidz,
", vectorization_factor: ",
vectorization_factor);
c->split(-1, vectorization_factor);
c->axis(-1)->parallelize(ParallelType::Vectorize);
// [Mo, No, m*n/vect, vect]

// step-3, Split out a warp for TIDx
c->split(-2, tidx);
c->axis(-2)->parallelize(ParallelType::TIDx);
// [Mo, No, m*n/vect/TIDx, TIDx, vect]

// step-4, Split out for TIDy and TIDz
// TIDy = cta_tile_n/warp_tile_n
// TIDz = cta_tile_m/warp_tile_m
c->split(-3, tidy);
c->axis(-3)->parallelize(ParallelType::TIDy);

c->split(-4, tidz);
c->axis(-4)->parallelize(ParallelType::TIDz);
// [Mo, No, m*n/vect/TIDx/TIDy/TIDz, TIDz, TIDy, TIDx, vect]
void HopperMultipleMatmulScheduler::scheduleEpilogue() {
std::vector<TensorView*> cached_tvs;

for (TensorView* mma_result : mma_results_) {
// step-5, Parallel first 2 dims same as mma_result
scheduler_utils::parallelizeAllLike(
mma_result,
2,
{c},
{ParallelType::BIDx, ParallelType::BIDy, ParallelType::BIDz});
// Propagate to (not including) the splitk output if there is a splitk
// else this is just mma_results_
std::vector<TensorView*> propagate_to =
splitk_sums_.empty() ? mma_results_ : splitk_sums_;
if (tensor_roles_.count(MatmulTensorRole::EPILOGUE_INPUT)) {
auto& c_tvs = tensor_roles_.at(MatmulTensorRole::EPILOGUE_INPUT);
// Load/cache the epilogue inputs if there are any.
for (auto* c : c_tvs) {
cached_tvs.push_back(c->cacheAfter());
}
propagate_to.insert(propagate_to.end(), c_tvs.begin(), c_tvs.end());
}
}

void HopperMultipleMatmulScheduler::scheduleEpilogue() {
// TODO: schedule epilogue by propagation backward from dc
if (!params_->use_smem_epilogue) {
for (Val* dv : fusion_->outputs()) {
auto* d = dv->as<TensorView>();
NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());
auto* dc = d->definition()->input(0)->as<TensorView>();

std::vector<TensorView*> tvs_to_schedule{d};
if (std::find(mma_results_.begin(), mma_results_.end(), dc) ==
mma_results_.end()) {
// Skip scheduling dc if it is an mma_result. This can happen if we are
// not casting back to half-precision in the output
tvs_to_schedule.push_back(dc);
}

// Block Schedule and Parallelize
blockTileTensors(tvs_to_schedule);
parallelizeBlocks(tvs_to_schedule);

// Apply mma common transformation
for (auto tv : tvs_to_schedule) {
transformLikeMmaOutput(tv, /*is_mma_result=*/false);
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv->getLoopDomain());
tv->setLoopDomain(s.as<IterDomain*>());
}
// Schedule the output TV and propagate it back to the outputs of the Mma
// op.
blockTileTensors({d});
parallelizeBlocks({d});
transformLikeMmaOutput(d, /*is_mma_result=*/false);

auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
d->getLoopDomain());
d->setLoopDomain(s.as<IterDomain*>());

// TODO: We need to check bank conflicts in this path.
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
d,
-1,
propagate_to,
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());

// We don't respect vectorization_factor as yet. We vectorize the
// inner-dim with extent 2.
// TODO: support vectorization_factor.
d->axis(-1)->parallelize(ParallelType::Vectorize);
if (!cached_tvs.empty()) {
scheduler_utils::parallelizeAllLike(d, -1, cached_tvs);
}
}
} else {
constexpr int64_t stmatrix_tile_m = 16;
@@ -609,48 +554,6 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
}
}

//! Propagates transformations from fusion output to fusion tv inputs that are
//! producers in the epilogue. Transformations' propagation aims at input tvs
//! which are not assigned to core roles, that is, are not MMA inputs.
void HopperMultipleMatmulScheduler::scheduleFusionInputsForEpilogue() {
std::vector<TensorView*> cached_tvs;

// Handling transformations in fusion input tvs with assigned EPILOGUE_INPUT
// role by propagating fusion output transformations through cached views
// of EPILOGUE_INPUT fusion input tvs and by setting vectorization of the
// inner most iterdomain of these cached views
if (tensor_roles_.count(MatmulTensorRole::EPILOGUE_INPUT)) {
auto& c_tvs = tensor_roles_.at(MatmulTensorRole::EPILOGUE_INPUT);

// The system supports only scenario where there is only one fusion output
// with assigned OUTPUT role, this condition is already verified so there
// is no need for an additional checks here
auto output_d = tensor_roles_.at(MatmulTensorRole::OUTPUT).front();
for (auto* c : c_tvs) {
cached_tvs.push_back(c->cacheAfter());
}

scheduler_utils::BoundedDirectionalTransformPropagator::backward(
output_d, -1, c_tvs);

std::unordered_set<ParallelType> parallel_types = {};
if (params_->use_smem_epilogue) {
// In cases where smem epilogue feature is enabled, the vectorization
// of domains will be propagated to fusion inputs that are epilogue
// inputs, this may result in unaligned memory reads. Vectorization is
// explicitly excluded from parallelization types to avoid this issue.
// This should be changed when vectorization analysis is available and
// enabled for matmul scheduler.
parallel_types = allParallelTypesExcept({ParallelType::Vectorize});
}
scheduler_utils::parallelizeAllLike(
output_d, -1, cached_tvs, parallel_types);

// The cached EPILOGUE_INPUT tvs are not needed anymore
cached_tvs.clear();
}
}

void HopperMultipleMatmulScheduler::scheduleSplitKSum() {
if (params_->splitk_factor == 1) {
return;
7 changes: 0 additions & 7 deletions csrc/scheduler/hopper_multi_matmul.h
@@ -171,15 +171,8 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler {

void scheduleMmaResults();

void scheduleOutputTensor(TensorView* c);

void scheduleEpilogue();

//! Propagates transformations from fusion output to fusion tv inputs that are
//! producers in the epilogue. Transformations' propagation aims at input tvs
//! which are not assigned to core roles, that is, are not MMA inputs.
void scheduleFusionInputsForEpilogue();

void scheduleSplitKSum();

void setUpInlining();
76 changes: 74 additions & 2 deletions tests/cpp/test_matmul_scheduler.cpp
@@ -3296,8 +3296,7 @@ class HopperMatmulSchedulerTest
KernelExecutor ke;
ke.compile(fusion, inputs, LaunchParams(), matmul_cparams);
auto nvf_out = ke.run(inputs);
// NOTE Relax tolerances for split-k case
EXPECT_TRUE(at::allclose(nvf_out.at(0), tref, 1e-3, 1e-3));
EXPECT_TRUE(at::allclose(nvf_out.at(0), tref, 1e-2, 1e-2));
}

protected:
@@ -3377,6 +3376,79 @@ TEST_P(HopperMatmulSchedulerTest, FusedMultiplySum) {
tref = atMatmul(A.squeeze(), B.squeeze(), layout);
}

// TODO: Remove this test once the architecture-agnostic scheduler can be
// run on Hopper.
TEST_P(HopperMatmulSchedulerTest, FusedMultiplySumBiasNeg) {
if (use_smem_epilogue) {
GTEST_SKIP()
<< "TODO: We don't support smem epilogue in the Hopper matmul scheduler right now";
}
const auto& [A, B] =
matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype));
const auto& C = matmulAtInput2D(
layout, TensorMatmulPos::Bias, data_type_to_aten(dtype), M, N, K);
inputs = {A, B, C};

TensorView* tv0 = nullptr;
TensorView* tv1 = nullptr;
std::unordered_map<int64_t, int64_t> old2new;
int64_t k_axis = 0;

switch (layout) {
case MmaLayout::TT:
// Inner dims KN, order is MKN
tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype);
tv1 = makeContigConcreteTensor({1, -1, -1}, dtype);
old2new = {{-2, -1}, {-1, -2}};
k_axis = -2;
break;
case MmaLayout::TN:
// Inner dims KK, order is MNK
tv0 = makeContigConcreteTensor({-1, 1, -1}, dtype);
tv1 = makeContigConcreteTensor({1, -1, -1}, dtype);
old2new = {};
k_axis = -1;
break;
case MmaLayout::NT:
// Inner dims MN, order is KMN
tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype);
tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype);
old2new = {{-3, -1}};
k_axis = -3;
break;
case MmaLayout::NN:
// Inner dims MK, order is NKM
tv0 = makeContigConcreteTensor({1, -1, -1}, dtype);
tv1 = makeContigConcreteTensor({-1, -1, 1}, dtype);
old2new = {{-1, -3}};
k_axis = -2;
break;
}
TensorView* tv2 = makeContigConcreteTensor({-1}, dtype);

fusion->addInput(tv0);
fusion->addInput(tv1);
fusion->addInput(tv2);

auto tv3 = fusedMultiplySum(tv0, tv1, {k_axis});

// Reorder the accumulator as [M, N, K]
tv3->reorder(old2new);
tv3->commitLeafToLogical();

auto* tv4 = maybeCastOp(DataType::Float, tv2);
auto* tv5 = biasEpilogue(tv3, tv4);
auto* tv6 = neg(tv5);
auto* tv7 = castOp(dtype, tv6);
fusion->addOutput(tv7);

tref = atBiasEpilogue(
atMatmul(A.squeeze(), B.squeeze(), layout),
C.to(data_type_to_aten(DataType::Float)))
.neg_()
.to(data_type_to_aten(DataType::Half));
}

INSTANTIATE_TEST_SUITE_P(
General,
HopperMatmulSchedulerTest,
