From 61ffac91e715b0f57fd658fb815e2f79a14685cd Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Tue, 12 Nov 2024 19:05:01 -0800 Subject: [PATCH] Change tests to call runPass directly. (#3398) --- .../allocation_order_inference.cpp | 4 +-- .../allocation_order_inference.h | 9 ------ tests/cpp/test_allocation_order_inference.cpp | 28 ++++++++++++------- 3 files changed, 20 insertions(+), 21 deletions(-) diff --git a/csrc/preseg_passes/allocation_order_inference.cpp b/csrc/preseg_passes/allocation_order_inference.cpp index c01f07d7b15..25a706f56d2 100644 --- a/csrc/preseg_passes/allocation_order_inference.cpp +++ b/csrc/preseg_passes/allocation_order_inference.cpp @@ -214,8 +214,6 @@ void mapAllocationDomain( } } -} // namespace - // Note [ Allocation Order Propagation ] // // The propagation tries to populate allocation domain from srcs to dsts. @@ -336,6 +334,8 @@ void inferenceAllocationOrder( } } +} // namespace + void AllocationDomainPass::runPass(Fusion* fusion) { // mark input TensorViews as propagation sources auto input_tvs = ir_utils::filterByType(fusion->inputs()); diff --git a/csrc/preseg_passes/allocation_order_inference.h b/csrc/preseg_passes/allocation_order_inference.h index d47b80432d7..d7ac7bddcac 100644 --- a/csrc/preseg_passes/allocation_order_inference.h +++ b/csrc/preseg_passes/allocation_order_inference.h @@ -14,15 +14,6 @@ namespace nvfuser::preseg_passes { -// Propagate allocation domain from srcs to dsts. -// The pass update allocation domain on dsts tensor views. -// -// See details in Note [ Allocation Order Propagation ] -void inferenceAllocationOrder( - Fusion* fusion, - const std::vector& srcs, - const std::vector& dsts); - // Realize allocation order propagation on fusion inputs to optimize allocation // domain of output tensor. This optimization pass currently only applies to // fusion outputs, but not intermediate tensors. diff --git a/tests/cpp/test_allocation_order_inference.cpp b/tests/cpp/test_allocation_order_inference.cpp index c24d679bfbb..14b7910ff48 100644 --- a/tests/cpp/test_allocation_order_inference.cpp +++ b/tests/cpp/test_allocation_order_inference.cpp @@ -51,7 +51,8 @@ TEST_F(AllocationOrderInferenceTest, BroadcastOpPropagation) { tv0->axis(0), tv0->axis(2), tv0->axis(3), tv0->axis(1)}; tv0->setAllocationDomain(tv0_nhwc, true); - preseg_passes::inferenceAllocationOrder(&fusion, {tv0, tv1}, {tv2, tv3}); + preseg_passes::OptimizationPass::runPass( + &fusion); EXPECT_THAT( getAllocationDomainPermutation(tv2), ElementsAre(0, 3, 5, 7, 1, 4, 6, 2)); EXPECT_THAT(getAllocationDomainPermutation(tv3), ElementsAre(0, 2, 3, 1)); @@ -71,7 +72,8 @@ TEST_F(AllocationOrderInferenceTest, UnaryOpPropagation) { tv0->axis(0), tv0->axis(2), tv0->axis(3), tv0->axis(1)}; tv0->setAllocationDomain(tv0_nhwc, true); - preseg_passes::inferenceAllocationOrder(&fusion, {tv0}, {tv1}); + preseg_passes::OptimizationPass::runPass( + &fusion); EXPECT_THAT(getAllocationDomainPermutation(tv1), ElementsAre(0, 2, 3, 1)); } @@ -101,7 +103,8 @@ TEST_F(AllocationOrderInferenceTest, BinaryOpPropagationOneTV) { tv0->axis(0), tv0->axis(2), tv0->axis(3), tv0->axis(1)}; tv0->setAllocationDomain(tv0_nhwc, true); - preseg_passes::inferenceAllocationOrder(&fusion, {tv0}, {tv2, tv3, tv6, tv7}); + preseg_passes::OptimizationPass::runPass( + &fusion); EXPECT_THAT(getAllocationDomainPermutation(tv2), ElementsAre(0, 2, 3, 1)); EXPECT_THAT(getAllocationDomainPermutation(tv3), ElementsAre(0, 2, 3, 1)); EXPECT_THAT(getAllocationDomainPermutation(tv6), ElementsAre(0, 2, 3, 1)); @@ -131,7 +134,8 @@ TEST_F(AllocationOrderInferenceTest, BinaryOpPropagationTwoTV) { tv1->axis(1), tv1->axis(0), tv1->axis(2), tv1->axis(3)}; tv1->setAllocationDomain(tv1_format, true); - preseg_passes::inferenceAllocationOrder(&fusion, {tv0, tv1}, {tv2, tv3}); + preseg_passes::OptimizationPass::runPass( + &fusion); EXPECT_THAT(getAllocationDomainPermutation(tv2), ElementsAre(1, 0, 2, 3)); EXPECT_THAT(getAllocationDomainPermutation(tv3), ElementsAre(1, 0, 2, 3)); } @@ -157,7 +161,8 @@ TEST_F(AllocationOrderInferenceTest, BinaryOpPropagationWithBroadcast) { tv0->axis(3), tv0->axis(2), tv0->axis(0), tv0->axis(1)}; tv0->setAllocationDomain(tv0_alloc, true); - preseg_passes::inferenceAllocationOrder(&fusion, {tv0, tv1}, {tv2}); + preseg_passes::OptimizationPass::runPass( + &fusion); EXPECT_THAT(getAllocationDomainPermutation(tv2), ElementsAre(0, 3, 2, 1)); } @@ -186,7 +191,8 @@ TEST_F(AllocationOrderInferenceTest, TensorFactoryBinaryOpPropagation) { std::vector tv1_c_last = {tv1->axis(0), tv1->axis(1)}; tv1->setAllocationDomain(tv1_c_last, true); - preseg_passes::inferenceAllocationOrder(&fusion, {tv0}, {tv2, tv3}); + preseg_passes::OptimizationPass::runPass( + &fusion); EXPECT_THAT(getAllocationDomainPermutation(tv2), ElementsAre(1, 0)); EXPECT_THAT(getAllocationDomainPermutation(tv3), ElementsAre(1, 0)); } @@ -214,7 +220,8 @@ TEST_F(AllocationOrderInferenceTest, TensorEmptyAllocationOrderPropagation) { std::vector tv0_c_last = {tv0->axis(1), tv0->axis(0)}; tv0->setAllocationDomain(tv0_c_last, true); - preseg_passes::inferenceAllocationOrder(&fusion, {tv0}, {tv4}); + preseg_passes::OptimizationPass::runPass( + &fusion); EXPECT_THAT(getAllocationDomainPermutation(tv4), ElementsAre(1, 0)); } @@ -244,7 +251,8 @@ TEST_F(AllocationOrderInferenceTest, TernaryOpPropagation) { tv2->axis(0), tv2->axis(2), tv2->axis(3), tv2->axis(1)}; tv2->setAllocationDomain(tv2_nhwc, true); - preseg_passes::inferenceAllocationOrder(&fusion, {tv0, tv1, tv2}, {tv3, tv4}); + preseg_passes::OptimizationPass::runPass( + &fusion); EXPECT_THAT(getAllocationDomainPermutation(tv3), ElementsAre(0, 2, 3, 1)); EXPECT_THAT(getAllocationDomainPermutation(tv4), ElementsAre(0, 2, 3, 1)); } @@ -281,8 +289,8 @@ TEST_F(AllocationOrderInferenceTest, ReductionOpPropagation) { auto tv5 = broadcast(tv3, {true, false, false, true}); fusion.addOutput(tv5); - preseg_passes::inferenceAllocationOrder( - &fusion, {tv0, tv1}, {tv2, tv3, tv4, tv5}); + preseg_passes::OptimizationPass::runPass( + &fusion); #if true // permutation here is strange because in propagation we are preserving // reduction iter domain in its position in logical domain See issue: