# Accept axis mapping when defining MmaOp #3391

## Conversation
```
@@ -1265,158 +1265,3 @@ int64_t getOperationCount(Val* val) {
}

} // namespace nvfuser::ir_utils

namespace nvfuser::MmaOpUtils {
```
This was only needed for defining `MmaOp::mAxes()` and friends, but:
- Those methods are never used, so I removed them, and
- We can reconstruct that information easily using `mma->axisMapping()`.
!test

!test
This caused failures in split-K.

!test

Using the wrong graph meant that we could not detect any id roles.

!test
After this PR, one thing we can do is specify the dimension order of the output of the op defined at line 603 in d34553f. So in that case we can see how the AxisMapping is standing in for a root->logical reordering. Since there is one for each input operand, this feels like another nice use case for read/write/compute domains, as suggested by @zasdfgbnm for indexing ldmatrix.
csrc/ir/internal_nodes.h
Outdated
```
// corresponding position of each output axis in either the A or B input.
// Positions are absolute and refer to the noReductions logical domain. NOTE:
// -1 indicates that the axis does not exist, so Broadcast and Reduction
// dimensions should not have position -1.
```
Could you provide an example? For example, what does it mean if we have:
```
a_axes = {0, 2, 1, 3}
b_axes = {1, 0, 3, 2}
```
Does it mean `a.axis(0)` is mapped to `b.axis(1)`?
I'll put this in a comment, but yes, this would mean we would map the following sets:
```
{a.axis(0), b.axis(1), out.axis(0)}
{a.axis(2), b.axis(0), out.axis(1)}
{a.axis(1), b.axis(3), out.axis(2)}
{a.axis(3), b.axis(2), out.axis(3)}
```
```
const mma_utils::MatmulPattern& pattern = patterns.front();

IdModel id_model(&fusion);
const ValGraph& permissive_graph =
```
Let's use broadcast graph instead of permissive graph.
I am actually planning to do this, but it means I need to update mma_utils since permissive is still used there, so I will do that in a follow-up.
!build

!build
Now that we can define MmaOp with unbroadcasted inputs (see #3391), it is possible to have ops for which some consumer loop IDs are not used at all for indexing some producers. For example, the output of `MmaOp` has three logical dimensions, M, N, and K. These are scheduled by splitting, merging, and swizzling, so in the end the consumer loop domain can contain things like a split of the N dimension into two other IterDomains. Now if we look at the producer A, it has logical size [M K], so there is no N dimension at all. Our current predicate elimination pass places a predicate on this operation when the N dimension is symbolic and we can't prove that the producer is parallelized the same way as this consumer in this dimension. However, since N cannot affect the indexing of the producer A, which has no N dimension, we should skip checking these IterDomains. This PR does this by performing a BFS from the collection of consumer root IDs that map to producer logical IDs to the consumer leaf domain. Only IDs along that path are checked using the existing conditions.
## Detailed example

In the test included in this PR, we have shared memory operand tensors that are scheduled like this:

```
Inputs:
  T0_g___half[ iS0{i0}, iS1{i1} ]
  T1_g___half[ iS2{i3}, iS3{i4} ]
Outputs:
  T3_g___half[ iblockIdx.y27{( ceilDiv(i1, 128) )}, iblockIdx.x29{( ceilDiv(i4, 256) )}, ithreadIdx.y77{2}, ithreadIdx.x111{128}, iS106{32}, iS105{2}, iV109{2} ] ca_pos( 6 ) produce_pos( 6 )

%kernel_math {
T4_s___half[ iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8} ] ca_pos( 3 )
   = CpAsyncBulkTensorTile( T0_g___half[ iS0{i0}, iS1{i1} ] )
T5_s___half[ bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8} ] ca_pos( 3 )
   = CpAsyncBulkTensorTile( T1_g___half[ iS2{i3}, iS3{i4} ] )
T2_l_float[ iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16} ] ca_pos( 2 ) produce_pos( 3 )
   = mma(T4_s___half[ iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8} ] ca_pos( 3 ),
         T5_s___half[ bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8} ] ca_pos( 3 ))
T6_l___half[ iblockIdx.y23{( ceilDiv(i1, 128) )}, iblockIdx.x25{( ceilDiv(i4, 256) )}, ithreadIdx.y72{2}, ithreadIdx.x101{128}, iS96{32}, iS95{2}, iS99{2} ] ca_pos( 6 ) produce_pos( 2 )
   = __float2half(T2_l_float[ iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16} ] ca_pos( 2 ) produce_pos( 3 ));
T3_g___half[ iblockIdx.y27{( ceilDiv(i1, 128) )}, iblockIdx.x29{( ceilDiv(i4, 256) )}, ithreadIdx.y77{2}, ithreadIdx.x111{128}, iS106{32}, iS105{2}, iV109{2} ] ca_pos( 6 ) produce_pos( 6 )
   = Set( T6_l___half[ iblockIdx.y23{( ceilDiv(i1, 128) )}, iblockIdx.x25{( ceilDiv(i4, 256) )}, ithreadIdx.y72{2}, ithreadIdx.x101{128}, iS96{32}, iS95{2}, iS99{2} ] ca_pos( 6 ) produce_pos( 2 ), cache_op=Streaming )
} // %kernel_math

T0_g___half[ iS0{i0}, iS1{i1} ]
 logical domain : (iS0{i0}, iS1{i1})
 contiguity: t t
 loop domain : (iS0{i0}, iS1{i1})
T4_s___half[ iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8} ] ca_pos( 3 )
 logical domain : (iS9{i0}, iS10{i1})
 allocation domain : (iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8})
 contiguity: t n t n t t t t t t
  Split: iS10{i1} by factor 128 -> iblockIdx.y31{( ceilDiv(i1, 128) )}, iS32{128}
  Split: iS9{i0} by factor 16 -> iS35{( ceilDiv(i0, 16) )}, iS36{16}
  Split: iS32{128} by factor 64 -> iS43{2}, iS44{64}
  Split: iS36{16} by factor 8 -> iB45{2}, iS46{8}
  Split: iS46{8} by factor 1 -> iS47{8}, iB48{1}
  Split: iS44{64} by factor 8 -> iS49{8}, iB50{8}
  Xor(2D): iS47{8} , iS49{8} -> iB51{8} , iB52{8}
 loop domain : (iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8})
T1_g___half[ iS2{i3}, iS3{i4} ]
 logical domain : (iS2{i3}, iS3{i4})
 contiguity: t t
 loop domain : (iS2{i3}, iS3{i4})
T5_s___half[ bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8} ] ca_pos( 3 )
 logical domain : (iS11{i3}, iS12{i4})
 allocation domain : (bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8})
 contiguity: n t t n t t t t t t
  Split: iS12{i4} by factor 256 -> iblockIdx.x39{( ceilDiv(i4, 256) )}, iS40{256}
  Split: iS11{i3} by factor 16 -> iS41{( ceilDiv(i3, 16) )}, iS42{16}
  Split: iS40{256} by factor 64 -> iS53{4}, iS54{64}
  Split: iS42{16} by factor 8 -> iB55{2}, iS56{8}
  Split: iS56{8} by factor 1 -> iS57{8}, iB58{1}
  Split: iS54{64} by factor 8 -> iS59{8}, iB60{8}
  Xor(2D): iS57{8} , iS59{8} -> iB61{8} , iB62{8}
 loop domain : (bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8})
T2_l_float[ iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16} ] ca_pos( 2 ) produce_pos( 3 )
 logical domain : (iS4{i1}, iS5{i4}, rS6{i0})
 allocation domain : (iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, ithreadIdx.x87{128}, iMMA82{32}, iMMA81{2}, iMMA85{2}, rMMA90{2}, rMMA91{4}, rMMA89{2})
 contiguity: t t n t t t t t n n n
  Split: iS4{i1} by factor 128 -> iblockIdx.y17{( ceilDiv(i1, 128) )}, iS18{128}
  Split: iS5{i4} by factor 256 -> iblockIdx.x19{( ceilDiv(i4, 256) )}, iS20{256}
  Split: rS6{i0} by factor 16 -> rS21{( ceilDiv(i0, 16) )}, rMMA22{16}
  Split: iS18{128} by factor 64 -> iS63{2}, iMMA64{64}
  Split: iS20{256} by factor 256 -> iS65{1}, iMMA66{256}
  Merge: iS63{2} and iS65{1} -> ithreadIdx.y67{2}
 loop domain : (iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16})
```

Notice that in `T4_s` the loop broadcasts `bblockIdx.x33{1}` and `bS34{256}` are not derived from the logical domain.
Instead, they are actually both products of a `Split` involving an original "loop broadcast", although this is not currently shown in `fusion->printTransforms()`:

```
Split: bS15{1} by factor 256 -> bblockIdx.x33{1}, bS34{256}
```

In the predicate elimination pass with `T4_s` as producer and `T2_l` as consumer, the consumer ID `iblockIdx.x19{( ceilDiv(i4, 256) )}` _normally_ would map to a logical broadcast ID in `T4_s`, but with these loop domain broadcasts we do not have such a mapping. Before this PR, that would cause predication. This PR notices that `iblockIdx.x19{( ceilDiv(i4, 256) )}` is not actually used for indexing the producer `T4_s`, so we do not need to worry about out-of-bounds accesses in this dimension.

Without this PR, if we remove the check at https://github.com/NVIDIA/Fuser/blob/3266b9d21cb82272fe6e766b71fb9a9f298de833/csrc/device_lower/analysis/predicate_elimination.cpp#L34-L37 then we generate the following code:

```c++
__global__ void nvfuser_none_f0_c0_r0_g0(
    Tensor<__half, 2, 2> T0,
    Tensor<__half, 2, 2> T1,
    const __grid_constant__ TensorMap var0,
    const __grid_constant__ TensorMap var1,
    Tensor<__half, 2, 2> T3) {
  // ...
  nvfuser_index_t i4;
  i4 = 256 * ((nvfuser_index_t)blockIdx.x);
  nvfuser_index_t i7;
  i7 = 128 * ((nvfuser_index_t)blockIdx.y);
  nvfuser_index_t i19;
  i19 = 64 * ((nvfuser_index_t)threadIdx.y);
  bool b20;
  b20 = (i4 < T1.logical_size[1LL]) && ((i19 + i7) < T0.logical_size[1LL]);
#pragma unroll 1
  for (nvfuser_index_t i23 = 0; i23 < i2; ++i23) {
    nvfuser_index_t i24;
    i24 = 16 * i23;
    // ... load operands ...
    __syncthreads();
    if ((b20 && (i24 < T0.logical_size[0LL]))) {
      asm volatile(
          "{\n"
          "  .reg .pred p0; \n"
          "  setp.ne.b32 p0, %130, 0;\n"
          "  wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 {..."
          /* ... long parameter list ... */);
    }
    asm volatile("wgmma.commit_group.sync.aligned;\n");
    asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(0LL) : "memory");
  }
  asm volatile("wgmma.commit_group.sync.aligned;\n");
  asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(0LL) : "memory");
  // ... epilogue and write outputs ...
}
```

After this PR, the predicate around the `wgmma` call is removed and the `assertOnWarpOps` check can be restored.

---------

Co-authored-by: Naoya Maruyama <[email protected]>
Stacked on #3414

This PR enables us to inline an MmaOp properly when its inputs are missing broadcast dimensions. We do this by always allowing inlining past loop broadcasts or their transforms. For example:

```
tv0:
  logical [ iS1{i0} ]
  loop [ iS1{i0} bS5{1} ]
tv1:
  logical [ iS2{i1} ]
  loop [ bS6{1} iS2{i1} ]
tv2 = foo(tv0, tv1)
  logical [ iS3{i0} iS4{i1} ]
```

As long as the operation `foo` properly maps its arguments despite the missing logical dimensions (as `MmaOp` does as of #3391), then we should be able to fully inline this case, because the loop broadcasts `bS5` and `bS6` are imaginary in the sense that they don't impact indexing.
This keeps the default interface of `fusedMultiplySum` but also adds an option to provide an `MmaOp::AxisMapping` object. This mapping defines, for each output dimension, which axis in each operand (if any) corresponds to that output dimension.

This PR does not alter the behavior of `mma_utils::MatmulPattern::translateToMmaOp`, meaning we still have `BroadcastOp` in translations for Hopper matmuls, but that change should be relatively simple.

Fixes #3372

The included test only checks that dimensions are properly mapped in an MmaOp defined without broadcast axes. In follow-up PRs I plan to do the following:
- Update the Hopper matmul scheduler so it no longer relies on `moveInnerBroadcastLeft`; we may also need to adjust the swizzling of the TMA smem load TensorView. At this point we will be able to automatically schedule a manually defined `MmaOp` without broadcasted inputs using our automatic scheduler.
- Implement `MatmulPattern::translateToMmaOp(/*avoid_intermediates=*/true)` and enable it in the Hopper matmul scheduler. At this point it will be safe for us to accept `MatmulOp` and `LinearOp` in the Hopper matmul scheduler.