From 2e293a32017d11aa872a7ec1a362e7fc9640f156 Mon Sep 17 00:00:00 2001
From: "Gao, Xiang"
Date: Tue, 16 Jul 2024 09:15:52 -0700
Subject: [PATCH] Add test for forward propagation in TMA analysis (#2571)

Rename the existing NonTrivialGmemAllocationDomain test to
NonTrivialGmemAllocationDomain1 and add NonTrivialGmemAllocationDomain2.

The new test copies a 6-D contiguous float input through a shared-memory
tensor whose defining op is switched to CpAsyncBulkTensorTile. The input
is given an allocation domain produced by four merges (see the diagram in
the test). The test then checks that TMADimChecker reports 3 for the
compiled kernel, calls TMAPredicateChecker::checkPredicate(kernel, 1),
and validates the output against the input.

---
NOTE(review): line breaks, indentation and the spacing of the ASCII
schedule diagram were reconstructed from a flattened copy of this patch.
Context-line indentation is a best guess; confirm it against
tests/cpp/test_memory.cpp before running git apply / git am.

NOTE(review): `as<LoadStoreOp>()` restores a template argument that was
missing (`as()`); LoadStoreOp is inferred from the LoadStoreOpType value
passed to setOpType -- confirm.

NOTE(review): the From: header carries no e-mail address; restore the
author's address before feeding this file to git am.

 tests/cpp/test_memory.cpp | 56 ++++++++++++++++++++++++++++++++++++++-
 1 file changed, 55 insertions(+), 1 deletion(-)

diff --git a/tests/cpp/test_memory.cpp b/tests/cpp/test_memory.cpp
index 6e61fdac8de..877fd0e51ec 100644
--- a/tests/cpp/test_memory.cpp
+++ b/tests/cpp/test_memory.cpp
@@ -1140,7 +1140,7 @@ TEST_F(TMAIndexingTest, DefineBoxByRotation3) {
           ::testing::HasSubstr("must be divisible by 23")));
 }
 
-TEST_F(TMAIndexingTest, NonTrivialGmemAllocationDomain) {
+TEST_F(TMAIndexingTest, NonTrivialGmemAllocationDomain1) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
@@ -1183,6 +1183,60 @@ TEST_F(TMAIndexingTest, NonTrivialGmemAllocationDomain) {
   testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
 }
 
+TEST_F(TMAIndexingTest, NonTrivialGmemAllocationDomain2) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  const DataType dtype = DataType::Float;
+
+  auto tv0 = makeContigTensor(6, dtype);
+  fusion.addInput(tv0);
+  auto tv1 = set(tv0);
+  auto tv2 = set(tv1);
+  fusion.addOutput(tv2);
+
+  tv1->setMemoryType(MemoryType::Shared);
+  tv1->definition()->as<LoadStoreOp>()->setOpType(
+      LoadStoreOpType::CpAsyncBulkTensorTile);
+
+  // Schedule like this:
+  // 0   1   2   3   4   5
+  //  \   \ /   /     \ /
+  //   \   6   /       7
+  //    \ /   /
+  //     8   /
+  //      \ /
+  //       9
+  // where 1 and 5 are bulk IDs. This way, [merge 1, 2 -> 6] is a "striding
+  // split", and [merge 0, 6 -> 8] and [merge 4, 5 -> 7] are "boxing splits".
+  tv0->merge(1);
+  tv0->merge(0);
+  tv0->merge(-2);
+  tv0->merge(0);
+  tv0->setAllocationDomain(tv0->getLoopDomain(), true);
+
+  for (auto tv : {tv1, tv2}) {
+    tv->reorder({{1, -2}});
+    tv->merge(-2);
+    tv->flatten(0, -2);
+    tv->axis(0)->parallelize(ParallelType::BIDx);
+  }
+  tv1->axis(1)->parallelize(ParallelType::Bulk);
+  tv2->axis(1)->parallelize(ParallelType::TIDx);
+
+  auto options =
+      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
+  auto t0 = at::randn({2, 3, 5, 7, 11, 32}, options);
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, {t0}, {}, matmul_cparams);
+
+  EXPECT_EQ(TMADimChecker::getDim(fe.kernel()), 3);
+  TMAPredicateChecker::checkPredicate(fe.kernel(), 1);
+
+  auto cg_outputs = fe.runFusion({t0});
+  testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
+}
+
 // TODO: improve validation of TMA, and add tests for invalid cases.
 // TODO: test that broadcasting IterDomains are correctly handled by TMA.
 