Allow inlining past loop broadcasts for MmaOp #3416
Conversation
After this, we can actually generate a proper kernel and run it. I will rebase #3406 onto this and modify the test to compile and run in that PR so we can inspect the generated kernel there. We can keep this PR for discussing the inlining changes only.
Does this only apply to broadcast IDs added by
Yes, that's the intention. I am using
Yes. @zasdfgbnm, when you added this, were you thinking about having non-broadcast IDs in
To be safe I'll check the IterType when skipping. |
csrc/scheduler/tools/inlining.cpp
Outdated
for ([[maybe_unused]] auto [expr, dir] : IRBFS::getExprsBetween(
         {tv->domain()->additionalIDs().begin(),
          tv->domain()->additionalIDs().end()},
         {tv->getLoopDomain().begin(), tv->getLoopDomain().end()},
         /*require_all_to_visited=*/false)) {
This includes all IDs that are between additionalIDs() and the loop domain. However, we could have something like this:
tv->broadcast(0, 16);
tv->merge(0);
In this case, we'll be merging the new broadcast ID with a pre-existing loop ID, so we should not ignore that. I think what we should do instead is traverse from the root domain to the loop domain; the complement will then be the "pure" loop broadcasts, which we can ignore.
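A minimal sketch of that alternative (illustrative only; it reuses the IRBFS::getExprsBetween call from the diff above and assumes the usual getMaybeRootDomain()/getLoopDomain() accessors, not the code in this PR):
std::unordered_set<Val*> reachable_from_root(
    tv->getMaybeRootDomain().begin(), tv->getMaybeRootDomain().end());
for ([[maybe_unused]] auto [expr, dir] : IRBFS::getExprsBetween(
         {tv->getMaybeRootDomain().begin(), tv->getMaybeRootDomain().end()},
         {tv->getLoopDomain().begin(), tv->getLoopDomain().end()},
         /*require_all_to_visited=*/false)) {
  // Record every ID produced on the way from the root to the loop domain.
  for (Val* out : expr->outputs()) {
    reachable_from_root.insert(out);
  }
}
for (IterDomain* loop_id : tv->getLoopDomain()) {
  if (!reachable_from_root.count(loop_id) && loop_id->isBroadcast()) {
    // loop_id is a "pure" loop broadcast (or derived only from loop
    // broadcasts) and could be skipped when computing the inlining position.
  }
}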
Actually, I suppose that would also automatically allow us to inline past regular broadcasts created using BroadcastOp, since those new Broadcast IDs are not reachable from the root domain either; but I believe we already inline past those IDs anyway.
No, I added it primarily for storing these new broadcasts.
In the latest pushed changes, I do a BFS from producer logical to producer allocation and from consumer root to consumer loop. This lets me collect the IDs that are used for indexing (assuming no shorter paths are discovered later). I then restrict the
This updates the NVF_THROW check to rule out the BroadcastOp case.
codediff passed
!test
!test --diff
csrc/scheduler/tools/inlining.cpp
Outdated
@@ -193,6 +195,21 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer(
  }
  return producer->nDims();
} else {
std::unordered_set<ValGroup> loop_path_groups; | |||
if (consumer->definition()->isA<MmaOp>()) {
// Get ValGroups between producer and consumer loop in the inlining graph |
Maybe it would make more sense to start by using a PairwiseLogicalDomainMap to find mapped ID groups between producer and consumer, then doing BFS to find the path from this to both loop domain val groups. That way we could avoid missing loop groups in both the producer and consumer. That case won't occur for MmaOp, but it might be clearer than traversing from producer loop to consumer loop, since it's not immediately clear why we go that direction.
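A rough sketch of that direction (the mapProducerToConsumer() call is my assumption of the usual interface, not code from this PR):
// Seed a traversal with the producer/consumer logical IDs that map to each other.
std::unordered_map<IterDomain*, IterDomain*> p2c =
    PairwiseLogicalDomainMap(producer, consumer).mapProducerToConsumer();
std::vector<Val*> seeds;
for (const auto& [p_id, c_id] : p2c) {
  seeds.push_back(p_id);
  seeds.push_back(c_id);
}
// A BFS from these seeds to producer->getLoopDomain() and to
// consumer->getLoopDomain() (with require_all_to_visited=false) visits exactly
// the groups connected to mapped IDs; loop IDs never visited on either side
// are the loop-broadcast-only ones that inlining can skip.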
csrc/scheduler/tools/inlining.cpp
Outdated
loop_path_groups.count(inliningGraph().toGroup(p_id)) ||
loop_path_groups.count(inliningGraph().toGroup(c_id))) &&
For the MmaOp case, we will hit this when c_id is not in loop_path_groups because it is an M or N dimension. Note to self: I think we should probably assert that when this happens, both c_id and p_id are not found in that set, to avoid mistakenly inlining in a case where a mapped producer dimension is in the same position as an unmapped consumer dimension.
Is it always the case that an ignored producer ID is a broadcast?
Note to self: I think we should probably assert that when this happens, both c_id and p_id are not found in that set
If that's the case, do we need to check c_id? Isn't loop_path_groups.count(inliningGraph().toGroup(p_id)) sufficient?
Is it always the case that an ignored producer ID is a broadcast?
This is currently the case: we create a loop broadcast for the operands to MmaOp.
Shouldn't we assert that? If not broadcast, I'm not sure if it's safe to skip.
csrc/scheduler/tools/inlining.cpp
Outdated
@@ -193,6 +195,21 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer(
  }
  return producer->nDims();
} else {
std::unordered_set<ValGroup> loop_path_groups; |
Since an empty set can mean it's just not set or there's indeed no val group, I think this would be clearer and less error-prone:
std::unordered_set<ValGroup> loop_path_groups;
std::optional<std::unordered_set<ValGroup>> loop_path_groups;
csrc/scheduler/tools/inlining.cpp
Outdated
if ((loop_path_groups.empty() ||
     loop_path_groups.count(inliningGraph().toGroup(p_id)) ||
     loop_path_groups.count(inliningGraph().toGroup(c_id))) &&
I think separating this logic would be easier to follow, like lines 222-224.
if ((loop_path_groups.empty() ||
     loop_path_groups.count(inliningGraph().toGroup(p_id)) ||
     loop_path_groups.count(inliningGraph().toGroup(c_id))) &&
// We can inline past consumer IDs that are not connected to the producer.
//
// For example, an MmaOp with no broadcasts could contain the following:
if (loop_path_groups.has_value() &&
    (!loop_path_groups.count(inliningGraph().toGroup(p_id)) ||
     !loop_path_groups.count(inliningGraph().toGroup(c_id)))) {
  continue;
}
tests/cpp/test_matmul.cpp
Outdated
I feel we should have a few more tests. Can we create a test that should not get inlined even with the added condition?
I added a section to the test that swaps the No and Ko axes in the mma result (which is scheduled Mo, No, Ko, ...). This is done on a copy of the scheduled fusion; then I call inlineMost to check that we don't mistakenly inline past the unmapped No axis in tv0c.
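Roughly, that section looks like this (a sketch with illustrative names and positions, not the literal test code):
// On a copy of the already-scheduled fusion, swap the No and Ko loop axes of
// the mma result, which is scheduled [Mo, No, Ko, ...].
mma_result->reorder({{1, 2}, {2, 1}});
inlineMost();
// tv0c must not be inlined past the now-unmapped No axis, so its compute-at
// position should stay small (the exact bound here is illustrative).
EXPECT_LE(tv0c->getComputeAtPosition(), 1);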
// tv0:
//   root/logical: [ iS0, iS1 ]
//   loop: [ iS0, bS7, iS1 ]
// tv1:
//   root/logical: [ iS2, iS3 ]
//   loop: [ bS8, iS2, iS3 ]
// tv2:
//   root/logical/loop: [ iS4, iS5, rS6 ]
How are these MmaOp inputs and output actually scheduled? Do their loop domains look just as shown above? I'm asking because if, for example, bS8 gets merged with iS2, it would be included in the BFS path between tv1 and tv2, so we wouldn't skip the merge output domain; but since it isn't mapped with any of the loop IDs of tv2 in the broadcast graph (because bS8 is not mapped), the inlining would be stopped at that point.
This is how the schedule looks in the test:
T2_l_float[iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16}] ca_pos( 2 ) produce_pos( 3 )
= mma(T4_s___half[iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8}] ca_pos( 3 ),
T5_s___half[bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8}] ca_pos( 3 ))
T4_s___half[iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8}] ca_pos( 3 )
logical domain : (iS9{i0}, iS10{i1})
allocation domain : (iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8})
contiguity: t n t n t t t t t t
Split: iS10{i1} by factor 128 -> iblockIdx.y31{( ceilDiv(i1, 128) )}, iS32{128}
Split: iS9{i0} by factor 16 -> iS35{( ceilDiv(i0, 16) )}, iS36{16}
Split: iS32{128} by factor 64 -> iS43{2}, iS44{64}
Split: iS36{16} by factor 8 -> iB45{2}, iS46{8}
Split: iS46{8} by factor 1 -> iS47{8}, iB48{1}
Split: iS44{64} by factor 8 -> iS49{8}, iB50{8}
Xor(2D): iS47{8} , iS49{8} -> iB51{8} , iB52{8}
loop domain : (iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8})
T5_s___half[bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8}] ca_pos( 3 )
logical domain : (iS11{i3}, iS12{i4})
allocation domain : (bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8})
contiguity: n t t n t t t t t t
Split: iS12{i4} by factor 256 -> iblockIdx.x39{( ceilDiv(i4, 256) )}, iS40{256}
Split: iS11{i3} by factor 16 -> iS41{( ceilDiv(i3, 16) )}, iS42{16}
Split: iS40{256} by factor 64 -> iS53{4}, iS54{64}
Split: iS42{16} by factor 8 -> iB55{2}, iS56{8}
Split: iS56{8} by factor 1 -> iS57{8}, iB58{1}
Split: iS54{64} by factor 8 -> iS59{8}, iB60{8}
Xor(2D): iS57{8} , iS59{8} -> iB61{8} , iB62{8}
loop domain : (bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8})
T2_l_float[iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16}] ca_pos( 2 ) produce_pos( 3 )
logical domain : (iS4{i1}, iS5{i4}, rS6{i0})
allocation domain : (iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, ithreadIdx.x87{128}, iMMA82{32}, iMMA81{2}, iMMA85{2}, rMMA90{2}, rMMA91{4}, rMMA89{2})
contiguity: t t n t t t t t n n n
Split: iS4{i1} by factor 128 -> iblockIdx.y17{( ceilDiv(i1, 128) )}, iS18{128}
Split: iS5{i4} by factor 256 -> iblockIdx.x19{( ceilDiv(i4, 256) )}, iS20{256}
Split: rS6{i0} by factor 16 -> rS21{( ceilDiv(i0, 16) )}, rMMA22{16}
Split: iS18{128} by factor 64 -> iS63{2}, iMMA64{64}
Split: iS20{256} by factor 256 -> iS65{1}, iMMA66{256}
Merge: iS63{2} and iS65{1} -> ithreadIdx.y67{2}
loop domain : (iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16})
Also note that these expressions are not shown in the printout because they come from loop broadcasts bS15{1} and bS16{1}, but we split those broadcasts for each operand:
Split: bS15{1} by factor 256 -> bblockIdx.x33{1}, bS34{256}
Split: bS16{1} by factor 128 -> bblockIdx.y37{1}, bS38{128}
The issue comes when we want to inline past these two outer IDs.
You are right that in some other case we might merge with some actual mapped domains like iS2 in your example. In that case it would be included, so we would not skip it and we would not be able to inline, even though we probably should be able to. I'm not sure how to handle such a case. In your example, if the merge(iS2, bS8) ID in a producer is aligned with a merge(iS9, iS10) in the consumer, we have no way to represent that bS8 should map to iS10, since bS8 is a loop broadcast and is not going to Broadcast map to anything. But that's really the type of relationship I'm trying to fake here -- in my case I'm faking that the loop broadcast M or N dimension will Broadcast map to the mma output M or N dimension.
Thanks. I think this is good enough for now, but I wonder if we should extend the inlining graph. Suppose we have:
t0: [i0]
t1: [i1, i2]
// [i0, b3]
t0->broadcast(0);
In this case, b3 is the added broadcast ID. Since it's not part of the math definition, it won't be mapped with anything. However, if we have:
// [i4]
t0->merge(0, 1);
// [i5]
t1->merge(0, 1);
Maybe we should map i4 and i5 because we want to allow inlining of i4 to i5. And maybe we should actually map b3 and i2 for the inlining purpose.
I think my point can also be phrased as: instead of tweaking the inlining logic, should we define the inlining graph such that it allows patterns like the above? Again, not asking for any change in this PR, but to me, if we call something the inlining graph, it should precisely reflect what can and cannot be inlined.
I also wonder if we could solve this problem by just moving the broadcast IDs to innermost. For T5,
loop domain : (bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8})
If we move b37 and b38 to the innermost position, would that solve the inlining problem? Somewhat related prior fix: #2799
Suppose we have:
I think your example here sums up the issue. We would like to map b3 and i2 in the inlining graph, but we have no way to express this wish currently. One option would be to add an optional argument to TensorView::broadcast letting us pass a vector of IterDomains that we'd like to inline with; then we could detect these just after building the Broadcast graph, copy the graph if any of those mappings are found, and perform those mappings, using the resulting graph as the inlining graph.
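Something like this hypothetical signature, purely to illustrate the idea (not an existing API), using the t0/t1/b3/i2 example from above:
// Hypothetical: record that the new loop broadcast b3 should be treated as
// inlinable with (i.e., pretend-Broadcast-mapped to) t1's i2.
t0->broadcast(/*axis=*/0, /*inline_with=*/{t1->axis(1)});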
but to me if we call something the inlining graph, it should precisely reflect what can and cannot be inlined.
👍
moving the broadcast IDs to innermost
They are only broadcast in one of the operands. In the output tensor and in the other operand they are Iteration domains. I don't think we can move them because some of the ones that are Broadcast are the ones we need to inline; they are the outer split dimensions which we parallelize with BIDx/BIDy, i.e. these are the tile coordinates.
letting us pass a vector of IterDomains that we'd like to inline with
By this I mean "a vector of IterDomains we'd like to pretend that our new broadcast loop ID is broadcast mapped to".
I added this a little before:
https://github.com/NVIDIA/Fuser/blob/main/csrc/ir/nodes.cpp#L2533-L2535
Not recommended, though. If it's the only WAR, then sure, but otherwise I found it not robust. For example, if IDs of x and y are registered as exactly mapped, that information may not be preserved when an iter domain is replaced by another (e.g., replaceSymbolicSizes).
I think my immediate concerns are just this and adding a few more tests. Let me know when it's ready for another review.
!test
LGTM
!build
Stacked on #3414
This PR enables us to inline an MmaOp properly when its inputs are missing broadcast dimensions. We do this by always allowing inlining past loop broadcasts or their transforms. For example:
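(The original example is not reproduced here; the sketch below shows only the general shape, with foo standing in for an op like MmaOp and the ID names purely illustrative.)
// tv0: logical [ M, K ],  loop [ M, bS5, K ]   // bS5: loop broadcast standing in for N
// tv1: logical [ N, K ],  loop [ bS6, N, K ]   // bS6: loop broadcast standing in for M
// tv2 = foo(tv0, tv1): logical/loop [ M, N, rK ]
//
// bS5 and bS6 exist only in the operands' loop domains; they are not reachable
// from any logical ID, so indexing never depends on them.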
As long as the operation foo properly maps its arguments despite the missing logical dimensions (as MmaOp does as of #3391), we should be able to fully inline this case, because the loop broadcasts bS5 and bS6 are imaginary in the sense that they don't impact indexing.