Remove outdated checks from the matmul scheduler. (#2221)
For #2199 

Broadcasts before Mma are optional. matmul_expr_eval still has problems
with this, but I'll file a separate issue for that.
wujingyue authored May 9, 2024
1 parent 89e792e commit 02e327d
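
For context, the check removed below used to require every MmaOp operand to be produced by an explicit BroadcastOp fed from a fusion input (or a LoadStoreOp output). Here is a minimal sketch of that previously required pattern, written in the same nvFuser C++ test idioms as the new test further down (names, shapes, and the standalone setup are illustrative, not part of this diff):

// Sketch only, not code from this commit. The new FusedMultiplySumOnly test
// below shows the now-accepted alternative, where the fusion inputs already
// carry size-1 broadcast axes and feed fusedMultiplySum directly.
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

constexpr int64_t M = 128, N = 256, K = 512;
TensorView* a = makeContigConcreteTensor({M, K}, DataType::Half);
TensorView* b = makeContigConcreteTensor({N, K}, DataType::Half);
fusion->addInput(a);
fusion->addInput(b);

// Explicit broadcasts ahead of the MmaOp:
// [M, K] -> [M, 1, K] and [N, K] -> [1, N, K].
TensorView* a_bcast = broadcast(a, {false, true, false});
TensorView* b_bcast = broadcast(b, {true, false, false});
fusion->addOutput(fusedMultiplySum(a_bcast, b_bcast, {-1}));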
Showing 2 changed files with 44 additions and 26 deletions.
27 changes: 1 addition & 26 deletions csrc/scheduler/matmul_utils.cpp
@@ -62,9 +62,8 @@ inline std::optional<MmaMacro> getMmaOp(
      return (use_small_n) ? MacroType::Ampere_16_8_16
                           : MacroType::Ampere_16_16_16;
    default:
-      break;
+      return std::nullopt;
  }
-  return std::nullopt;
}

//! A wrapper for core heuristics initialization.
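
For reference, a hedged sketch of how an std::optional-returning helper like getMmaOp is typically consumed on the caller side; getMmaOp's full parameter list is truncated in the hunk above, so the arguments and surrounding names used here (problem_shape, params) are assumptions for illustration only:

// Hypothetical caller-side fragment; argument list and surrounding heuristics
// code are assumed, not taken from this diff.
const std::optional<MmaMacro> macro =
    getMmaOp(/*compute_capability=*/86, /*problem=*/problem_shape);
if (!macro.has_value()) {
  // With the change above, the default case returns std::nullopt directly
  // (previously via break plus a trailing return); either way the caller sees
  // an empty optional here and rejects the fusion for the matmul scheduler.
  return false;
}
params->mma_macro = macro.value();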
@@ -267,30 +266,6 @@ std::string isMatmulFusionDefinitionSupported(
}
}

-  // MmaOp inputs/outputs dependencies check
-  // TODO: check to be removed when more rules are added to TV roles
-  // calculations
-  {
-    // Check the expected path between MmaOp input and fusion inputs
-    const auto areMmaOpInputDependeciesValid = [](const Val* val) {
-      if (val->definition()->isA<BroadcastOp>()) {
-        const auto& bcast_inputs = val->definition()->inputs();
-        // BroadcastOp has single input/output, not need to check other things
-        return bcast_inputs.front()->isFusionInput() ||
-            (dynamic_cast<LoadStoreOp*>(bcast_inputs.front()->definition()) !=
-             nullptr);
-      }
-      return false;
-    };
-
-    // MmaOp input is a result of broadcast op with input being fusion input
-    for (const auto* mma_in : mma_inputs) {
-      if (!areMmaOpInputDependeciesValid(mma_in)) {
-        return "MmaOp input has unsupported dependency";
-      }
-    }
-  }

return "";
}

43 changes: 43 additions & 0 deletions tests/cpp/test_matmul_scheduler.cpp
@@ -1055,6 +1055,49 @@ INSTANTIATE_TEST_SUITE_P(
return os.str();
});

TEST_F(MatmulSchedulerTest, FusedMultiplySumOnly) {
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);

auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

constexpr int64_t M = 128, N = 256, K = 512;
TensorView* x = makeContigConcreteTensor({M, 1, K}, DataType::Half);
TensorView* y = makeContigConcreteTensor({1, N, K}, DataType::Half);
TensorView* z = fusedMultiplySum(x, y, {-1});

fusion->addInput(x);
fusion->addInput(y);
fusion->addOutput(z);

auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
auto x_ref = at::randn({M, 1, K}, options);
auto y_ref = at::randn({1, N, K}, options);
auto z_ref = atMatmul(x_ref, y_ref, MmaLayout::TN);

FusionExecutorCache executor_cache(std::move(fusion));

auto out_tensors = executor_cache.runFusionWithInputs({x_ref, y_ref});

NVF_CHECK(
!executor_cache.getMostRecentKernelRuntime()->isSegmented(),
"fusion got segmented, expected to match whole fusion with single segment");

NVF_CHECK(
isSchedulerInUse(
executor_cache.getMostRecentKernelRuntime(),
ScheduleHeuristic::Matmul),
"matmul scheduler was not used to handle prepared fusion");

testValidate(
executor_cache.fusion(),
out_tensors,
{x_ref, y_ref},
{z_ref},
__LINE__,
__FILE__);
}

// Matmul test that uses segmenter for 'C = A x B' fusion,
// for Ampere with strict ref check, hence single layout check
TEST_F(MatmulSchedulerTest, BasicMatmulStrictCheckTT) {
