diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp
index fbb95d46df2..509d3213380 100644
--- a/csrc/scheduler/hopper_multi_matmul.cpp
+++ b/csrc/scheduler/hopper_multi_matmul.cpp
@@ -421,106 +421,52 @@ void HopperMultipleMatmulScheduler::scheduleMmaResults() {
 }
 
 void HopperMultipleMatmulScheduler::scheduleOutputTensor(TensorView* c) {
-  const MatMulTileOptions& gemm_tile = params_->tile_sizes;
-  const int64_t vectorization_factor = params_->supported_vec_size.epilogue;
-  // input tensor is in the form of [Mo,No,cta_tile_m,cta_tile_n]
-  mma_utils::checkConcreteStaticDim(c->axis(-2));
-  mma_utils::checkConcreteStaticDim(c->axis(-1));
-  const int64_t tile_size_m = c->axis(-2)->extent()->evaluate().as<int64_t>();
-  const int64_t tile_size_n = c->axis(-1)->extent()->evaluate().as<int64_t>();
-  NVF_ERROR(
-      tile_size_m == gemm_tile.cta_tile.m,
-      "Actual tile size at axis(-2) in output tensor is different from CTA tile size! Expected: ",
-      gemm_tile.cta_tile.m,
-      ", actual: ",
-      tile_size_m);
-  NVF_ERROR(
-      tile_size_n == gemm_tile.cta_tile.n,
-      "Actual tile size at axis(-1) in output tensor is different from CTA tile size! Expected: ",
-      gemm_tile.cta_tile.n,
-      ", actual: ",
-      tile_size_n);
-  const int64_t tot_elements = tile_size_m * tile_size_n;
-  constexpr int64_t warp_size = 32l;
-  const int64_t tidx = warp_size;
-  const int64_t tidy = gemm_tile.cta_tile.n / gemm_tile.warp_tile.n;
-  const int64_t tidz = gemm_tile.cta_tile.m / gemm_tile.warp_tile.m;
-  // step-1, merge last 2 dims
-  c->merge(-2);
-  // [Mo, No, m*n]
-
-  // step-2, set vectorization to maximum
-  // We have fixed tidx, tidy, and tidz, so we need to make sure that the
-  // output tensor is divisible by tidx * tidy * tidz * vectorization_factor
-  NVF_ERROR(
-      tot_elements % (tidx * tidy * tidz * vectorization_factor) == 0,
-      "Output tensor cannot be fully vectorized! tot_elements:",
-      tot_elements,
-      ", tidx: ",
-      tidx,
-      ", tidy: ",
-      tidy,
-      ", tidz: ",
-      tidz,
-      ", vectorization_factor: ",
-      vectorization_factor);
-  c->split(-1, vectorization_factor);
-  c->axis(-1)->parallelize(ParallelType::Vectorize);
-  // [Mo, No, m*n/vect, vect]
-
-  // step-3, Split out a warp for TIDx
-  c->split(-2, tidx);
-  c->axis(-2)->parallelize(ParallelType::TIDx);
-  // [Mo, No, m*n/vect/TIDx, TIDx, vect]
-
-  // step-4, Split out for TIDy and TIDz
-  // TIDy = cta_tile_n/warp_tile_n
-  // TIDz = cta_tile_m/warp_tile_m
-  c->split(-3, tidy);
-  c->axis(-3)->parallelize(ParallelType::TIDy);
-
-  c->split(-4, tidz);
-  c->axis(-4)->parallelize(ParallelType::TIDz);
-  // [Mo, No, m*n/vect/TIDx/TIDy/TIDz, TIDz, TIDy, TIDx, vect]
-
-  for (TensorView* mma_result : mma_results_) {
-    // step-5, Parallel first 2 dims same as mma_result
-    scheduler_utils::parallelizeAllLike(
-        mma_result,
-        2,
-        {c},
-        {ParallelType::BIDx, ParallelType::BIDy, ParallelType::BIDz});
-  }
+  // Block-tile and parallelize across blocks
+  blockTileTensors({c});
+  parallelizeBlocks({c});
+
+  // Apply the common MMA transformation
+  c->split(-2, getM(params_->mma_macro));
+  c->split(-1, getN(params_->mma_macro));
+  // [..., Mo, No, Mio, Mii, Nio, Nii]
+  // -> [..., Mo, No, Mio, Nio, Mii, Nii]
+  c->reorder({{-3, -2}});
+  c->merge(-4);
+
+  // [..., Mii, Nii] ->
+  // [..., Mii/16, Miioi(2), Miii(8), Nii/8, Niio(4), Niii(2)] ->
+  // [..., Mii/16, Miii(8), Niio(4), Nii/8, Miioi(2), Niii(2)] ->
+  // [..., Mii/16 * Miii(8) * Niio(4), Nii/8, Miioi(2), Niii(2)]
+  // Mii/16 * Miii(8) * Niio(4) is 128 for Hopper and this is parallelized as
+  // TIDx.
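+  // (Hopper MMA macros have M = 64, so this product is 64/16 * 8 * 4 = 128
+  // threads, i.e. one full warp group of four warps.)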
+  auto s =
+      mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(c->getLoopDomain());
+  c->setLoopDomain(s.as<IterDomain*>());
+  c->axis(-5)->parallelize(ParallelType::TIDy);
 }
 
 void HopperMultipleMatmulScheduler::scheduleEpilogue() {
-  // TODO: schedule epilogue by propagation backward from dc
   if (!params_->use_smem_epilogue) {
     for (Val* dv : fusion_->outputs()) {
       auto* d = dv->as<TensorView>();
       NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());
-      auto* dc = d->definition()->input(0)->as<TensorView>();
-
-      // Block Schedule and Parallelize
-      blockTileTensors({dc, d});
-      parallelizeBlocks({dc, d});
-      // Apply mma common transformation
-      for (auto tv : {dc, d}) {
-        // [..., Mo, No, Mi, Ni]
-        tv->split(-2, getM(params_->mma_macro));
-        tv->split(-1, getN(params_->mma_macro));
-        // [..., Mo, No, Mio, Mii, Nio, Nii]
-        // -> [..., Mo, No, Mio, Nio, Mii, Nii]
-        tv->reorder({{-3, -2}});
-        tv->merge(-4);
-        auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
-            tv->getLoopDomain());
-        tv->setLoopDomain(s.as<IterDomain*>());
-        tv->axis(-5)->parallelize(ParallelType::TIDy);
-      }
+      // Schedule the output TV and propagate the transforms back to the
+      // outputs of the MmaOp.
+      scheduleOutputTensor(d);
+      scheduler_utils::BoundedDirectionalTransformPropagator::backward(
+          d,
+          -1,
+          mma_results_,
+          scheduler_utils::BoundedDirectionalTransformPropagator::Options()
+              .propagateParallelType());
+
+      // We don't respect vectorization_factor yet; we vectorize the
+      // inner dim, which has extent 2.
+      // TODO: support vectorization_factor.
       d->axis(-1)->parallelize(ParallelType::Vectorize);
     }
+    scheduleFusionInputsForEpilogue();
   } else {
     constexpr int64_t stmatrix_tile_m = 16;
     constexpr int64_t stmatrix_tile_n = 16;
diff --git a/tests/cpp/test_matmul_scheduler.cpp b/tests/cpp/test_matmul_scheduler.cpp
index fa80a096dce..39a813893db 100644
--- a/tests/cpp/test_matmul_scheduler.cpp
+++ b/tests/cpp/test_matmul_scheduler.cpp
@@ -3218,7 +3218,7 @@ class HopperMatmulSchedulerTest
     KernelExecutor ke;
     ke.compile(fusion, inputs, LaunchParams(), matmul_cparams);
    auto nvf_out = ke.run(inputs);
-    EXPECT_TRUE(at::allclose(nvf_out.at(0), tref, 1e-5, 1e-5));
+    EXPECT_TRUE(at::allclose(nvf_out.at(0), tref, 1e-2, 1e-2));
   }
 
  protected:
@@ -3297,6 +3297,77 @@ TEST_P(HopperMatmulSchedulerTest, FusedMultiplySum) {
   tref = atMatmul(A.squeeze(), B.squeeze(), layout);
 }
 
+TEST_P(HopperMatmulSchedulerTest, FusedMultiplySumBiasNeg) {
+  if (use_smem_epilogue) {
+    GTEST_SKIP()
+        << "TODO: We don't support smem epilogue in the Hopper matmul scheduler right now";
+  }
+  const auto& [A, B] =
+      matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype));
+  const auto& C = matmulAtInput2D(
+      layout, TensorMatmulPos::Bias, data_type_to_aten(dtype), M, N, K);
+  inputs = {A, B, C};
+
+  TensorView* tv0 = nullptr;
+  TensorView* tv1 = nullptr;
+  std::unordered_map<int64_t, int64_t> old2new;
+  int64_t k_axis = 0;
+
+  switch (layout) {
+    case MmaLayout::TT:
+      // Inner dims KN, order is MKN
+      tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype);
+      tv1 = makeContigConcreteTensor({1, -1, -1}, dtype);
+      old2new = {{-2, -1}, {-1, -2}};
+      k_axis = -2;
+      break;
+    case MmaLayout::TN:
+      // Inner dims KK, order is MNK
+      tv0 = makeContigConcreteTensor({-1, 1, -1}, dtype);
+      tv1 = makeContigConcreteTensor({1, -1, -1}, dtype);
+      old2new = {};
+      k_axis = -1;
+      break;
+    case MmaLayout::NT:
+      // Inner dims MN, order is KMN
+      tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype);
+      tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype);
+      old2new = {{-3, -1}};
+      k_axis = -3;
+      break;
+    case MmaLayout::NN:
+      // Inner dims MK, order is NKM
+      tv0 = makeContigConcreteTensor({1, -1, -1}, dtype);
+      tv1 = makeContigConcreteTensor({-1, -1, 1}, dtype);
+      old2new = {{-1, -3}};
+      k_axis = -2;
+      break;
+  }
+  TensorView* tv2 = makeContigConcreteTensor({-1}, dtype);
+
+  fusion->addInput(tv0);
+  fusion->addInput(tv1);
+  fusion->addInput(tv2);
+
+  auto tv3 = fusedMultiplySum(tv0, tv1, {k_axis});
+
+  // Reorder the accumulator as [M, N, K]
+  tv3->reorder(old2new);
+  tv3->commitLeafToLogical();
+
+  auto* tv4 = maybeCastOp(DataType::Float, tv2);
+  auto* tv5 = biasEpilogue(tv3, tv4);
+  auto* tv6 = neg(tv5);
+  auto* tv7 = castOp(dtype, tv6);
+  fusion->addOutput(tv7);
+
+  tref = atBiasEpilogue(
+             atMatmul(A.squeeze(), B.squeeze(), layout),
+             C.to(data_type_to_aten(DataType::Float)))
+             .neg_()
+             .to(data_type_to_aten(DataType::Half));
+}
+
 INSTANTIATE_TEST_SUITE_P(
     General,
     HopperMatmulSchedulerTest,