
Schedule epilogue (for Hopper Matmul) by propagation backward from output - smem epilogue not supported. #3580

Merged · 6 commits · Dec 14, 2024
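In outline, the new non-smem epilogue path schedules the fusion output d first (block tiling, block parallelization, MMA-output transform and swizzle) and then propagates those transforms backward toward the mma results, or toward the split-K sums when split-K is used. Below is a condensed sketch of that flow, assembled from the diff that follows; it is not standalone-compilable, since it relies on scheduler members such as blockTileTensors, parallelizeBlocks, and transformLikeMmaOutput.

// Condensed sketch of the new non-smem epilogue path. Error handling,
// epilogue-input caching, and the smem/TMA branch are elided; see the
// full diff of csrc/scheduler/hopper_multi_matmul.cpp below.
for (Val* dv : fusion_->outputs()) {
  auto* d = dv->as<TensorView>();

  // Schedule the output tensor itself.
  blockTileTensors({d});
  parallelizeBlocks({d});
  transformLikeMmaOutput(d, /*is_mma_result=*/false);
  auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
      d->getLoopDomain());
  d->setLoopDomain(s.as<IterDomain*>());

  // Propagate the schedule backward, stopping at the mma results
  // (or at the split-K sums when a split-K factor is used).
  scheduler_utils::BoundedDirectionalTransformPropagator::backward(
      d,
      -1,
      propagate_to,
      scheduler_utils::BoundedDirectionalTransformPropagator::Options()
          .propagateParallelType());

  // Vectorize the innermost dimension of the output.
  d->axis(-1)->parallelize(ParallelType::Vectorize);
}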
171 changes: 37 additions & 134 deletions csrc/scheduler/hopper_multi_matmul.cpp
@@ -434,107 +434,52 @@ void HopperMultipleMatmulScheduler::scheduleMmaResults() {
}
}

void HopperMultipleMatmulScheduler::scheduleOutputTensor(TensorView* c) {
const MatMulTileOptions& gemm_tile = params_->tile_sizes;
const int64_t vectorization_factor = params_->supported_vec_size.epilogue;
// input tensor is in the form of [Mo,No,cta_tile_m,cta_tile_n]
mma_utils::checkConcreteStaticDim(c->axis(-2));
mma_utils::checkConcreteStaticDim(c->axis(-1));
const int64_t tile_size_m = c->axis(-2)->extent()->evaluate().as<int64_t>();
const int64_t tile_size_n = c->axis(-1)->extent()->evaluate().as<int64_t>();
NVF_ERROR(
tile_size_m == gemm_tile.cta_tile.m,
"Actual tile size at axis(-2) in output tensor is different from CTA tile size! Expected: ",
gemm_tile.cta_tile.m,
", actual: ",
tile_size_m);
NVF_ERROR(
tile_size_n == gemm_tile.cta_tile.n,
"Actual tile size at axis(-1) in output tensor is different from CTA tile size! Expected: ",
gemm_tile.cta_tile.n,
", actual: ",
tile_size_n);
const int64_t tot_elements = tile_size_m * tile_size_n;
constexpr int64_t warp_size = 32l;
const int64_t tidx = warp_size;
const int64_t tidy = gemm_tile.cta_tile.n / gemm_tile.warp_tile.n;
const int64_t tidz = gemm_tile.cta_tile.m / gemm_tile.warp_tile.m;
// step-1, merge last 2 dims
c->merge(-2);
// [Mo, No, m*n]

// step-2, set vectorization to maximum
// We have fixed tidx, tidy, and tidz, so we need to make sure that the
// output tensor is divisible by tidx * tidy * tidz * vectorization_factor
NVF_ERROR(
tot_elements % (tidx * tidy * tidz * vectorization_factor) == 0,
"Output tensor cannot be fully vectorized! tot_elements:",
tot_elements,
", tidx: ",
tidx,
", tidy: ",
tidy,
", tidz: ",
tidz,
", vectorization_factor: ",
vectorization_factor);
c->split(-1, vectorization_factor);
c->axis(-1)->parallelize(ParallelType::Vectorize);
// [Mo, No, m*n/vect, vect]

// step-3, Split out a warp for TIDx
c->split(-2, tidx);
c->axis(-2)->parallelize(ParallelType::TIDx);
// [Mo, No, m*n/vect/TIDx, TIDx, vect]

// step-4, Split out for TIDy and TIDz
// TIDy = cta_tile_n/warp_tile_n
// TIDz = cta_tile_m/warp_tile_m
c->split(-3, tidy);
c->axis(-3)->parallelize(ParallelType::TIDy);

c->split(-4, tidz);
c->axis(-4)->parallelize(ParallelType::TIDz);
// [Mo, No, m*n/vect/TIDx/TIDy/TIDz, TIDz, TIDy, TIDx, vect]
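As a side note on the removed scheduleOutputTensor above: its divisibility requirement can be checked with a small standalone arithmetic example. The tile sizes below (128x256 CTA tile, 64x128 warp tile, vectorization factor 4) are hypothetical values chosen for illustration, not values taken from this PR.

// Illustration only (not part of this PR): the constraint that the removed
// scheduleOutputTensor placed on the output tile, for hypothetical sizes.
#include <cassert>
#include <cstdint>

int main() {
  constexpr int64_t cta_m = 128, cta_n = 256;  // hypothetical cta_tile
  constexpr int64_t warp_m = 64, warp_n = 128; // hypothetical warp_tile
  constexpr int64_t vectorization_factor = 4;  // hypothetical epilogue vec width
  constexpr int64_t warp_size = 32;

  const int64_t tidx = warp_size;       // 32 threads along TIDx
  const int64_t tidy = cta_n / warp_n;  // 256 / 128 = 2
  const int64_t tidz = cta_m / warp_m;  // 128 / 64  = 2
  const int64_t tot_elements = cta_m * cta_n;  // 128 * 256 = 32768

  // The removed code required the tile to divide evenly across
  // tidx * tidy * tidz threads times the vectorization factor:
  // 32768 % (32 * 2 * 2 * 4) == 32768 % 512 == 0, so this config passes.
  assert(tot_elements % (tidx * tidy * tidz * vectorization_factor) == 0);
  return 0;
}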
void HopperMultipleMatmulScheduler::scheduleEpilogue() {
std::vector<TensorView*> cached_tvs;

for (TensorView* mma_result : mma_results_) {
// step-5, Parallelize the first 2 dims the same as mma_result
scheduler_utils::parallelizeAllLike(
mma_result,
2,
{c},
{ParallelType::BIDx, ParallelType::BIDy, ParallelType::BIDz});
// Propagate to (but not including) the split-K output if there is one;
// otherwise this is just mma_results_
std::vector<TensorView*> propagate_to =
splitk_sums_.empty() ? mma_results_ : splitk_sums_;
if (tensor_roles_.count(MatmulTensorRole::EPILOGUE_INPUT)) {
auto& c_tvs = tensor_roles_.at(MatmulTensorRole::EPILOGUE_INPUT);
// Load/cache the epilogue inputs if there are any.
for (auto* c : c_tvs) {
cached_tvs.push_back(c->cacheAfter());
}
propagate_to.insert(propagate_to.end(), c_tvs.begin(), c_tvs.end());
}
}

void HopperMultipleMatmulScheduler::scheduleEpilogue() {
// TODO: schedule epilogue by propagation backward from dc
if (!params_->use_smem_epilogue) {
for (Val* dv : fusion_->outputs()) {
auto* d = dv->as<TensorView>();
NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());
auto* dc = d->definition()->input(0)->as<TensorView>();

std::vector<TensorView*> tvs_to_schedule{d};
if (std::find(mma_results_.begin(), mma_results_.end(), dc) ==
mma_results_.end()) {
// Skip scheduling dc if it is an mma_result. This can happen if we are
// not casting back to half-precision in the output
tvs_to_schedule.push_back(dc);
}

// Block Schedule and Parallelize
blockTileTensors(tvs_to_schedule);
parallelizeBlocks(tvs_to_schedule);

// Apply mma common transformation
for (auto tv : tvs_to_schedule) {
transformLikeMmaOutput(tv, /*is_mma_result=*/false);
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv->getLoopDomain());
tv->setLoopDomain(s.as<IterDomain*>());
}
// Schedule the output TV and propagate it back to the outputs of the Mma
// op.
blockTileTensors({d});
parallelizeBlocks({d});
transformLikeMmaOutput(d, /*is_mma_result=*/false);

auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
d->getLoopDomain());
d->setLoopDomain(s.as<IterDomain*>());

// TODO: We need to check bank conflicts in this path.
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
d,
-1,
propagate_to,
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());

Review comment (collaborator): Just a note: we'll need to revisit this path for cases where TMA is disabled. We need to make sure it is doing coalesced stores here (can check with ncu). I imagine that we will almost always want to use TMA, but it's possible that we would skip it if it increases smem usage too much. In those cases we'll hit this path.

// We don't respect vectorization_factor as yet. We vectorize the
// inner-dim with extent 2.
// TODO: support vectorization_factor.
d->axis(-1)->parallelize(ParallelType::Vectorize);
if (!cached_tvs.empty()) {
scheduler_utils::parallelizeAllLike(d, -1, cached_tvs);
}
}
} else {
constexpr int64_t stmatrix_tile_m = 16;
@@ -609,48 +554,6 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
}
}

//! Propagates transformations from fusion output to fusion tv inputs that are
//! producers in the epilogue. Transformations' propagation aims at input tvs
//! which are not assigned to core roles, that is, are not MMA inputs.
void HopperMultipleMatmulScheduler::scheduleFusionInputsForEpilogue() {
std::vector<TensorView*> cached_tvs;

// Handle transformations in fusion input tvs with the EPILOGUE_INPUT role
// by propagating fusion output transformations through cached views of
// those inputs and by setting vectorization on the innermost iterdomain
// of these cached views
if (tensor_roles_.count(MatmulTensorRole::EPILOGUE_INPUT)) {
auto& c_tvs = tensor_roles_.at(MatmulTensorRole::EPILOGUE_INPUT);

// The system supports only the scenario where there is a single fusion
// output with the OUTPUT role assigned. This condition has already been
// verified, so there is no need for additional checks here
auto output_d = tensor_roles_.at(MatmulTensorRole::OUTPUT).front();
for (auto* c : c_tvs) {
cached_tvs.push_back(c->cacheAfter());
}

scheduler_utils::BoundedDirectionalTransformPropagator::backward(
output_d, -1, c_tvs);

std::unordered_set<ParallelType> parallel_types = {};
if (params_->use_smem_epilogue) {
// In cases where the smem epilogue feature is enabled, the vectorization
// of domains will be propagated to fusion inputs that are epilogue
// inputs, which may result in unaligned memory reads. Vectorization is
// explicitly excluded from the parallelization types to avoid this issue.
// This should be changed once vectorization analysis is available and
// enabled for the matmul scheduler.
parallel_types = allParallelTypesExcept({ParallelType::Vectorize});
}
scheduler_utils::parallelizeAllLike(
output_d, -1, cached_tvs, parallel_types);

// The cached EPILOGUE_INPUT tvs are not needed anymore
cached_tvs.clear();
}
}

void HopperMultipleMatmulScheduler::scheduleSplitKSum() {
if (params_->splitk_factor == 1) {
return;
7 changes: 0 additions & 7 deletions csrc/scheduler/hopper_multi_matmul.h
@@ -171,15 +171,8 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler {

void scheduleMmaResults();

void scheduleOutputTensor(TensorView* c);

void scheduleEpilogue();

//! Propagates transformations from fusion output to fusion tv inputs that are
//! producers in the epilogue. Transformations' propagation aims at input tvs
//! which are not assigned to core roles, that is, are not MMA inputs.
void scheduleFusionInputsForEpilogue();

void scheduleSplitKSum();

void setUpInlining();
76 changes: 74 additions & 2 deletions tests/cpp/test_matmul_scheduler.cpp
@@ -3296,8 +3296,7 @@ class HopperMatmulSchedulerTest
KernelExecutor ke;
ke.compile(fusion, inputs, LaunchParams(), matmul_cparams);
auto nvf_out = ke.run(inputs);
// NOTE Relax tolerances for split-k case
EXPECT_TRUE(at::allclose(nvf_out.at(0), tref, 1e-3, 1e-3));
EXPECT_TRUE(at::allclose(nvf_out.at(0), tref, 1e-2, 1e-2));
}

protected:
@@ -3377,6 +3376,79 @@ TEST_P(HopperMatmulSchedulerTest, FusedMultiplySum) {
tref = atMatmul(A.squeeze(), B.squeeze(), layout);
}

// TODO: Remove this test once the architecture-agnostic version can be
// run on Hopper.
TEST_P(HopperMatmulSchedulerTest, FusedMultiplySumBiasNeg) {
if (use_smem_epilogue) {
GTEST_SKIP()
<< "TODO: We don't support smem epilogue in the Hopper matmul scheduler right now";
}
const auto& [A, B] =
matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype));
const auto& C = matmulAtInput2D(
layout, TensorMatmulPos::Bias, data_type_to_aten(dtype), M, N, K);
inputs = {A, B, C};

TensorView* tv0 = nullptr;
TensorView* tv1 = nullptr;
std::unordered_map<int64_t, int64_t> old2new;
int64_t k_axis = 0;

switch (layout) {
case MmaLayout::TT:
// Inner dims KN, order is MKN
tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype);
tv1 = makeContigConcreteTensor({1, -1, -1}, dtype);
old2new = {{-2, -1}, {-1, -2}};
k_axis = -2;
break;
case MmaLayout::TN:
// Inner dims KK, order is MNK
tv0 = makeContigConcreteTensor({-1, 1, -1}, dtype);
tv1 = makeContigConcreteTensor({1, -1, -1}, dtype);
old2new = {};
k_axis = -1;
break;
case MmaLayout::NT:
// Inner dims MN, order is KMN
tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype);
tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype);
old2new = {{-3, -1}};
k_axis = -3;
break;
case MmaLayout::NN:
// Inner dims MK, order is NKM
tv0 = makeContigConcreteTensor({1, -1, -1}, dtype);
tv1 = makeContigConcreteTensor({-1, -1, 1}, dtype);
old2new = {{-1, -3}};
k_axis = -2;
break;
}
TensorView* tv2 = makeContigConcreteTensor({-1}, dtype);

fusion->addInput(tv0);
fusion->addInput(tv1);
fusion->addInput(tv2);

auto tv3 = fusedMultiplySum(tv0, tv1, {k_axis});

// Reorder the accumulator as [M, N, K]
tv3->reorder(old2new);
tv3->commitLeafToLogical();

auto* tv4 = maybeCastOp(DataType::Float, tv2);
auto* tv5 = biasEpilogue(tv3, tv4);
auto* tv6 = neg(tv5);
auto* tv7 = castOp(dtype, tv6);
fusion->addOutput(tv7);

tref = atBiasEpilogue(
atMatmul(A.squeeze(), B.squeeze(), layout),
C.to(data_type_to_aten(DataType::Float)))
.neg_()
.to(data_type_to_aten(DataType::Half));
}

INSTANTIATE_TEST_SUITE_P(
General,
HopperMatmulSchedulerTest,