diff --git a/csrc/id_model/indexing_traversal.cpp b/csrc/id_model/indexing_traversal.cpp
index 3e9b35dbc94..c2a6127a861 100644
--- a/csrc/id_model/indexing_traversal.cpp
+++ b/csrc/id_model/indexing_traversal.cpp
@@ -64,25 +64,33 @@ std::optional<IndexingTraversal::ExprPath> IndexingTraversal::
       /*additional_tvs=*/{},
       /*build_graphs=*/false);
 
-  // If there's no resize in the producer and consumer tensors of this
-  // expr, it should not need this WAR.
-  std::vector<Resize*> resize_exprs;
-  for (const auto& [id, use_exprs] : local_model.idUses()) {
-    for (const auto& use_expr : use_exprs) {
-      if (auto resize = dynamic_cast<Resize*>(use_expr)) {
-        resize_exprs.push_back(resize);
+  // Gather all resize exprs for each of the inputs and outputs
+  std::unordered_map<TensorView*, std::vector<Resize*>> tv_resize_map;
+  for (auto inp : ir_utils::filterByType<TensorView>(expr->inputs())) {
+    for (auto expr : inp->domain()->allExprs()) {
+      if (auto resize = dynamic_cast<Resize*>(expr)) {
+        tv_resize_map[inp].push_back(resize);
+      }
+    }
+  }
+  for (auto out : ir_utils::filterByType<TensorView>(expr->outputs())) {
+    for (auto expr : out->domain()->allExprs()) {
+      if (auto resize = dynamic_cast<Resize*>(expr)) {
+        tv_resize_map[out].push_back(resize);
       }
     }
   }
 
-  if (resize_exprs.empty()) {
+  // If there's no resize in the producer and consumer tensors of this
+  // expr, it should not need this WAR.
+  if (tv_resize_map.empty()) {
     return std::nullopt;
   }
 
   // The indexing issue with resize may happen when a single iter
-  // domain is resized multiple times. In other words, there must be
-  // at least two connected resize exprs. If not, this WAR is not
-  // necessary.
+  // domain is resized multiple times between a producer and a
+  // consumer. In other words, there must be at least two connected
+  // resize exprs. If not, this WAR is not necessary.
   //
   // Note that the actual indexing is done from the loop IDs, which
   // might be promoted to IDs outside of this particular expr. Thus,
@@ -92,32 +100,63 @@
   // promotion should not further add resize exprs, it is sufficient
   // to analyze only the IDs of this expr.
 
-  // Shortcut for a common case to avoid building the graph below
-  if (resize_exprs.size() < 2) {
-    return std::nullopt;
-  }
-
   const auto& local_graph = local_model.buildAlmostExactGraph();
 
-  // See if these resize expr groups are connected. Note that in the
-  // current default scheduling method, any tensor ops using resize
-  // should only show up with a fusion input as its input, so there
-  // must be no chained resize ops. The default scheduling, this
-  // function should not move beyond this point. In the case of the
-  // new resize scheduler that is currently under development will
-  // have multiple chained resize ops, but the scheduler should
-  // explicitly set the loop domain such that no promotion would
-  // happen, thus avoiding hitting the assertion down below.
-  ExprGroups resize_groups = local_graph.toGroups(resize_exprs);
+  // The analysis below is done for each producer-consumer pair, so it
+  // can be rather expensive, but in practice most cases should just
+  // bail out at the first if condition.
+
+  auto isSingleIdResizedMultipleTimes = [&](TensorView* inp,
+                                            TensorView* out) -> bool {
+    auto num_resizes = tv_resize_map[inp].size() + tv_resize_map[out].size();
+    if (num_resizes < 2) {
+      return false;
+    }
+
+    std::vector<Resize*> resize_exprs;
+    resize_exprs.reserve(num_resizes);
+    resize_exprs.insert(
+        resize_exprs.end(),
+        tv_resize_map[inp].begin(),
+        tv_resize_map[inp].end());
+    resize_exprs.insert(
+        resize_exprs.end(),
+        tv_resize_map[out].begin(),
+        tv_resize_map[out].end());
+
+    // See if these resize expr groups are connected. Note that in the
+    // current default scheduling method, any tensor ops using resize
+    // should only show up with a fusion input as its input, so there
+    // must be no chained resize ops. With the default scheduling, this
+    // function should not move beyond this point. The new resize
+    // scheduler that is currently under development will have multiple
+    // chained resize ops, but the scheduler should explicitly set the
+    // loop domain such that no promotion would happen, thus avoiding
+    // hitting the assertion down below.
+    ExprGroups resize_groups = local_graph.toGroups(resize_exprs);
+    for (const auto i : c10::irange(resize_groups.size() - 1)) {
+      const auto resize_i = resize_groups.at(i);
+      std::vector<ValGraphBFS::NodeType> other_resizes{
+          resize_groups.begin() + i + 1, resize_groups.end()};
+      auto reachable_nodes = getReachableNodesFrom<ValGraphBFS>(
+          {resize_i}, other_resizes, Direction::Undefined, local_graph);
+      if (!reachable_nodes.empty()) {
+        return true;
+      }
+    }
+
+    return false;
+  };
+
   bool single_id_resized_multiple_times = false;
-  for (const auto i : c10::irange(resize_groups.size() - 1)) {
-    const auto resize_i = resize_groups.at(i);
-    std::vector<ValGraphBFS::NodeType> other_resizes{
-        resize_groups.begin() + i + 1, resize_groups.end()};
-    auto reachable_nodes = getReachableNodesFrom<ValGraphBFS>(
-        {resize_i}, other_resizes, Direction::Undefined, local_graph);
-    if (!reachable_nodes.empty()) {
-      single_id_resized_multiple_times = true;
+  for (auto out : ir_utils::filterByType<TensorView>(expr->outputs())) {
+    for (auto inp : ir_utils::filterByType<TensorView>(expr->inputs())) {
+      if (isSingleIdResizedMultipleTimes(inp, out)) {
+        single_id_resized_multiple_times = true;
+        break;
+      }
+    }
+    if (single_id_resized_multiple_times) {
       break;
     }
  }
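For readers unfamiliar with the ValGraph machinery, here is a minimal standalone sketch of the connectivity check that the `isSingleIdResizedMultipleTimes` lambda performs. The names `Graph`, `reachable`, and `anyPairConnected` are hypothetical stand-ins, not nvFuser code: `Graph` stands in for the almost-exact `ValGraph`, `reachable` for `getReachableNodesFrom<ValGraphBFS>`, and `anyPairConnected` for the pairwise loop over resize expr groups.

```cpp
#include <cstddef>
#include <queue>
#include <unordered_set>
#include <vector>

// Nodes stand in for resize expr groups; edges connect groups whose
// iter domains are mapped in the almost-exact graph.
using Graph = std::vector<std::vector<std::size_t>>;

// Undirected BFS reachability, mirroring Direction::Undefined in the
// real getReachableNodesFrom call.
bool reachable(const Graph& g, std::size_t from, std::size_t to) {
  std::unordered_set<std::size_t> visited{from};
  std::queue<std::size_t> frontier;
  frontier.push(from);
  while (!frontier.empty()) {
    const auto node = frontier.front();
    frontier.pop();
    if (node == to) {
      return true;
    }
    for (const auto next : g.at(node)) {
      if (visited.insert(next).second) {
        frontier.push(next);
      }
    }
  }
  return false;
}

// True if any two of the given resize groups are connected, i.e., a
// single iter domain is resized more than once, which is when the
// WAR is needed.
bool anyPairConnected(const Graph& g, const std::vector<std::size_t>& groups) {
  for (std::size_t i = 0; i + 1 < groups.size(); ++i) {
    for (std::size_t j = i + 1; j < groups.size(); ++j) {
      if (reachable(g, groups.at(i), groups.at(j))) {
        return true;
      }
    }
  }
  return false;
}
```

The key design change in the diff is the scope of `groups`: the old code pooled every resize in the local model into one list, while the new code runs the check once per producer-consumer pair, after the cheap `num_resizes < 2` shortcut.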
diff --git a/tests/cpp/test_indexing.cpp b/tests/cpp/test_indexing.cpp
index 549fb08ee84..eb64da645f3 100644
--- a/tests/cpp/test_indexing.cpp
+++ b/tests/cpp/test_indexing.cpp
@@ -5408,7 +5408,7 @@ TEST_F(IndexingTest, ResizeRotation) {
 
 // Repro of issue #3505. The indexing WAR for resize triggered an
 // assertion due to loop promotion.
-TEST_F(IndexingTest, Issue3505) {
+TEST_F(IndexingTest, Issue3505Repro1) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
@@ -5452,6 +5452,54 @@
   testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
 }
 
+// Another repro of issue #3505
+TEST_F(IndexingTest, Issue3505Repro2) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  const int64_t i0 = 8;
+  const int64_t i1 = 2;
+  const auto zero = fusion.zeroVal();
+
+  EnableOptionsGuard enable_options_guard;
+  EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});
+
+  auto tv0 = makeContigConcreteTensor({i0});
+  fusion.addInput(tv0);
+  auto tv1 = makeContigConcreteTensor({i1, i0 / 2});
+  fusion.addInput(tv1);
+
+  // Left half
+  auto tv2 = slice(tv0, {{zero, IrBuilder::create<Val>(i0 / 2)}});
+  // Right half
+  auto tv3 = slice(
+      tv0, {{IrBuilder::create<Val>(i0 / 2), IrBuilder::create<Val>(i0)}});
+
+  // The two inputs of this add expression have a resize of the same
+  // ID, but this should not mean the resize WAR path is required.
+  auto tv4 = add(tv2, tv3);
+  auto tv5 = broadcast(tv4, {true, false});
+  auto tv6 = add(tv1, tv5);
+  fusion.addOutput(tv6);
+
+  // Make loop promotion required
+  for (auto tv : {tv2, tv3, tv4, tv5, tv6}) {
+    tv->flatten();
+  }
+  inlineMost();
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto t0 = at::randn({i0}, options);
+  auto t1 = at::randn({i1, i0 / 2}, options);
+  std::vector<c10::IValue> inputs{t0, t1};
+
+  KernelExecutor ke;
+  ke.compile(&fusion, inputs);
+  auto outputs = ke.run(inputs);
+
+  testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
+}
+
 TEST_F(IndexingTest, AlmostExactIndexingUpdate) {
   EnableOptionsGuard enable_options_guard;
   EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});
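To see why `Issue3505Repro2` previously took the WAR path and no longer does, the sketch after the indexing_traversal.cpp diff can be applied to this fusion (again, a hypothetical illustration, not nvFuser code). `tv2` and `tv3` each carry one resize of the same `tv0` iter domain, so their resize groups are connected in the almost-exact graph, but neither producer-consumer pair of the `add` expr contains two resizes on its own.

```cpp
#include <cassert>
// Assumes the Graph / anyPairConnected sketch defined above.

int main() {
  // Two resize groups: tv2's slice and tv3's slice. They resize the
  // same iter domain of tv0, so they are connected in the graph.
  const Graph g = {{1}, {0}};

  // Old behavior: all resizes in the local model were pooled, the
  // connected pair was found, and the WAR path (and its
  // loop-promotion assertion) was taken.
  assert(anyPairConnected(g, {0, 1}));

  // New behavior: each producer-consumer pair of the add expr sees a
  // single resize, so the analysis bails out before any graph walk
  // (the num_resizes < 2 shortcut in the diff; a one-element group
  // list here).
  assert(!anyPairConnected(g, {0}));
  assert(!anyPairConnected(g, {1}));
  return 0;
}
```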