-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add slice tests to demonstrate manual scheduling #2898
Conversation
csrc/id_model/indexing_traversal.cpp
Outdated
auto resize = dynamic_cast<Resize*>(id->definition()); | ||
if (resize == nullptr) { | ||
continue; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nitpick:
if (auto resize = dynamic_cast<Resize*>(id->definition()) {
tests/cpp/test_resize.cpp
Outdated
IrBuilder::create<Val>(-1, DataType::Index), | ||
IrBuilder::create<Val>(-1, DataType::Index)); | ||
|
||
tv2->setLoopDomain({tv2_loop_id}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
naive question: why are we manually adding the Resize
with (-1, -1) here? Wasn't tv2->getLogicalDomain().at(0)
already equivalent to tv1->getLogicalDomain().at(0)
and hence can't we just re-use the loop domain as tv2->setLoopDomain(tv1->getLogicalDomain())
?
tv1->getLoopDomain().at(0), | ||
IrBuilder::create<Val>(1, DataType::Index), | ||
IrBuilder::create<Val>(1, DataType::Index)); | ||
tv1->setLoopDomain({tv1_padded}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I sorta see the need to create tv1_padded, Just for clarification, is there a requirement about the dependency of LoopDomain? i.e. can we use the domain of its consumer as the loop domain of the producer?
i.e. tv1->setLoopDomain(tv2->getLogicalDomain())
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As discussed offline, it's because each tensor needs its own domains and exprs, at this moment.
void validateResize(Fusion* fusion) { | ||
auto fusion_vals = fusion->usedMathVals(); | ||
for (auto tv : ir_utils::filterByType<TensorView>(fusion_vals)) { | ||
// Make sure resize is only used as part of root to logical transformations |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No longer holds
!build |
!build |
|
||
fusion.addOutput(tv3); | ||
|
||
tv2->setLoopDomain(tv2->getRootDomain()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to clarify here, we would want tv2's loop domain to be identical to tv3 after the pad?
Here we can reuse the root domain, simply because the pad cancels out the slice.
So if we were to have a pad with a different pad widths, we'll want to clone tv3->getLogicalDomain()
, and then a manual Resize
to establish the pad width afterwards?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this case, we can use the logical domain of the input or the logical domain of the output. Both are equivalent, but that equivalence relationship is not detected by IdModel. Adding some analysis to detect that relationship is part of the overall plan, but it isn't strictly required.
In this test, I'm using the input logical domain, which is mapped with the tv1 logical and the tv2 root domains. Since each tensor needs to use its own IDs,
tv2->setLoopDomain(tv2->getRootDomain())
is the simplest way to set the loop domain of tv2.
If the padding width of tv3 is different, then it'd be unclear what loop domain should be used, especially for the right-half case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the padding width of tv3 is different, then it'd be unclear what loop domain should be used, especially for the right-half case.
Does this mean that we require the final output of the fusion to be of the same shape (for the slice/pad dimension) with the original input with a slices+pads
chain?
Is this a requirement on patterns we support? This feels a big strange to have, i.e. looking at the pad
only test above PadScheduledLikeConsumer
, we don't have such restriction there.
Moving forward the chains of slice/pad(concat) below, how would we parse this restriction in that case, is the restriction translate to that each matching pair of slice/pad needs to cancel each other?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it's a pad only, it should be fine. But if there's a slice and then a pad, there will be:
- input domain
- sliced domain of the input domain
- padded domain of the sliced domain
The requirement is that #1 and #3 correspond to the same domain. If not, it's unclear what domain to generate a loop for it. Technically, it should be a domain that has an extent of max(#1, #3)
, but we don't have a capability like this.
If there's only a pad,
- input domain
- padded domain of the input domain
left_half, | ||
fusion.zeroVal(), | ||
IrBuilder::create<Val>(shape[0] / 2, DataType::Index)); | ||
tv6->setLoopDomain(tv6_loop); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not quite following what's happening here. Looks like we replay the left_half
transform here (slice_resize -> pad_resize). My naive questions:
- Why do we need to have this replay in tv6_loop? since the slice+pad cancelled out each other, can't we just use tv2->getRootDomain()[0] directly?
- Follow up question is, if we do need to replay the transform for some reason, why can we just ignore the
right half
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you asking why this won't work?
std::vector<IterDomain*> tv6_loop{
tv2->getRootDomain()[0]->cloneWithoutRFactor(),
};
tv6->setLoopDomain(tv6_loop);
This won't work as this tv6_loop
domain is just a copy of the tv2
domain and is not connected with anything yet, i.e., it doesn't have any definition or use expr. We need to define how it is related to the existing domains of tv6
.
To connect the domain to tv6
, we can use either the left or right half. It won't matter. I should have mentioned that by a comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please bear with more of my naive questions.
a copy of the tv2 domain and is not connected with anything yet, i.e., it doesn't have any definition or use expr.
I thought tv2->getRootDomain()
is consumed by a resize to create its logical domain. Since we have
auto tv2 = slice(tv1, {{fusion.zeroVal(), IrBuilder::create<Val>(shape[0] / 2)}});
We need to define how it is related to the existing domains of tv6
since we have a chain of slice -> pad -> binary_add
. With exact graph, if we view the IterDomain as a ValGroup, I think the connection is well established.
Of course this is just theoretical discussion ^^^.
We need to define how it is related to the existing domains of tv6
-> Does it mean our analysis today requires the chain of relationship needs to be resolved within the IterDomains of a single TensorDomain?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought tv2->getRootDomain() is consumed by a resize to create its logical domain. Since we have
auto tv2 = slice(tv1, {{fusion.zeroVal(), IrBuilder::create(shape[0] / 2)}});
That is true, but tv2->getRootDomain()[0]->cloneWithoutRFactor()
is not. And since it's just a plain copy, it isn't mapped with anything.
since we have a chain of slice -> pad -> binary_add. With exact graph, if we view the IterDomain as a ValGroup, I think the connection is well established.
Again, that's true with tv2->getRootDomain()[0]
but not with its clone.
Does it mean our analysis today requires the chain of relationship needs to be resolved within the IterDomains of a single TensorDomain?
Yes and no, depending on what you're trying to get from the analysis. In this case, however, since the clone is not mapped with anything, there's nothing we can get from it without connecting it to other domains through replays.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
but tv2->getRootDomain()[0]->cloneWithoutRFactor() is not. And since it's just a plain copy
🤯 Ha, I see what I'm missing now. Thanks a lot for walking me through this 🙇
} | ||
|
||
// RoPE pattern except for the rotation | ||
TEST_F(ResizeTest, SliceSliceConcatConcat) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note for myself, one last example to go through.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for walking me through these examples!
Slice and concat patterns without rotation. See #2851.
Stacked on #2897, #2875
Closes #2870