diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 6ac52ba0564..f9f422cd994 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -120,18 +120,23 @@ class NVF_API IterDomain : public Val { static std::vector clone( const std::vector& domains); - //! When `rfactor_domain` is true, also set the `is_rfactor_domain_` flag of - //! the result IterDomain. + //! The optional parameters of rfactor_domain and iter_type can be + //! used to override the default behavior. static IterDomain* merge( IterDomain* outer, IterDomain* inner, - bool rfactor_domain = false); + std::optional rfactor_domain = std::nullopt, + std::optional iter_type = std::nullopt); + //! The optional parameters of rfactor_domain, outer_iter_type and + //! inner_iter_type can be used to override the default behavior. static std::pair split( IterDomain* in, Val* factor, bool inner_split, - bool rfactor_domain = false); + std::optional rfactor_domain = std::nullopt, + std::optional outer_iter_type = std::nullopt, + std::optional inner_iter_type = std::nullopt); //! Resize an IterDomain by expanding both the left and right sides //! by given widths. The resulting IterDomain has an extent of diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 3c7be6b6262..827106132bb 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -2550,7 +2550,8 @@ IterDomain* IterDomain::cloneWithoutRFactor(bool map_with_original) { IterDomain* IterDomain::merge( IterDomain* outer, IterDomain* inner, - bool rfactor_domain) { + std::optional rfactor_domain, + std::optional iter_type) { NVF_CHECK( outer->isReduction() == inner->isReduction(), "Merging IterDomains requires that their iteration types match. ", @@ -2563,24 +2564,33 @@ IterDomain* IterDomain::merge( !outer->isStride() && !inner->isStride(), "No support for merging stride domains"); + // By default, if not specified, don't create rfactor + // outputs. 
Reshape transformations should propagate the flag, which + // should explicitly specify the flag + if (!rfactor_domain.has_value()) { + rfactor_domain = false; + } + Val* merged_id_size = mul(outer->extent(), inner->extent()); - IterType itype = outer->getIterType(); + if (!iter_type.has_value()) { + iter_type = outer->getIterType(); - if (outer->isBroadcast() && inner->isBroadcast()) { - itype = IterType::Broadcast; - } + if (outer->isBroadcast() && inner->isBroadcast()) { + iter_type = IterType::Broadcast; + } - if ((outer->isBroadcast() || inner->isBroadcast()) && - (outer->getIterType() == IterType::Iteration || - inner->getIterType() == IterType::Iteration)) { - itype = IterType::Iteration; - } + if ((outer->isBroadcast() || inner->isBroadcast()) && + (outer->getIterType() == IterType::Iteration || + inner->getIterType() == IterType::Iteration)) { + iter_type = IterType::Iteration; + } - if ((outer->isBroadcast() || inner->isBroadcast()) && - (outer->getIterType() == IterType::GatherScatter || - inner->getIterType() == IterType::GatherScatter)) { - itype = IterType::GatherScatter; + if ((outer->isBroadcast() || inner->isBroadcast()) && + (outer->getIterType() == IterType::GatherScatter || + inner->getIterType() == IterType::GatherScatter)) { + iter_type = IterType::GatherScatter; + } } Val* expanded_extent = nullptr; @@ -2606,8 +2616,8 @@ IterDomain* IterDomain::merge( IterDomainBuilder(outer->container()->zeroVal(), merged_id_size) .parallel_type(outer->getParallelType()) .expanded_extent(expanded_extent) - .iter_type(itype) - .is_rfactor_domain(rfactor_domain) + .iter_type(*iter_type) + .is_rfactor_domain(*rfactor_domain) .build(); IrBuilder::createInContainer( @@ -2620,7 +2630,9 @@ std::pair IterDomain::split( IterDomain* in, Val* factor, bool inner_split, - bool rfactor_domain) { + std::optional rfactor_domain, + std::optional outer_iter_type, + std::optional inner_iter_type) { NVF_CHECK( factor->isIntegralScalar(), "Cannot split by non-integer value ", 
factor); @@ -2631,6 +2643,22 @@ std::pair IterDomain::split( expanded_remainder = ceilDiv(in->expandedExtent(), factor); } + // By default, if not specified, don't create rfactor + // outputs. Reshape transformations should propagate the flag, which + // should explicitly specify the flag + if (!rfactor_domain.has_value()) { + rfactor_domain = false; + } + + // If not specified, inherit these properties from the input iter domain + if (!outer_iter_type.has_value()) { + outer_iter_type = in->getIterType(); + } + + if (!inner_iter_type.has_value()) { + inner_iter_type = in->getIterType(); + } + // outer loop IterDomain IterDomain* ido = IterDomainBuilder( @@ -2639,8 +2667,8 @@ std::pair IterDomain::split( in->hasExpandedExtent() && inner_split ? expanded_remainder : nullptr) .parallel_type(in->getParallelType()) - .iter_type(in->getIterType()) - .is_rfactor_domain(rfactor_domain) + .iter_type(*outer_iter_type) + .is_rfactor_domain(*rfactor_domain) .build(); // inner loop IterDomain @@ -2651,8 +2679,8 @@ std::pair IterDomain::split( in->hasExpandedExtent() && !inner_split ? expanded_remainder : nullptr) .parallel_type(in->getParallelType()) - .iter_type(in->getIterType()) - .is_rfactor_domain(rfactor_domain) + .iter_type(*inner_iter_type) + .is_rfactor_domain(*rfactor_domain) .build(); IrBuilder::createInContainer( diff --git a/csrc/transform_replay.cpp b/csrc/transform_replay.cpp index 093715f92a8..b404134f7f8 100644 --- a/csrc/transform_replay.cpp +++ b/csrc/transform_replay.cpp @@ -52,25 +52,17 @@ class ReplaySelf : public ReplayTransformations { loop_ids_.find(mapped) != loop_ids_.end(), "Transform traversal failed, modified a node but it was not a loop node."); - // outer loop size - Val* remainder = ceilDiv(mapped->extent(), s->factor()); - - // Manually replay the split, following the output of the operations. - // This is so rfactor ops are replayed correctly. 
- IterDomain* ido = IterDomainBuilder(s->outer()) - .start(s->container()->zeroVal()) - .extent(s->innerSplit() ? remainder : s->factor()) - .build(); - - // inner IterDomain - IterDomain* idi = IterDomainBuilder(s->inner()) - .start(s->container()->zeroVal()) - .extent(s->innerSplit() ? s->factor() : remainder) - .build(); - - // Generate the split node - IrBuilder::createInContainer( - s->container(), ido, idi, mapped, s->factor(), s->innerSplit()); + NVF_ERROR(mapped != nullptr); + + NVF_ERROR(s->outer()->isRFactorProduct() == s->inner()->isRFactorProduct()); + + auto [ido, idi] = IterDomain::split( + mapped, + s->factor(), + s->innerSplit(), + s->outer()->isRFactorProduct(), + s->outer()->getIterType(), + s->inner()->getIterType()); // Remove mapped id from loop IDs loop_ids_.erase(mapped); @@ -107,16 +99,7 @@ class ReplaySelf : public ReplayTransformations { id_inner_mapped, " however one or both are not loop nodes."); - Val* merged_id_size = - mul(id_outer_mapped->extent(), id_inner_mapped->extent()); - - IterDomain* merged_id = IterDomainBuilder(m->out()) - .start(m->container()->zeroVal()) - .extent(merged_id_size) - .build(); - - IrBuilder::createInContainer( - m->container(), merged_id, id_outer_mapped, id_inner_mapped); + IterDomain* merged_id = IterDomain::merge(id_outer_mapped, id_inner_mapped); // Remove inputs from the loop IDs loop_ids_.erase(id_outer_mapped); diff --git a/csrc/transform_rfactor.cpp b/csrc/transform_rfactor.cpp index 311bec23796..982f7cc62ec 100644 --- a/csrc/transform_rfactor.cpp +++ b/csrc/transform_rfactor.cpp @@ -108,9 +108,6 @@ class ReplayRFactor : public ReplayTransformations { loop_ids_.find(mapped) != loop_ids_.end(), "Transform traversal failed, modified a node but it was not a loop node."); - // outer loop size - Val* remainder = ceilDiv(mapped->extent(), s->factor()); - // Check if we need to mark the outputs as an logical domain meaning this // transformation must be present in replays otherwise it breaks the compute // 
definition of the fusion. Iter domains are actually not static, its the @@ -119,32 +116,20 @@ class ReplayRFactor : public ReplayTransformations { bool static_logical_outputs = static_logical_ids_.count(s->outer()) || static_logical_ids_.count(s->inner()); - // Manually replay the split, making reduction = false and rfactor = true - // outer IterDomain - IterDomain* ido = - IterDomainBuilder( - s->container()->zeroVal(), - s->innerSplit() ? remainder : s->factor()) - .iter_type( - rfactor_axes_.count(s->outer()) ? IterType::Reduction - : IterType::Iteration) - .is_rfactor_domain(static_logical_outputs) - .build(); - - // inner IterDomain - IterDomain* idi = - IterDomainBuilder( - s->container()->zeroVal(), - s->innerSplit() ? s->factor() : remainder) - .iter_type( - rfactor_axes_.count(s->inner()) ? IterType::Reduction - : IterType::Iteration) - .is_rfactor_domain(static_logical_outputs) - .build(); + auto outer_iter_type = rfactor_axes_.count(s->outer()) + ? IterType::Reduction + : IterType::Iteration; + auto inner_iter_type = rfactor_axes_.count(s->inner()) + ? IterType::Reduction + : IterType::Iteration; - // Generate the split node - IrBuilder::createInContainer( - s->container(), ido, idi, mapped, s->factor(), s->innerSplit()); + auto [ido, idi] = IterDomain::split( + mapped, + s->factor(), + s->innerSplit(), + static_logical_outputs, + outer_iter_type, + inner_iter_type); // Remove mapped id from loop IDs loop_ids_.erase(mapped); @@ -182,23 +167,21 @@ class ReplayRFactor : public ReplayTransformations { id_inner_mapped, " however one or both are not loop nodes."); - Val* merged_id_size = - mul(id_outer_mapped->extent(), id_inner_mapped->extent()); - - bool is_bcast = - id_outer_mapped->isBroadcast() && id_inner_mapped->isBroadcast(); - auto iter_type = rfactor_axes_.count(m->out()) - ? IterType::Reduction - : (is_bcast ? 
IterType::Broadcast : IterType::Iteration); - - IterDomain* merged_id = - IterDomainBuilder(m->container()->zeroVal(), merged_id_size) - .iter_type(iter_type) - .is_rfactor_domain(static_logical_ids_.count(m->out())) - .build(); + // Let IterDomain::merge determine the correct IterType, except + // when m->out is a reduction domain. If it isn't involved in + // the rfactor, it's no longer a reduction domain + std::optional iter_type; + if (std::find(target_domain_.begin(), target_domain_.end(), m->out()) != + target_domain_.end() && + m->out()->isReduction() && !rfactor_axes_.count(m->out())) { + iter_type = IterType::Iteration; + } - IrBuilder::createInContainer( - m->container(), merged_id, id_outer_mapped, id_inner_mapped); + IterDomain* merged_id = IterDomain::merge( + id_outer_mapped, + id_inner_mapped, + static_logical_ids_.count(m->out()), + iter_type); // Remove inputs from the loop IDs loop_ids_.erase(id_outer_mapped);