Skip to content

Commit

Permalink
Change some tests to use FusionExecutorCache.
Browse files Browse the repository at this point in the history
  • Loading branch information
wujingyue committed Nov 6, 2024
1 parent a4d3a12 commit d1e4655
Showing 1 changed file with 9 additions and 15 deletions.
24 changes: 9 additions & 15 deletions tests/cpp/test_allocation_domain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,39 +107,33 @@ TEST_F(AllocationDomainTest, NCHW4d_To_NHWC4d) {
// A global->global copy kernel converting NCHW memory format into NHWC, with a
// 1d allocation domain in output.
TEST_F(AllocationDomainTest, NCHW4d_To_NHWC1d) {
Fusion fusion;
FusionGuard fg(&fusion);
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

auto tv0 = makeContigTensor(4);
fusion.addInput(tv0);
fusion->addInput(tv0);
auto tv1 = set(tv0);
fusion.addOutput(tv1);
fusion->addOutput(tv1);

// [N, C, H, W]
tv1->reorder({{1, -1}});
// [N, H, W, C]
tv1->flatten();
tv1->setAllocationDomain({tv1->axis(0)}, true);
// [N*H*W*C]
tv1->split(0, 128);
tv1->axis(1)->parallelize(ParallelType::TIDx);
tv1->axis(0)->parallelize(ParallelType::BIDx);
// [BIDx, TIDx]
tv1->setAllocationDomain(tv1->getLoopDomain(), true);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

int n = 31, h = 64, w = 103, c = 21;

at::Tensor t0 = at::randn({n, c, h, w}, options);

FusionExecutor fe;
fe.compileFusion(&fusion, {t0});

auto cg_outputs = fe.runFusion({t0});
FusionExecutorCache executor_cache(std::move(fusion));
auto cg_outputs = executor_cache.runFusionWithInputs({t0});

ASSERT_TRUE(cg_outputs[0].is_contiguous(at::MemoryFormat::ChannelsLast));
EXPECT_TRUE(cg_outputs[0].is_contiguous(at::MemoryFormat::ChannelsLast));

testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__);
testValidate(executor_cache.fusion(), cg_outputs, {t0}, __LINE__, __FILE__);
}

// A global->global copy kernel converting NCHW memory format into NHWC, with a
Expand Down

0 comments on commit d1e4655

Please sign in to comment.