Use IterDomain::split and IterDomain::merge #3327
Instead of manually creating output IDs and directly constructing Split and Merge nodes, IterDomain::split and IterDomain::merge are used. This would reduce duplication of logic.
!test --diff
  static IterDomain* merge(
      IterDomain* outer,
      IterDomain* inner,
-     bool rfactor_domain = false);
+     std::optional<bool> rfactor_domain = std::nullopt,
Changed this parameter to std::optional so that it's clear whether an argument of false is intentional.
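To illustrate why an optional flag helps, here is a minimal, self-contained sketch. ToyDomain and toyMerge are hypothetical stand-ins, not nvFuser's IterDomain or IterDomain::merge, and the "derive from inputs" rule is only an assumption made for the example.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-in for an iteration domain; not nvFuser's IterDomain.
struct ToyDomain {
  bool is_rfactor = false;
};

// Sketch of a merge whose rfactor flag is optional:
//   std::nullopt -> the caller did not specify; derive the flag from the inputs
//   true / false -> the caller's explicit choice, used as is
inline ToyDomain toyMerge(
    const ToyDomain& outer,
    const ToyDomain& inner,
    std::optional<bool> rfactor_domain = std::nullopt) {
  // Assumed default rule for this sketch only: the output is an rfactor
  // domain if either input is.
  const bool derived = outer.is_rfactor || inner.is_rfactor;
  return ToyDomain{rfactor_domain.value_or(derived)};
}
```

With a plain bool defaulting to false, toyMerge(a, b) and toyMerge(a, b, false) would be indistinguishable. With std::optional, the second call is visibly an explicit override.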
csrc/ir/nodes.cpp
There should be no change of logic in this file.
!test --diff
// rfactored domains. If it isn't involved in the rfactor, it's no
// longer a reduction domain
std::optional<IterType> outer_iter_type;
if (s->outer()->isReduction() && !rfactor_dep_ids_.count(s->outer())) {
Previously, every output iter domain used IterType::Iteration unless it was one of the rfactored loop IDs. For example:
auto t0 = [i0, r1, r2];
t0->merge(1, 2); // [i0, r1*r2]
t0->split(1, 4); // [i0, r1*r2/4, 4]
t1 = t0->rFactor(-1);
In this case, t1 would inherit all the transformations of t0, and the merge output would not be marked as a reduction domain.
Generally, when a tensor is rfactored, the distinction between Iteration and Reduction seems to be blurred, and there's a case where a reduction iter domain is merged with a non-reduction iter domain. That seems to be one of the reasons IterDomain::merge wasn't used, since the function asserts that both inputs have the same iter type.
In this PR, I tried to make it more consistent by looking at all IDs between the logical reduction IDs and the rfactored loop IDs. All of those IDs, including the intermediate ones, should remain reductions, so there should be no case of merging Iteration and Reduction.
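The rule described above can be sketched on a toy graph. Everything here (ToyExpr, rfactorDependencies, replayedType, string IDs) is hypothetical and only models the idea; it is not the code in this PR. The sketch assumes the expressions are listed in topological order and walks them backwards from the rfactored loop IDs to collect every ID they depend on.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

enum class ToyIterType { Iteration, Reduction };

// Toy expression: a merge (2 inputs, 1 output) or a split (1 input, 2 outputs).
struct ToyExpr {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// Collects the rfactored loop IDs plus every ID they are derived from.
// Assumes `exprs` is in topological order (producers before consumers).
inline std::set<std::string> rfactorDependencies(
    const std::vector<ToyExpr>& exprs,
    const std::vector<std::string>& rfactored_loop_ids) {
  assert(!rfactored_loop_ids.empty());
  std::set<std::string> deps(
      rfactored_loop_ids.begin(), rfactored_loop_ids.end());
  // Walk backwards: if any output is a dependency, so are all inputs.
  for (auto it = exprs.rbegin(); it != exprs.rend(); ++it) {
    bool feeds_rfactor = false;
    for (const auto& out : it->outputs) {
      feeds_rfactor = feeds_rfactor || deps.count(out) > 0;
    }
    if (feeds_rfactor) {
      deps.insert(it->inputs.begin(), it->inputs.end());
    }
  }
  return deps;
}

// A reduction ID stays a reduction in the rfactor producer only if it is
// involved in the rfactor; otherwise it becomes an iteration ID.
inline ToyIterType replayedType(
    const std::string& id,
    bool was_reduction,
    const std::set<std::string>& deps) {
  return (was_reduction && deps.count(id) > 0) ? ToyIterType::Reduction
                                               : ToyIterType::Iteration;
}
```

For the t0 example, with a merge of r1 and r2 into r12, a split of r12 into r12o and r12i, and r12i as the rfactored loop ID, the dependency set contains r12i, r12, r1 and r2. The merge output r12 therefore stays a reduction, while r12o becomes an iteration ID.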
Added a simple repro just in case.
!build
!test
LGTM. Nice change.
// Due to rfactor transformations, the iter types of the outputs
// may not follow the default rule. For example, even if the input
// is a reduction iter domain, the outputs may not. To replay the
// original split expression, the output iter types need to be
// specified explicitly.
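As a rough illustration of replaying a split with explicit output iter types, the sketch below uses hypothetical names (ToyId, toySplit) rather than nvFuser's API. It assumes, for the example only, that the default rule is for both outputs to inherit the input's iter type and that the outer extent is a ceiling division.

```cpp
#include <cassert>
#include <optional>
#include <utility>

enum class ToyIterType { Iteration, Reduction };

// Hypothetical stand-in for an iteration domain.
struct ToyId {
  long extent;
  ToyIterType type;
};

// Splits `in` by `factor` into (outer, inner). When an output type is not
// given, it falls back to the assumed default rule: inherit the input type.
inline std::pair<ToyId, ToyId> toySplit(
    const ToyId& in,
    long factor,
    std::optional<ToyIterType> outer_type = std::nullopt,
    std::optional<ToyIterType> inner_type = std::nullopt) {
  assert(factor > 0);
  const long outer_extent = (in.extent + factor - 1) / factor;  // ceilDiv
  ToyId outer{outer_extent, outer_type.value_or(in.type)};
  ToyId inner{factor, inner_type.value_or(in.type)};
  return {outer, inner};
}
```

A replay that must reproduce an outer Iteration output from a Reduction input would call toySplit(r, 4, ToyIterType::Iteration), leaving the inner type to the default.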
Thanks for this comment.
auto tv2 = tv1->rFactor({-1});

auto split = dynamic_cast<Split*>(tv2->axis(-1)->definition());
ASSERT_NE(split, nullptr);
Can we also do EXPECT_TRUE(tv2->axis(-1)->isIteration()) to check that altering the intermediate domains to Reduction does not interfere with the factoring, i.e. that this is still only a partial sum?
Sure, though it's actually tv2->axis(-1)->isReduction() and tv2->axis(-2)->isIteration(), since it's the innermost domain that is rfactored. I'll add an assertion.
!build
There's a slight change in the iteration types of IDs generated when TensorView::rFactor is used. See the inline comment. As of 8321cfd, everything seems fine.
Closes #2520