Skip to content

Commit

Permalink
Use IterDomain::split and IterDomain::merge
Browse files Browse the repository at this point in the history
Instead of manually creating output IDs and directly constructing Split
and Merge nodes, IterDomain::split and IterDomain::merge are used. This
would reduce duplication of logic.
  • Loading branch information
naoyam committed Nov 1, 2024
1 parent f08bd51 commit 0f89b96
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 98 deletions.
13 changes: 9 additions & 4 deletions csrc/ir/internal_base_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,18 +120,23 @@ class NVF_API IterDomain : public Val {
static std::vector<IterDomain*> clone(
const std::vector<IterDomain*>& domains);

//! When `rfactor_domain` is true, also set the `is_rfactor_domain_` flag of
//! the result IterDomain.
//! The optional parameters of rfactor_domain and iter_type can be
//! used to override the default behavior.
static IterDomain* merge(
IterDomain* outer,
IterDomain* inner,
bool rfactor_domain = false);
std::optional<bool> rfactor_domain = std::nullopt,
std::optional<IterType> iter_type = std::nullopt);

//! The optional parameters of rfactor_domain, outer_iter_type and
//! inner_iter_type can be used to override the default behavior.
static std::pair<IterDomain*, IterDomain*> split(
IterDomain* in,
Val* factor,
bool inner_split,
bool rfactor_domain = false);
std::optional<bool> rfactor_domain = std::nullopt,
std::optional<IterType> outer_iter_type = std::nullopt,
std::optional<IterType> inner_iter_type = std::nullopt);

//! Resize an IterDomain by expanding both the left and right sides
//! by given widths. The resulting IterDomain has an extent of
Expand Down
70 changes: 49 additions & 21 deletions csrc/ir/nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2550,7 +2550,8 @@ IterDomain* IterDomain::cloneWithoutRFactor(bool map_with_original) {
IterDomain* IterDomain::merge(
IterDomain* outer,
IterDomain* inner,
bool rfactor_domain) {
std::optional<bool> rfactor_domain,
std::optional<IterType> iter_type) {
NVF_CHECK(
outer->isReduction() == inner->isReduction(),
"Merging IterDomains requires that their iteration types match. ",
Expand All @@ -2563,24 +2564,33 @@ IterDomain* IterDomain::merge(
!outer->isStride() && !inner->isStride(),
"No support for merging stride domains");

// By default, if not specified, don't create rfactor
// outputs. Reshape transformations should propagate the flag by
// specifying it explicitly
if (!rfactor_domain.has_value()) {
rfactor_domain = false;
}

Val* merged_id_size = mul(outer->extent(), inner->extent());

IterType itype = outer->getIterType();
if (!iter_type.has_value()) {
iter_type = outer->getIterType();

if (outer->isBroadcast() && inner->isBroadcast()) {
itype = IterType::Broadcast;
}
if (outer->isBroadcast() && inner->isBroadcast()) {
iter_type = IterType::Broadcast;
}

if ((outer->isBroadcast() || inner->isBroadcast()) &&
(outer->getIterType() == IterType::Iteration ||
inner->getIterType() == IterType::Iteration)) {
itype = IterType::Iteration;
}
if ((outer->isBroadcast() || inner->isBroadcast()) &&
(outer->getIterType() == IterType::Iteration ||
inner->getIterType() == IterType::Iteration)) {
iter_type = IterType::Iteration;
}

if ((outer->isBroadcast() || inner->isBroadcast()) &&
(outer->getIterType() == IterType::GatherScatter ||
inner->getIterType() == IterType::GatherScatter)) {
itype = IterType::GatherScatter;
if ((outer->isBroadcast() || inner->isBroadcast()) &&
(outer->getIterType() == IterType::GatherScatter ||
inner->getIterType() == IterType::GatherScatter)) {
iter_type = IterType::GatherScatter;
}
}

Val* expanded_extent = nullptr;
Expand All @@ -2606,8 +2616,8 @@ IterDomain* IterDomain::merge(
IterDomainBuilder(outer->container()->zeroVal(), merged_id_size)
.parallel_type(outer->getParallelType())
.expanded_extent(expanded_extent)
.iter_type(itype)
.is_rfactor_domain(rfactor_domain)
.iter_type(*iter_type)
.is_rfactor_domain(*rfactor_domain)
.build();

IrBuilder::createInContainer<Merge>(
Expand All @@ -2620,7 +2630,9 @@ std::pair<IterDomain*, IterDomain*> IterDomain::split(
IterDomain* in,
Val* factor,
bool inner_split,
bool rfactor_domain) {
std::optional<bool> rfactor_domain,
std::optional<IterType> outer_iter_type,
std::optional<IterType> inner_iter_type) {
NVF_CHECK(
factor->isIntegralScalar(), "Cannot split by non-integer value ", factor);

Expand All @@ -2631,6 +2643,22 @@ std::pair<IterDomain*, IterDomain*> IterDomain::split(
expanded_remainder = ceilDiv(in->expandedExtent(), factor);
}

// By default, if not specified, don't create rfactor
// outputs. Reshape transformations should propagate the flag by
// specifying it explicitly
if (!rfactor_domain.has_value()) {
rfactor_domain = false;
}

// If not specified, inherit these properties from the input iter domain
if (!outer_iter_type.has_value()) {
outer_iter_type = in->getIterType();
}

if (!inner_iter_type.has_value()) {
inner_iter_type = in->getIterType();
}

// outer loop IterDomain
IterDomain* ido =
IterDomainBuilder(
Expand All @@ -2639,8 +2667,8 @@ std::pair<IterDomain*, IterDomain*> IterDomain::split(
in->hasExpandedExtent() && inner_split ? expanded_remainder
: nullptr)
.parallel_type(in->getParallelType())
.iter_type(in->getIterType())
.is_rfactor_domain(rfactor_domain)
.iter_type(*outer_iter_type)
.is_rfactor_domain(*rfactor_domain)
.build();

// inner loop IterDomain
Expand All @@ -2651,8 +2679,8 @@ std::pair<IterDomain*, IterDomain*> IterDomain::split(
in->hasExpandedExtent() && !inner_split ? expanded_remainder
: nullptr)
.parallel_type(in->getParallelType())
.iter_type(in->getIterType())
.is_rfactor_domain(rfactor_domain)
.iter_type(*inner_iter_type)
.is_rfactor_domain(*rfactor_domain)
.build();

IrBuilder::createInContainer<Split>(
Expand Down
41 changes: 12 additions & 29 deletions csrc/transform_replay.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,25 +52,17 @@ class ReplaySelf : public ReplayTransformations {
loop_ids_.find(mapped) != loop_ids_.end(),
"Transform traversal failed, modified a node but it was not a loop node.");

// outer loop size
Val* remainder = ceilDiv(mapped->extent(), s->factor());

// Manually replay the split, following the output of the operations.
// This is so rfactor ops are replayed correctly.
IterDomain* ido = IterDomainBuilder(s->outer())
.start(s->container()->zeroVal())
.extent(s->innerSplit() ? remainder : s->factor())
.build();

// inner IterDomain
IterDomain* idi = IterDomainBuilder(s->inner())
.start(s->container()->zeroVal())
.extent(s->innerSplit() ? s->factor() : remainder)
.build();

// Generate the split node
IrBuilder::createInContainer<Split>(
s->container(), ido, idi, mapped, s->factor(), s->innerSplit());
NVF_ERROR(mapped != nullptr);

NVF_ERROR(s->outer()->isRFactorProduct() == s->inner()->isRFactorProduct());

auto [ido, idi] = IterDomain::split(
mapped,
s->factor(),
s->innerSplit(),
s->outer()->isRFactorProduct(),
s->outer()->getIterType(),
s->inner()->getIterType());

// Remove mapped id from loop IDs
loop_ids_.erase(mapped);
Expand Down Expand Up @@ -107,16 +99,7 @@ class ReplaySelf : public ReplayTransformations {
id_inner_mapped,
" however one or both are not loop nodes.");

Val* merged_id_size =
mul(id_outer_mapped->extent(), id_inner_mapped->extent());

IterDomain* merged_id = IterDomainBuilder(m->out())
.start(m->container()->zeroVal())
.extent(merged_id_size)
.build();

IrBuilder::createInContainer<Merge>(
m->container(), merged_id, id_outer_mapped, id_inner_mapped);
IterDomain* merged_id = IterDomain::merge(id_outer_mapped, id_inner_mapped);

// Remove inputs from the loop IDs
loop_ids_.erase(id_outer_mapped);
Expand Down
71 changes: 27 additions & 44 deletions csrc/transform_rfactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,6 @@ class ReplayRFactor : public ReplayTransformations {
loop_ids_.find(mapped) != loop_ids_.end(),
"Transform traversal failed, modified a node but it was not a loop node.");

// outer loop size
Val* remainder = ceilDiv(mapped->extent(), s->factor());

// Check if we need to mark the outputs as a logical domain meaning this
// transformation must be present in replays otherwise it breaks the compute
// definition of the fusion. Iter domains are actually not static, it's the
Expand All @@ -119,32 +116,20 @@ class ReplayRFactor : public ReplayTransformations {
bool static_logical_outputs = static_logical_ids_.count(s->outer()) ||
static_logical_ids_.count(s->inner());

// Manually replay the split, making reduction = false and rfactor = true
// outer IterDomain
IterDomain* ido =
IterDomainBuilder(
s->container()->zeroVal(),
s->innerSplit() ? remainder : s->factor())
.iter_type(
rfactor_axes_.count(s->outer()) ? IterType::Reduction
: IterType::Iteration)
.is_rfactor_domain(static_logical_outputs)
.build();

// inner IterDomain
IterDomain* idi =
IterDomainBuilder(
s->container()->zeroVal(),
s->innerSplit() ? s->factor() : remainder)
.iter_type(
rfactor_axes_.count(s->inner()) ? IterType::Reduction
: IterType::Iteration)
.is_rfactor_domain(static_logical_outputs)
.build();
auto outer_iter_type = rfactor_axes_.count(s->outer())
? IterType::Reduction
: IterType::Iteration;
auto inner_iter_type = rfactor_axes_.count(s->inner())
? IterType::Reduction
: IterType::Iteration;

// Generate the split node
IrBuilder::createInContainer<Split>(
s->container(), ido, idi, mapped, s->factor(), s->innerSplit());
auto [ido, idi] = IterDomain::split(
mapped,
s->factor(),
s->innerSplit(),
static_logical_outputs,
outer_iter_type,
inner_iter_type);

// Remove mapped id from loop IDs
loop_ids_.erase(mapped);
Expand Down Expand Up @@ -182,23 +167,21 @@ class ReplayRFactor : public ReplayTransformations {
id_inner_mapped,
" however one or both are not loop nodes.");

Val* merged_id_size =
mul(id_outer_mapped->extent(), id_inner_mapped->extent());

bool is_bcast =
id_outer_mapped->isBroadcast() && id_inner_mapped->isBroadcast();
auto iter_type = rfactor_axes_.count(m->out())
? IterType::Reduction
: (is_bcast ? IterType::Broadcast : IterType::Iteration);

IterDomain* merged_id =
IterDomainBuilder(m->container()->zeroVal(), merged_id_size)
.iter_type(iter_type)
.is_rfactor_domain(static_logical_ids_.count(m->out()))
.build();
// Let IterDomain::merge determine the correct IterType, except
// when m->out is a reduction domain. If it isn't involved in
// the rfactor, it's no longer a reduction domain
std::optional<IterType> iter_type;
if (std::find(target_domain_.begin(), target_domain_.end(), m->out()) !=
target_domain_.end() &&
m->out()->isReduction() && !rfactor_axes_.count(m->out())) {
iter_type = IterType::Iteration;
}

IrBuilder::createInContainer<Merge>(
m->container(), merged_id, id_outer_mapped, id_inner_mapped);
IterDomain* merged_id = IterDomain::merge(
id_outer_mapped,
id_inner_mapped,
static_logical_ids_.count(m->out()),
iter_type);

// Remove inputs from the loop IDs
loop_ids_.erase(id_outer_mapped);
Expand Down

0 comments on commit 0f89b96

Please sign in to comment.