Commit

adding a new unit test for mma+bias and propagating schedules
protonu committed Dec 12, 2024
1 parent 2749296 commit 8bf8997
Showing 2 changed files with 108 additions and 91 deletions.
126 changes: 36 additions & 90 deletions csrc/scheduler/hopper_multi_matmul.cpp
@@ -421,106 +421,52 @@ void HopperMultipleMatmulScheduler::scheduleMmaResults() {
}

void HopperMultipleMatmulScheduler::scheduleOutputTensor(TensorView* c) {
- const MatMulTileOptions& gemm_tile = params_->tile_sizes;
- const int64_t vectorization_factor = params_->supported_vec_size.epilogue;
- // input tensor is in the form of [Mo,No,cta_tile_m,cta_tile_n]
- mma_utils::checkConcreteStaticDim(c->axis(-2));
- mma_utils::checkConcreteStaticDim(c->axis(-1));
- const int64_t tile_size_m = c->axis(-2)->extent()->evaluate().as<int64_t>();
- const int64_t tile_size_n = c->axis(-1)->extent()->evaluate().as<int64_t>();
- NVF_ERROR(
- tile_size_m == gemm_tile.cta_tile.m,
- "Actual tile size at axis(-2) in output tensor is different from CTA tile size! Expected: ",
- gemm_tile.cta_tile.m,
- ", actual: ",
- tile_size_m);
- NVF_ERROR(
- tile_size_n == gemm_tile.cta_tile.n,
- "Actual tile size at axis(-1) in output tensor is different from CTA tile size! Expected: ",
- gemm_tile.cta_tile.n,
- ", actual: ",
- tile_size_n);
- const int64_t tot_elements = tile_size_m * tile_size_n;
- constexpr int64_t warp_size = 32l;
- const int64_t tidx = warp_size;
- const int64_t tidy = gemm_tile.cta_tile.n / gemm_tile.warp_tile.n;
- const int64_t tidz = gemm_tile.cta_tile.m / gemm_tile.warp_tile.m;
- // step-1, merge last 2 dims
- c->merge(-2);
- // [Mo, No, m*n]
-
- // step-2, set vectorization to maximum
- // We have fixed tidx, tidy, and tidz, so we need to make sure that the
- // output tensor is divisible by tidx * tidy * tidz * vectorization_factor
- NVF_ERROR(
- tot_elements % (tidx * tidy * tidz * vectorization_factor) == 0,
- "Output tensor cannot be fully vectorized! tot_elements:",
- tot_elements,
- ", tidx: ",
- tidx,
- ", tidy: ",
- tidy,
- ", tidz: ",
- tidz,
- ", vectorization_factor: ",
- vectorization_factor);
- c->split(-1, vectorization_factor);
- c->axis(-1)->parallelize(ParallelType::Vectorize);
- // [Mo, No, m*n/vect, vect]
-
- // step-3, Split out a warp for TIDx
- c->split(-2, tidx);
- c->axis(-2)->parallelize(ParallelType::TIDx);
- // [Mo, No, m*n/vect/TIDx, TIDx, vect]
-
- // step-4, Split out for TIDy and TIDz
- // TIDy = cta_tile_n/warp_tile_n
- // TIDz = cta_tile_m/warp_tile_m
- c->split(-3, tidy);
- c->axis(-3)->parallelize(ParallelType::TIDy);
-
- c->split(-4, tidz);
- c->axis(-4)->parallelize(ParallelType::TIDz);
- // [Mo, No, m*n/vect/TIDx/TIDy/TIDz, TIDz, TIDy, TIDx, vect]
-
- for (TensorView* mma_result : mma_results_) {
- // step-5, Parallel first 2 dims same as mma_result
- scheduler_utils::parallelizeAllLike(
- mma_result,
- 2,
- {c},
- {ParallelType::BIDx, ParallelType::BIDy, ParallelType::BIDz});
- }
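
For intuition, the divisibility requirement that the removed path above asserted in step-2 works out as follows. This is a minimal standalone sketch; the tile configuration (cta_tile 128x256, warp_tile 64x128) and the vectorization factor 4 are assumed values for illustration, not taken from this commit:

#include <cassert>
#include <cstdint>
#include <iostream>

int main() {
  // Assumed tile configuration, for illustration only.
  constexpr int64_t cta_m = 128, cta_n = 256;
  constexpr int64_t warp_m = 64, warp_n = 128;
  constexpr int64_t vectorization_factor = 4;
  constexpr int64_t warp_size = 32;

  const int64_t tidx = warp_size;             // 32 threads in x: one warp
  const int64_t tidy = cta_n / warp_n;        // 2 warp tiles along N
  const int64_t tidz = cta_m / warp_m;        // 2 warp tiles along M
  const int64_t tot_elements = cta_m * cta_n; // 32768 elements per CTA tile

  // The removed step-2 required the merged m*n extent to be divisible by
  // the full thread count times the vectorization factor.
  assert(tot_elements % (tidx * tidy * tidz * vectorization_factor) == 0);
  std::cout << "iterations per thread: "
            << tot_elements / (tidx * tidy * tidz * vectorization_factor)
            << "\n"; // 32768 / (32 * 2 * 2 * 4) = 64
  return 0;
}
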
+ // Block Schedule and Parallelize
+ blockTileTensors({c});
+ parallelizeBlocks({c});
+
+ // Apply mma common transformation
+ c->split(-2, getM(params_->mma_macro));
+ c->split(-1, getN(params_->mma_macro));
+ // [..., Mo, No, Mio, Mii, Nio, Nii]
+ // -> [..., Mo, No, Mio, Nio, Mii, Nii]
+ c->reorder({{-3, -2}});
+ c->merge(-4);
+
+ // [..., Mii, Nii] ->
+ // [..., Mii/16, Miioi(2), Miii(8), Nii/8, Niio(4), Niii(2)] ->
+ // [..., Mii/16, Miii(8), Niio(4), Nii/8, Miioi(2), Niii(2)] ->
+ // [..., Mii/16 * Miii(8) * Niio(4), Nii/8, Miioi(2), Niii(2)]
+ // Mii/16 * Miii(8) * Niio(4) is 128 for Hopper and this is parallelized as
+ // TIDx.
+ auto s =
+ mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(c->getLoopDomain());
+ c->setLoopDomain(s.as<IterDomain*>());
+ c->axis(-5)->parallelize(ParallelType::TIDy);
}
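
The split/reorder/merge bookkeeping in the new scheduleOutputTensor above can be traced on a plain extent vector. A minimal sketch, assuming Mo=2, No=4, a 128x256 CTA tile, and a 64x64 mma macro (all hypothetical values); the helpers only mimic the shape arithmetic of TensorView::split/reorder/merge and are not the nvFuser API:

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

using Shape = std::vector<int64_t>;

// Wrap a possibly negative axis index, as TensorView axes allow.
int wrap(const Shape& s, int i) {
  return i < 0 ? static_cast<int>(s.size()) + i : i;
}

// split(axis, factor): extent E becomes [E/factor, factor].
void split(Shape& s, int axis, int64_t factor) {
  int i = wrap(s, axis);
  s[i] /= factor;
  s.insert(s.begin() + i + 1, factor);
}

// merge(axis): fuse axis with axis+1 into one extent.
void merge(Shape& s, int axis) {
  int i = wrap(s, axis);
  s[i] *= s[i + 1];
  s.erase(s.begin() + i + 1);
}

void print(const char* tag, const Shape& s) {
  std::cout << tag << ": [";
  for (size_t i = 0; i < s.size(); ++i) {
    std::cout << s[i] << (i + 1 < s.size() ? ", " : "");
  }
  std::cout << "]\n";
}

int main() {
  Shape s = {2, 4, 128, 256}; // [Mo, No, Mi, Ni], hypothetical extents
  print("start", s);
  split(s, -2, 64);           // assumed getM(macro) == 64
  split(s, -1, 64);           // assumed getN(macro) == 64
  print("split", s);          // [2, 4, 2, 64, 4, 64] = [Mo, No, Mio, Mii, Nio, Nii]
  std::swap(s[wrap(s, -3)], s[wrap(s, -2)]); // reorder({{-3, -2}}): swap Mii, Nio
  merge(s, -4);               // fuse Mio and Nio
  print("merged", s);         // [2, 4, 8, 64, 64] = [Mo, No, Mio*Nio, Mii, Nii]
  // MmaSwizzler then decomposes [Mii, Nii]; with Mii = 64,
  // Mii/16 * Miii(8) * Niio(4) = 4 * 8 * 4 = 128 lanes bound to TIDx.
  return 0;
}

After the swizzler's decomposition of [Mii, Nii] into four axes, the fused Mio*Nio axis sits at position -5, which is consistent with the axis(-5) TIDy binding above.
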

void HopperMultipleMatmulScheduler::scheduleEpilogue() {
// TODO: schedule epilogue by propagation backward from dc
if (!params_->use_smem_epilogue) {
for (Val* dv : fusion_->outputs()) {
auto* d = dv->as<TensorView>();
NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());
- auto* dc = d->definition()->input(0)->as<TensorView>();
-
- // Block Schedule and Parallelize
- blockTileTensors({dc, d});
- parallelizeBlocks({dc, d});
-
- // Apply mma common transformation
- for (auto tv : {dc, d}) {
- // [..., Mo, No, Mi, Ni]
- tv->split(-2, getM(params_->mma_macro));
- tv->split(-1, getN(params_->mma_macro));
- // [..., Mo, No, Mio, Mii, Nio, Nii]
- // -> [..., Mo, No, Mio, Nio, Mii, Nii]
- tv->reorder({{-3, -2}});
- tv->merge(-4);
- auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
- tv->getLoopDomain());
- tv->setLoopDomain(s.as<IterDomain*>());
- tv->axis(-5)->parallelize(ParallelType::TIDy);
- }
+ // Schedule the output TV and propagate it back to the outputs of the Mma
+ // op.
+ scheduleOutputTensor(d);
+ scheduler_utils::BoundedDirectionalTransformPropagator::backward(
+ d,
+ -1,
+ mma_results_,
+ scheduler_utils::BoundedDirectionalTransformPropagator::Options()
+ .propagateParallelType());
+
+ // We don't yet respect vectorization_factor; for now we vectorize the
+ // inner dimension, which has extent 2.
+ // TODO: support vectorization_factor.
+ d->axis(-1)->parallelize(ParallelType::Vectorize);
}
scheduleFusionInputsForEpilogue();
} else {
constexpr int64_t stmatrix_tile_m = 16;
constexpr int64_t stmatrix_tile_n = 16;
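
The net effect of the rewritten epilogue is that only the output tensor is scheduled directly: scheduleOutputTensor(d) builds the block tiling, mma-output swizzle, and parallelization on d, and BoundedDirectionalTransformPropagator::backward then replays those loop-domain transforms from d onto the producers between it and mma_results_, with the propagateParallelType() option also copying the thread bindings along the way (as I read the call, the position argument -1 bounds the replay to the full loop domain). This is the "propagating schedules" half of the commit title, and it is what lets the bias+neg epilogue chain in the new test below pick up a consistent schedule without being scheduled by hand.
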
73 changes: 72 additions & 1 deletion tests/cpp/test_matmul_scheduler.cpp
@@ -3218,7 +3218,7 @@ class HopperMatmulSchedulerTest
KernelExecutor ke;
ke.compile(fusion, inputs, LaunchParams(), matmul_cparams);
auto nvf_out = ke.run(inputs);
- EXPECT_TRUE(at::allclose(nvf_out.at(0), tref, 1e-5, 1e-5));
+ EXPECT_TRUE(at::allclose(nvf_out.at(0), tref, 1e-2, 1e-2));
}

protected:
@@ -3297,6 +3297,77 @@ TEST_P(HopperMatmulSchedulerTest, FusedMultiplySum) {
tref = atMatmul(A.squeeze(), B.squeeze(), layout);
}

+ TEST_P(HopperMatmulSchedulerTest, FusedMultiplySumBiasNeg) {
+ if (use_smem_epilogue) {
+ GTEST_SKIP()
+ << "TODO: We don't support smem epilogue in the Hopper matmul scheduler right now";
+ }
+ const auto& [A, B] =
+ matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype));
+ const auto& C = matmulAtInput2D(
+ layout, TensorMatmulPos::Bias, data_type_to_aten(dtype), M, N, K);
+ inputs = {A, B, C};
+
+ TensorView* tv0 = nullptr;
+ TensorView* tv1 = nullptr;
+ std::unordered_map<int64_t, int64_t> old2new;
+ int64_t k_axis = 0;
+
+ switch (layout) {
+ case MmaLayout::TT:
+ // Inner dims KN, order is MKN
+ tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype);
+ tv1 = makeContigConcreteTensor({1, -1, -1}, dtype);
+ old2new = {{-2, -1}, {-1, -2}};
+ k_axis = -2;
+ break;
+ case MmaLayout::TN:
+ // Inner dims KK, order is MNK
+ tv0 = makeContigConcreteTensor({-1, 1, -1}, dtype);
+ tv1 = makeContigConcreteTensor({1, -1, -1}, dtype);
+ old2new = {};
+ k_axis = -1;
+ break;
+ case MmaLayout::NT:
+ // Inner dims MN, order is KMN
+ tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype);
+ tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype);
+ old2new = {{-3, -1}};
+ k_axis = -3;
+ break;
+ case MmaLayout::NN:
+ // Inner dims MK, order is NKM
+ tv0 = makeContigConcreteTensor({1, -1, -1}, dtype);
+ tv1 = makeContigConcreteTensor({-1, -1, 1}, dtype);
+ old2new = {{-1, -3}};
+ k_axis = -2;
+ break;
+ }
+ TensorView* tv2 = makeContigConcreteTensor({-1}, dtype);
+
+ fusion->addInput(tv0);
+ fusion->addInput(tv1);
+ fusion->addInput(tv2);
+
+ auto tv3 = fusedMultiplySum(tv0, tv1, {k_axis});
+
+ // Reorder the accumulator as [M, N, K]
+ tv3->reorder(old2new);
+ tv3->commitLeafToLogical();
+
+ auto* tv4 = maybeCastOp(DataType::Float, tv2);
+ auto* tv5 = biasEpilogue(tv3, tv4);
+ auto* tv6 = neg(tv5);
+ auto* tv7 = castOp(dtype, tv6);
+ fusion->addOutput(tv7);
+
+ tref = atBiasEpilogue(
+ atMatmul(A.squeeze(), B.squeeze(), layout),
+ C.to(data_type_to_aten(DataType::Float)))
+ .neg_()
+ .to(data_type_to_aten(DataType::Half));
+ }

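All four old2new maps in the switch above normalize the accumulator to [M, N, K], after which tv3->commitLeafToLogical() commits that reordered order as the logical domain. A standalone sketch of reorder's old-to-new semantics that checks each case; this is a toy model of TensorView::reorder, not the nvFuser API, and the dim labels are illustrative:

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Toy model of TensorView::reorder: old2new maps old axis -> new axis
// (negative indices wrap); unmapped axes keep their relative order in
// the remaining slots.
std::vector<std::string> reorder(
    const std::vector<std::string>& dims,
    const std::map<int64_t, int64_t>& old2new) {
  const int n = static_cast<int>(dims.size());
  auto wrap = [n](int64_t i) { return static_cast<int>(i < 0 ? i + n : i); };
  std::vector<std::string> out(n);
  std::vector<bool> moved(n, false), filled(n, false);
  for (const auto& [o, nw] : old2new) {
    out[wrap(nw)] = dims[wrap(o)];
    moved[wrap(o)] = true;
    filled[wrap(nw)] = true;
  }
  int slot = 0;
  for (int i = 0; i < n; ++i) {
    if (moved[i]) continue;
    while (filled[slot]) ++slot;
    out[slot++] = dims[i];
  }
  return out;
}

int main() {
  struct Case {
    const char* name;
    std::vector<std::string> dims;      // accumulator order before reorder
    std::map<int64_t, int64_t> old2new; // map from the test above
  };
  const std::vector<Case> cases = {
      {"TT", {"M", "K", "N"}, {{-2, -1}, {-1, -2}}},
      {"TN", {"M", "N", "K"}, {}},
      {"NT", {"K", "M", "N"}, {{-3, -1}}},
      {"NN", {"N", "K", "M"}, {{-1, -3}}},
  };
  for (const auto& c : cases) {
    const auto r = reorder(c.dims, c.old2new);
    std::cout << c.name << " -> [" << r[0] << ", " << r[1] << ", " << r[2]
              << "]\n"; // every layout prints [M, N, K]
  }
  return 0;
}

The rest of the test then builds the epilogue chain tv4 through tv7, i.e. D = castOp(dtype, neg(biasEpilogue(tv3, tv4))), which is exactly what the atBiasEpilogue(...).neg_() reference computes.
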
INSTANTIATE_TEST_SUITE_P(
General,
HopperMatmulSchedulerTest,
