From 0f89b96be7a68836e0930aac61c991818edce598 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 31 Oct 2024 22:45:48 -0700 Subject: [PATCH 1/6] Use IterDomain::split and IterDomain::merge Instead of manually creating output IDs and directly constructing Split and Merge nodes, IterDomain::split and IterDomain::merge are used. This would reduce duplication of logic. --- csrc/ir/internal_base_nodes.h | 13 +++++-- csrc/ir/nodes.cpp | 70 +++++++++++++++++++++++----------- csrc/transform_replay.cpp | 41 ++++++-------------- csrc/transform_rfactor.cpp | 71 +++++++++++++---------------------- 4 files changed, 97 insertions(+), 98 deletions(-) diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 6ac52ba0564..f9f422cd994 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -120,18 +120,23 @@ class NVF_API IterDomain : public Val { static std::vector clone( const std::vector& domains); - //! When `rfactor_domain` is true, also set the `is_rfactor_domain_` flag of - //! the result IterDomain. + //! The optional parameters of rfactor_domain and iter_type can be + //! used to override the default behavior. static IterDomain* merge( IterDomain* outer, IterDomain* inner, - bool rfactor_domain = false); + std::optional rfactor_domain = std::nullopt, + std::optional iter_type = std::nullopt); + //! The optional parameters of rfactor_domain, outer_iter_type and + //! inner_iter_type can be used to override the default behavior. static std::pair split( IterDomain* in, Val* factor, bool inner_split, - bool rfactor_domain = false); + std::optional rfactor_domain = std::nullopt, + std::optional outer_iter_type = std::nullopt, + std::optional inner_iter_type = std::nullopt); //! Resize an IterDomain by expanding both the left and right sides //! by given widths. The resulting IterDomain has an extent of diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 3c7be6b6262..827106132bb 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -2550,7 +2550,8 @@ IterDomain* IterDomain::cloneWithoutRFactor(bool map_with_original) { IterDomain* IterDomain::merge( IterDomain* outer, IterDomain* inner, - bool rfactor_domain) { + std::optional rfactor_domain, + std::optional iter_type) { NVF_CHECK( outer->isReduction() == inner->isReduction(), "Merging IterDomains requires that their iteration types match. ", @@ -2563,24 +2564,33 @@ IterDomain* IterDomain::merge( !outer->isStride() && !inner->isStride(), "No support for merging stride domains"); + // By default, if not specified, don't create rfactor + // outputs. Reshape transformations should propagate the flag, which + // should explicitly specify the flag + if (!rfactor_domain.has_value()) { + rfactor_domain = false; + } + Val* merged_id_size = mul(outer->extent(), inner->extent()); - IterType itype = outer->getIterType(); + if (!iter_type.has_value()) { + iter_type = outer->getIterType(); - if (outer->isBroadcast() && inner->isBroadcast()) { - itype = IterType::Broadcast; - } + if (outer->isBroadcast() && inner->isBroadcast()) { + iter_type = IterType::Broadcast; + } - if ((outer->isBroadcast() || inner->isBroadcast()) && - (outer->getIterType() == IterType::Iteration || - inner->getIterType() == IterType::Iteration)) { - itype = IterType::Iteration; - } + if ((outer->isBroadcast() || inner->isBroadcast()) && + (outer->getIterType() == IterType::Iteration || + inner->getIterType() == IterType::Iteration)) { + iter_type = IterType::Iteration; + } - if ((outer->isBroadcast() || inner->isBroadcast()) && - (outer->getIterType() == IterType::GatherScatter || - inner->getIterType() == IterType::GatherScatter)) { - itype = IterType::GatherScatter; + if ((outer->isBroadcast() || inner->isBroadcast()) && + (outer->getIterType() == IterType::GatherScatter || + inner->getIterType() == IterType::GatherScatter)) { + iter_type = IterType::GatherScatter; + } } Val* expanded_extent = nullptr; @@ -2606,8 +2616,8 @@ IterDomain* IterDomain::merge( IterDomainBuilder(outer->container()->zeroVal(), merged_id_size) .parallel_type(outer->getParallelType()) .expanded_extent(expanded_extent) - .iter_type(itype) - .is_rfactor_domain(rfactor_domain) + .iter_type(*iter_type) + .is_rfactor_domain(*rfactor_domain) .build(); IrBuilder::createInContainer( @@ -2620,7 +2630,9 @@ std::pair IterDomain::split( IterDomain* in, Val* factor, bool inner_split, - bool rfactor_domain) { + std::optional rfactor_domain, + std::optional outer_iter_type, + std::optional inner_iter_type) { NVF_CHECK( factor->isIntegralScalar(), "Cannot split by non-integer value ", factor); @@ -2631,6 +2643,22 @@ std::pair IterDomain::split( expanded_remainder = ceilDiv(in->expandedExtent(), factor); } + // By default, if not specified, don't create rfactor + // outputs. Reshape transformations should propagate the flag, which + // should explicitly specify the flag + if (!rfactor_domain.has_value()) { + rfactor_domain = false; + } + + // If not specified, inherit these properties from the input iter domain + if (!outer_iter_type.has_value()) { + outer_iter_type = in->getIterType(); + } + + if (!inner_iter_type.has_value()) { + inner_iter_type = in->getIterType(); + } + // outer loop IterDomain IterDomain* ido = IterDomainBuilder( @@ -2639,8 +2667,8 @@ std::pair IterDomain::split( in->hasExpandedExtent() && inner_split ? expanded_remainder : nullptr) .parallel_type(in->getParallelType()) - .iter_type(in->getIterType()) - .is_rfactor_domain(rfactor_domain) + .iter_type(*outer_iter_type) + .is_rfactor_domain(*rfactor_domain) .build(); // inner loop IterDomain @@ -2651,8 +2679,8 @@ std::pair IterDomain::split( in->hasExpandedExtent() && !inner_split ? expanded_remainder : nullptr) .parallel_type(in->getParallelType()) - .iter_type(in->getIterType()) - .is_rfactor_domain(rfactor_domain) + .iter_type(*inner_iter_type) + .is_rfactor_domain(*rfactor_domain) .build(); IrBuilder::createInContainer( diff --git a/csrc/transform_replay.cpp b/csrc/transform_replay.cpp index 093715f92a8..b404134f7f8 100644 --- a/csrc/transform_replay.cpp +++ b/csrc/transform_replay.cpp @@ -52,25 +52,17 @@ class ReplaySelf : public ReplayTransformations { loop_ids_.find(mapped) != loop_ids_.end(), "Transform traversal failed, modified a node but it was not a loop node."); - // outer loop size - Val* remainder = ceilDiv(mapped->extent(), s->factor()); - - // Manually replay the split, following the output of the operations. - // This is so rfactor ops are replayed correctly. - IterDomain* ido = IterDomainBuilder(s->outer()) - .start(s->container()->zeroVal()) - .extent(s->innerSplit() ? remainder : s->factor()) - .build(); - - // inner IterDomain - IterDomain* idi = IterDomainBuilder(s->inner()) - .start(s->container()->zeroVal()) - .extent(s->innerSplit() ? s->factor() : remainder) - .build(); - - // Generate the split node - IrBuilder::createInContainer( - s->container(), ido, idi, mapped, s->factor(), s->innerSplit()); + NVF_ERROR(mapped != nullptr); + + NVF_ERROR(s->outer()->isRFactorProduct() == s->inner()->isRFactorProduct()); + + auto [ido, idi] = IterDomain::split( + mapped, + s->factor(), + s->innerSplit(), + s->outer()->isRFactorProduct(), + s->outer()->getIterType(), + s->inner()->getIterType()); // Remove mapped id from loop IDs loop_ids_.erase(mapped); @@ -107,16 +99,7 @@ class ReplaySelf : public ReplayTransformations { id_inner_mapped, " however one or both are not loop nodes."); - Val* merged_id_size = - mul(id_outer_mapped->extent(), id_inner_mapped->extent()); - - IterDomain* merged_id = IterDomainBuilder(m->out()) - .start(m->container()->zeroVal()) - .extent(merged_id_size) - .build(); - - IrBuilder::createInContainer( - m->container(), merged_id, id_outer_mapped, id_inner_mapped); + IterDomain* merged_id = IterDomain::merge(id_outer_mapped, id_inner_mapped); // Remove inputs from the loop IDs loop_ids_.erase(id_outer_mapped); diff --git a/csrc/transform_rfactor.cpp b/csrc/transform_rfactor.cpp index 311bec23796..982f7cc62ec 100644 --- a/csrc/transform_rfactor.cpp +++ b/csrc/transform_rfactor.cpp @@ -108,9 +108,6 @@ class ReplayRFactor : public ReplayTransformations { loop_ids_.find(mapped) != loop_ids_.end(), "Transform traversal failed, modified a node but it was not a loop node."); - // outer loop size - Val* remainder = ceilDiv(mapped->extent(), s->factor()); - // Check if we need to mark the outputs as an logical domain meaning this // transformation must be present in replays otherwise it breaks the compute // definition of the fusion. Iter domains are actually not static, its the @@ -119,32 +116,20 @@ class ReplayRFactor : public ReplayTransformations { bool static_logical_outputs = static_logical_ids_.count(s->outer()) || static_logical_ids_.count(s->inner()); - // Manually replay the split, making reduction = false and rfactor = true - // outer IterDomain - IterDomain* ido = - IterDomainBuilder( - s->container()->zeroVal(), - s->innerSplit() ? remainder : s->factor()) - .iter_type( - rfactor_axes_.count(s->outer()) ? IterType::Reduction - : IterType::Iteration) - .is_rfactor_domain(static_logical_outputs) - .build(); - - // inner IterDomain - IterDomain* idi = - IterDomainBuilder( - s->container()->zeroVal(), - s->innerSplit() ? s->factor() : remainder) - .iter_type( - rfactor_axes_.count(s->inner()) ? IterType::Reduction - : IterType::Iteration) - .is_rfactor_domain(static_logical_outputs) - .build(); + auto outer_iter_type = rfactor_axes_.count(s->outer()) + ? IterType::Reduction + : IterType::Iteration; + auto inner_iter_type = rfactor_axes_.count(s->inner()) + ? IterType::Reduction + : IterType::Iteration; - // Generate the split node - IrBuilder::createInContainer( - s->container(), ido, idi, mapped, s->factor(), s->innerSplit()); + auto [ido, idi] = IterDomain::split( + mapped, + s->factor(), + s->innerSplit(), + static_logical_outputs, + outer_iter_type, + inner_iter_type); // Remove mapped id from loop IDs loop_ids_.erase(mapped); @@ -182,23 +167,21 @@ class ReplayRFactor : public ReplayTransformations { id_inner_mapped, " however one or both are not loop nodes."); - Val* merged_id_size = - mul(id_outer_mapped->extent(), id_inner_mapped->extent()); - - bool is_bcast = - id_outer_mapped->isBroadcast() && id_inner_mapped->isBroadcast(); - auto iter_type = rfactor_axes_.count(m->out()) - ? IterType::Reduction - : (is_bcast ? IterType::Broadcast : IterType::Iteration); - - IterDomain* merged_id = - IterDomainBuilder(m->container()->zeroVal(), merged_id_size) - .iter_type(iter_type) - .is_rfactor_domain(static_logical_ids_.count(m->out())) - .build(); + // Let IterDomain::merge determine the correct IterType, except + // when m->out is a reduction domain. If it isn't involved in + // the rfactor, it's no longer a redunction domain + std::optional iter_type; + if (std::find(target_domain_.begin(), target_domain_.end(), m->out()) != + target_domain_.end() && + m->out()->isReduction() && !rfactor_axes_.count(m->out())) { + iter_type = IterType::Iteration; + } - IrBuilder::createInContainer( - m->container(), merged_id, id_outer_mapped, id_inner_mapped); + IterDomain* merged_id = IterDomain::merge( + id_outer_mapped, + id_inner_mapped, + static_logical_ids_.count(m->out()), + iter_type); // Remove inputs from the loop IDs loop_ids_.erase(id_outer_mapped); From 654b064e37ad3a9bf3578eac8f4671386edf464c Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 1 Nov 2024 22:39:58 -0700 Subject: [PATCH 2/6] comment --- csrc/transform_replay.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/csrc/transform_replay.cpp b/csrc/transform_replay.cpp index b404134f7f8..06e15929aa9 100644 --- a/csrc/transform_replay.cpp +++ b/csrc/transform_replay.cpp @@ -52,10 +52,13 @@ class ReplaySelf : public ReplayTransformations { loop_ids_.find(mapped) != loop_ids_.end(), "Transform traversal failed, modified a node but it was not a loop node."); - NVF_ERROR(mapped != nullptr); - NVF_ERROR(s->outer()->isRFactorProduct() == s->inner()->isRFactorProduct()); + // Due to rfactor transformations, the iter types of the outputs + // may not follow the default rule. For example, even if the input + // is a reduction iter domain, the outputs may not. To replay the + // original split expression, the output iter types need to be + // specified explicitly. auto [ido, idi] = IterDomain::split( mapped, s->factor(), From 8321cfdd27ac15e9464f551e0c2ea0e5159357db Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Sat, 2 Nov 2024 12:51:54 -0700 Subject: [PATCH 3/6] make sure only rfactor-affected ids are changed --- csrc/transform_rfactor.cpp | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/csrc/transform_rfactor.cpp b/csrc/transform_rfactor.cpp index 982f7cc62ec..90a0bac0085 100644 --- a/csrc/transform_rfactor.cpp +++ b/csrc/transform_rfactor.cpp @@ -116,12 +116,15 @@ class ReplayRFactor : public ReplayTransformations { bool static_logical_outputs = static_logical_ids_.count(s->outer()) || static_logical_ids_.count(s->inner()); - auto outer_iter_type = rfactor_axes_.count(s->outer()) - ? IterType::Reduction - : IterType::Iteration; - auto inner_iter_type = rfactor_axes_.count(s->inner()) - ? IterType::Reduction - : IterType::Iteration; + std::optional outer_iter_type; + if (s->outer()->isReduction() && !rfactor_dep_ids_.count(s->outer())) { + outer_iter_type = IterType::Iteration; + } + + std::optional inner_iter_type; + if (s->inner()->isReduction() && !rfactor_dep_ids_.count(s->inner())) { + inner_iter_type = IterType::Iteration; + } auto [ido, idi] = IterDomain::split( mapped, @@ -171,9 +174,7 @@ class ReplayRFactor : public ReplayTransformations { // when m->out is a reduction domain. If it isn't involved in // the rfactor, it's no longer a redunction domain std::optional iter_type; - if (std::find(target_domain_.begin(), target_domain_.end(), m->out()) != - target_domain_.end() && - m->out()->isReduction() && !rfactor_axes_.count(m->out())) { + if (m->out()->isReduction() && !rfactor_dep_ids_.count(m->out())) { iter_type = IterType::Iteration; } @@ -219,6 +220,9 @@ class ReplayRFactor : public ReplayTransformations { // The IterDomains in the original_domain that are being factored into the // first stage of the two stage reduction (the producer). std::unordered_set rfactor_axes_; + // All iter domains between the logical and the loop that the + // rfactor_axes_ depend on + std::unordered_set rfactor_dep_ids_; // Iter domains whose history cannot be changed as it would break rfactor // dependencies. std::unordered_set static_logical_ids_; @@ -245,6 +249,14 @@ class ReplayRFactor : public ReplayTransformations { rfactor_axes_(std::move(rfactor_axes)), static_logical_ids_(std::move(static_logical_ids)), logical_domain_(original_domain->logical()) { + const auto all_dep_vals = DependencyCheck::getAllValsBetween( + {original_domain->maybeRoot().begin(), + original_domain->maybeRoot().end()}, + {rfactor_axes_.begin(), rfactor_axes_.end()}); + + auto all_dep_ids = ir_utils::filterByType(all_dep_vals); + rfactor_dep_ids_.insert(all_dep_ids.begin(), all_dep_ids.end()); + setErrorOnFailure(false); } }; From 90b654a5be6c42e7af76531eb60214b185ecc43a Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 4 Nov 2024 10:52:52 -0800 Subject: [PATCH 4/6] comment cleanup --- csrc/transform_rfactor.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/csrc/transform_rfactor.cpp b/csrc/transform_rfactor.cpp index 90a0bac0085..07799487eb0 100644 --- a/csrc/transform_rfactor.cpp +++ b/csrc/transform_rfactor.cpp @@ -116,6 +116,10 @@ class ReplayRFactor : public ReplayTransformations { bool static_logical_outputs = static_logical_ids_.count(s->outer()) || static_logical_ids_.count(s->inner()); + // Let IterDomain::split determine the correct IterType, except + // when the output is a reduction domain but not part of the + // rfactored domains. If it isn't involved in the rfactor, it's no + // longer a redunction domain std::optional outer_iter_type; if (s->outer()->isReduction() && !rfactor_dep_ids_.count(s->outer())) { outer_iter_type = IterType::Iteration; @@ -171,8 +175,9 @@ class ReplayRFactor : public ReplayTransformations { " however one or both are not loop nodes."); // Let IterDomain::merge determine the correct IterType, except - // when m->out is a reduction domain. If it isn't involved in - // the rfactor, it's no longer a redunction domain + // when the output is a reduction domain but not part of the + // rfactored domains. If it isn't involved in the rfactor, it's no + // longer a redunction domain std::optional iter_type; if (m->out()->isReduction() && !rfactor_dep_ids_.count(m->out())) { iter_type = IterType::Iteration; From b774f6becfe91c6a1c8d3d81d900aaeb45231989 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 4 Nov 2024 11:16:07 -0800 Subject: [PATCH 5/6] repro --- tests/cpp/test_gpu3.cpp | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp index 9862dcb8b07..0df37008a00 100644 --- a/tests/cpp/test_gpu3.cpp +++ b/tests/cpp/test_gpu3.cpp @@ -8874,6 +8874,31 @@ TEST_F(NVFuserTest, CpAsyncDataTypeBool) { auto cg_outputs = fe.runFusion({t0}); testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } + +// Intermediate IDs generaetd by rFactor should also remain +// reductions. See #3327 for more info. +TEST_F(NVFuserTest, RfactorIntermediateIDs) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(3); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {1, 2}); + fusion.addOutput(tv1); + + tv1->merge(1, 2); + tv1->split(1, 4); + + auto tv2 = tv1->rFactor({-1}); + + auto split = dynamic_cast(tv2->axis(-1)->definition()); + ASSERT_NE(split, nullptr); + + auto merge_out = split->in(); + EXPECT_TRUE(merge_out->isReduction()); +} + // Test file size should be up to 10K LoC. Create a new file for more tests. } // namespace nvfuser From 5583acdac21d6b4a7b986a842590a31cd78d689e Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 4 Nov 2024 12:45:59 -0800 Subject: [PATCH 6/6] add assertions --- tests/cpp/test_gpu3.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp index 0df37008a00..cd8a96e9ba9 100644 --- a/tests/cpp/test_gpu3.cpp +++ b/tests/cpp/test_gpu3.cpp @@ -8892,6 +8892,9 @@ TEST_F(NVFuserTest, RfactorIntermediateIDs) { auto tv2 = tv1->rFactor({-1}); + EXPECT_TRUE(tv2->axis(-1)->isReduction()); + EXPECT_FALSE(tv2->axis(-2)->isReduction()); + auto split = dynamic_cast(tv2->axis(-1)->definition()); ASSERT_NE(split, nullptr);