Skip to content

Commit

Permalink
Change tests to call runPass directly. (#3398)
Browse files Browse the repository at this point in the history
  • Loading branch information
wujingyue authored Nov 13, 2024
1 parent 6b937c4 commit 61ffac9
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 21 deletions.
4 changes: 2 additions & 2 deletions csrc/preseg_passes/allocation_order_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,6 @@ void mapAllocationDomain(
}
}

} // namespace

// Note [ Allocation Order Propagation ]
//
// The propagation tries to populate allocation domain from srcs to dsts.
Expand Down Expand Up @@ -336,6 +334,8 @@ void inferenceAllocationOrder(
}
}

} // namespace

void AllocationDomainPass::runPass(Fusion* fusion) {
// mark input TensorViews as propagation sources
auto input_tvs = ir_utils::filterByType<TensorView>(fusion->inputs());
Expand Down
9 changes: 0 additions & 9 deletions csrc/preseg_passes/allocation_order_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,6 @@

namespace nvfuser::preseg_passes {

// Propagate allocation domain from srcs to dsts.
// The pass update allocation domain on dsts tensor views.
//
// See details in Note [ Allocation Order Propagation ]
void inferenceAllocationOrder(
Fusion* fusion,
const std::vector<TensorView*>& srcs,
const std::vector<TensorView*>& dsts);

// Realize allocation order propagation on fusion inputs to optimize allocation
// domain of output tensor. This optimization pass currently only applies to
// fusion outputs, but not intermediate tensors.
Expand Down
28 changes: 18 additions & 10 deletions tests/cpp/test_allocation_order_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ TEST_F(AllocationOrderInferenceTest, BroadcastOpPropagation) {
tv0->axis(0), tv0->axis(2), tv0->axis(3), tv0->axis(1)};
tv0->setAllocationDomain(tv0_nhwc, true);

preseg_passes::inferenceAllocationOrder(&fusion, {tv0, tv1}, {tv2, tv3});
preseg_passes::OptimizationPass<preseg_passes::AllocationDomainPass>::runPass(
&fusion);
EXPECT_THAT(
getAllocationDomainPermutation(tv2), ElementsAre(0, 3, 5, 7, 1, 4, 6, 2));
EXPECT_THAT(getAllocationDomainPermutation(tv3), ElementsAre(0, 2, 3, 1));
Expand All @@ -71,7 +72,8 @@ TEST_F(AllocationOrderInferenceTest, UnaryOpPropagation) {
tv0->axis(0), tv0->axis(2), tv0->axis(3), tv0->axis(1)};
tv0->setAllocationDomain(tv0_nhwc, true);

preseg_passes::inferenceAllocationOrder(&fusion, {tv0}, {tv1});
preseg_passes::OptimizationPass<preseg_passes::AllocationDomainPass>::runPass(
&fusion);
EXPECT_THAT(getAllocationDomainPermutation(tv1), ElementsAre(0, 2, 3, 1));
}

Expand Down Expand Up @@ -101,7 +103,8 @@ TEST_F(AllocationOrderInferenceTest, BinaryOpPropagationOneTV) {
tv0->axis(0), tv0->axis(2), tv0->axis(3), tv0->axis(1)};
tv0->setAllocationDomain(tv0_nhwc, true);

preseg_passes::inferenceAllocationOrder(&fusion, {tv0}, {tv2, tv3, tv6, tv7});
preseg_passes::OptimizationPass<preseg_passes::AllocationDomainPass>::runPass(
&fusion);
EXPECT_THAT(getAllocationDomainPermutation(tv2), ElementsAre(0, 2, 3, 1));
EXPECT_THAT(getAllocationDomainPermutation(tv3), ElementsAre(0, 2, 3, 1));
EXPECT_THAT(getAllocationDomainPermutation(tv6), ElementsAre(0, 2, 3, 1));
Expand Down Expand Up @@ -131,7 +134,8 @@ TEST_F(AllocationOrderInferenceTest, BinaryOpPropagationTwoTV) {
tv1->axis(1), tv1->axis(0), tv1->axis(2), tv1->axis(3)};
tv1->setAllocationDomain(tv1_format, true);

preseg_passes::inferenceAllocationOrder(&fusion, {tv0, tv1}, {tv2, tv3});
preseg_passes::OptimizationPass<preseg_passes::AllocationDomainPass>::runPass(
&fusion);
EXPECT_THAT(getAllocationDomainPermutation(tv2), ElementsAre(1, 0, 2, 3));
EXPECT_THAT(getAllocationDomainPermutation(tv3), ElementsAre(1, 0, 2, 3));
}
Expand All @@ -157,7 +161,8 @@ TEST_F(AllocationOrderInferenceTest, BinaryOpPropagationWithBroadcast) {
tv0->axis(3), tv0->axis(2), tv0->axis(0), tv0->axis(1)};
tv0->setAllocationDomain(tv0_alloc, true);

preseg_passes::inferenceAllocationOrder(&fusion, {tv0, tv1}, {tv2});
preseg_passes::OptimizationPass<preseg_passes::AllocationDomainPass>::runPass(
&fusion);
EXPECT_THAT(getAllocationDomainPermutation(tv2), ElementsAre(0, 3, 2, 1));
}

Expand Down Expand Up @@ -186,7 +191,8 @@ TEST_F(AllocationOrderInferenceTest, TensorFactoryBinaryOpPropagation) {
std::vector<IterDomain*> tv1_c_last = {tv1->axis(0), tv1->axis(1)};
tv1->setAllocationDomain(tv1_c_last, true);

preseg_passes::inferenceAllocationOrder(&fusion, {tv0}, {tv2, tv3});
preseg_passes::OptimizationPass<preseg_passes::AllocationDomainPass>::runPass(
&fusion);
EXPECT_THAT(getAllocationDomainPermutation(tv2), ElementsAre(1, 0));
EXPECT_THAT(getAllocationDomainPermutation(tv3), ElementsAre(1, 0));
}
Expand Down Expand Up @@ -214,7 +220,8 @@ TEST_F(AllocationOrderInferenceTest, TensorEmptyAllocationOrderPropagation) {
std::vector<IterDomain*> tv0_c_last = {tv0->axis(1), tv0->axis(0)};
tv0->setAllocationDomain(tv0_c_last, true);

preseg_passes::inferenceAllocationOrder(&fusion, {tv0}, {tv4});
preseg_passes::OptimizationPass<preseg_passes::AllocationDomainPass>::runPass(
&fusion);
EXPECT_THAT(getAllocationDomainPermutation(tv4), ElementsAre(1, 0));
}

Expand Down Expand Up @@ -244,7 +251,8 @@ TEST_F(AllocationOrderInferenceTest, TernaryOpPropagation) {
tv2->axis(0), tv2->axis(2), tv2->axis(3), tv2->axis(1)};
tv2->setAllocationDomain(tv2_nhwc, true);

preseg_passes::inferenceAllocationOrder(&fusion, {tv0, tv1, tv2}, {tv3, tv4});
preseg_passes::OptimizationPass<preseg_passes::AllocationDomainPass>::runPass(
&fusion);
EXPECT_THAT(getAllocationDomainPermutation(tv3), ElementsAre(0, 2, 3, 1));
EXPECT_THAT(getAllocationDomainPermutation(tv4), ElementsAre(0, 2, 3, 1));
}
Expand Down Expand Up @@ -281,8 +289,8 @@ TEST_F(AllocationOrderInferenceTest, ReductionOpPropagation) {
auto tv5 = broadcast(tv3, {true, false, false, true});
fusion.addOutput(tv5);

preseg_passes::inferenceAllocationOrder(
&fusion, {tv0, tv1}, {tv2, tv3, tv4, tv5});
preseg_passes::OptimizationPass<preseg_passes::AllocationDomainPass>::runPass(
&fusion);
#if true
// permutation here is strange because in propagation we are preserving
// reduction iter domain in its position in logical domain See issue:
Expand Down

0 comments on commit 61ffac9

Please sign in to comment.