From 0b4d62a4c38089c650216f9f8e8aa91d609b39b4 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 2 Dec 2024 21:42:27 -0800 Subject: [PATCH 1/4] repro and a fix --- csrc/bfs.h | 49 +++++++++++++++++++++------- csrc/id_model/indexing_traversal.cpp | 45 +++++++++++++++++++++---- tests/cpp/test_indexing.cpp | 46 ++++++++++++++++++++++++++ 3 files changed, 122 insertions(+), 18 deletions(-) diff --git a/csrc/bfs.h b/csrc/bfs.h index bcc6bf0902c..1fb7c13bc82 100644 --- a/csrc/bfs.h +++ b/csrc/bfs.h @@ -588,30 +588,57 @@ std::vector::type> getOutputsOfExprPath( return getInputsOfExprPath(reverse(path), get_inputs, get_outputs); } -// Given a set of vals, get all reachable ones from another set of vals +// Given a set of exprs and vals, get all reachable ones from another set of +// nodes template -std::vector getReachableValsFrom( - const std::vector& from, - const std::vector& vals, +std::vector getReachableNodesFrom( + const std::vector& from, + const std::vector& nodes, Direction allowed_direction = Direction::Undefined, const AdditionalArgs&... additional_args) { BFSType bfs( additional_args..., - {from.begin(), from.end()}, - {vals.begin(), vals.end()}, + from, + nodes, /*require_all_to_visited=*/false, allowed_direction); bfs.traverse(); - std::vector reachable_vals; - for (const auto& val : vals) { - if (bfs.isVisited(val) || - std::find(from.begin(), from.end(), val) != from.end()) { - reachable_vals.push_back(val); + std::vector reachable_nodes; + for (const auto& node : nodes) { + if (bfs.isVisited(node) || + std::find(from.begin(), from.end(), node) != from.end()) { + reachable_nodes.push_back(node); } } + return reachable_nodes; +} + +// Shortcut of getReachableNodesFrom for Vals +template +std::vector getReachableValsFrom( + const std::vector& from, + const std::vector& vals, + Direction allowed_direction = Direction::Undefined, + const AdditionalArgs&... 
additional_args) { + auto reachable_nodes = getReachableNodesFrom( + {from.begin(), from.end()}, + {vals.begin(), vals.end()}, + allowed_direction, + additional_args...); + + std::vector reachable_vals; + reachable_vals.reserve(reachable_nodes.size()); + std::transform( + reachable_nodes.begin(), + reachable_nodes.end(), + std::back_inserter(reachable_vals), + [](const auto& node) { + return std::get(node); + }); + return reachable_vals; } diff --git a/csrc/id_model/indexing_traversal.cpp b/csrc/id_model/indexing_traversal.cpp index a5f14fa5e5a..1413bb52069 100644 --- a/csrc/id_model/indexing_traversal.cpp +++ b/csrc/id_model/indexing_traversal.cpp @@ -62,18 +62,49 @@ std::optional IndexingTraversal:: // If there's no resize in the producer and consumer tensors of this // expr, it should not need this WAR. - if (std::none_of( - local_model.idUses().begin(), - local_model.idUses().end(), - [](const auto& kv) { - const VectorOfUniqueEntries& exprs = kv.second; - return !exprs.empty() && exprs.at(0)->isA(); - })) { + std::vector resize_exprs; + for (const auto& [id, use_exprs] : local_model.idUses()) { + for (const auto& use_expr : use_exprs) { + if (auto resize = dynamic_cast(use_expr)) { + resize_exprs.push_back(resize); + } + } + } + + if (resize_exprs.empty()) { + return std::nullopt; + } + + // The indexing issue with resize may happen when a single iter + // domain is resized multiple times. In other words, if there's only + // one resize, there's no problem with the default indexing path. 
+ + // Shortcut for a common case to avoid building the graph below + if (resize_exprs.size() < 2) { return std::nullopt; } const auto& local_graph = local_model.buildAlmostExactGraph(); + ExprGroups resize_groups = local_graph.toGroups(resize_exprs); + + bool single_id_resized_multiple_times = false; + for (const auto i : c10::irange(resize_groups.size() - 1)) { + const auto resize_i = resize_groups.at(i); + std::vector other_resizes{ + resize_groups.begin() + i + 1, resize_groups.end()}; + auto reachable_nodes = getReachableNodesFrom( + {resize_i}, other_resizes, Direction::Undefined, local_graph); + if (!reachable_nodes.empty()) { + single_id_resized_multiple_times = true; + break; + } + } + + if (!single_id_resized_multiple_times) { + return std::nullopt; + } + // from_ids are loop domains, which are representative // domains of loop groups and not necessarily domains of any // of the producer and the consumer. In that case, find an ID out diff --git a/tests/cpp/test_indexing.cpp b/tests/cpp/test_indexing.cpp index 265943df94b..549fb08ee84 100644 --- a/tests/cpp/test_indexing.cpp +++ b/tests/cpp/test_indexing.cpp @@ -5406,6 +5406,52 @@ TEST_F(IndexingTest, ResizeRotation) { testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); } +// Repro of issue #3505. The indexing WAR for resize triggered an +// assertion due to loop promotion. 
+TEST_F(IndexingTest, Issue3505) { + Fusion fusion; + FusionGuard fg(&fusion); + + const int64_t i0 = 2; + const int64_t i1 = 4; + const int64_t i2 = 8; + const auto zero = fusion.zeroVal(); + + EnableOptionsGuard enable_options_guard; + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + + auto tv0 = makeContigConcreteTensor({i1, i2}); + fusion.addInput(tv0); + auto tv1 = makeContigConcreteTensor({i0, i1 / 2, i2 / 2}); + fusion.addInput(tv1); + + // One slice can reproduce the error but just to trigger the + // reachability check between multiple resize ops + auto tv2 = slice( + tv0, + {{zero, IrBuilder::create(i1 / 2)}, + {zero, IrBuilder::create(i2 / 2)}}); + auto tv3 = broadcast(tv2, {true, false, false}); + auto tv4 = add(tv1, tv3); + fusion.addOutput(tv4); + + for (auto tv : {tv2, tv3, tv4}) { + tv->flatten(); + } + inlineMost(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({i1, i2}, options); + auto t1 = at::randn({i0, i1 / 2, i2 / 2}, options); + std::vector inputs{t0, t1}; + + KernelExecutor ke; + ke.compile(&fusion, inputs); + auto outputs = ke.run(inputs); + + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); +} + TEST_F(IndexingTest, AlmostExactIndexingUpdate) { EnableOptionsGuard enable_options_guard; EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); From 446d998389f3f86af93980a717b4c43c72aec464 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 2 Dec 2024 23:33:21 -0800 Subject: [PATCH 2/4] comment --- csrc/id_model/indexing_traversal.cpp | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/csrc/id_model/indexing_traversal.cpp b/csrc/id_model/indexing_traversal.cpp index 1413bb52069..459ee4c3099 100644 --- a/csrc/id_model/indexing_traversal.cpp +++ b/csrc/id_model/indexing_traversal.cpp @@ -55,6 +55,10 @@ std::optional IndexingTraversal:: auto consumer_tv = ir_utils::getTvOutput(expr); NVF_ERROR(consumer_tv 
!= nullptr); + // First, try to limit the use of this WAR as much as possible since + // the WAR itself has a limitation that it assumes the loop domain + // is not promoted. + IdModel local_model( std::vector{consumer_tv->definition()}, /*additional_tvs=*/{}, @@ -76,8 +80,9 @@ std::optional IndexingTraversal:: } // The indexing issue with resize may happen when a single iter - // domain is resized multiple times. In other words, if there's only - // one resize, there's no problem with the default indexing path. + // domain is resized multiple times. In other words, there must be + // at least two connected resize exprs. If not, this WAR is not + // necessary. // Shortcut for a common case to avoid building the graph below if (resize_exprs.size() < 2) { @@ -86,8 +91,8 @@ std::optional IndexingTraversal:: const auto& local_graph = local_model.buildAlmostExactGraph(); + // See if these resize expr groups are connected ExprGroups resize_groups = local_graph.toGroups(resize_exprs); - bool single_id_resized_multiple_times = false; for (const auto i : c10::irange(resize_groups.size() - 1)) { const auto resize_i = resize_groups.at(i); @@ -101,6 +106,8 @@ std::optional IndexingTraversal:: } } + // No connection between the resize exprs is found, which means they are + // all independent and there's no need to use this WAR if (!single_id_resized_multiple_times) { return std::nullopt; } From 5b364e60b814fc4aff869af4601fd4c591b12222 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 2 Dec 2024 23:49:05 -0800 Subject: [PATCH 3/4] comment --- csrc/id_model/indexing_traversal.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/csrc/id_model/indexing_traversal.cpp b/csrc/id_model/indexing_traversal.cpp index 459ee4c3099..6a131f6b027 100644 --- a/csrc/id_model/indexing_traversal.cpp +++ b/csrc/id_model/indexing_traversal.cpp @@ -83,6 +83,14 @@ std::optional IndexingTraversal:: // domain is resized multiple times. 
In other words, there must be // at least two connected resize exprs. If not, this WAR is not // necessary. + // + // Note that the actual indexing is done from the loop IDs, which + // might be promoted to IDs outside of this particular expr. Thus, + // to get the true indexing path, the global IdModel may need to be + // used rather than the local model. Here, since we just need to + // know if there are multiple dependent resize exprs, and loop + // promotion should not further add resize exprs, it is sufficient + // to analyze only the IDs of this expr only. // Shortcut for a common case to avoid building the graph below if (resize_exprs.size() < 2) { From ccd0e80f5502bf55cd783bced5f6b97f8945ce32 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 3 Dec 2024 13:19:22 -0800 Subject: [PATCH 4/4] pr feedback --- csrc/id_model/indexing_traversal.cpp | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/csrc/id_model/indexing_traversal.cpp b/csrc/id_model/indexing_traversal.cpp index 6a131f6b027..3e9b35dbc94 100644 --- a/csrc/id_model/indexing_traversal.cpp +++ b/csrc/id_model/indexing_traversal.cpp @@ -90,7 +90,7 @@ std::optional IndexingTraversal:: // used rather than the local model. Here, since we just need to // know if there are multiple dependent resize exprs, and loop // promotion should not further add resize exprs, it is sufficient - // to analyze only the IDs of this expr only. + // to analyze only the IDs of this expr. // Shortcut for a common case to avoid building the graph below if (resize_exprs.size() < 2) { @@ -99,7 +99,15 @@ std::optional IndexingTraversal:: const auto& local_graph = local_model.buildAlmostExactGraph(); - // See if these resize expr groups are connected + // See if these resize expr groups are connected. Note that in the + // current default scheduling method, any tensor ops using resize + // should only show up with a fusion input as its input, so there + // must be no chained resize ops. 
With the default scheduling, this + // function should not move beyond this point. The new resize + // scheduler that is currently under development will + // have multiple chained resize ops, but the scheduler should + // explicitly set the loop domain such that no promotion would + // happen, thus avoiding hitting the assertion down below. + ExprGroups resize_groups = local_graph.toGroups(resize_exprs); bool single_id_resized_multiple_times = false; for (const auto i : c10::irange(resize_groups.size() - 1)) {