Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use IterDomain::split and IterDomain::merge #3327

Merged
merged 6 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions csrc/ir/internal_base_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,18 +120,23 @@ class NVF_API IterDomain : public Val {
static std::vector<IterDomain*> clone(
const std::vector<IterDomain*>& domains);

//! When `rfactor_domain` is true, also set the `is_rfactor_domain_` flag of
//! the result IterDomain.
//! The optional parameters of rfactor_domain and iter_type can be
//! used to override the default behavior.
static IterDomain* merge(
IterDomain* outer,
IterDomain* inner,
bool rfactor_domain = false);
std::optional<bool> rfactor_domain = std::nullopt,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed this parameter to std::optional so that it's clear if the argument of false is intentional.

std::optional<IterType> iter_type = std::nullopt);

//! The optional parameters of rfactor_domain, outer_iter_type and
//! inner_iter_type can be used to override the default behavior.
static std::pair<IterDomain*, IterDomain*> split(
IterDomain* in,
Val* factor,
bool inner_split,
bool rfactor_domain = false);
std::optional<bool> rfactor_domain = std::nullopt,
std::optional<IterType> outer_iter_type = std::nullopt,
std::optional<IterType> inner_iter_type = std::nullopt);

//! Resize an IterDomain by expanding both the left and right sides
//! by given widths. The resulting IterDomain has an extent of
Expand Down
70 changes: 49 additions & 21 deletions csrc/ir/nodes.cpp
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should be no change of logic in this file.

Original file line number Diff line number Diff line change
Expand Up @@ -2550,7 +2550,8 @@ IterDomain* IterDomain::cloneWithoutRFactor(bool map_with_original) {
IterDomain* IterDomain::merge(
IterDomain* outer,
IterDomain* inner,
bool rfactor_domain) {
std::optional<bool> rfactor_domain,
std::optional<IterType> iter_type) {
NVF_CHECK(
outer->isReduction() == inner->isReduction(),
"Merging IterDomains requires that their iteration types match. ",
Expand All @@ -2563,24 +2564,33 @@ IterDomain* IterDomain::merge(
!outer->isStride() && !inner->isStride(),
"No support for merging stride domains");

// By default, if not specified, don't create rfactor
// outputs. Reshape transformations should propagate the flag, which
// should explicitly specify the flag
if (!rfactor_domain.has_value()) {
rfactor_domain = false;
}

Val* merged_id_size = mul(outer->extent(), inner->extent());

IterType itype = outer->getIterType();
if (!iter_type.has_value()) {
iter_type = outer->getIterType();

if (outer->isBroadcast() && inner->isBroadcast()) {
itype = IterType::Broadcast;
}
if (outer->isBroadcast() && inner->isBroadcast()) {
iter_type = IterType::Broadcast;
}

if ((outer->isBroadcast() || inner->isBroadcast()) &&
(outer->getIterType() == IterType::Iteration ||
inner->getIterType() == IterType::Iteration)) {
itype = IterType::Iteration;
}
if ((outer->isBroadcast() || inner->isBroadcast()) &&
(outer->getIterType() == IterType::Iteration ||
inner->getIterType() == IterType::Iteration)) {
iter_type = IterType::Iteration;
}

if ((outer->isBroadcast() || inner->isBroadcast()) &&
(outer->getIterType() == IterType::GatherScatter ||
inner->getIterType() == IterType::GatherScatter)) {
itype = IterType::GatherScatter;
if ((outer->isBroadcast() || inner->isBroadcast()) &&
(outer->getIterType() == IterType::GatherScatter ||
inner->getIterType() == IterType::GatherScatter)) {
iter_type = IterType::GatherScatter;
}
}

Val* expanded_extent = nullptr;
Expand All @@ -2606,8 +2616,8 @@ IterDomain* IterDomain::merge(
IterDomainBuilder(outer->container()->zeroVal(), merged_id_size)
.parallel_type(outer->getParallelType())
.expanded_extent(expanded_extent)
.iter_type(itype)
.is_rfactor_domain(rfactor_domain)
.iter_type(*iter_type)
.is_rfactor_domain(*rfactor_domain)
.build();

IrBuilder::createInContainer<Merge>(
Expand All @@ -2620,7 +2630,9 @@ std::pair<IterDomain*, IterDomain*> IterDomain::split(
IterDomain* in,
Val* factor,
bool inner_split,
bool rfactor_domain) {
std::optional<bool> rfactor_domain,
std::optional<IterType> outer_iter_type,
std::optional<IterType> inner_iter_type) {
NVF_CHECK(
factor->isIntegralScalar(), "Cannot split by non-integer value ", factor);

Expand All @@ -2631,6 +2643,22 @@ std::pair<IterDomain*, IterDomain*> IterDomain::split(
expanded_remainder = ceilDiv(in->expandedExtent(), factor);
}

// By default, if not specified, don't create rfactor
// outputs. Reshape transformations should propagate the flag, which
// should explicitly specify the flag
if (!rfactor_domain.has_value()) {
rfactor_domain = false;
}

// If not specified, inherit these properties from the input iter domain
if (!outer_iter_type.has_value()) {
outer_iter_type = in->getIterType();
}

if (!inner_iter_type.has_value()) {
inner_iter_type = in->getIterType();
}

// outer loop IterDomain
IterDomain* ido =
IterDomainBuilder(
Expand All @@ -2639,8 +2667,8 @@ std::pair<IterDomain*, IterDomain*> IterDomain::split(
in->hasExpandedExtent() && inner_split ? expanded_remainder
: nullptr)
.parallel_type(in->getParallelType())
.iter_type(in->getIterType())
.is_rfactor_domain(rfactor_domain)
.iter_type(*outer_iter_type)
.is_rfactor_domain(*rfactor_domain)
.build();

// inner loop IterDomain
Expand All @@ -2651,8 +2679,8 @@ std::pair<IterDomain*, IterDomain*> IterDomain::split(
in->hasExpandedExtent() && !inner_split ? expanded_remainder
: nullptr)
.parallel_type(in->getParallelType())
.iter_type(in->getIterType())
.is_rfactor_domain(rfactor_domain)
.iter_type(*inner_iter_type)
.is_rfactor_domain(*rfactor_domain)
.build();

IrBuilder::createInContainer<Split>(
Expand Down
44 changes: 15 additions & 29 deletions csrc/transform_replay.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,25 +52,20 @@ class ReplaySelf : public ReplayTransformations {
loop_ids_.find(mapped) != loop_ids_.end(),
"Transform traversal failed, modified a node but it was not a loop node.");

// outer loop size
Val* remainder = ceilDiv(mapped->extent(), s->factor());

// Manually replay the split, following the output of the operations.
// This is so rfactor ops are replayed correctly.
IterDomain* ido = IterDomainBuilder(s->outer())
.start(s->container()->zeroVal())
.extent(s->innerSplit() ? remainder : s->factor())
.build();

// inner IterDomain
IterDomain* idi = IterDomainBuilder(s->inner())
.start(s->container()->zeroVal())
.extent(s->innerSplit() ? s->factor() : remainder)
.build();

// Generate the split node
IrBuilder::createInContainer<Split>(
s->container(), ido, idi, mapped, s->factor(), s->innerSplit());
NVF_ERROR(s->outer()->isRFactorProduct() == s->inner()->isRFactorProduct());

// Due to rfactor transformations, the iter types of the outputs
// may not follow the default rule. For example, even if the input
// is a reduction iter domain, the outputs may not. To replay the
// original split expression, the output iter types need to be
// specified explicitly.
Comment on lines +57 to +61
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this comment.

auto [ido, idi] = IterDomain::split(
mapped,
s->factor(),
s->innerSplit(),
s->outer()->isRFactorProduct(),
s->outer()->getIterType(),
s->inner()->getIterType());

// Remove mapped id from loop IDs
loop_ids_.erase(mapped);
Expand Down Expand Up @@ -107,16 +102,7 @@ class ReplaySelf : public ReplayTransformations {
id_inner_mapped,
" however one or both are not loop nodes.");

Val* merged_id_size =
mul(id_outer_mapped->extent(), id_inner_mapped->extent());

IterDomain* merged_id = IterDomainBuilder(m->out())
.start(m->container()->zeroVal())
.extent(merged_id_size)
.build();

IrBuilder::createInContainer<Merge>(
m->container(), merged_id, id_outer_mapped, id_inner_mapped);
IterDomain* merged_id = IterDomain::merge(id_outer_mapped, id_inner_mapped);

// Remove inputs from the loop IDs
loop_ids_.erase(id_outer_mapped);
Expand Down
86 changes: 43 additions & 43 deletions csrc/transform_rfactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,6 @@ class ReplayRFactor : public ReplayTransformations {
loop_ids_.find(mapped) != loop_ids_.end(),
"Transform traversal failed, modified a node but it was not a loop node.");

// outer loop size
Val* remainder = ceilDiv(mapped->extent(), s->factor());

// Check if we need to mark the outputs as an logical domain meaning this
// transformation must be present in replays otherwise it breaks the compute
// definition of the fusion. Iter domains are actually not static, its the
Expand All @@ -119,32 +116,27 @@ class ReplayRFactor : public ReplayTransformations {
bool static_logical_outputs = static_logical_ids_.count(s->outer()) ||
static_logical_ids_.count(s->inner());

// Manually replay the split, making reduction = false and rfactor = true
// outer IterDomain
IterDomain* ido =
IterDomainBuilder(
s->container()->zeroVal(),
s->innerSplit() ? remainder : s->factor())
.iter_type(
rfactor_axes_.count(s->outer()) ? IterType::Reduction
: IterType::Iteration)
.is_rfactor_domain(static_logical_outputs)
.build();
// Let IterDomain::split determine the correct IterType, except
// when the output is a reduction domain but not part of the
// rfactored domains. If it isn't involved in the rfactor, it's no
// longer a redunction domain
std::optional<IterType> outer_iter_type;
if (s->outer()->isReduction() && !rfactor_dep_ids_.count(s->outer())) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously, every output iter domain uses IterType::Iteration unless it's one of the rfactored loop IDs. This means that, for example:

auto t0 = [i0, r1, r2];
t0->merge(1, 2); // [i0, r1*r2]
t0->split(1, 4); // [i0, r1*r2/4, 4]
t1 = t0->rFactor(-1);

In this case, t1 would inherit all the transformations of t0 and the merge output would not be marked as a reduction domain.

Generally, when a tensor is rfactored, whether it's Iteration or Reduction seems to be blurred, and there's a case where a reduction iter domain is merged with a non-reduction iter domain. That seems to be one of the reasons IterType::merge wasn't used since the function asserts both inputs should have the same iter type.

In this PR, I tried to make it more consistent by looking at all IDs between the logical reduction IDs and the rfactored loop IDs. Any IDs, including the intermediate ones should remain reductions, so there should be no case of merging Iteration and Reduction.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a simple repro just in case.

outer_iter_type = IterType::Iteration;
}

// inner IterDomain
IterDomain* idi =
IterDomainBuilder(
s->container()->zeroVal(),
s->innerSplit() ? s->factor() : remainder)
.iter_type(
rfactor_axes_.count(s->inner()) ? IterType::Reduction
: IterType::Iteration)
.is_rfactor_domain(static_logical_outputs)
.build();
std::optional<IterType> inner_iter_type;
if (s->inner()->isReduction() && !rfactor_dep_ids_.count(s->inner())) {
inner_iter_type = IterType::Iteration;
}

// Generate the split node
IrBuilder::createInContainer<Split>(
s->container(), ido, idi, mapped, s->factor(), s->innerSplit());
auto [ido, idi] = IterDomain::split(
mapped,
s->factor(),
s->innerSplit(),
static_logical_outputs,
outer_iter_type,
inner_iter_type);

// Remove mapped id from loop IDs
loop_ids_.erase(mapped);
Expand Down Expand Up @@ -182,23 +174,20 @@ class ReplayRFactor : public ReplayTransformations {
id_inner_mapped,
" however one or both are not loop nodes.");

Val* merged_id_size =
mul(id_outer_mapped->extent(), id_inner_mapped->extent());

bool is_bcast =
id_outer_mapped->isBroadcast() && id_inner_mapped->isBroadcast();
auto iter_type = rfactor_axes_.count(m->out())
? IterType::Reduction
: (is_bcast ? IterType::Broadcast : IterType::Iteration);

IterDomain* merged_id =
IterDomainBuilder(m->container()->zeroVal(), merged_id_size)
.iter_type(iter_type)
.is_rfactor_domain(static_logical_ids_.count(m->out()))
.build();
// Let IterDomain::merge determine the correct IterType, except
// when the output is a reduction domain but not part of the
// rfactored domains. If it isn't involved in the rfactor, it's no
// longer a redunction domain
std::optional<IterType> iter_type;
if (m->out()->isReduction() && !rfactor_dep_ids_.count(m->out())) {
iter_type = IterType::Iteration;
}

IrBuilder::createInContainer<Merge>(
m->container(), merged_id, id_outer_mapped, id_inner_mapped);
IterDomain* merged_id = IterDomain::merge(
id_outer_mapped,
id_inner_mapped,
static_logical_ids_.count(m->out()),
iter_type);

// Remove inputs from the loop IDs
loop_ids_.erase(id_outer_mapped);
Expand Down Expand Up @@ -236,6 +225,9 @@ class ReplayRFactor : public ReplayTransformations {
// The IterDomains in the original_domain that are being factored into the
// first stage of the two stage reduction (the producer).
std::unordered_set<IterDomain*> rfactor_axes_;
// All iter domains between the logical and the loop that the
// rfactor_axes_ depend on
std::unordered_set<IterDomain*> rfactor_dep_ids_;
// Iter domains whose history cannot be changed as it would break rfactor
// dependencies.
std::unordered_set<IterDomain*> static_logical_ids_;
Expand All @@ -262,6 +254,14 @@ class ReplayRFactor : public ReplayTransformations {
rfactor_axes_(std::move(rfactor_axes)),
static_logical_ids_(std::move(static_logical_ids)),
logical_domain_(original_domain->logical()) {
const auto all_dep_vals = DependencyCheck::getAllValsBetween(
{original_domain->maybeRoot().begin(),
original_domain->maybeRoot().end()},
{rfactor_axes_.begin(), rfactor_axes_.end()});

auto all_dep_ids = ir_utils::filterByType<IterDomain>(all_dep_vals);
rfactor_dep_ids_.insert(all_dep_ids.begin(), all_dep_ids.end());

setErrorOnFailure(false);
}
};
Expand Down
Loading