
Use IterDomain::split and IterDomain::merge #3327

Merged
naoyam merged 6 commits into main from split_merge_cleanup on Nov 4, 2024

Conversation

naoyam
Collaborator

@naoyam naoyam commented Nov 1, 2024

Instead of manually creating output IDs and directly constructing Split and Merge nodes, IterDomain::split and IterDomain::merge are used. This would reduce duplication of logic.
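To illustrate the deduplication argument, here is a minimal sketch of a toy IterDomain whose static split and merge are the only code that computes output extents and iter types. Callers therefore never rebuild that logic by hand. ToyIterDomain, its fields, and the exception it throws are hypothetical. The sketch models only extents and iter types, not nvFuser's actual classes or error handling.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical, heavily simplified stand-ins for nvFuser's IterType and
// IterDomain; only extents and iter types are modeled.
enum class IterType { Iteration, Reduction };

struct ToyIterDomain {
  int64_t extent;
  IterType iter_type;
  std::string defined_by;  // "split", "merge", or "" for a logical ID

  // The single place that builds split outputs, so other passes (e.g. a
  // replay) do not re-implement the extent and iter-type rules.
  static std::pair<ToyIterDomain, ToyIterDomain> split(
      const ToyIterDomain& in, int64_t factor) {
    const int64_t outer_extent = (in.extent + factor - 1) / factor;  // ceilDiv
    return {ToyIterDomain{outer_extent, in.iter_type, "split"},
            ToyIterDomain{factor, in.iter_type, "split"}};
  }

  // The single place that builds merge outputs. As described in this thread
  // for the real helper, inputs with different iter types are rejected.
  static ToyIterDomain merge(const ToyIterDomain& outer,
                             const ToyIterDomain& inner) {
    if (outer.iter_type != inner.iter_type) {
      throw std::invalid_argument("merge: inputs have different iter types");
    }
    return ToyIterDomain{outer.extent * inner.extent, outer.iter_type, "merge"};
  }
};
```

Because both rules live in one place, a fix to, say, the extent computation applies to every caller automatically.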

There's a slight change in the iteration types of IDs generated when TensorView::rFactor is used. See the inline comment.

As of 8321cfd, everything seems fine.

Closes #2520

@naoyam
Collaborator Author

naoyam commented Nov 1, 2024

!test --diff

static IterDomain* merge(
    IterDomain* outer,
    IterDomain* inner,
-   bool rfactor_domain = false);
+   std::optional<bool> rfactor_domain = std::nullopt,
Collaborator Author

Changed this parameter to std::optional so that it's clear whether an argument of false is intentional.
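A minimal sketch of the tri-state pattern: std::nullopt means the caller expressed no preference, while an explicit true or false is honored as given. The helper resolveRfactorFlag and its default rule (inherit the flag from either input) are assumptions for illustration, not the actual body of IterDomain::merge.

```cpp
#include <cassert>
#include <optional>

// Hypothetical helper illustrating a tri-state parameter:
//   std::nullopt -> caller did not decide; derive the value from the inputs
//   true / false -> caller decided explicitly; use it as-is
bool resolveRfactorFlag(std::optional<bool> rfactor_domain,
                        bool outer_is_rfactor, bool inner_is_rfactor) {
  if (rfactor_domain.has_value()) {
    // An explicit false is respected rather than overridden by the default.
    return *rfactor_domain;
  }
  // Assumed default rule for this sketch: inherit from either input.
  return outer_is_rfactor || inner_is_rfactor;
}
```

With a plain `bool rfactor_domain = false`, the second and third calls below would be indistinguishable from "no preference".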

Collaborator Author

There should be no logic change in this file.

@naoyam
Collaborator Author

naoyam commented Nov 2, 2024

!test --diff

// rfactored domains. If it isn't involved in the rfactor, it's no
// longer a reduction domain
std::optional<IterType> outer_iter_type;
if (s->outer()->isReduction() && !rfactor_dep_ids_.count(s->outer())) {
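The snippet above is truncated. The following is a minimal sketch of the decision it appears to make, under the assumption that IDs are tracked by name in a set standing in for rfactor_dep_ids_. A reduction output that is not involved in the rfactor gets an explicit Iteration override. Otherwise no override is returned, and the split/merge helper's default applies. The Id type and overrideIterType are hypothetical.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <unordered_set>

enum class IterType { Iteration, Reduction };

// Hypothetical stand-in: an ID identified by name, with its iter type.
struct Id {
  std::string name;
  IterType type;
};

// Returns an explicit iter type for a replayed output, or std::nullopt to
// keep the helper's default (inherit from the input).
std::optional<IterType> overrideIterType(
    const Id& id, const std::unordered_set<std::string>& rfactor_dep_ids) {
  // A reduction ID outside the rfactor dependency set is no longer a
  // reduction in the rfactored tensor.
  if (id.type == IterType::Reduction && rfactor_dep_ids.count(id.name) == 0) {
    return IterType::Iteration;
  }
  return std::nullopt;
}
```

In the [i0, r1*r2/4, 4] example discussed in this thread, with the inner 4 rfactored, the merge output stays a reduction, while the outer split output becomes Iteration.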
Collaborator Author

Previously, every output iter domain used IterType::Iteration unless it was one of the rfactored loop IDs. For example:

// t0: [i0, r1, r2]
t0->merge(1, 2); // [i0, r1*r2]
t0->split(1, 4); // [i0, r1*r2/4, 4]
auto t1 = t0->rFactor({-1});

In this case, t1 would inherit all the transformations of t0 and the merge output would not be marked as a reduction domain.

Generally, when a tensor is rfactored, the distinction between Iteration and Reduction seems to be blurred, and there are cases where a reduction iter domain is merged with a non-reduction iter domain. That seems to be one reason IterDomain::merge wasn't used, since it asserts that both inputs have the same iter type.

In this PR, I tried to make this more consistent by looking at all IDs between the logical reduction IDs and the rfactored loop IDs. All of those IDs, including the intermediate ones, should remain reductions, so Iteration and Reduction IDs should never be merged.
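A minimal sketch of how the set of IDs "between" the logical reduction IDs and the rfactored loop IDs could be computed on a toy transformation graph. An ID qualifies if it is reachable forward from a source and also reaches a target backward. The Expr type, idsBetween, and the assumption that exprs is topologically ordered are all hypothetical. This does not describe how nvFuser actually computes rfactor_dep_ids_.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical transformation node: input IDs -> output IDs (a split or merge).
struct Expr {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// Returns the IDs on some path from `sources` to `targets`, inclusive.
// Assumes `exprs` is topologically ordered (producers before consumers).
std::set<std::string> idsBetween(const std::vector<Expr>& exprs,
                                 const std::set<std::string>& sources,
                                 const std::set<std::string>& targets) {
  // Forward closure: everything derived from a source.
  std::set<std::string> fwd = sources;
  for (const auto& e : exprs) {
    bool uses_fwd = false;
    for (const auto& in : e.inputs) uses_fwd = uses_fwd || fwd.count(in) > 0;
    if (uses_fwd) fwd.insert(e.outputs.begin(), e.outputs.end());
  }
  // Backward closure: everything a target is derived from.
  std::set<std::string> bwd = targets;
  for (auto it = exprs.rbegin(); it != exprs.rend(); ++it) {
    bool feeds_bwd = false;
    for (const auto& out : it->outputs) feeds_bwd = feeds_bwd || bwd.count(out) > 0;
    if (feeds_bwd) bwd.insert(it->inputs.begin(), it->inputs.end());
  }
  // On a path = descendant of a source and ancestor of a target.
  std::set<std::string> result;
  for (const auto& id : fwd) {
    if (bwd.count(id) > 0) result.insert(id);
  }
  return result;
}
```

For merge(r1, r2) -> m followed by split(m) -> (mo, mi), with mi rfactored, the result is {r1, r2, m, mi}. That set includes the intermediate merge output but excludes mo.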

Collaborator Author

Added a simple repro just in case.

@naoyam naoyam marked this pull request as ready for review November 4, 2024 19:03
@naoyam
Collaborator Author

naoyam commented Nov 4, 2024

!build

@naoyam naoyam requested a review from jacobhinkle November 4, 2024 19:07
@naoyam
Collaborator Author

naoyam commented Nov 4, 2024

!test

Collaborator

@jacobhinkle jacobhinkle left a comment

LGTM. Nice change.

Comment on lines +57 to +61
// Due to rfactor transformations, the iter types of the outputs
// may not follow the default rule. For example, even if the input
// is a reduction iter domain, the outputs may not. To replay the
// original split expression, the output iter types need to be
// specified explicitly.
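A minimal sketch of what "specified explicitly" buys during replay. Here a split helper takes optional per-output iter types, so both outputs inherit the input's type by default, but a replay can pin either one. The Id type and splitId are hypothetical. This shows only the shape of the idea, not the signature of IterDomain::split.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>

enum class IterType { Iteration, Reduction };

// Hypothetical ID: only extent and iter type are modeled.
struct Id {
  int64_t extent;
  IterType type;
};

// Outputs inherit the input's iter type unless the caller overrides them,
// which a replay needs when the rfactor changed an output's iter type.
std::pair<Id, Id> splitId(const Id& in, int64_t factor,
                          std::optional<IterType> outer_type = std::nullopt,
                          std::optional<IterType> inner_type = std::nullopt) {
  const int64_t outer_extent = (in.extent + factor - 1) / factor;  // ceilDiv
  return {Id{outer_extent, outer_type.value_or(in.type)},
          Id{factor, inner_type.value_or(in.type)}};
}
```

Replaying the split of a reduction input with `outer_type = IterType::Iteration` reproduces an original expression whose outer output was no longer a reduction.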
Collaborator

Thanks for this comment.

auto tv2 = tv1->rFactor({-1});

auto split = dynamic_cast<Split*>(tv2->axis(-1)->definition());
ASSERT_NE(split, nullptr);
Collaborator

Can we also add EXPECT_TRUE(tv2->axis(-1)->isIteration()) to check that changing the intermediate domains to Reduction does not interfere with the rfactor, i.e., that this is still only a partial sum?

Collaborator Author

Sure, though it's actually tv2->axis(-1)->isReduction() and tv2->axis(-2)->isIteration(), since it's the innermost domain that is rfactored. I'll add an assertion.
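A minimal sketch of the expected loop iter types that the added assertion checks. Under the assumption modeled here, the rfactor tensor keeps the selected reduction axes as Reduction, since it reduces over them. The remaining axes become Iteration, because the consumer finishes the reduction over them. rfactorProducerTypes is hypothetical and ignores everything except iter types.

```cpp
#include <cassert>
#include <set>
#include <vector>

enum class IterType { Iteration, Reduction };

// Hypothetical model of the rfactor tensor's loop iter types: selected
// reduction axes stay Reduction; every other axis becomes Iteration.
// Negative indices count from the end, as in rFactor({-1}).
std::vector<IterType> rfactorProducerTypes(const std::vector<IterType>& loop,
                                           const std::set<int>& rfactor_axes) {
  std::vector<IterType> out;
  const int n = static_cast<int>(loop.size());
  for (int i = 0; i < n; ++i) {
    const bool selected =
        rfactor_axes.count(i) > 0 || rfactor_axes.count(i - n) > 0;
    out.push_back(selected && loop[i] == IterType::Reduction
                      ? IterType::Reduction
                      : IterType::Iteration);
  }
  return out;
}
```

For loop types [Iteration, Reduction, Reduction] and rfactor axes {-1}, this gives [Iteration, Iteration, Reduction]. That matches axis(-1) being a reduction and axis(-2) an iteration.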

@naoyam
Collaborator Author

naoyam commented Nov 4, 2024

!build

@xwang233
Collaborator

xwang233 commented Nov 4, 2024

!build

@naoyam naoyam merged commit df888ac into main Nov 4, 2024
13 of 14 checks passed
@naoyam naoyam deleted the split_merge_cleanup branch November 4, 2024 21:37
Development

Successfully merging this pull request may close these issues.

refactor handle(Merge*/Split* s) to use IterDomain::merge/split
3 participants