diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 6ac52ba0564..f9f422cd994 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -120,18 +120,23 @@ class NVF_API IterDomain : public Val { static std::vector clone( const std::vector& domains); - //! When `rfactor_domain` is true, also set the `is_rfactor_domain_` flag of - //! the result IterDomain. + //! The optional parameters of rfactor_domain and iter_type can be + //! used to override the default behavior. static IterDomain* merge( IterDomain* outer, IterDomain* inner, - bool rfactor_domain = false); + std::optional rfactor_domain = std::nullopt, + std::optional iter_type = std::nullopt); + //! The optional parameters of rfactor_domain, outer_iter_type and + //! inner_iter_type can be used to override the default behavior. static std::pair split( IterDomain* in, Val* factor, bool inner_split, - bool rfactor_domain = false); + std::optional rfactor_domain = std::nullopt, + std::optional outer_iter_type = std::nullopt, + std::optional inner_iter_type = std::nullopt); //! Resize an IterDomain by expanding both the left and right sides //! by given widths. The resulting IterDomain has an extent of diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 3c7be6b6262..827106132bb 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -2550,7 +2550,8 @@ IterDomain* IterDomain::cloneWithoutRFactor(bool map_with_original) { IterDomain* IterDomain::merge( IterDomain* outer, IterDomain* inner, - bool rfactor_domain) { + std::optional rfactor_domain, + std::optional iter_type) { NVF_CHECK( outer->isReduction() == inner->isReduction(), "Merging IterDomains requires that their iteration types match. ", @@ -2563,24 +2564,33 @@ IterDomain* IterDomain::merge( !outer->isStride() && !inner->isStride(), "No support for merging stride domains"); + // By default, if not specified, don't create rfactor + // outputs. 
Reshape transformations should propagate the flag, which + // should explicitly specify the flag + if (!rfactor_domain.has_value()) { + rfactor_domain = false; + } + Val* merged_id_size = mul(outer->extent(), inner->extent()); - IterType itype = outer->getIterType(); + if (!iter_type.has_value()) { + iter_type = outer->getIterType(); - if (outer->isBroadcast() && inner->isBroadcast()) { - itype = IterType::Broadcast; - } + if (outer->isBroadcast() && inner->isBroadcast()) { + iter_type = IterType::Broadcast; + } - if ((outer->isBroadcast() || inner->isBroadcast()) && - (outer->getIterType() == IterType::Iteration || - inner->getIterType() == IterType::Iteration)) { - itype = IterType::Iteration; - } + if ((outer->isBroadcast() || inner->isBroadcast()) && + (outer->getIterType() == IterType::Iteration || + inner->getIterType() == IterType::Iteration)) { + iter_type = IterType::Iteration; + } - if ((outer->isBroadcast() || inner->isBroadcast()) && - (outer->getIterType() == IterType::GatherScatter || - inner->getIterType() == IterType::GatherScatter)) { - itype = IterType::GatherScatter; + if ((outer->isBroadcast() || inner->isBroadcast()) && + (outer->getIterType() == IterType::GatherScatter || + inner->getIterType() == IterType::GatherScatter)) { + iter_type = IterType::GatherScatter; + } } Val* expanded_extent = nullptr; @@ -2606,8 +2616,8 @@ IterDomain* IterDomain::merge( IterDomainBuilder(outer->container()->zeroVal(), merged_id_size) .parallel_type(outer->getParallelType()) .expanded_extent(expanded_extent) - .iter_type(itype) - .is_rfactor_domain(rfactor_domain) + .iter_type(*iter_type) + .is_rfactor_domain(*rfactor_domain) .build(); IrBuilder::createInContainer( @@ -2620,7 +2630,9 @@ std::pair IterDomain::split( IterDomain* in, Val* factor, bool inner_split, - bool rfactor_domain) { + std::optional rfactor_domain, + std::optional outer_iter_type, + std::optional inner_iter_type) { NVF_CHECK( factor->isIntegralScalar(), "Cannot split by non-integer value ", 
factor); @@ -2631,6 +2643,22 @@ std::pair IterDomain::split( expanded_remainder = ceilDiv(in->expandedExtent(), factor); } + // By default, if not specified, don't create rfactor + // outputs. Reshape transformations should propagate the flag, which + // should explicitly specify the flag + if (!rfactor_domain.has_value()) { + rfactor_domain = false; + } + + // If not specified, inherit these properties from the input iter domain + if (!outer_iter_type.has_value()) { + outer_iter_type = in->getIterType(); + } + + if (!inner_iter_type.has_value()) { + inner_iter_type = in->getIterType(); + } + // outer loop IterDomain IterDomain* ido = IterDomainBuilder( @@ -2639,8 +2667,8 @@ std::pair IterDomain::split( in->hasExpandedExtent() && inner_split ? expanded_remainder : nullptr) .parallel_type(in->getParallelType()) - .iter_type(in->getIterType()) - .is_rfactor_domain(rfactor_domain) + .iter_type(*outer_iter_type) + .is_rfactor_domain(*rfactor_domain) .build(); // inner loop IterDomain @@ -2651,8 +2679,8 @@ std::pair IterDomain::split( in->hasExpandedExtent() && !inner_split ? expanded_remainder : nullptr) .parallel_type(in->getParallelType()) - .iter_type(in->getIterType()) - .is_rfactor_domain(rfactor_domain) + .iter_type(*inner_iter_type) + .is_rfactor_domain(*rfactor_domain) .build(); IrBuilder::createInContainer( diff --git a/csrc/transform_replay.cpp b/csrc/transform_replay.cpp index 093715f92a8..06e15929aa9 100644 --- a/csrc/transform_replay.cpp +++ b/csrc/transform_replay.cpp @@ -52,25 +52,20 @@ class ReplaySelf : public ReplayTransformations { loop_ids_.find(mapped) != loop_ids_.end(), "Transform traversal failed, modified a node but it was not a loop node."); - // outer loop size - Val* remainder = ceilDiv(mapped->extent(), s->factor()); - - // Manually replay the split, following the output of the operations. - // This is so rfactor ops are replayed correctly. 
- IterDomain* ido = IterDomainBuilder(s->outer()) - .start(s->container()->zeroVal()) - .extent(s->innerSplit() ? remainder : s->factor()) - .build(); - - // inner IterDomain - IterDomain* idi = IterDomainBuilder(s->inner()) - .start(s->container()->zeroVal()) - .extent(s->innerSplit() ? s->factor() : remainder) - .build(); - - // Generate the split node - IrBuilder::createInContainer( - s->container(), ido, idi, mapped, s->factor(), s->innerSplit()); + NVF_ERROR(s->outer()->isRFactorProduct() == s->inner()->isRFactorProduct()); + + // Due to rfactor transformations, the iter types of the outputs + // may not follow the default rule. For example, even if the input + // is a reduction iter domain, the outputs may not. To replay the + // original split expression, the output iter types need to be + // specified explicitly. + auto [ido, idi] = IterDomain::split( + mapped, + s->factor(), + s->innerSplit(), + s->outer()->isRFactorProduct(), + s->outer()->getIterType(), + s->inner()->getIterType()); // Remove mapped id from loop IDs loop_ids_.erase(mapped); @@ -107,16 +102,7 @@ class ReplaySelf : public ReplayTransformations { id_inner_mapped, " however one or both are not loop nodes."); - Val* merged_id_size = - mul(id_outer_mapped->extent(), id_inner_mapped->extent()); - - IterDomain* merged_id = IterDomainBuilder(m->out()) - .start(m->container()->zeroVal()) - .extent(merged_id_size) - .build(); - - IrBuilder::createInContainer( - m->container(), merged_id, id_outer_mapped, id_inner_mapped); + IterDomain* merged_id = IterDomain::merge(id_outer_mapped, id_inner_mapped); // Remove inputs from the loop IDs loop_ids_.erase(id_outer_mapped); diff --git a/csrc/transform_rfactor.cpp b/csrc/transform_rfactor.cpp index 311bec23796..07799487eb0 100644 --- a/csrc/transform_rfactor.cpp +++ b/csrc/transform_rfactor.cpp @@ -108,9 +108,6 @@ class ReplayRFactor : public ReplayTransformations { loop_ids_.find(mapped) != loop_ids_.end(), "Transform traversal failed, modified a node but 
it was not a loop node."); - // outer loop size - Val* remainder = ceilDiv(mapped->extent(), s->factor()); - // Check if we need to mark the outputs as an logical domain meaning this // transformation must be present in replays otherwise it breaks the compute // definition of the fusion. Iter domains are actually not static, its the @@ -119,32 +116,27 @@ class ReplayRFactor : public ReplayTransformations { bool static_logical_outputs = static_logical_ids_.count(s->outer()) || static_logical_ids_.count(s->inner()); - // Manually replay the split, making reduction = false and rfactor = true - // outer IterDomain - IterDomain* ido = - IterDomainBuilder( - s->container()->zeroVal(), - s->innerSplit() ? remainder : s->factor()) - .iter_type( - rfactor_axes_.count(s->outer()) ? IterType::Reduction - : IterType::Iteration) - .is_rfactor_domain(static_logical_outputs) - .build(); + // Let IterDomain::split determine the correct IterType, except + // when the output is a reduction domain but not part of the + // rfactored domains. If it isn't involved in the rfactor, it's no + // longer a reduction domain + std::optional outer_iter_type; + if (s->outer()->isReduction() && !rfactor_dep_ids_.count(s->outer())) { + outer_iter_type = IterType::Iteration; + } - // inner IterDomain - IterDomain* idi = - IterDomainBuilder( - s->container()->zeroVal(), - s->innerSplit() ? s->factor() : remainder) - .iter_type( - rfactor_axes_.count(s->inner()) ? 
IterType::Reduction - : IterType::Iteration) - .is_rfactor_domain(static_logical_outputs) - .build(); + std::optional inner_iter_type; + if (s->inner()->isReduction() && !rfactor_dep_ids_.count(s->inner())) { + inner_iter_type = IterType::Iteration; + } - // Generate the split node - IrBuilder::createInContainer( - s->container(), ido, idi, mapped, s->factor(), s->innerSplit()); + auto [ido, idi] = IterDomain::split( + mapped, + s->factor(), + s->innerSplit(), + static_logical_outputs, + outer_iter_type, + inner_iter_type); // Remove mapped id from loop IDs loop_ids_.erase(mapped); @@ -182,23 +174,20 @@ class ReplayRFactor : public ReplayTransformations { id_inner_mapped, " however one or both are not loop nodes."); - Val* merged_id_size = - mul(id_outer_mapped->extent(), id_inner_mapped->extent()); - - bool is_bcast = - id_outer_mapped->isBroadcast() && id_inner_mapped->isBroadcast(); - auto iter_type = rfactor_axes_.count(m->out()) - ? IterType::Reduction - : (is_bcast ? IterType::Broadcast : IterType::Iteration); - - IterDomain* merged_id = - IterDomainBuilder(m->container()->zeroVal(), merged_id_size) - .iter_type(iter_type) - .is_rfactor_domain(static_logical_ids_.count(m->out())) - .build(); + // Let IterDomain::merge determine the correct IterType, except + // when the output is a reduction domain but not part of the + // rfactored domains. 
If it isn't involved in the rfactor, it's no + // longer a reduction domain + std::optional iter_type; + if (m->out()->isReduction() && !rfactor_dep_ids_.count(m->out())) { + iter_type = IterType::Iteration; + } - IrBuilder::createInContainer( - m->container(), merged_id, id_outer_mapped, id_inner_mapped); + IterDomain* merged_id = IterDomain::merge( + id_outer_mapped, + id_inner_mapped, + static_logical_ids_.count(m->out()), + iter_type); // Remove inputs from the loop IDs loop_ids_.erase(id_outer_mapped); @@ -236,6 +225,9 @@ class ReplayRFactor : public ReplayTransformations { // The IterDomains in the original_domain that are being factored into the // first stage of the two stage reduction (the producer). std::unordered_set rfactor_axes_; + // All iter domains between the logical and the loop that the + // rfactor_axes_ depend on + std::unordered_set rfactor_dep_ids_; // Iter domains whose history cannot be changed as it would break rfactor // dependencies. std::unordered_set static_logical_ids_; @@ -262,6 +254,14 @@ class ReplayRFactor : public ReplayTransformations { rfactor_axes_(std::move(rfactor_axes)), static_logical_ids_(std::move(static_logical_ids)), logical_domain_(original_domain->logical()) { + const auto all_dep_vals = DependencyCheck::getAllValsBetween( + {original_domain->maybeRoot().begin(), + original_domain->maybeRoot().end()}, + {rfactor_axes_.begin(), rfactor_axes_.end()}); + + auto all_dep_ids = ir_utils::filterByType(all_dep_vals); + rfactor_dep_ids_.insert(all_dep_ids.begin(), all_dep_ids.end()); + setErrorOnFailure(false); } }; diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp index 9862dcb8b07..cd8a96e9ba9 100644 --- a/tests/cpp/test_gpu3.cpp +++ b/tests/cpp/test_gpu3.cpp @@ -8874,6 +8874,34 @@ TEST_F(NVFuserTest, CpAsyncDataTypeBool) { auto cg_outputs = fe.runFusion({t0}); testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } + +// Intermediate IDs generated by rFactor should also remain +// reductions. 
See #3327 for more info. +TEST_F(NVFuserTest, RfactorIntermediateIDs) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(3); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {1, 2}); + fusion.addOutput(tv1); + + tv1->merge(1, 2); + tv1->split(1, 4); + + auto tv2 = tv1->rFactor({-1}); + + EXPECT_TRUE(tv2->axis(-1)->isReduction()); + EXPECT_FALSE(tv2->axis(-2)->isReduction()); + + auto split = dynamic_cast(tv2->axis(-1)->definition()); + ASSERT_NE(split, nullptr); + + auto merge_out = split->in(); + EXPECT_TRUE(merge_out->isReduction()); +} + // Test file size should be up to 10K LoC. Create a new file for more tests. } // namespace nvfuser