Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use IterDomain::split and IterDomain::merge #3327

Merged
merged 6 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions csrc/ir/internal_base_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,18 +120,23 @@ class NVF_API IterDomain : public Val {
static std::vector<IterDomain*> clone(
const std::vector<IterDomain*>& domains);

//! When `rfactor_domain` is true, also set the `is_rfactor_domain_` flag of
//! the result IterDomain.
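//! The optional parameters rfactor_domain and iter_type can be used to
//! override the default behavior. When omitted, the result is not marked
//! as an rfactor domain and its iteration type is derived from the inputs.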
static IterDomain* merge(
IterDomain* outer,
IterDomain* inner,
bool rfactor_domain = false);
std::optional<bool> rfactor_domain = std::nullopt,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed this parameter to std::optional so that it's clear whether an argument of false is intentional.

std::optional<IterType> iter_type = std::nullopt);

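//! The optional parameters rfactor_domain, outer_iter_type and
//! inner_iter_type can be used to override the default behavior. When
//! omitted, the outputs are not marked as rfactor domains and both inherit
//! the iteration type of the input.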
static std::pair<IterDomain*, IterDomain*> split(
IterDomain* in,
Val* factor,
bool inner_split,
bool rfactor_domain = false);
std::optional<bool> rfactor_domain = std::nullopt,
std::optional<IterType> outer_iter_type = std::nullopt,
std::optional<IterType> inner_iter_type = std::nullopt);

//! Resize an IterDomain by expanding both the left and right sides
//! by given widths. The resulting IterDomain has an extent of
Expand Down
70 changes: 49 additions & 21 deletions csrc/ir/nodes.cpp
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should be no change of logic in this file.

Original file line number Diff line number Diff line change
Expand Up @@ -2550,7 +2550,8 @@ IterDomain* IterDomain::cloneWithoutRFactor(bool map_with_original) {
IterDomain* IterDomain::merge(
IterDomain* outer,
IterDomain* inner,
bool rfactor_domain) {
std::optional<bool> rfactor_domain,
std::optional<IterType> iter_type) {
NVF_CHECK(
outer->isReduction() == inner->isReduction(),
"Merging IterDomains requires that their iteration types match. ",
Expand All @@ -2563,24 +2564,33 @@ IterDomain* IterDomain::merge(
!outer->isStride() && !inner->isStride(),
"No support for merging stride domains");

// By default, if not specified, don't create rfactor
// outputs. Reshape transformations need to propagate the flag, so
// they should specify it explicitly.
if (!rfactor_domain.has_value()) {
rfactor_domain = false;
}

Val* merged_id_size = mul(outer->extent(), inner->extent());

IterType itype = outer->getIterType();
if (!iter_type.has_value()) {
iter_type = outer->getIterType();

if (outer->isBroadcast() && inner->isBroadcast()) {
itype = IterType::Broadcast;
}
if (outer->isBroadcast() && inner->isBroadcast()) {
iter_type = IterType::Broadcast;
}

if ((outer->isBroadcast() || inner->isBroadcast()) &&
(outer->getIterType() == IterType::Iteration ||
inner->getIterType() == IterType::Iteration)) {
itype = IterType::Iteration;
}
if ((outer->isBroadcast() || inner->isBroadcast()) &&
(outer->getIterType() == IterType::Iteration ||
inner->getIterType() == IterType::Iteration)) {
iter_type = IterType::Iteration;
}

if ((outer->isBroadcast() || inner->isBroadcast()) &&
(outer->getIterType() == IterType::GatherScatter ||
inner->getIterType() == IterType::GatherScatter)) {
itype = IterType::GatherScatter;
if ((outer->isBroadcast() || inner->isBroadcast()) &&
(outer->getIterType() == IterType::GatherScatter ||
inner->getIterType() == IterType::GatherScatter)) {
iter_type = IterType::GatherScatter;
}
}

Val* expanded_extent = nullptr;
Expand All @@ -2606,8 +2616,8 @@ IterDomain* IterDomain::merge(
IterDomainBuilder(outer->container()->zeroVal(), merged_id_size)
.parallel_type(outer->getParallelType())
.expanded_extent(expanded_extent)
.iter_type(itype)
.is_rfactor_domain(rfactor_domain)
.iter_type(*iter_type)
.is_rfactor_domain(*rfactor_domain)
.build();

IrBuilder::createInContainer<Merge>(
Expand All @@ -2620,7 +2630,9 @@ std::pair<IterDomain*, IterDomain*> IterDomain::split(
IterDomain* in,
Val* factor,
bool inner_split,
bool rfactor_domain) {
std::optional<bool> rfactor_domain,
std::optional<IterType> outer_iter_type,
std::optional<IterType> inner_iter_type) {
NVF_CHECK(
factor->isIntegralScalar(), "Cannot split by non-integer value ", factor);

Expand All @@ -2631,6 +2643,22 @@ std::pair<IterDomain*, IterDomain*> IterDomain::split(
expanded_remainder = ceilDiv(in->expandedExtent(), factor);
}

// By default, if not specified, don't create rfactor
// outputs. Reshape transformations need to propagate the flag, so
// they should specify it explicitly.
if (!rfactor_domain.has_value()) {
rfactor_domain = false;
}

// If not specified, inherit these properties from the input iter domain
if (!outer_iter_type.has_value()) {
outer_iter_type = in->getIterType();
}

if (!inner_iter_type.has_value()) {
inner_iter_type = in->getIterType();
}

// outer loop IterDomain
IterDomain* ido =
IterDomainBuilder(
Expand All @@ -2639,8 +2667,8 @@ std::pair<IterDomain*, IterDomain*> IterDomain::split(
in->hasExpandedExtent() && inner_split ? expanded_remainder
: nullptr)
.parallel_type(in->getParallelType())
.iter_type(in->getIterType())
.is_rfactor_domain(rfactor_domain)
.iter_type(*outer_iter_type)
.is_rfactor_domain(*rfactor_domain)
.build();

// inner loop IterDomain
Expand All @@ -2651,8 +2679,8 @@ std::pair<IterDomain*, IterDomain*> IterDomain::split(
in->hasExpandedExtent() && !inner_split ? expanded_remainder
: nullptr)
.parallel_type(in->getParallelType())
.iter_type(in->getIterType())
.is_rfactor_domain(rfactor_domain)
.iter_type(*inner_iter_type)
.is_rfactor_domain(*rfactor_domain)
.build();

IrBuilder::createInContainer<Split>(
Expand Down
41 changes: 12 additions & 29 deletions csrc/transform_replay.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,25 +52,17 @@ class ReplaySelf : public ReplayTransformations {
loop_ids_.find(mapped) != loop_ids_.end(),
"Transform traversal failed, modified a node but it was not a loop node.");

// outer loop size
Val* remainder = ceilDiv(mapped->extent(), s->factor());

// Manually replay the split, following the output of the operations.
// This is so rfactor ops are replayed correctly.
IterDomain* ido = IterDomainBuilder(s->outer())
.start(s->container()->zeroVal())
.extent(s->innerSplit() ? remainder : s->factor())
.build();

// inner IterDomain
IterDomain* idi = IterDomainBuilder(s->inner())
.start(s->container()->zeroVal())
.extent(s->innerSplit() ? s->factor() : remainder)
.build();

// Generate the split node
IrBuilder::createInContainer<Split>(
s->container(), ido, idi, mapped, s->factor(), s->innerSplit());
NVF_ERROR(mapped != nullptr);

NVF_ERROR(s->outer()->isRFactorProduct() == s->inner()->isRFactorProduct());

auto [ido, idi] = IterDomain::split(
mapped,
s->factor(),
s->innerSplit(),
s->outer()->isRFactorProduct(),
s->outer()->getIterType(),
s->inner()->getIterType());

// Remove mapped id from loop IDs
loop_ids_.erase(mapped);
Expand Down Expand Up @@ -107,16 +99,7 @@ class ReplaySelf : public ReplayTransformations {
id_inner_mapped,
" however one or both are not loop nodes.");

Val* merged_id_size =
mul(id_outer_mapped->extent(), id_inner_mapped->extent());

IterDomain* merged_id = IterDomainBuilder(m->out())
.start(m->container()->zeroVal())
.extent(merged_id_size)
.build();

IrBuilder::createInContainer<Merge>(
m->container(), merged_id, id_outer_mapped, id_inner_mapped);
IterDomain* merged_id = IterDomain::merge(id_outer_mapped, id_inner_mapped);

// Remove inputs from the loop IDs
loop_ids_.erase(id_outer_mapped);
Expand Down
71 changes: 27 additions & 44 deletions csrc/transform_rfactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,6 @@ class ReplayRFactor : public ReplayTransformations {
loop_ids_.find(mapped) != loop_ids_.end(),
"Transform traversal failed, modified a node but it was not a loop node.");

// outer loop size
Val* remainder = ceilDiv(mapped->extent(), s->factor());

// Check if we need to mark the outputs as a logical domain, meaning this
// transformation must be present in replays; otherwise it breaks the compute
// definition of the fusion. Iter domains are actually not static, it's the
Expand All @@ -119,32 +116,20 @@ class ReplayRFactor : public ReplayTransformations {
bool static_logical_outputs = static_logical_ids_.count(s->outer()) ||
static_logical_ids_.count(s->inner());

// Manually replay the split, making reduction = false and rfactor = true
// outer IterDomain
IterDomain* ido =
IterDomainBuilder(
s->container()->zeroVal(),
s->innerSplit() ? remainder : s->factor())
.iter_type(
rfactor_axes_.count(s->outer()) ? IterType::Reduction
: IterType::Iteration)
.is_rfactor_domain(static_logical_outputs)
.build();

// inner IterDomain
IterDomain* idi =
IterDomainBuilder(
s->container()->zeroVal(),
s->innerSplit() ? s->factor() : remainder)
.iter_type(
rfactor_axes_.count(s->inner()) ? IterType::Reduction
: IterType::Iteration)
.is_rfactor_domain(static_logical_outputs)
.build();
auto outer_iter_type = rfactor_axes_.count(s->outer())
? IterType::Reduction
: IterType::Iteration;
auto inner_iter_type = rfactor_axes_.count(s->inner())
? IterType::Reduction
: IterType::Iteration;

// Generate the split node
IrBuilder::createInContainer<Split>(
s->container(), ido, idi, mapped, s->factor(), s->innerSplit());
auto [ido, idi] = IterDomain::split(
mapped,
s->factor(),
s->innerSplit(),
static_logical_outputs,
outer_iter_type,
inner_iter_type);

// Remove mapped id from loop IDs
loop_ids_.erase(mapped);
Expand Down Expand Up @@ -182,23 +167,21 @@ class ReplayRFactor : public ReplayTransformations {
id_inner_mapped,
" however one or both are not loop nodes.");

Val* merged_id_size =
mul(id_outer_mapped->extent(), id_inner_mapped->extent());

bool is_bcast =
id_outer_mapped->isBroadcast() && id_inner_mapped->isBroadcast();
auto iter_type = rfactor_axes_.count(m->out())
? IterType::Reduction
: (is_bcast ? IterType::Broadcast : IterType::Iteration);

IterDomain* merged_id =
IterDomainBuilder(m->container()->zeroVal(), merged_id_size)
.iter_type(iter_type)
.is_rfactor_domain(static_logical_ids_.count(m->out()))
.build();
// Let IterDomain::merge determine the correct IterType, except
// when m->out is a reduction domain. If it isn't involved in
// the rfactor, it's no longer a reduction domain
std::optional<IterType> iter_type;
if (std::find(target_domain_.begin(), target_domain_.end(), m->out()) !=
target_domain_.end() &&
m->out()->isReduction() && !rfactor_axes_.count(m->out())) {
iter_type = IterType::Iteration;
}

IrBuilder::createInContainer<Merge>(
m->container(), merged_id, id_outer_mapped, id_inner_mapped);
IterDomain* merged_id = IterDomain::merge(
id_outer_mapped,
id_inner_mapped,
static_logical_ids_.count(m->out()),
iter_type);

// Remove inputs from the loop IDs
loop_ids_.erase(id_outer_mapped);
Expand Down
Loading