Enable MmaOp to receive unbroadcasted inputs #3372

Closed
jacobhinkle opened this issue Nov 7, 2024 · 0 comments · Fixed by #3391

This is a proposal to enable MmaOp to receive inputs shaped like [M, K] and [N, K] instead of [M, 1, K] and [1, N, K].

This is an alternative to #3366.

Motivation

Currently, MmaOp requires inputs that are at least 3D and whose dimensions all "line up". That means the M dimension should be Iteration in the A operand and Broadcast in the B operand, and likewise N should be Broadcast in A and Iteration in B. For example:

tv_a [ iS0{M}, bS1{1}, iS2{K} ]
tv_b [ bS3{1}, iS4{N}, iS5{K} ]
tv_c [ iS6{M}, iS7{N}, rS8{K} ] = fusedMultiplySum(tv_a, tv_b, axes={-1})

This lets us use the default exact domain mapping between the operands and the MmaOp output to determine the groups {0, 3, 6} (M), {1, 4, 7} (N), and {2, 5, 8} (K).

However, it means that if we are translating a Fusion that has MatmulOp or LinearOp to use MmaOp, we need to introduce BroadcastOp nodes, which interferes with the optimal gmem->smem->mma pipeline on Hopper.
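
For reference, here is a minimal sketch of how this broadcast-based pattern is built today with the nvFuser C++ API (makeContigTensor is the helper from nvFuser's test utilities; the dtype and the function wrapper are incidental):

#include <fusion.h>
#include <ops/all_ops.h>

using namespace nvfuser;

// Builds [M, K] x [N, K] -> [M, N] with the current interface, which
// requires explicit BroadcastOp nodes before fusedMultiplySum.
void buildBroadcastMma(Fusion* fusion) {
  FusionGuard fg(fusion);

  TensorView* tv_a = makeContigTensor(2, DataType::Half); // [M, K]
  TensorView* tv_b = makeContigTensor(2, DataType::Half); // [N, K]
  fusion->addInput(tv_a);
  fusion->addInput(tv_b);

  // The BroadcastOp nodes this proposal wants to avoid:
  TensorView* tv_ab = broadcast(tv_a, {false, true, false}); // [M, 1, K]
  TensorView* tv_bb = broadcast(tv_b, {true, false, false}); // [1, N, K]

  // Multiply and reduce over the trailing K axis.
  TensorView* tv_c = fusedMultiplySum(tv_ab, tv_bb, /*axes=*/{-1});
  fusion->addOutput(tv_c);
}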

Proposed Approach

In order to avoid needing BroadcastOp when our segment inputs do not already have broadcasts, we need to handle cases like this:

(MatmulOp translation)
tv_a [ iS0{M}, iS1{K} ]
tv_b [ iS2{K}, iS3{N} ]
tv_c [ iS4{M}, iS5{N}, rS6{K} ] = fusedMultiplySum(tv_a, tv_b, ??)

We can no longer simply specify numbered axes to reduce, since the inputs do not have the same number of axes as the output. Even if we could express that in the op, the IR node would need to hold more information so that we can perform Exact and Broadcast mapping of IterDomains across the MmaOp.

We now need to specify the axes to reduce as well as the axis correspondences in the inputs. One possibility is to specify, for each output axis, the position of the corresponding input axis, using -1 for output axes that have no counterpart in that input:

tv_c = fusedMultiplySum(
    tv_a, // [ iS0{M}, iS1{K} ]
    tv_b, // [ iS2{K}, iS3{N} ]
    /*init=*/nullptr,
    // One entry per output axis of tv_c [ M, N, rK ]:
    /*axis_mapping=*/{/*a_axes=*/{0, -1, 1}, /*b_axes=*/{-1, 1, 0}});

I know this looks a little verbose, but remember that we're not going to be creating these by hand very often: our main path will be to receive MatmulOp and LinearOp from Thunder, and that interface will not change.

I propose to do the following:

  • Add attributes to the MmaOp indicating the roles of dimensions in the tv_a and tv_b inputs (see the sketch after this list).
  • Add a special case in PairwiseLogicalDomainMap that will map the output domains to domains in the inputs using the map above. This is similar to what we do for SdpaFwdOp and SdpaBwdOp currently.
  • Update mma_utils::MatmulPattern::translateToMmaOp to skip inserting broadcasts and use this interface instead.
  • Update the Ampere matmul scheduler to not assume there is a broadcast M or N dimension in the ab and bb tensors.
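
To make the first two bullets concrete, here is a hypothetical sketch of the new attribute and the mapping special case. The struct name, fields, and function signature are illustrative only, not a committed interface; IterDomain comes from nvFuser's IR headers.

#include <cstdint>
#include <unordered_map>
#include <vector>

// Hypothetical attribute stored on the MmaOp node.
struct AxisMapping {
  // For each output logical axis, the position of the corresponding
  // operand axis, or -1 if the operand has no such axis. For the
  // MatmulOp translation above (output [M, N, rK]):
  //   a_axes = {0, -1, 1}; // tv_a is [M, K]
  //   b_axes = {-1, 1, 0}; // tv_b is [K, N]
  std::vector<int64_t> a_axes;
  std::vector<int64_t> b_axes;
};

// Sketch of the special case in PairwiseLogicalDomainMap: walk the output
// logical domain and map each IterDomain to the producer IterDomain named
// by the axis mapping, skipping -1 entries instead of relying on a
// Broadcast domain in the producer.
std::unordered_map<IterDomain*, IterDomain*> mapMmaOperand(
    const std::vector<IterDomain*>& producer_logical,
    const std::vector<IterDomain*>& consumer_logical,
    const std::vector<int64_t>& operand_axes) {
  std::unordered_map<IterDomain*, IterDomain*> id_map;
  for (size_t out_pos = 0; out_pos < consumer_logical.size(); ++out_pos) {
    int64_t in_pos = operand_axes.at(out_pos);
    if (in_pos == -1) {
      continue; // no counterpart in this operand; nothing to map
    }
    id_map[producer_logical.at(static_cast<size_t>(in_pos))] =
        consumer_logical.at(out_pos);
  }
  return id_map;
}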

Note that I also plan to keep the current interface for fusedMultiplySum available, so that we can use broadcasted inputs if we want to. The only caveat with keeping that old behavior around is that it might complicate the changes to the Ampere scheduler.
