Skip to content

Commit

Permalink
More test fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
csarofeen committed Oct 7, 2024
1 parent ab83fb7 commit ff57be3
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 71 deletions.
140 changes: 70 additions & 70 deletions tests/cpp/test_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmul) {
LaunchParams(),
matmul_cparams));
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
auto tref = atMatmul(
inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
Expand Down Expand Up @@ -186,8 +186,8 @@ TEST_P(MatmulTestWithLayout, AmperePrologueFusionBroadcast) {
LaunchParams(),
matmul_cparams));
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
auto tref = atMatmul(
inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
Expand Down Expand Up @@ -245,8 +245,8 @@ TEST_P(MatmulTestWithLayout, AmpereProloguePointwise) {
LaunchParams(),
matmul_cparams));
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
auto tref = atMatmul(
inputs.first.sin().to(at::kFloat),
Expand Down Expand Up @@ -304,8 +304,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulBFloat16) {
LaunchParams(),
matmul_cparams));
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
auto tref = atMatmul(
inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
Expand Down Expand Up @@ -365,8 +365,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulPipelineGmem) {
LaunchParams(),
matmul_cparams));
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
auto tref = atMatmul(
inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
Expand Down Expand Up @@ -447,8 +447,8 @@ TEST_P(MatmulTestWithLayout, AmpereSwizzle) {
LaunchParams(),
matmul_cparams));
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
auto tref = atMatmul(
inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
Expand Down Expand Up @@ -573,8 +573,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulRegCircularBuffer) {
LaunchParams(),
matmul_cparams));
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
auto tref = atMatmul(
inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
Expand Down Expand Up @@ -863,8 +863,8 @@ TEST_F(MatmulTest, MatmulMatmulAmpere) {
fe.compileFusion(&fusion, {t0, t1, t2}, LaunchParams(), matmul_cparams));

auto cg_outputs = fe.runFusion({t0, t1, t2});
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
// relaxed check for now, err accumulation is significant.
NVF_CHECK(cg_outputs[0].allclose(tref, 0.1, 0.1));
}
Expand Down Expand Up @@ -1243,8 +1243,8 @@ TEST_F(MatmulTest, MatmulSoftmaxMatmulAmpere) {
fe.compileFusion(&fusion, {t0, t1, t2}, LaunchParams(), matmul_cparams));

auto cg_outputs = fe.runFusion({t0, t1, t2});
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto g1 = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat));
auto sg1 = at::_softmax(g1, -1, false);
auto gsg1 = sg1.matmul(t2.t().to(at::kFloat));
Expand Down Expand Up @@ -1292,8 +1292,8 @@ TEST_P(MatmulTestWithLayout, TuringMatmul) {
NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
7, 5, fe.compileFusion(&fusion, {inputs.first, inputs.second}));
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
auto tref = atMatmul(
inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
Expand Down Expand Up @@ -1439,8 +1439,8 @@ TEST_F(MatmulTest, AmpereMatmulTNCpAsync) {
fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams));

auto cg_outputs = fe.runFusion({t0, t1});
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat));

NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
Expand Down Expand Up @@ -1608,8 +1608,8 @@ TEST_F(MatmulTest, AmpereStridedBatchedMatmulTN) {
fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams));

auto cg_outputs = fe.runFusion({t0, t1});
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
// ref implementation:
auto ref_t0 = t0.permute({0, 2, 1, 3})
.contiguous()
Expand Down Expand Up @@ -1781,8 +1781,8 @@ TEST_F(MatmulTest, AmpereViewMatmulTN) {
fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams));

auto cg_outputs = fe.runFusion({t0, t1});
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto tref =
at::native::view(t0, {M, K}).to(at::kFloat).matmul(t1.t().to(at::kFloat));

Expand Down Expand Up @@ -1964,8 +1964,8 @@ TEST_F(MatmulTest, AmpereMatmulTNSwizzled) {
FusionExecutor fe;
fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams);
auto cg_outputs = fe.runFusion({t0, t1});
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat));

NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
Expand Down Expand Up @@ -2020,8 +2020,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulLargeLoad) {
LaunchParams(),
matmul_cparams));
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
auto tref = atMatmul(
inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
Expand Down Expand Up @@ -2074,8 +2074,8 @@ TEST_P(MatmulTestWithLayout, TuringMatmulLargeLoad) {
LaunchParams(),
matmul_cparams));
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
auto tref = atMatmul(
inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
Expand Down Expand Up @@ -2145,8 +2145,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulTileCheck4warp) {
matmul_cparams));
EXPECT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto tref = atMatmul(
inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
NVF_CHECK(
Expand Down Expand Up @@ -2223,8 +2223,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulTileCheck8warp) {
LaunchParams(),
matmul_cparams));
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
auto tref = atMatmul(
inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
Expand Down Expand Up @@ -2292,8 +2292,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulTileCheck6warp) {
LaunchParams(),
matmul_cparams));
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
auto tref = atMatmul(
inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
Expand Down Expand Up @@ -2350,8 +2350,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulLargeLoadLargeK) {
LaunchParams(),
matmul_cparams));
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
auto tref = atMatmul(
inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
Expand Down Expand Up @@ -2402,8 +2402,8 @@ TEST_P(MatmulTestWithLayout, AmpereSplitKLikeStridedBatchedMatmul) {
0,
fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams));
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto cg_outputs = fe.runFusion({t0, t1});
auto tref = splitkLikeAtMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout);
NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
Expand Down Expand Up @@ -2500,8 +2500,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulSmemEpilogue) {

// check bank conflicts
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
// (0.001, 0.001) passed on local A100 but failed on CI A100
NVF_CHECK(
cg_outputs[0].allclose(tref, 0.01, 0.01),
Expand Down Expand Up @@ -2641,8 +2641,8 @@ TEST_F(MatmulTest, AmpereMatmulSmemEpiloguePromotionRequiredA100) {

// check bank conflicts
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
// (0.001, 0.001) passed on local A100 but failed on CI A100
NVF_CHECK(
cg_outputs[0].allclose(tref, 0.01, 0.01),
Expand Down Expand Up @@ -2741,8 +2741,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulSmemEpilogueCast) {
tref = tref.to(at::kHalf);
// check bank conflicts
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
// (0.001, 0.001) passed on local A100 but failed on CI A100
NVF_CHECK(
cg_outputs[0].allclose(tref, 0.01, 0.01),
Expand Down Expand Up @@ -2838,8 +2838,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulSmemEpilogueRelu) {

// check bank conflicts
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
// (0.001, 0.001) passed on local A100 but failed on CI A100
NVF_CHECK(
cg_outputs[0].allclose(tref, 0.01, 0.01),
Expand Down Expand Up @@ -2919,8 +2919,8 @@ TEST_P(MatmulTestWithLayout, FusionAmpereMatmulSplitK_CUDA) {
7, 5, fe.compileFusion(&fusion, {inputs.first, inputs.second}));
EXPECT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto tref = atMatmul(
inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);

Expand Down Expand Up @@ -2987,8 +2987,8 @@ TEST_P(MatmulTestWithLayout, FusionAmpereMatmulSplitKBias_CUDA) {
7, 5, fe.compileFusion(&fusion, inputs));
EXPECT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
auto cg_outputs = fe.runFusion(inputs);
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto tref = atBiasEpilogue(
atMatmul(aten_a.to(at::kFloat), aten_b.to(at::kFloat), layout),
aten_bias);
Expand Down Expand Up @@ -3052,8 +3052,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulBatchSplitK) {
NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
7, 5, fe.compileFusion(&fusion, inputs));
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto cg_outputs = fe.runFusion(inputs);
auto tref =
atMatmul(aten_a.to(at::kFloat), aten_b.to(at::kFloat), layout);
Expand Down Expand Up @@ -3122,8 +3122,8 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulBatchSplitKBias) {
NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
7, 5, fe.compileFusion(&fusion, inputs));
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto cg_outputs = fe.runFusion(inputs);
auto tref = atBiasEpilogue(
atMatmul(aten_a.to(at::kFloat), aten_b.to(at::kFloat), layout),
Expand Down Expand Up @@ -3187,8 +3187,8 @@ TEST_F(MatmulTest, ReproIssue1808) {
LaunchParams(),
matmul_cparams));
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
auto tref = atMatmul(
inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
Expand Down Expand Up @@ -3339,8 +3339,8 @@ TEST_P(MatmulTestWithLayout, MisalignedVectorization) {
fe.compileFusion(
fusion.get(), inputs, LaunchParams(), matmul_cparams));
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto outputs = fe.runFusion(inputs);

EXPECT_TRUE(outputs[0].allclose(tref, 0.001, 0.001));
Expand Down Expand Up @@ -3396,8 +3396,8 @@ TEST_F(MatmulTest, MultipleConsecutiveDims) {
NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
8, 0, fe.compileFusion(&fusion, inputs, LaunchParams(), matmul_cparams));
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto cg_outputs = fe.runFusion(inputs);
auto tref = at::reshape(
at::linear(
Expand Down Expand Up @@ -3462,8 +3462,8 @@ TEST_F(MatmulTest, DISABLED_MultipleNonConsecutiveMDims) {
NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
8, 0, fe.compileFusion(&fusion, inputs, LaunchParams(), matmul_cparams));
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto cg_outputs = fe.runFusion(inputs);
auto Apermuted = A.permute({{1, 2}}).reshape({M1 * M2, K});
auto tref = at::linear(Apermuted.to(at::kFloat), B.to(at::kFloat))
Expand Down Expand Up @@ -3528,8 +3528,8 @@ TEST_F(MatmulTest, DISABLED_MultipleNonConsecutiveNDims) {
NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
8, 0, fe.compileFusion(&fusion, inputs, LaunchParams(), matmul_cparams));
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto cg_outputs = fe.runFusion(inputs);
auto Bpermuted = B.permute({{1, 2}}).reshape({N1 * N2, K});
auto tref = at::linear(A.to(at::kFloat), Bpermuted.to(at::kFloat))
Expand Down Expand Up @@ -3586,8 +3586,8 @@ TEST_F(MatmulTest, MultipleMDimsBatch) {
NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
8, 0, fe.compileFusion(&fusion, inputs, LaunchParams(), matmul_cparams));
ASSERT_TRUE(getBankConflictInfo(fe.compiledKernel()->kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
fe.compiledKernel()->kernel()));
auto cg_outputs = fe.runFusion(inputs);
auto tref =
at::matmul(A.to(at::kFloat), at::permute(B.to(at::kFloat), {0, 2, 1}));
Expand Down
Loading

0 comments on commit ff57be3

Please sign in to comment.