Skip to content

Commit

Permalink
Add test for forward propagation in TMA analysis (NVIDIA#2571)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm authored Jul 16, 2024
1 parent 77e0317 commit 2e293a3
Showing 1 changed file with 55 additions and 1 deletion.
56 changes: 55 additions & 1 deletion tests/cpp/test_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1140,7 +1140,7 @@ TEST_F(TMAIndexingTest, DefineBoxByRotation3) {
::testing::HasSubstr("must be divisible by 23")));
}

TEST_F(TMAIndexingTest, NonTrivialGmemAllocationDomain) {
TEST_F(TMAIndexingTest, NonTrivialGmemAllocationDomain1) {
Fusion fusion;
FusionGuard fg(&fusion);

Expand Down Expand Up @@ -1183,6 +1183,60 @@ TEST_F(TMAIndexingTest, NonTrivialGmemAllocationDomain) {
testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
}

// Covers forward propagation in the TMA analysis: the fusion input gets an
// allocation domain built purely out of merges, and is then copied into shared
// memory with a CpAsyncBulkTensorTile load before being written to the output.
TEST_F(TMAIndexingTest, NonTrivialGmemAllocationDomain2) {
  Fusion fusion;
  FusionGuard guard(&fusion);

  const DataType elem_type = DataType::Float;

  // Pipeline: 6-D input -> shared-memory copy -> output.
  auto input_tv = makeContigTensor(6, elem_type);
  fusion.addInput(input_tv);
  auto smem_tv = set(input_tv);
  auto output_tv = set(smem_tv);
  fusion.addOutput(output_tv);

  // Turn the first copy into a TMA tile load that lands in shared memory.
  smem_tv->setMemoryType(MemoryType::Shared);
  auto load_op = smem_tv->definition()->as<LoadStoreOp>();
  load_op->setOpType(LoadStoreOpType::CpAsyncBulkTensorTile);

  // Allocation domain of the input, numbering the original axes 0..5:
  //
  //   0   1   2   3   4   5
  //    \   \ /   /     \ /
  //     \   6   /       7
  //      \ /   /
  //       8   /
  //        \ /
  //         9
  //
  // Axes 1 and 5 are the bulk IDs. With that choice, [merge 1, 2 -> 6] is a
  // "striding split", while [merge 0, 6 -> 8] and [merge 4, 5 -> 7] are
  // "boxing splits".
  input_tv->merge(1); // 1, 2 -> 6
  input_tv->merge(0); // 0, 6 -> 8
  input_tv->merge(-2); // 4, 5 -> 7
  input_tv->merge(0); // 8, 3 -> 9
  input_tv->setAllocationDomain(input_tv->getLoopDomain(), true);

  // Shared schedule for both copies: move axis 1 next to axis 5 and merge the
  // two into the innermost loop axis, collapse everything else into axis 0,
  // and bind axis 0 to BIDx.
  auto schedule_copy = [](auto tv) {
    tv->reorder({{1, -2}});
    tv->merge(-2);
    tv->flatten(0, -2);
    tv->axis(0)->parallelize(ParallelType::BIDx);
  };
  schedule_copy(smem_tv);
  schedule_copy(output_tv);

  // The innermost axis is the bulk axis on the TMA load and TIDx on the store.
  smem_tv->axis(1)->parallelize(ParallelType::Bulk);
  output_tv->axis(1)->parallelize(ParallelType::TIDx);

  auto tensor_opts = at::TensorOptions()
                         .dtype(data_type_to_aten(elem_type))
                         .device(at::kCUDA, 0);
  auto input = at::randn({2, 3, 5, 7, 11, 32}, tensor_opts);
  FusionExecutor executor;
  executor.compileFusion(&fusion, {input}, {}, matmul_cparams);

  // Expectations on the generated kernel: TMADimChecker must report 3, and
  // the predicate check is run with argument 1 (see TMAPredicateChecker for
  // the exact meaning of that argument).
  EXPECT_EQ(TMADimChecker::getDim(executor.kernel()), 3);
  TMAPredicateChecker::checkPredicate(executor.kernel(), 1);

  // The fusion is a plain copy, so the output must equal the input.
  auto outputs = executor.runFusion({input});
  testValidate(&fusion, outputs, {input}, {input}, __LINE__, __FILE__);
}

// TODO: improve validation of TMA, and add tests for invalid cases.
// TODO: test that broadcasting IterDomains are correctly handled by TMA.

Expand Down

0 comments on commit 2e293a3

Please sign in to comment.