From 02e327dbc643d82ddfaf54dd5f8dc045b354e730 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 9 May 2024 09:08:15 -0700 Subject: [PATCH] Remove outdated checks from the matmul scheduler. (#2221) For #2199 Broadcasts before Mma are optional. matmul_expr_eval still has problems with this, but I'll file a separate issue for that. --- csrc/scheduler/matmul_utils.cpp | 27 +----------------- tests/cpp/test_matmul_scheduler.cpp | 43 +++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 26 deletions(-) diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index f967d534618..b2287c9cf2f 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -62,9 +62,8 @@ inline std::optional getMmaOp( return (use_small_n) ? MacroType::Ampere_16_8_16 : MacroType::Ampere_16_16_16; default: - break; + return std::nullopt; } - return std::nullopt; } //! A wrapper for core heuristics initialization. @@ -267,30 +266,6 @@ std::string isMatmulFusionDefinitionSupported( } } - // MmaOp inputs/outputs dependencies check - // TODO: check to be removed when more rules are added to TV roles - // calculations - { - // Check the expected path between MmaOp input and fusion inputs - const auto areMmaOpInputDependeciesValid = [](const Val* val) { - if (val->definition()->isA()) { - const auto& bcast_inputs = val->definition()->inputs(); - // BroadcastOp has single input/output, not need to check other things - return bcast_inputs.front()->isFusionInput() || - (dynamic_cast(bcast_inputs.front()->definition()) != - nullptr); - } - return false; - }; - - // MmaOp input is a result of broadcast op with input being fusion input - for (const auto* mma_in : mma_inputs) { - if (!areMmaOpInputDependeciesValid(mma_in)) { - return "MmaOp input has unsupported dependency"; - } - } - } - return ""; } diff --git a/tests/cpp/test_matmul_scheduler.cpp b/tests/cpp/test_matmul_scheduler.cpp index 71be65b98c0..75757a2667c 100644 --- a/tests/cpp/test_matmul_scheduler.cpp +++ b/tests/cpp/test_matmul_scheduler.cpp @@ -1055,6 +1055,49 @@ INSTANTIATE_TEST_SUITE_P( return os.str(); }); +TEST_F(MatmulSchedulerTest, FusedMultiplySumOnly) { + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + constexpr int64_t M = 128, N = 256, K = 512; + TensorView* x = makeContigConcreteTensor({M, 1, K}, DataType::Half); + TensorView* y = makeContigConcreteTensor({1, N, K}, DataType::Half); + TensorView* z = fusedMultiplySum(x, y, {-1}); + + fusion->addInput(x); + fusion->addInput(y); + fusion->addOutput(z); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA); + auto x_ref = at::randn({M, 1, K}, options); + auto y_ref = at::randn({1, N, K}, options); + auto z_ref = atMatmul(x_ref, y_ref, MmaLayout::TN); + + FusionExecutorCache executor_cache(std::move(fusion)); + + auto out_tensors = executor_cache.runFusionWithInputs({x_ref, y_ref}); + + NVF_CHECK( + !executor_cache.getMostRecentKernelRuntime()->isSegmented(), + "fusion got segmented, expected to match whole fusion with single segment"); + + NVF_CHECK( + isSchedulerInUse( + executor_cache.getMostRecentKernelRuntime(), + ScheduleHeuristic::Matmul), + "matmul scheduler was not used to handle prepared fusion"); + + testValidate( + executor_cache.fusion(), + out_tensors, + {x_ref, y_ref}, + {z_ref}, + __LINE__, + __FILE__); +} + // Matmul test that uses segmenter for 'C = A x B' fusion, // for Ampere with strict ref check, hence single layout check TEST_F(MatmulSchedulerTest, BasicMatmulStrictCheckTT) {