From 9bc3ecc2e742859a98472b09b55357b865a56228 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 20 Nov 2024 08:50:53 -0800 Subject: [PATCH] Return a bool indicating if all nodes are visited (#3452) Just a mechanical change: `getShortestExprPath` and `getExprsBetween` now return a pair of the expression path and a bool indicating whether all target nodes were visited, so existing callers take `.first` to get the path. This makes it a little cumbersome when `require_all_to_visited` is not used, but sometimes we also need to check if all nodes were visited. --- csrc/bfs.h | 7 +-- csrc/device_lower/analysis/tma.cpp | 7 +-- csrc/device_lower/pass/index.cpp | 2 +- csrc/device_lower/utils.cpp | 20 +------- csrc/device_lower/utils.h | 6 --- csrc/device_lower/validation.cpp | 7 +-- csrc/id_model/circular_buffer_indexing.cpp | 1 + csrc/id_model/indexing_traversal.cpp | 1 + csrc/id_model/indexing_traversal.h | 2 +- csrc/ir/nodes.cpp | 7 +-- csrc/ir/utils.cpp | 24 ++++++++- csrc/ir/utils.h | 6 +++ csrc/iter_visitor.cpp | 5 +- csrc/iter_visitor.h | 2 +- csrc/kernel_ir.cpp | 1 + .../scheduler/tools/loop_domain_scheduler.cpp | 33 +++++++------ csrc/val_graph_visitor.cpp | 11 +++-- csrc/val_graph_visitor.h | 4 +- tests/cpp/test_bfs.cpp | 49 +++++++++++-------- tests/cpp/test_gpu3.cpp | 17 ++++--- tests/cpp/test_loop_domain_scheduling.cpp | 8 +-- 21 files changed, 123 insertions(+), 97 deletions(-) diff --git a/csrc/bfs.h b/csrc/bfs.h index b5a02067ad2..cef047f6077 100644 --- a/csrc/bfs.h +++ b/csrc/bfs.h @@ -199,9 +199,10 @@ class BFS { } } - // Find the shortest path from the from_ to to_. This + // Find the shortest path from the from_ to to_. A boolean value + // indicating if all nodes are visited is also returned. This // must be only used once traversal is completed. 
- virtual ExprPath getShortestExprPath() { + virtual std::pair getShortestExprPath() { NVF_ERROR( !require_all_to_visited_ || allToNodesVisited(), "Traveral is either not done or failed"); @@ -316,7 +317,7 @@ class BFS { VectorOfUniqueEntries> unique_path( path.rbegin(), path.rend()); - return unique_path.vector(); + return std::make_pair(unique_path.vector(), allToNodesVisited()); } // Check if a node is ready to visit. If yes, return the direction diff --git a/csrc/device_lower/analysis/tma.cpp b/csrc/device_lower/analysis/tma.cpp index 05a2991814a..eb2463b3923 100644 --- a/csrc/device_lower/analysis/tma.cpp +++ b/csrc/device_lower/analysis/tma.cpp @@ -1025,9 +1025,10 @@ TMAInfo getTMAInfo(LoadStoreOp* ldst) { std::list> exprs = [&]() { ValGraph& id_graph = GpuLower::current()->tensorIndexer().traversalGraph(); auto exprs_vec = ValGraphBFS::getExprsBetween( - id_graph, - id_graph.toGroups(consumer_tv->getLoopDomain()), - id_graph.toGroups(gmem_tv->getMaybeAllocationDomain())); + id_graph, + id_graph.toGroups(consumer_tv->getLoopDomain()), + id_graph.toGroups(gmem_tv->getMaybeAllocationDomain())) + .first; return std::list(exprs_vec.begin(), exprs_vec.end()); }(); diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index 1ac49f5f2b4..1b7deccf303 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -1832,7 +1832,7 @@ ValGroup getInnerMmaLoopGroup(TensorView* tv, const MmaOp* mma) { ValGroup inner = alloc_domain.back(); auto exprs = - ValGraphBFS::getExprsBetween(id_graph, loop_domain, alloc_domain); + ValGraphBFS::getExprsBetween(id_graph, loop_domain, alloc_domain).first; while (!exprs.empty()) { auto [expr, direction] = exprs.back(); exprs.pop_back(); diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index f77a7520e9a..5298bfcb88f 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -280,24 +280,6 @@ std::vector getTvs(const std::vector& vals) { return 
tvs; } -TensorView* getTvOutput(const Expr* expr) { - for (auto out : expr->outputs()) { - if (auto tv = getTv(out)) { - return tv; - } - } - return nullptr; -} - -TensorView* getTvInput(const Expr* expr) { - for (auto inp : expr->inputs()) { - if (auto tv = getTv(inp)) { - return tv; - } - } - return nullptr; -} - bool isScalarOp(const Expr* expr) { for (auto out : expr->outputs()) { if (!out->isScalar()) { @@ -1923,7 +1905,7 @@ Val* proveLinearAndGetStride( // Propagate from linear_g to domain. Use frontier to keep track of the // how linear_g lives in the current propagation front. Projection frontier = linear_g; - auto path = ValGraphBFS::getExprsBetween(id_graph, domain, {linear_g}); + auto path = ValGraphBFS::getExprsBetween(id_graph, domain, {linear_g}).first; while (!path.empty()) { const auto& [eg, direction] = path.back(); path.pop_back(); diff --git a/csrc/device_lower/utils.h b/csrc/device_lower/utils.h index 2f53e7ed0ae..fa62d3b3f76 100644 --- a/csrc/device_lower/utils.h +++ b/csrc/device_lower/utils.h @@ -97,12 +97,6 @@ bool isTV(const Val* const); // Returns if Expr is a TensorView or TensorIndex Expr. NVF_API bool isTvOp(const Expr*); -// Returns the first output of Expr that is a TensorView -NVF_API TensorView* getTvOutput(const Expr*); - -// Returns the first input of Expr that is a TensorView -TensorView* getTvInput(const Expr*); - //! Returns the iterdomain that maps to the thread dimension grouped //! to warps. Returns nullopt if the reduction is not to be lowered to //! a warp reduction. 
diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index cc1b2dec53a..1afa860b0d2 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -415,9 +415,10 @@ class VectorizeValidator : public OptInDispatch { const auto& graph = id_model.idGraph(IdMappingMode::EXACT); auto expr_path = ValGraphBFS::getExprsBetween( - graph, - graph.toGroups(tv->getMaybeAllocationDomain()), - graph.toGroups(std::vector{v_id})); + graph, + graph.toGroups(tv->getMaybeAllocationDomain()), + graph.toGroups(std::vector{v_id})) + .first; expr_path = reverse(expr_path); ValGroup cur_group = graph.toGroup(v_id); diff --git a/csrc/id_model/circular_buffer_indexing.cpp b/csrc/id_model/circular_buffer_indexing.cpp index 521d435ca34..a638793e198 100644 --- a/csrc/id_model/circular_buffer_indexing.cpp +++ b/csrc/id_model/circular_buffer_indexing.cpp @@ -7,6 +7,7 @@ // clang-format on #include #include +#include namespace nvfuser { diff --git a/csrc/id_model/indexing_traversal.cpp b/csrc/id_model/indexing_traversal.cpp index 1712cfacfae..b823d6141bf 100644 --- a/csrc/id_model/indexing_traversal.cpp +++ b/csrc/id_model/indexing_traversal.cpp @@ -6,6 +6,7 @@ */ // clang-format on #include +#include namespace nvfuser { diff --git a/csrc/id_model/indexing_traversal.h b/csrc/id_model/indexing_traversal.h index f92062669cc..8381134f9d7 100644 --- a/csrc/id_model/indexing_traversal.h +++ b/csrc/id_model/indexing_traversal.h @@ -40,7 +40,7 @@ class IndexingTraversal : public ValGraphBFS { {from_groups.vector().begin(), from_groups.vector().end()}, {to_groups.vector().begin(), to_groups.vector().end()}); traversal.traverse(); - return traversal.getShortestExprPath(); + return traversal.getShortestExprPath().first; } using ValGraphBFS::isVisited; diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 8758a8022cb..a74d32b6d67 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -3722,9 +3722,10 @@ std::vector TensorDomain::allIDs() const { 
continue; } auto path = IRBFS::getExprsBetween( - {all_domains[i]->begin(), all_domains[i]->end()}, - {all_domains[j]->begin(), all_domains[j]->end()}, - false); + {all_domains[i]->begin(), all_domains[i]->end()}, + {all_domains[j]->begin(), all_domains[j]->end()}, + false) + .first; for (auto [expr, _] : path) { discovered_ids.pushBack( ir_utils::filterByType(expr->outputs())); diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index c2f3ba1b33f..39eac95f417 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -826,8 +826,10 @@ CompareDomainResult compareDomains( toDelimitedString(dom1)); dom0.insert(dom0.end(), additional_ids.begin(), additional_ids.end()); - auto exprs = IRBFS::getExprsBetween( - {dom0.begin(), dom0.end()}, {dom1.begin(), dom1.end()}, false); + auto exprs = + IRBFS::getExprsBetween( + {dom0.begin(), dom0.end()}, {dom1.begin(), dom1.end()}, false) + .first; std::unordered_set frontier(dom0.begin(), dom0.end()); @@ -1285,4 +1287,22 @@ ForLoop* createRangeLoop(int64_t size) { return loop; } +TensorView* getTvOutput(const Expr* expr) { + for (auto out : expr->outputs()) { + if (auto tv = getTv(out)) { + return tv; + } + } + return nullptr; +} + +TensorView* getTvInput(const Expr* expr) { + for (auto inp : expr->inputs()) { + if (auto tv = getTv(inp)) { + return tv; + } + } + return nullptr; +} + } // namespace nvfuser::ir_utils diff --git a/csrc/ir/utils.h b/csrc/ir/utils.h index 17d62716960..d0f70a1925e 100644 --- a/csrc/ir/utils.h +++ b/csrc/ir/utils.h @@ -727,4 +727,10 @@ int64_t getOperationCount(Val* val); // for (int i = 0; i < size; i++) ForLoop* createRangeLoop(int64_t size); +// Returns the first output of Expr that is a TensorView +TensorView* getTvOutput(const Expr*); + +// Returns the first input of Expr that is a TensorView +TensorView* getTvInput(const Expr*); + } // namespace nvfuser::ir_utils diff --git a/csrc/iter_visitor.cpp b/csrc/iter_visitor.cpp index 5700e1d26cf..aa9c389cce3 100644 --- a/csrc/iter_visitor.cpp +++ 
b/csrc/iter_visitor.cpp @@ -1175,7 +1175,7 @@ std::vector IRBFS::getValsBetween( const std::vector& from, const std::vector& to) { auto path = - IRBFS::getExprsBetween(from, to, /*require_all_to_visited=*/false); + IRBFS::getExprsBetween(from, to, /*require_all_to_visited=*/false).first; VectorOfUniqueEntries unique_vals; for (auto [expr, _] : path) { @@ -1197,7 +1197,8 @@ std::vector IRBFS::getValsBetween( std::vector IRBFS::getDependenciesTo( const std::vector& vals, const std::vector& to) { - auto path = IRBFS::getExprsBetween(vals, to, /*require_all_to_visited=*/true); + auto path = + IRBFS::getExprsBetween(vals, to, /*require_all_to_visited=*/true).first; VectorOfUniqueEntries unique_vals; diff --git a/csrc/iter_visitor.h b/csrc/iter_visitor.h index fb8add34185..aadb743674d 100644 --- a/csrc/iter_visitor.h +++ b/csrc/iter_visitor.h @@ -600,7 +600,7 @@ class IRBFS // Find the shortest path from the from_groups_ to to_groups_ on a // given graph. Dependency between vals and exprs must be satisfied. // It is an error if no valid path is found. 
- static ExprPath getExprsBetween( + static std::pair getExprsBetween( const std::vector& from, const std::vector& to, bool require_all_to_visited = true) { diff --git a/csrc/kernel_ir.cpp b/csrc/kernel_ir.cpp index ea8bdafe656..32bd5b8fe2a 100644 --- a/csrc/kernel_ir.cpp +++ b/csrc/kernel_ir.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include diff --git a/csrc/scheduler/tools/loop_domain_scheduler.cpp b/csrc/scheduler/tools/loop_domain_scheduler.cpp index 5b0eb416733..2e47b8aa69f 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.cpp +++ b/csrc/scheduler/tools/loop_domain_scheduler.cpp @@ -327,30 +327,33 @@ ValGraphBFS::ExprPath LoopDomainScheduler::getReplayPath(TensorView* tv) const { return all_ancestors_of_ref_.has(tv_target_domain); })) { return ValGraphBFS::getExprsBetween( - graph(), - ref_id_groups_, - tv_target_domains, - /*require_all_to_visited=*/true, - Direction::Backward); + graph(), + ref_id_groups_, + tv_target_domains, + /*require_all_to_visited=*/true, + Direction::Backward) + .first; } // Find the forward path from the ancestors to the target tensor auto forward_path = ValGraphBFS::getExprsBetween( - graph(), - all_ancestors_of_ref_, - tv_target_domains, - /*require_all_to_visited=*/true, - Direction::Forward); + graph(), + all_ancestors_of_ref_, + tv_target_domains, + /*require_all_to_visited=*/true, + Direction::Forward) + .first; // Find the path from the ref to the forward path. 
auto inputs_of_forward_path = getInputsOfExprPath(graph(), forward_path); auto backward_path = ValGraphBFS::getExprsBetween( - graph(), - ref_id_groups_, - inputs_of_forward_path, - /*require_all_to_visited=*/true, - Direction::Backward); + graph(), + ref_id_groups_, + inputs_of_forward_path, + /*require_all_to_visited=*/true, + Direction::Backward) + .first; // Overall replay path = backward_path + forward_path ValGraphBFS::ExprPath replay_path; diff --git a/csrc/val_graph_visitor.cpp b/csrc/val_graph_visitor.cpp index 024bc9a532f..2b139800eca 100644 --- a/csrc/val_graph_visitor.cpp +++ b/csrc/val_graph_visitor.cpp @@ -242,11 +242,12 @@ std::unordered_set ValGraphBFS::projectTo( std::unordered_set projection{from}; // Reverse order auto exprs = ValGraphBFS::getExprsBetween( - id_graph, - to, - {from}, - /*require_all_to_visited=*/false, - allowed_direction); + id_graph, + to, + {from}, + /*require_all_to_visited=*/false, + allowed_direction) + .first; while (!exprs.empty()) { const auto [expr, direction] = exprs.back(); exprs.pop_back(); diff --git a/csrc/val_graph_visitor.h b/csrc/val_graph_visitor.h index 58d58f1b9ff..42ab97326c8 100644 --- a/csrc/val_graph_visitor.h +++ b/csrc/val_graph_visitor.h @@ -199,7 +199,7 @@ class ValGraphBFS : public BFS< // Find the shortest path from the from_groups_ to to_groups_ on a // given graph. Dependency between vals and exprs must be satisfied. // It is an error if no valid path is found. 
- static ExprPath getExprsBetween( + static std::pair getExprsBetween( const ValGraph& graph, std::vector from, std::vector to, @@ -214,7 +214,7 @@ class ValGraphBFS : public BFS< bfs.traverse(); return bfs.getShortestExprPath(); } - static ExprPath getExprsBetween( + static std::pair getExprsBetween( const ValGraph& graph, const ValGroups& from, const ValGroups& to, diff --git a/tests/cpp/test_bfs.cpp b/tests/cpp/test_bfs.cpp index eeaad52e917..9bca6af3c16 100644 --- a/tests/cpp/test_bfs.cpp +++ b/tests/cpp/test_bfs.cpp @@ -58,7 +58,8 @@ TEST_F(BFSTest, ValGraphBFS1) { // Since the loop domains of tv0 and tv1 are grouped together, the // path between them is empty ExprPath tv1_to_tv0 = - ValGraphBFS::getExprsBetween(graph, tv1_loop_groups, tv0_loop_groups); + ValGraphBFS::getExprsBetween(graph, tv1_loop_groups, tv0_loop_groups) + .first; EXPECT_TRUE(tv1_to_tv0.empty()); // Traversal should fail if not all dependencies are met @@ -79,7 +80,8 @@ TEST_F(BFSTest, ValGraphBFS1) { // tv2 loop domain backward to its root and then forward from tv1 root to // tv1 loop domain. 
ExprPath tv2_to_tv1 = - ValGraphBFS::getExprsBetween(graph, tv2_loop_groups, tv1_loop_groups); + ValGraphBFS::getExprsBetween(graph, tv2_loop_groups, tv1_loop_groups) + .first; ExprPath tv2_to_tv1_ref; tv2_to_tv1_ref.emplace_back( @@ -124,7 +126,8 @@ TEST_F(BFSTest, ValGraphBFS2) { // Since the loop domains of tv0 and tv1 are grouped together, the // path between them is empty ExprPath tv1_to_tv0 = - ValGraphBFS::getExprsBetween(graph, tv1_loop_groups, tv0_loop_groups); + ValGraphBFS::getExprsBetween(graph, tv1_loop_groups, tv0_loop_groups) + .first; ExprPath tv1_to_tv0_ref; tv1_to_tv0_ref.emplace_back( @@ -141,7 +144,8 @@ TEST_F(BFSTest, ValGraphBFS2) { tv0_partial_groups.pushBack(graph.toGroup(tv0->axis(1))); tv0_partial_groups.pushBack(graph.toGroup(tv0->axis(2))); ExprPath tv1_to_tv0_partial = - ValGraphBFS::getExprsBetween(graph, tv1_loop_groups, tv0_partial_groups); + ValGraphBFS::getExprsBetween(graph, tv1_loop_groups, tv0_partial_groups) + .first; EXPECT_EQ(tv1_to_tv0_partial, tv1_to_tv0_ref); } @@ -181,7 +185,7 @@ TEST_F(BFSTest, ValGraphBFS3) { ValGroups tv0_groups = graph.toGroups(tv0->getLoopDomain()); ExprPath tv4_to_tv0 = - ValGraphBFS::getExprsBetween(graph, tv4_groups, tv0_groups); + ValGraphBFS::getExprsBetween(graph, tv4_groups, tv0_groups).first; ExprPath tv4_to_tv0_ref; tv4_to_tv0_ref.emplace_back( graph.toGroup(tv1->axis(0)->definition()), Direction::Backward); @@ -230,7 +234,7 @@ TEST_F(BFSTest, ValGraphBFS4) { // and tv3, but the shortest path should be just one merge for tv1 ExprPath tv4_to_tv0 = - ValGraphBFS::getExprsBetween(graph, tv4_groups, tv0_groups); + ValGraphBFS::getExprsBetween(graph, tv4_groups, tv0_groups).first; ExprPath tv4_to_tv0_ref; tv4_to_tv0_ref.emplace_back( @@ -424,11 +428,12 @@ TEST_F(BFSTest, TraversalDirection) { // Shortest path from the input to tv7 should forward the second // path and then move one Merge backward auto shortest_path = ValGraphBFS::getExprsBetween( - exact_graph, - 
exact_graph.toGroups(tv0->getLogicalDomain()), - exact_graph.toGroups(tv7->getLogicalDomain()), - /*require_all_to_visited=*/true, - Direction::Undefined); + exact_graph, + exact_graph.toGroups(tv0->getLogicalDomain()), + exact_graph.toGroups(tv7->getLogicalDomain()), + /*require_all_to_visited=*/true, + Direction::Undefined) + .first; ValGraphBFS::ExprPath shortest_path_reference = { {exact_graph.toGroup(tv9->axis(-1)->definition()), Direction::Forward}, {exact_graph.toGroup(tv10->axis(-1)->definition()), Direction::Forward}, @@ -439,11 +444,12 @@ TEST_F(BFSTest, TraversalDirection) { // Forward only path should take tv1 through tv7 auto forward_path = ValGraphBFS::getExprsBetween( - exact_graph, - exact_graph.toGroups(tv0->getLogicalDomain()), - exact_graph.toGroups(tv7->getLogicalDomain()), - /*require_all_to_visited=*/true, - Direction::Forward); + exact_graph, + exact_graph.toGroups(tv0->getLogicalDomain()), + exact_graph.toGroups(tv7->getLogicalDomain()), + /*require_all_to_visited=*/true, + Direction::Forward) + .first; ValGraphBFS::ExprPath forward_path_reference = { {exact_graph.toGroup(tv1->axis(-1)->definition()), Direction::Forward}, {exact_graph.toGroup(tv2->axis(-1)->definition()), Direction::Forward}, @@ -458,11 +464,12 @@ TEST_F(BFSTest, TraversalDirection) { // Backward only path should not find anything auto backward_path = ValGraphBFS::getExprsBetween( - exact_graph, - exact_graph.toGroups(tv0->getLogicalDomain()), - exact_graph.toGroups(tv7->getLogicalDomain()), - /*require_all_to_visited=*/false, - Direction::Backward); + exact_graph, + exact_graph.toGroups(tv0->getLogicalDomain()), + exact_graph.toGroups(tv7->getLogicalDomain()), + /*require_all_to_visited=*/false, + Direction::Backward) + .first; EXPECT_TRUE(backward_path.empty()) << "Actual: " << backward_path; } diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp index c040376e7b5..719dfedc56f 100644 --- a/tests/cpp/test_gpu3.cpp +++ b/tests/cpp/test_gpu3.cpp @@ -6506,7 +6506,7 
@@ TEST_F(NVFuserTest, AllIDsWithExtraLoopIDs1) { {tv2->getLogicalDomain().begin(), tv2->getLogicalDomain().end()}, {tv2->getLoopDomain().begin(), tv2->getLoopDomain().end()}, false) - .empty()); + .first.empty()); // This ordering should find two exprs (i.e., the merge and the split). EXPECT_EQ( @@ -6514,7 +6514,7 @@ TEST_F(NVFuserTest, AllIDsWithExtraLoopIDs1) { {tv2->getLoopDomain().begin(), tv2->getLoopDomain().end()}, {tv2->getLogicalDomain().begin(), tv2->getLogicalDomain().end()}, false) - .size(), + .first.size(), 2); std::unordered_set tv2_all_ids_ref; @@ -6579,13 +6579,16 @@ TEST_F(NVFuserTest, AllIDsWithExtraLoopIDs2) { {tv2->getLoopDomain().begin(), tv2->getLoopDomain().end()}, {tv2->getLogicalDomain().begin(), tv2->getLogicalDomain().end()}, false) - .empty()); + .first.empty()); // From the initial loop to the current loop should find the split expr - auto exprs_between = IRBFS::getExprsBetween( - {tv2->getInitialLoopDomain().begin(), tv2->getInitialLoopDomain().end()}, - {tv2->getLoopDomain().begin(), tv2->getLoopDomain().end()}, - false); + auto exprs_between = + IRBFS::getExprsBetween( + {tv2->getInitialLoopDomain().begin(), + tv2->getInitialLoopDomain().end()}, + {tv2->getLoopDomain().begin(), tv2->getLoopDomain().end()}, + false) + .first; EXPECT_EQ(exprs_between.size(), 1); EXPECT_EQ(exprs_between.front().first, tv2_split); diff --git a/tests/cpp/test_loop_domain_scheduling.cpp b/tests/cpp/test_loop_domain_scheduling.cpp index 710be9ce08a..9178f311b8c 100644 --- a/tests/cpp/test_loop_domain_scheduling.cpp +++ b/tests/cpp/test_loop_domain_scheduling.cpp @@ -215,9 +215,11 @@ TEST_F(LoopDomainSchedulingTest, ReshapeTraversalDirection) { } // Validate the history of tv5 loop IDs - auto tv5_loop_to_logical = IRBFS::getExprsBetween( - {tv5->getLoopDomain().begin(), tv5->getLoopDomain().end()}, - {tv5->getLogicalDomain().begin(), tv5->getLogicalDomain().end()}); + auto tv5_loop_to_logical = + IRBFS::getExprsBetween( + {tv5->getLoopDomain().begin(), 
tv5->getLoopDomain().end()}, + {tv5->getLogicalDomain().begin(), tv5->getLogicalDomain().end()}) + .first; // 1. Backward split (tv7 reshape) EXPECT_TRUE(