From 61a77e0a64d5bc446ba1c009f04a19204a28eab2 Mon Sep 17 00:00:00 2001 From: Liqiang Lu <116412316+liqiangxl@users.noreply.github.com> Date: Mon, 7 Oct 2024 13:19:53 -0400 Subject: [PATCH] check shared memory predicate in matmul tests (#3102) Adding check to matmul tests since shared memory access in matmul kernels shouldn't be predicated. Helps https://github.com/NVIDIA/Fuser/pull/2339 --- tests/cpp/test_matmul.cpp | 76 +++++++++++++++++++++++++++++++++++---- tests/cpp/utils.h | 27 ++++++++++++++ 2 files changed, 97 insertions(+), 6 deletions(-) diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index 239843f52eb..7db0b85cd3b 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -132,6 +132,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmul) { LaunchParams(), matmul_cparams)); ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); @@ -184,6 +186,8 @@ TEST_P(MatmulTestWithLayout, AmperePrologueFusionBroadcast) { LaunchParams(), matmul_cparams)); ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); @@ -241,6 +245,8 @@ TEST_P(MatmulTestWithLayout, AmpereProloguePointwise) { LaunchParams(), matmul_cparams)); ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); auto tref = atMatmul( inputs.first.sin().to(at::kFloat), @@ -298,6 +304,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulBFloat16) { LaunchParams(), 
matmul_cparams)); ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); @@ -357,6 +365,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulPipelineGmem) { LaunchParams(), matmul_cparams)); ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); @@ -437,6 +447,8 @@ TEST_P(MatmulTestWithLayout, AmpereSwizzle) { LaunchParams(), matmul_cparams)); ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); @@ -561,6 +573,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulRegCircularBuffer) { LaunchParams(), matmul_cparams)); ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); @@ -849,7 +863,8 @@ TEST_F(MatmulTest, MatmulMatmulAmpere) { fe.compileFusion(&fusion, {t0, t1, t2}, LaunchParams(), matmul_cparams)); auto cg_outputs = fe.runFusion({t0, t1, t2}); - + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); // relaxed check for now, err accumulation is significant. 
NVF_CHECK(cg_outputs[0].allclose(tref, 0.1, 0.1)); } @@ -1228,7 +1243,8 @@ TEST_F(MatmulTest, MatmulSoftmaxMatmulAmpere) { fe.compileFusion(&fusion, {t0, t1, t2}, LaunchParams(), matmul_cparams)); auto cg_outputs = fe.runFusion({t0, t1, t2}); - + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto g1 = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); auto sg1 = at::_softmax(g1, -1, false); auto gsg1 = sg1.matmul(t2.t().to(at::kFloat)); @@ -1276,6 +1292,8 @@ TEST_P(MatmulTestWithLayout, TuringMatmul) { NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( 7, 5, fe.compileFusion(&fusion, {inputs.first, inputs.second})); ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); @@ -1421,7 +1439,8 @@ TEST_F(MatmulTest, AmpereMatmulTNCpAsync) { fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams)); auto cg_outputs = fe.runFusion({t0, t1}); - + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); @@ -1589,7 +1608,8 @@ TEST_F(MatmulTest, AmpereStridedBatchedMatmulTN) { fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams)); auto cg_outputs = fe.runFusion({t0, t1}); - + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); // ref implementation: auto ref_t0 = t0.permute({0, 2, 1, 3}) .contiguous() @@ -1761,7 +1781,8 @@ TEST_F(MatmulTest, AmpereViewMatmulTN) { fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams)); auto cg_outputs = fe.runFusion({t0, t1}); - + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto tref = at::native::view(t0, {M, 
K}).to(at::kFloat).matmul(t1.t().to(at::kFloat)); @@ -1943,7 +1964,8 @@ TEST_F(MatmulTest, AmpereMatmulTNSwizzled) { FusionExecutor fe; fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams); auto cg_outputs = fe.runFusion({t0, t1}); - + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); @@ -1998,6 +2020,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulLargeLoad) { LaunchParams(), matmul_cparams)); ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); @@ -2050,6 +2074,8 @@ TEST_P(MatmulTestWithLayout, TuringMatmulLargeLoad) { LaunchParams(), matmul_cparams)); ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); @@ -2119,6 +2145,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulTileCheck4warp) { matmul_cparams)); EXPECT_TRUE(getBankConflictInfo(fe.kernel()).empty()); auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); NVF_CHECK( @@ -2195,6 +2223,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulTileCheck8warp) { LaunchParams(), matmul_cparams)); ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); auto tref 
= atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); @@ -2262,6 +2292,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulTileCheck6warp) { LaunchParams(), matmul_cparams)); ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); @@ -2318,6 +2350,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulLargeLoadLargeK) { LaunchParams(), matmul_cparams)); ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); @@ -2368,6 +2402,8 @@ TEST_P(MatmulTestWithLayout, AmpereSplitKLikeStridedBatchedMatmul) { 0, fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams)); ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto cg_outputs = fe.runFusion({t0, t1}); auto tref = splitkLikeAtMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); @@ -2464,6 +2500,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulSmemEpilogue) { // check bank conflicts ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); // (0.001, 0.001) passed on local A100 but failed on CI A100 NVF_CHECK( cg_outputs[0].allclose(tref, 0.01, 0.01), @@ -2602,6 +2640,8 @@ TEST_F(MatmulTest, AmpereMatmulSmemEpiloguePromotionRequiredA100) { // check bank conflicts ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); 
// (0.001, 0.001) passed on local A100 but failed on CI A100 NVF_CHECK( cg_outputs[0].allclose(tref, 0.01, 0.01), @@ -2700,6 +2740,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulSmemEpilogueCast) { tref = tref.to(at::kHalf); // check bank conflicts ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); // (0.001, 0.001) passed on local A100 but failed on CI A100 NVF_CHECK( cg_outputs[0].allclose(tref, 0.01, 0.01), @@ -2795,6 +2837,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulSmemEpilogueRelu) { // check bank conflicts ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); // (0.001, 0.001) passed on local A100 but failed on CI A100 NVF_CHECK( cg_outputs[0].allclose(tref, 0.01, 0.01), @@ -2874,6 +2918,8 @@ TEST_P(MatmulTestWithLayout, FusionAmpereMatmulSplitK_CUDA) { 7, 5, fe.compileFusion(&fusion, {inputs.first, inputs.second})); EXPECT_TRUE(getBankConflictInfo(fe.kernel()).empty()); auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); @@ -2940,6 +2986,8 @@ TEST_P(MatmulTestWithLayout, FusionAmpereMatmulSplitKBias_CUDA) { 7, 5, fe.compileFusion(&fusion, inputs)); EXPECT_TRUE(getBankConflictInfo(fe.kernel()).empty()); auto cg_outputs = fe.runFusion(inputs); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto tref = atBiasEpilogue( atMatmul(aten_a.to(at::kFloat), aten_b.to(at::kFloat), layout), aten_bias); @@ -3003,6 +3051,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulBatchSplitK) { NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( 7, 5, fe.compileFusion(&fusion, inputs)); ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + 
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto cg_outputs = fe.runFusion(inputs); auto tref = atMatmul(aten_a.to(at::kFloat), aten_b.to(at::kFloat), layout); @@ -3071,6 +3121,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulBatchSplitKBias) { NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( 7, 5, fe.compileFusion(&fusion, inputs)); ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto cg_outputs = fe.runFusion(inputs); auto tref = atBiasEpilogue( atMatmul(aten_a.to(at::kFloat), aten_b.to(at::kFloat), layout), @@ -3134,6 +3186,8 @@ TEST_F(MatmulTest, ReproIssue1808) { LaunchParams(), matmul_cparams)); ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); @@ -3284,6 +3338,8 @@ TEST_P(MatmulTestWithLayout, MisalignedVectorization) { fe.compileFusion( fusion.get(), inputs, LaunchParams(), matmul_cparams)); ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto outputs = fe.runFusion(inputs); EXPECT_TRUE(outputs[0].allclose(tref, 0.001, 0.001)); @@ -3339,6 +3395,8 @@ TEST_F(MatmulTest, MultipleConsecutiveDims) { NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( 8, 0, fe.compileFusion(&fusion, inputs, LaunchParams(), matmul_cparams)); ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto cg_outputs = fe.runFusion(inputs); auto tref = at::reshape( at::linear( @@ -3403,6 +3461,8 @@ TEST_F(MatmulTest, DISABLED_MultipleNonConsecutiveMDims) { NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( 8, 0, fe.compileFusion(&fusion, inputs, LaunchParams(), matmul_cparams)); 
ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto cg_outputs = fe.runFusion(inputs); auto Apermuted = A.permute({{1, 2}}).reshape({M1 * M2, K}); auto tref = at::linear(Apermuted.to(at::kFloat), B.to(at::kFloat)) @@ -3467,6 +3527,8 @@ TEST_F(MatmulTest, DISABLED_MultipleNonConsecutiveNDims) { NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( 8, 0, fe.compileFusion(&fusion, inputs, LaunchParams(), matmul_cparams)); ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto cg_outputs = fe.runFusion(inputs); auto Bpermuted = B.permute({{1, 2}}).reshape({N1 * N2, K}); auto tref = at::linear(A.to(at::kFloat), Bpermuted.to(at::kFloat)) @@ -3523,6 +3585,8 @@ TEST_F(MatmulTest, MultipleMDimsBatch) { NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( 8, 0, fe.compileFusion(&fusion, inputs, LaunchParams(), matmul_cparams)); ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); auto cg_outputs = fe.runFusion(inputs); auto tref = at::matmul(A.to(at::kFloat), at::permute(B.to(at::kFloat), {0, 2, 1})); diff --git a/tests/cpp/utils.h b/tests/cpp/utils.h index ed525de9a66..a8fcd5d4cd7 100644 --- a/tests/cpp/utils.h +++ b/tests/cpp/utils.h @@ -169,6 +169,33 @@ class PredicatedChecker : public kir::IrVisitor { return isPredicated(tv->name(), kernel); } + static bool isPredicatedByIfThenElse( + StmtNameType tv_name, + kir::Kernel* kernel) { + PredicatedChecker checker(tv_name, kernel->topLevelExprs()); + return checker.predicated_ite_; + } + + // If CpAsync from gmem to smem, then loaded from smem to registers using + // ldmatrix, then it is used in mma and should not use if-then-else predicate. + // If just CpAsync from gmem to smem, without further copy to register, then + // it is not used in mma and can use if-then-else predicate. 
+  static bool isCpAsyncMmaPredicatedByIfThenElse(kir::Kernel* kernel) {
+    // Test every candidate tensor; a clean one must not end the scan early.
+    for (auto tv : kernel->allTvs()) {
+      if (!tv->definition() || !ir_utils::isCpAsyncOp(tv->definition())) {
+        continue;
+      }
+      const auto& consumers = ir_utils::consumerTvsOf(tv);
+      if (std::any_of(consumers.begin(), consumers.end(), [](TensorView* c) {
+            return ir_utils::isLdMatrixOp(c->definition());
+          }) && isPredicatedByIfThenElse(tv->name(), kernel)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
  private:
   PredicatedChecker() = delete;