-
Notifications
You must be signed in to change notification settings - Fork 53
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use IterDomain::split and IterDomain::merge #3327
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There should be no change of logic in this file. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -52,25 +52,20 @@ class ReplaySelf : public ReplayTransformations { | |
loop_ids_.find(mapped) != loop_ids_.end(), | ||
"Transform traversal failed, modified a node but it was not a loop node."); | ||
|
||
// outer loop size | ||
Val* remainder = ceilDiv(mapped->extent(), s->factor()); | ||
|
||
// Manually replay the split, following the output of the operations. | ||
// This is so rfactor ops are replayed correctly. | ||
IterDomain* ido = IterDomainBuilder(s->outer()) | ||
.start(s->container()->zeroVal()) | ||
.extent(s->innerSplit() ? remainder : s->factor()) | ||
.build(); | ||
|
||
// inner IterDomain | ||
IterDomain* idi = IterDomainBuilder(s->inner()) | ||
.start(s->container()->zeroVal()) | ||
.extent(s->innerSplit() ? s->factor() : remainder) | ||
.build(); | ||
|
||
// Generate the split node | ||
IrBuilder::createInContainer<Split>( | ||
s->container(), ido, idi, mapped, s->factor(), s->innerSplit()); | ||
NVF_ERROR(s->outer()->isRFactorProduct() == s->inner()->isRFactorProduct()); | ||
|
||
// Due to rfactor transformations, the iter types of the outputs | ||
// may not follow the default rule. For example, even if the input | ||
// is a reduction iter domain, the outputs may not. To replay the | ||
// original split expression, the output iter types need to be | ||
// specified explicitly. | ||
Comment on lines
+57
to
+61
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for this comment. |
||
auto [ido, idi] = IterDomain::split( | ||
mapped, | ||
s->factor(), | ||
s->innerSplit(), | ||
s->outer()->isRFactorProduct(), | ||
s->outer()->getIterType(), | ||
s->inner()->getIterType()); | ||
|
||
// Remove mapped id from loop IDs | ||
loop_ids_.erase(mapped); | ||
|
@@ -107,16 +102,7 @@ class ReplaySelf : public ReplayTransformations { | |
id_inner_mapped, | ||
" however one or both are not loop nodes."); | ||
|
||
Val* merged_id_size = | ||
mul(id_outer_mapped->extent(), id_inner_mapped->extent()); | ||
|
||
IterDomain* merged_id = IterDomainBuilder(m->out()) | ||
.start(m->container()->zeroVal()) | ||
.extent(merged_id_size) | ||
.build(); | ||
|
||
IrBuilder::createInContainer<Merge>( | ||
m->container(), merged_id, id_outer_mapped, id_inner_mapped); | ||
IterDomain* merged_id = IterDomain::merge(id_outer_mapped, id_inner_mapped); | ||
|
||
// Remove inputs from the loop IDs | ||
loop_ids_.erase(id_outer_mapped); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -108,9 +108,6 @@ class ReplayRFactor : public ReplayTransformations { | |
loop_ids_.find(mapped) != loop_ids_.end(), | ||
"Transform traversal failed, modified a node but it was not a loop node."); | ||
|
||
// outer loop size | ||
Val* remainder = ceilDiv(mapped->extent(), s->factor()); | ||
|
||
// Check if we need to mark the outputs as an logical domain meaning this | ||
// transformation must be present in replays otherwise it breaks the compute | ||
// definition of the fusion. Iter domains are actually not static, its the | ||
|
@@ -119,32 +116,27 @@ class ReplayRFactor : public ReplayTransformations { | |
bool static_logical_outputs = static_logical_ids_.count(s->outer()) || | ||
static_logical_ids_.count(s->inner()); | ||
|
||
// Manually replay the split, making reduction = false and rfactor = true | ||
// outer IterDomain | ||
IterDomain* ido = | ||
IterDomainBuilder( | ||
s->container()->zeroVal(), | ||
s->innerSplit() ? remainder : s->factor()) | ||
.iter_type( | ||
rfactor_axes_.count(s->outer()) ? IterType::Reduction | ||
: IterType::Iteration) | ||
.is_rfactor_domain(static_logical_outputs) | ||
.build(); | ||
// Let IterDomain::split determine the correct IterType, except | ||
// when the output is a reduction domain but not part of the | ||
// rfactored domains. If it isn't involved in the rfactor, it's no | ||
// longer a redunction domain | ||
std::optional<IterType> outer_iter_type; | ||
if (s->outer()->isReduction() && !rfactor_dep_ids_.count(s->outer())) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Previously, every output iter domain uses
In this case, Generally, when a tensor is rfactored, whether it's In this PR, I tried to make it more consistent by looking at all IDs between the logical reduction IDs and the rfactored loop IDs. Any IDs, including the intermediate ones should remain reductions, so there should be no case of merging There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added a simple repro just in case. |
||
outer_iter_type = IterType::Iteration; | ||
} | ||
|
||
// inner IterDomain | ||
IterDomain* idi = | ||
IterDomainBuilder( | ||
s->container()->zeroVal(), | ||
s->innerSplit() ? s->factor() : remainder) | ||
.iter_type( | ||
rfactor_axes_.count(s->inner()) ? IterType::Reduction | ||
: IterType::Iteration) | ||
.is_rfactor_domain(static_logical_outputs) | ||
.build(); | ||
std::optional<IterType> inner_iter_type; | ||
if (s->inner()->isReduction() && !rfactor_dep_ids_.count(s->inner())) { | ||
inner_iter_type = IterType::Iteration; | ||
} | ||
|
||
// Generate the split node | ||
IrBuilder::createInContainer<Split>( | ||
s->container(), ido, idi, mapped, s->factor(), s->innerSplit()); | ||
auto [ido, idi] = IterDomain::split( | ||
mapped, | ||
s->factor(), | ||
s->innerSplit(), | ||
static_logical_outputs, | ||
outer_iter_type, | ||
inner_iter_type); | ||
|
||
// Remove mapped id from loop IDs | ||
loop_ids_.erase(mapped); | ||
|
@@ -182,23 +174,20 @@ class ReplayRFactor : public ReplayTransformations { | |
id_inner_mapped, | ||
" however one or both are not loop nodes."); | ||
|
||
Val* merged_id_size = | ||
mul(id_outer_mapped->extent(), id_inner_mapped->extent()); | ||
|
||
bool is_bcast = | ||
id_outer_mapped->isBroadcast() && id_inner_mapped->isBroadcast(); | ||
auto iter_type = rfactor_axes_.count(m->out()) | ||
? IterType::Reduction | ||
: (is_bcast ? IterType::Broadcast : IterType::Iteration); | ||
|
||
IterDomain* merged_id = | ||
IterDomainBuilder(m->container()->zeroVal(), merged_id_size) | ||
.iter_type(iter_type) | ||
.is_rfactor_domain(static_logical_ids_.count(m->out())) | ||
.build(); | ||
// Let IterDomain::merge determine the correct IterType, except | ||
// when the output is a reduction domain but not part of the | ||
// rfactored domains. If it isn't involved in the rfactor, it's no | ||
// longer a redunction domain | ||
std::optional<IterType> iter_type; | ||
if (m->out()->isReduction() && !rfactor_dep_ids_.count(m->out())) { | ||
iter_type = IterType::Iteration; | ||
} | ||
|
||
IrBuilder::createInContainer<Merge>( | ||
m->container(), merged_id, id_outer_mapped, id_inner_mapped); | ||
IterDomain* merged_id = IterDomain::merge( | ||
id_outer_mapped, | ||
id_inner_mapped, | ||
static_logical_ids_.count(m->out()), | ||
iter_type); | ||
|
||
// Remove inputs from the loop IDs | ||
loop_ids_.erase(id_outer_mapped); | ||
|
@@ -236,6 +225,9 @@ class ReplayRFactor : public ReplayTransformations { | |
// The IterDomains in the original_domain that are being factored into the | ||
// first stage of the two stage reduction (the producer). | ||
std::unordered_set<IterDomain*> rfactor_axes_; | ||
// All iter domains between the logical and the loop that the | ||
// rfactor_axes_ depend on | ||
std::unordered_set<IterDomain*> rfactor_dep_ids_; | ||
// Iter domains whose history cannot be changed as it would break rfactor | ||
// dependencies. | ||
std::unordered_set<IterDomain*> static_logical_ids_; | ||
|
@@ -262,6 +254,14 @@ class ReplayRFactor : public ReplayTransformations { | |
rfactor_axes_(std::move(rfactor_axes)), | ||
static_logical_ids_(std::move(static_logical_ids)), | ||
logical_domain_(original_domain->logical()) { | ||
const auto all_dep_vals = DependencyCheck::getAllValsBetween( | ||
{original_domain->maybeRoot().begin(), | ||
original_domain->maybeRoot().end()}, | ||
{rfactor_axes_.begin(), rfactor_axes_.end()}); | ||
|
||
auto all_dep_ids = ir_utils::filterByType<IterDomain>(all_dep_vals); | ||
rfactor_dep_ids_.insert(all_dep_ids.begin(), all_dep_ids.end()); | ||
|
||
setErrorOnFailure(false); | ||
} | ||
}; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed this parameter to
std::optional
so that it's clear if the argument offalse
is intentional.