
Accept axis mapping when defining MmaOp #3391

Merged
merged 17 commits into main on Nov 12, 2024
Conversation

@jacobhinkle (Collaborator) commented on Nov 9, 2024:

This keeps the default interface of fusedMultiplySum but also adds an option to provide an MmaOp::AxisMapping object. This mapping defines, for each output dimension, which axis in each operand (if any) corresponds to that output dimension.

This PR does not alter the behavior of mma_utils::MatmulPattern::translateToMmaOp, meaning we still have BroadcastOp in translations for Hopper matmuls, but that change should be relatively simple.

Fixes #3372

The included test only checks that dimensions are properly mapped in an MmaOp defined without broadcast axes. In follow-up PRs I plan to do the following:

  1. Demonstrate scheduling a Hopper matmul with unbroadcasted inputs manually. This should surface any bugs in the lowering of the MmaOp instruction when broadcasts are absent.
  2. Ensure that we don't depend on having broadcast dims in the Hopper matmul scheduler. For example, we will handle this case in moveInnerBroadcastLeft and we may also need to adjust the swizzling of the TMA smem load TensorView. At this point we will be able to automatically schedule an MmaOp without broadcasted inputs that has been manually defined using our automatic scheduler.
  3. Add an option MatmulPattern::translateToMmaOp(/*avoid_intermediates=*/true) and enable that in the Hopper matmul scheduler. At this point it will be safe for us to accept MatmulOp and LinearOp in the Hopper matmul scheduler.

@@ -1265,158 +1265,3 @@ int64_t getOperationCount(Val* val) {
}

} // namespace nvfuser::ir_utils

namespace nvfuser::MmaOpUtils {
@jacobhinkle (Collaborator, Author) commented:

This was only needed for defining MmaOp::mAxes() and friends, but:

  1. Those methods are never used so I removed them and
  2. We can reconstruct that information easily using mma->axisMapping().

@jacobhinkle jacobhinkle changed the title [WIP] Accept axis mapping when defining MmaOp Accept axis mapping when defining MmaOp Nov 10, 2024
@jacobhinkle commented:

!test

@jacobhinkle commented:

!test

This caused failures in split-K
@jacobhinkle commented:

!test

Using the wrong graph meant that we could not detect any id roles
@jacobhinkle commented:

!test

@jacobhinkle jacobhinkle marked this pull request as ready for review November 11, 2024 01:48
@jacobhinkle commented:

After this PR, one thing we can do is specify the dimension order of the output of the MmaOp independently from the inputs. When we translate MatmulOp and LinearOp, the output already has logical order M, N and we are free to place K wherever we want, so I'll place it last. I think this will let us avoid using commitLeafToLogical as is done here:

tv2->commitLeafToLogical();

So in that case we can see how the AxisMapping is standing in for a root->logical reordering. Since there is one for each input operand, this feels like another nice use case for read/write/compute domains, as suggested by @zasdfgbnm for indexing ldmatrix.

csrc/ir/nodes.cpp (resolved review thread)
// corresponding position of each output axis in either the A or B input.
// Positions are absolute and refer to the noReductions logical domain. NOTE:
// -1 indicates that the axis does not exist, so Broadcast and Reduction
// dimensions should not have position -1.
A reviewer (Collaborator) commented:

Could you provide an example? For example, what does it mean if we have:

a_axes = {0, 2, 1, 3}
b_axes = {1, 0, 3, 2}

Does it mean a.axis(0) is mapped to b.axis(1)?

@jacobhinkle (Collaborator, Author) replied:

I'll put this in a comment but yes, this would mean we would map the following sets:

{a.axis(0), b.axis(1), out.axis(0)}
{a.axis(2), b.axis(0), out.axis(1)}
{a.axis(1), b.axis(3), out.axis(2)}
{a.axis(3), b.axis(2), out.axis(3)}
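The convention in the reply above can be stated mechanically: entry d of each vector names the operand axis that joins out.axis(d)'s mapping group. A small illustrative helper (hypothetical, not nvFuser code) that reconstructs those groups as strings:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Builds the mapped sets from an axis mapping: group d contains
// a.axis(a_axes[d]) and b.axis(b_axes[d]) together with out.axis(d).
// An entry of -1 means the operand has no axis in that group.
std::vector<std::string> mappedGroups(
    const std::vector<int>& a_axes,
    const std::vector<int>& b_axes) {
  std::vector<std::string> groups;
  for (size_t d = 0; d < a_axes.size(); ++d) {
    std::string g = "{";
    if (a_axes[d] >= 0) {
      g += "a.axis(" + std::to_string(a_axes[d]) + "), ";
    }
    if (b_axes[d] >= 0) {
      g += "b.axis(" + std::to_string(b_axes[d]) + "), ";
    }
    g += "out.axis(" + std::to_string(d) + ")}";
    groups.push_back(g);
  }
  return groups;
}
```

Feeding in the reviewer's example, `mappedGroups({0, 2, 1, 3}, {1, 0, 3, 2})`, reproduces the four sets listed above.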

csrc/ir/internal_nodes.h (resolved review thread)
const mma_utils::MatmulPattern& pattern = patterns.front();

IdModel id_model(&fusion);
const ValGraph& permissive_graph =
A reviewer (Collaborator) commented:

Let's use broadcast graph instead of permissive graph.

@jacobhinkle (Collaborator, Author) replied:

I am actually planning to do this, but it means I need to update mma_utils since permissive is still used there, so I will do that in a follow-up.


csrc/ir/internal_nodes.h (resolved review thread)
csrc/ops/arith.cpp (resolved review thread)
@jacobhinkle commented:

!build

@jacobhinkle commented:

!build

@jacobhinkle jacobhinkle merged commit 030c2ba into main Nov 12, 2024
13 of 14 checks passed
@jacobhinkle jacobhinkle deleted the mmaop_no_broadcast branch November 12, 2024 20:43
jacobhinkle added a commit that referenced this pull request Dec 5, 2024
Now that we can define MmaOp with unbroadcasted inputs (see #3391), it
is possible to have ops for which some consumer loop IDs are not used at
all for indexing some producers.

For example, the output of `MmaOp` has three logical dimensions, M, N,
and K. These are scheduled by splitting, merging, and swizzling, so in
the end the consumer loop domain can contain things like a split of the
N dimension into two other IterDomains. Now if we look at the producer
A, it has logical size [M K], so there is no N dimension at all. Our
current predicate elimination pass places a predicate on this operation
when the N dimension is symbolic and we can't prove that the producer is
parallelized the same way as this consumer in this dimension. However,
since N cannot affect the indexing of the producer A which has no N
dimension, we should skip checking these IterDomains.

This PR does this by performing a BFS from the collection of consumer
root IDs that map to producer logical IDs to the consumer leaf domain.
Only IDs along that path are checked using the existing conditions.
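The traversal described above can be sketched as a plain forward BFS. This is a hedged illustration only (integer IDs and an explicit edge map stand in for nvFuser's IterDomains and ValGraph): loop IDs not reached from the mapped consumer root IDs cannot affect producer indexing, so they are exempt from the predicate checks.

```cpp
#include <cassert>
#include <map>
#include <queue>
#include <set>
#include <vector>

// BFS from the consumer root IDs that map to producer logical IDs,
// following the transforms (split/merge/swizzle) that derive new IDs.
// Returns the subset of loop IDs reached, i.e. the only loop IDs that
// need the existing predicate-elimination checks.
std::set<int> reachableLoopIds(
    const std::map<int, std::vector<int>>& forward_edges, // id -> derived ids
    const std::vector<int>& mapped_root_ids,
    const std::set<int>& loop_ids) {
  std::set<int> visited(mapped_root_ids.begin(), mapped_root_ids.end());
  std::queue<int> frontier;
  for (int id : mapped_root_ids) {
    frontier.push(id);
  }
  while (!frontier.empty()) {
    int id = frontier.front();
    frontier.pop();
    auto it = forward_edges.find(id);
    if (it == forward_edges.end()) {
      continue;
    }
    for (int next : it->second) {
      if (visited.insert(next).second) {
        frontier.push(next);
      }
    }
  }
  std::set<int> result;
  for (int id : loop_ids) {
    if (visited.count(id)) {
      result.insert(id);
    }
  }
  return result;
}
```

For a producer shaped [M, K] whose consumer also has an N root, the N root and anything split from it never enter `visited`, so N-derived loop IDs drop out of the predicate check, mirroring the argument above.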

## Detailed example

In the test included in this PR, we have shared memory operand tensors
that are scheduled like this
```
Inputs:
  T0_g___half[ iS0{i0}, iS1{i1} ]
  T1_g___half[ iS2{i3}, iS3{i4} ]
Outputs:
  T3_g___half[ iblockIdx.y27{( ceilDiv(i1, 128) )}, iblockIdx.x29{( ceilDiv(i4, 256) )}, ithreadIdx.y77{2}, ithreadIdx.x111{128}, iS106{32}, iS105{2}, iV109{2} ] ca_pos( 6 ) produce_pos( 6 )

%kernel_math {
T4_s___half[ iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8} ] ca_pos( 3 )
   = CpAsyncBulkTensorTile( T0_g___half[ iS0{i0}, iS1{i1} ] )
T5_s___half[ bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8} ] ca_pos( 3 )
   = CpAsyncBulkTensorTile( T1_g___half[ iS2{i3}, iS3{i4} ] )
T2_l_float[ iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16} ] ca_pos( 2 ) produce_pos( 3 )
   = mma(T4_s___half[ iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8} ] ca_pos( 3 ),
         T5_s___half[ bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8} ] ca_pos( 3 ))
T6_l___half[ iblockIdx.y23{( ceilDiv(i1, 128) )}, iblockIdx.x25{( ceilDiv(i4, 256) )}, ithreadIdx.y72{2}, ithreadIdx.x101{128}, iS96{32}, iS95{2}, iS99{2} ] ca_pos( 6 ) produce_pos( 2 )
   = __float2half(T2_l_float[ iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16} ] ca_pos( 2 ) produce_pos( 3 ));
T3_g___half[ iblockIdx.y27{( ceilDiv(i1, 128) )}, iblockIdx.x29{( ceilDiv(i4, 256) )}, ithreadIdx.y77{2}, ithreadIdx.x111{128}, iS106{32}, iS105{2}, iV109{2} ] ca_pos( 6 ) produce_pos( 6 )
   = Set( T6_l___half[ iblockIdx.y23{( ceilDiv(i1, 128) )}, iblockIdx.x25{( ceilDiv(i4, 256) )}, ithreadIdx.y72{2}, ithreadIdx.x101{128}, iS96{32}, iS95{2}, iS99{2} ] ca_pos( 6 ) produce_pos( 2 ), cache_op=Streaming )
} // %kernel_math

T0_g___half[ iS0{i0}, iS1{i1} ]
 logical domain : (iS0{i0}, iS1{i1})
 contiguity: t t
 loop domain : (iS0{i0}, iS1{i1})
T4_s___half[ iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8} ] ca_pos( 3 )
 logical domain : (iS9{i0}, iS10{i1})
 allocation domain : (iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8})
 contiguity: t n t n t t t t t t
  Split: iS10{i1} by factor 128 -> iblockIdx.y31{( ceilDiv(i1, 128) )}, iS32{128}
  Split: iS9{i0} by factor 16 -> iS35{( ceilDiv(i0, 16) )}, iS36{16}
  Split: iS32{128} by factor 64 -> iS43{2}, iS44{64}
  Split: iS36{16} by factor 8 -> iB45{2}, iS46{8}
  Split: iS46{8} by factor 1 -> iS47{8}, iB48{1}
  Split: iS44{64} by factor 8 -> iS49{8}, iB50{8}
  Xor(2D): iS47{8} , iS49{8} -> iB51{8} , iB52{8}
 loop domain : (iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8})
T1_g___half[ iS2{i3}, iS3{i4} ]
 logical domain : (iS2{i3}, iS3{i4})
 contiguity: t t
 loop domain : (iS2{i3}, iS3{i4})
T5_s___half[ bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8} ] ca_pos( 3 )
 logical domain : (iS11{i3}, iS12{i4})
 allocation domain : (bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8})
 contiguity: n t t n t t t t t t
  Split: iS12{i4} by factor 256 -> iblockIdx.x39{( ceilDiv(i4, 256) )}, iS40{256}
  Split: iS11{i3} by factor 16 -> iS41{( ceilDiv(i3, 16) )}, iS42{16}
  Split: iS40{256} by factor 64 -> iS53{4}, iS54{64}
  Split: iS42{16} by factor 8 -> iB55{2}, iS56{8}
  Split: iS56{8} by factor 1 -> iS57{8}, iB58{1}
  Split: iS54{64} by factor 8 -> iS59{8}, iB60{8}
  Xor(2D): iS57{8} , iS59{8} -> iB61{8} , iB62{8}
 loop domain : (bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8})
T2_l_float[ iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16} ] ca_pos( 2 ) produce_pos( 3 )
 logical domain : (iS4{i1}, iS5{i4}, rS6{i0})
 allocation domain : (iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, ithreadIdx.x87{128}, iMMA82{32}, iMMA81{2}, iMMA85{2}, rMMA90{2}, rMMA91{4}, rMMA89{2})
 contiguity: t t n t t t t t n n n
  Split: iS4{i1} by factor 128 -> iblockIdx.y17{( ceilDiv(i1, 128) )}, iS18{128}
  Split: iS5{i4} by factor 256 -> iblockIdx.x19{( ceilDiv(i4, 256) )}, iS20{256}
  Split: rS6{i0} by factor 16 -> rS21{( ceilDiv(i0, 16) )}, rMMA22{16}
  Split: iS18{128} by factor 64 -> iS63{2}, iMMA64{64}
  Split: iS20{256} by factor 256 -> iS65{1}, iMMA66{256}
  Merge: iS63{2} and iS65{1} -> ithreadIdx.y67{2}
 loop domain : (iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16})
```
Notice that in `T4_s` the loop broadcasts `bblockIdx.x33{1}` and
`bS34{256}` are not derived from the logical domain. Instead, they are
actually both products of a `Split` involving an original "loop
broadcast", although this is not currently shown in
`fusion->printTransforms()`:
```
Split: bS15{1} by factor 256 -> bblockIdx.x33{1}, bS34{256}
```
In the predicate elimination pass with `T4_s` as producer and `T2_l` as
consumer, the consumer ID `iblockIdx.x19{( ceilDiv(i4, 256) )}`
_normally_ would map to a logical broadcast ID in `T4_s`, but with these
loop domain broadcasts we do not have such a mapping. Before this PR
that would cause predication. This PR notices that `iblockIdx.x19{(
ceilDiv(i4, 256) )}` is not actually used for indexing the producer
`T4_s` so we do not need to worry about out-of-bounds accesses in this
dimension.

Without this PR, if we remove the check at
https://github.com/NVIDIA/Fuser/blob/3266b9d21cb82272fe6e766b71fb9a9f298de833/csrc/device_lower/analysis/predicate_elimination.cpp#L34-L37
then we generate the following code:
```c++
__global__ void nvfuser_none_f0_c0_r0_g0(      
    Tensor<__half, 2, 2> T0,                                                          
    Tensor<__half, 2, 2> T1,                                                          
    const __grid_constant__ TensorMap var0,                                           
    const __grid_constant__ TensorMap var1,                                           
    Tensor<__half, 2, 2> T3) {
  // ...
  nvfuser_index_t i4;
  i4 = 256 * ((nvfuser_index_t)blockIdx.x);
  nvfuser_index_t i7;
  i7 = 128 * ((nvfuser_index_t)blockIdx.y);
  nvfuser_index_t i19;
  i19 = 64 * ((nvfuser_index_t)threadIdx.y);
  bool b20;
  b20 = (i4 < T1.logical_size[1LL]) && ((i19 + i7) < T0.logical_size[1LL]);

#pragma unroll 1
  for (nvfuser_index_t i23 = 0; i23 < i2; ++i23) {
    nvfuser_index_t i24;
    i24 = 16 * i23;

    // ... load operands ...

    __syncthreads();
    if ((b20 && (i24 < T0.logical_size[0LL]))) {
      asm volatile(
          "{\n"
          "  .reg .pred p0; \n"
          "  setp.ne.b32 p0, %130, 0;\n"
          "  wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 {..." /*... long parameter list ... */);
    }
    asm volatile("wgmma.commit_group.sync.aligned;\n");
    asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(0LL) : "memory");
  }
  asm volatile("wgmma.commit_group.sync.aligned;\n");
  asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(0LL) : "memory");
  // ... epilogue and write outputs ...
}
```
After this PR, the predicate around the `wgmma` call is removed and the
`assertOnWarpOps` check can be restored.

---------

Co-authored-by: Naoya Maruyama <[email protected]>
jacobhinkle added a commit that referenced this pull request Dec 11, 2024
Stacked on #3414 

This PR enables us to inline an MmaOp properly when its inputs are
missing broadcast dimensions. We do this by always allowing inlining
past loop broadcasts or their transforms. For example
```
tv0:
  logical [ iS1{i0} ]
  loop [ iS1{i0} bS5{1} ]
tv1:
  logical [ iS2{i1} ]
  loop [ bS6{1} iS2{i1} ]
tv2 = foo(tv0, tv1)
  logical [ iS3{i0} iS4{i1} ]
```
As long as the operation `foo` properly maps its arguments despite the
missing logical dimensions (as `MmaOp` does as of #3391), then we should
be able to fully inline this case because the loop broadcasts `bS5` and
`bS6` are imaginary in the sense that they don't impact indexing.
jacobhinkle added a commit that referenced this pull request Dec 12, 2024
Stacked on #3410, #3414, and #3416

This simply enables compilation of the test which uses #3391.
Linked issue: Enable MmaOp to receive unbroadcasted inputs (#3372)