
Allow inlining past loop broadcasts for MmaOp #3416

Merged: jacobhinkle merged 50 commits into main from mma_inlining on Dec 11, 2024

Conversation

jacobhinkle (Collaborator)

Stacked on #3414

This PR enables us to inline an MmaOp properly when its inputs are missing broadcast dimensions. We do this by always allowing inlining past loop broadcasts or their transforms. For example:

tv0:
  logical [ iS1{i0} ]
  loop [ iS1{i0} bS5{1} ]
tv1:
  logical [ iS2{i1} ]
  loop [ bS6{1} iS2{i1} ]
tv2 = foo(tv0, tv1)
  logical [ iS3{i0} iS4{i1} ]

As long as the operation foo properly maps its arguments despite the missing logical dimensions (as MmaOp does as of #3391), we should be able to fully inline this case, because the loop broadcasts bS5 and bS6 are imaginary in the sense that they don't impact indexing.
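
For illustration, a minimal sketch (not the PR's test code) of how such loop broadcasts can be created, using the two-argument broadcast() overload that appears later in this thread:

// Append loop-only broadcasts; the logical domains are unchanged, so the
// new IDs never participate in indexing.
tv0->broadcast(1, 1); // tv0 loop becomes [ iS1{i0} bS5{1} ]
tv1->broadcast(0, 1); // tv1 loop becomes [ bS6{1} iS2{i1} ]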

jacobhinkle (Collaborator Author) commented Nov 15, 2024

After this, we can actually generate a proper kernel and run it. I will rebase #3406 onto this and modify the test to compile and run in that PR so we can inspect the generated kernel there. We can keep this PR for discussing the inlining changes only.

naoyam (Collaborator) commented Nov 15, 2024

Does this only apply to broadcast IDs added by TensorView::broadcast()?

jacobhinkle (Collaborator Author)

> Does this only apply to broadcast IDs added by TensorView::broadcast()?

Yes, that's the intention. I am using tv->domain()->additionalIDs(), which I think is only those broadcasts?

naoyam (Collaborator) commented Nov 15, 2024

>> Does this only apply to broadcast IDs added by TensorView::broadcast()?
>
> Yes, that's the intention. I am using tv->domain()->additionalIDs(), which I think is only those broadcasts?

Yes. @zasdfgbnm, when you added this, were you thinking about having non-broadcast IDs in additional_ids_?

jacobhinkle (Collaborator Author)

>>> Does this only apply to broadcast IDs added by TensorView::broadcast()?
>>
>> Yes, that's the intention. I am using tv->domain()->additionalIDs(), which I think is only those broadcasts?
>
> Yes. @zasdfgbnm, when you added this, were you thinking about having non-broadcast IDs in additional_ids_?

To be safe I'll check the IterType when skipping.

Comment on lines 206 to 210
for ([[maybe_unused]] auto [expr, dir] : IRBFS::getExprsBetween(
         {tv->domain()->additionalIDs().begin(),
          tv->domain()->additionalIDs().end()},
         {tv->getLoopDomain().begin(), tv->getLoopDomain().end()},
         /*require_all_to_visited=*/false)) {
jacobhinkle (Collaborator Author)

This includes all IDs between additionalIDs() and the loop domain. However, we could have something like this:

tv->broadcast(0, 16);
tv->merge(0);

In this case, we'll be merging the new broadcast ID with a pre-existing loop ID, so we should not ignore that merge. Instead, I think we should traverse from the root domain to the loop domain; the complement will then be the "pure" loop broadcasts, which we can ignore.
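
For concreteness, a hedged sketch of that alternative traversal, mirroring the getExprsBetween call quoted above (the root-domain accessor name is an assumption):

// Traverse from the root domain to the loop domain. Loop IDs not reached by
// any of these exprs (and not in the root domain itself) are then "pure"
// loop broadcasts or their transforms, which are safe to ignore for inlining.
auto exprs = IRBFS::getExprsBetween(
    {tv->getMaybeRootDomain().begin(), tv->getMaybeRootDomain().end()},
    {tv->getLoopDomain().begin(), tv->getLoopDomain().end()},
    /*require_all_to_visited=*/false);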

jacobhinkle (Collaborator Author) commented Nov 15, 2024

Actually, I suppose that would also automatically allow us to inline past regular broadcasts created with BroadcastOp, since those new broadcast IDs are not reachable from the root domain either; but I believe we already inline past those IDs anyway.

zasdfgbnm (Collaborator)

> Yes. @zasdfgbnm, when you added this, were you thinking about having non-broadcast IDs in additional_ids_?

No, I added it primarily for storing these new broadcasts.

jacobhinkle (Collaborator Author)

In the latest pushed changes, I do a BFS from producer logical to producer allocation, and from consumer root to consumer loop. This lets me collect the IDs that are used for indexing (assuming no shorter paths are discovered later). I then restrict the strictAreMapped check to the cases where at least one of the producer or consumer IDs is on such a path. That covers loop broadcasts automatically, since they're not used for indexing, and it lets us inline past them when they appear in the same position as another ID that is likewise not used in indexing that particular producer, as in the mma use case I have in mind.
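
A minimal sketch of the guarded check being described (the control flow and the strictAreMapped call are assumptions based on this thread, not the literal diff; names follow the hunks quoted below):

// Only enforce strict mapping when at least one of the two IDs lies on a
// collected indexing path; pure loop broadcasts are on no such path.
const bool p_on_path = loop_path_groups.count(inliningGraph().toGroup(p_id)) > 0;
const bool c_on_path = loop_path_groups.count(inliningGraph().toGroup(c_id)) > 0;
if ((p_on_path || c_on_path) &&
    !inliningGraph().disjointValSets().strictAreMapped(p_id, c_id)) {
  break; // cannot inline past this position
}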

jacobhinkle (Collaborator Author)

!test

jacobhinkle (Collaborator Author)

!test --diff

Base automatically changed from mma_predicate_elimination to main December 5, 2024 19:01
@jacobhinkle jacobhinkle marked this pull request as ready for review December 6, 2024 00:34
@jacobhinkle jacobhinkle requested a review from naoyam December 6, 2024 00:34
@jacobhinkle jacobhinkle changed the title Allow inlining past loop broadcasts Allow inlining past loop broadcasts for MmaOp Dec 6, 2024
@@ -193,6 +195,21 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer(
     }
     return producer->nDims();
   } else {
+    std::unordered_set<ValGroup> loop_path_groups;
+    if (consumer->definition()->isA<MmaOp>()) {
+      // Get ValGroups between producer and consumer loop in the inlining graph
jacobhinkle (Collaborator Author)

Maybe it would make more sense to start by using a PairwiseLogicalDomainMap to find the mapped ID groups between producer and consumer, then do a BFS from those to both loop domains' val groups. That way we could avoid missing loop groups in both the producer and the consumer. That case won't occur for MmaOp, but it might be clearer than traversing from producer loop to consumer loop, since it's not immediately obvious why we go in that direction.
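
A hedged sketch of that alternative seeding (the PairwiseLogicalDomainMap usage here is an assumption about its API, not code from this PR):

// Seed the BFS with the producer/consumer ID groups that are logically
// mapped, then walk out to both loop domains; unmapped loop groups on
// either side stay excluded from the indexing paths.
std::vector<ValGroup> seeds;
for (auto&& [p_id, c_id] :
     PairwiseLogicalDomainMap(producer, consumer).mapProducerToConsumer()) {
  seeds.push_back(inliningGraph().toGroup(p_id));
}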

Comment on lines 249 to 250
loop_path_groups.count(inliningGraph().toGroup(p_id)) ||
loop_path_groups.count(inliningGraph().toGroup(c_id))) &&
jacobhinkle (Collaborator Author)

For the MmaOp case, we will hit this when c_id is not in loop_path_groups because it is an M or N dimension. Note to self: I think we should probably assert that when this happens, both c_id and p_id are not found in that set, to avoid mistakenly inlining in a case where a mapped producer dimension is in the same position as an unmapped consumer dimension.

naoyam (Collaborator)

Is it always the case that an ignored producer ID is a broadcast?

naoyam (Collaborator)

> Note to self: I think we should probably assert that when this happens, both c_id and p_id are not found in that set

If that's the case, do we need to check c_id? Isn't loop_path_groups.count(inliningGraph().toGroup(p_id)) sufficient?

jacobhinkle (Collaborator Author) commented Dec 6, 2024

> Is it always the case that an ignored producer ID is a broadcast?

This is currently the case: we create a loop broadcast for the operands to MmaOp.

naoyam (Collaborator)

Shouldn't we assert that? If it's not a broadcast, I'm not sure it's safe to skip.

@@ -193,6 +195,21 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer(
     }
     return producer->nDims();
   } else {
+    std::unordered_set<ValGroup> loop_path_groups;
naoyam (Collaborator)

Since an empty set can mean either that it's just not set or that there are indeed no val groups, I think this would be clearer and less error-prone:

Suggested change:
- std::unordered_set<ValGroup> loop_path_groups;
+ std::optional<std::unordered_set<ValGroup>> loop_path_groups;

Comment on lines 248 to 250
if ((loop_path_groups.empty() ||
     loop_path_groups.count(inliningGraph().toGroup(p_id)) ||
     loop_path_groups.count(inliningGraph().toGroup(c_id))) &&
naoyam (Collaborator)

I think separating this logic would be easier to follow, like lines 222-224.

Suggested change:
- if ((loop_path_groups.empty() ||
-      loop_path_groups.count(inliningGraph().toGroup(p_id)) ||
-      loop_path_groups.count(inliningGraph().toGroup(c_id))) &&
+ // We can inline past consumer IDs that are not connected to the producer.
+ //
+ // For example, an MmaOp with no broadcasts could contain the following:
+ if (loop_path_groups.has_value() &&
+     (!loop_path_groups->count(inliningGraph().toGroup(p_id)) ||
+      !loop_path_groups->count(inliningGraph().toGroup(c_id)))) {
+   continue;
+ }


naoyam (Collaborator)

I feel we should have a few more tests. Can we create a test that should not get inlined even with the added condition?

jacobhinkle (Collaborator Author)

I added a section to the test that swaps the No and Ko axes in the mma result (which is scheduled Mo, No, Ko, ...). This is done on a copy of the scheduled fusion, and then I call inlineMost to check that we don't mistakenly inline past the unmapped No axis in tv0c.
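
A hedged sketch of that test section (assumed names, not the literal test code):

// On a copy of the scheduled fusion, swap No and Ko on the mma result and
// verify that inlining now stops short of the unmapped No axis in tv0c.
mma_result->reorder({{1, 2}, {2, 1}}); // [Mo, No, Ko, ...] -> [Mo, Ko, No, ...]
inlineMost();
NVF_CHECK(tv0c->getComputeAtPosition() < 3); // previously inlined at ca_pos(3)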

Comment on lines +235 to +242
// tv0:
// root/logical: [ iS0, iS1 ]
// loop: [ iS0, bS7, iS1 ]
// tv1:
// root/logical: [ iS2, iS3 ]
// loop: [ bS8, iS2, iS3 ]
// tv2:
// root/logical/loop: [ iS4, iS5, rS6 ]
naoyam (Collaborator)

How are these MmaOp inputs and output actually scheduled? Do their loop domains look just like the ones shown above? I'm asking because if, for example, bS8 gets merged with iS2, the merge would be included in the BFS path between tv1 and tv2, so we wouldn't skip the merge output domain; but since that output isn't mapped with any loop ID of tv2 in the broadcast graph (because bS8 is not mapped), inlining would stop at that point.

jacobhinkle (Collaborator Author)

This is how the schedule looks in the test:

T2_l_float[iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16}] ca_pos( 2 ) produce_pos( 3 )
   = mma(T4_s___half[iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8}] ca_pos( 3 ),
         T5_s___half[bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8}] ca_pos( 3 ))

T4_s___half[iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8}] ca_pos( 3 )
 logical domain : (iS9{i0}, iS10{i1})
 allocation domain : (iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8})
 contiguity: t n t n t t t t t t
  Split: iS10{i1} by factor 128 -> iblockIdx.y31{( ceilDiv(i1, 128) )}, iS32{128}
  Split: iS9{i0} by factor 16 -> iS35{( ceilDiv(i0, 16) )}, iS36{16}
  Split: iS32{128} by factor 64 -> iS43{2}, iS44{64}
  Split: iS36{16} by factor 8 -> iB45{2}, iS46{8}
  Split: iS46{8} by factor 1 -> iS47{8}, iB48{1}
  Split: iS44{64} by factor 8 -> iS49{8}, iB50{8}
  Xor(2D): iS47{8} , iS49{8} -> iB51{8} , iB52{8}
 loop domain : (iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8})
T5_s___half[bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8}] ca_pos( 3 )
 logical domain : (iS11{i3}, iS12{i4})
 allocation domain : (bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8})
 contiguity: n t t n t t t t t t
  Split: iS12{i4} by factor 256 -> iblockIdx.x39{( ceilDiv(i4, 256) )}, iS40{256}
  Split: iS11{i3} by factor 16 -> iS41{( ceilDiv(i3, 16) )}, iS42{16}
  Split: iS40{256} by factor 64 -> iS53{4}, iS54{64}
  Split: iS42{16} by factor 8 -> iB55{2}, iS56{8}
  Split: iS56{8} by factor 1 -> iS57{8}, iB58{1}
  Split: iS54{64} by factor 8 -> iS59{8}, iB60{8}
  Xor(2D): iS57{8} , iS59{8} -> iB61{8} , iB62{8}
 loop domain : (bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8})
T2_l_float[iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16}] ca_pos( 2 ) produce_pos( 3 )
 logical domain : (iS4{i1}, iS5{i4}, rS6{i0})
 allocation domain : (iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, ithreadIdx.x87{128}, iMMA82{32}, iMMA81{2}, iMMA85{2}, rMMA90{2}, rMMA91{4}, rMMA89{2})
 contiguity: t t n t t t t t n n n
  Split: iS4{i1} by factor 128 -> iblockIdx.y17{( ceilDiv(i1, 128) )}, iS18{128}
  Split: iS5{i4} by factor 256 -> iblockIdx.x19{( ceilDiv(i4, 256) )}, iS20{256}
  Split: rS6{i0} by factor 16 -> rS21{( ceilDiv(i0, 16) )}, rMMA22{16}
  Split: iS18{128} by factor 64 -> iS63{2}, iMMA64{64}
  Split: iS20{256} by factor 256 -> iS65{1}, iMMA66{256}
  Merge: iS63{2} and iS65{1} -> ithreadIdx.y67{2}
 loop domain : (iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16})

Also note that the following expressions are not shown in the printout because they come from the loop broadcasts bS15{1} and bS16{1}; we split those broadcasts for each operand:

Split: bS15{1} by factor 256 -> bblockIdx.x33{1}, bS34{256}
Split: bS16{1} by factor 128 -> bblockIdx.y37{1}, bS38{128}

The issue comes when we want to inline past these two outer IDs.

You are right that in some other case we might merge with actual mapped domains like iS2 in your example. That merge would then be included in the path, so we would not skip it and we would not be able to inline, even though we probably should be able to. I'm not sure how to handle such a case. In your example, if the merge(iS2, bS8) ID in a producer is aligned with a merge(iS9, iS10) in the consumer, we have no way to represent that bS8 should map to iS10, since bS8 is a loop broadcast and is not going to Broadcast map to anything. But that's really the type of relationship I'm trying to fake here: in my case, I'm faking that the loop broadcast M or N dimension will Broadcast map to the mma output's M or N dimension.

naoyam (Collaborator)

Thanks. I think this is good enough for now, but I wonder if we should extend the inlining graph. Suppose we have:

t0: [i0]
t1: [i1, i2]

// [i0, b3]
t0->broadcast(0);

In this case, b3 is the added broadcast ID. Since it's not part of the math definition, it won't be mapped with anything. However, if we have:

// [i4]
t0->merge(0, 1);

// [i5]
t1->merge(0, 1);

Maybe we should map i4 and i5 because we want to allow inlining of i4 to i5.

And maybe we should actually map b3 and i2 for inlining purposes.

Another way to phrase my point: instead of tweaking the inlining logic, should we define the inlining graph such that it allows patterns like the above? Again, I'm not asking for any change in this PR, but to me, if we call something the inlining graph, it should precisely reflect what can and cannot be inlined.
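
A hedged sketch of what that could look like (the graph copy and the mapVals usage are assumptions, not existing behavior):

// Build the inlining graph as a copy of the broadcast graph, then add the
// extra mappings that inlining should permit.
ValGraph inlining_graph = id_model.idGraph(IdMappingMode::BROADCAST); // copy
inlining_graph.mapVals(i4, i5); // merge outputs that should inline together
inlining_graph.mapVals(b3, i2); // pretend b3 is broadcast-mapped to i2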

naoyam (Collaborator)

I also wonder if we could solve this problem by just moving the broadcast IDs to the innermost position. For T5:

loop domain : (bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8})

If we move b37 and b38 to the innermost position, would that solve the inlining problem? Somewhat related prior fix: #2799

jacobhinkle (Collaborator Author)

> Suppose we have:

I think your example here sums up the issue. We would like to map b3 and i2 in the inlining graph, but we currently have no way to express that wish. One option would be to add an optional argument to TensorView::broadcast letting us pass a vector of IterDomains that we'd like to inline with; then, just after building the Broadcast graph, we could detect these requests, copy the graph if any such mappings are found, perform those mappings, and use the resulting graph as the inlining graph.
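
A hypothetical sketch of that option, using the tensors from the PR description (the inline_with parameter is invented for illustration and does not exist):

// Declare, at broadcast time, which IDs the new loop broadcast should be
// treated as broadcast-mapped to for inlining purposes.
tv0->broadcast(1, /*extent=*/1, /*inline_with=*/{tv2->axis(1)}); // bS5 ~ iS4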

> but to me if we call something the inlining graph, it should precisely reflect what can and cannot be inlined.

👍

> moving the broadcast IDs to innermost

They are only broadcast in one of the operands. In the output tensor and in the other operand they are Iteration domains. I don't think we can move them because some of the ones that are Broadcast are the ones we need to inline; they are the outer split dimensions which we parallelize with BIDx/BIDy, i.e. these are the tile coordinates.

jacobhinkle (Collaborator Author)

> letting us pass a vector of IterDomains that we'd like to inline with

By this I mean "a vector of IterDomains we'd like to pretend that our new broadcast loop ID is broadcast-mapped to".

naoyam (Collaborator)

I added this a little before:

https://github.com/NVIDIA/Fuser/blob/main/csrc/ir/nodes.cpp#L2533-L2535

Not recommended, though. If it's the only WAR, then sure, but otherwise I found it not robust. For example, if the IDs of x and y are registered as exactly mapped, that information may not be preserved when an iter domain is replaced by another (e.g., by replaceSymbolicSizes).

naoyam (Collaborator) commented Dec 10, 2024

I think my immediate concerns are just this and adding a few more tests. Let me know when it's ready for another review.

jacobhinkle (Collaborator Author)

!test

naoyam (Collaborator) left a comment

LGTM

jacobhinkle (Collaborator Author)

!build

@jacobhinkle jacobhinkle merged commit f5f2ab5 into main Dec 11, 2024
17 checks passed
@jacobhinkle jacobhinkle deleted the mma_inlining branch December 11, 2024 14:41
jacobhinkle added a commit that referenced this pull request Dec 12, 2024
Stacked on #3410, #3414, and #3416

This simply enables compilation of the test which uses #3391.