Horizontal fusion of multiple matmuls in a single kernel #2454

Open
jacobhinkle opened this issue Jun 25, 2024 · 0 comments
This is just a tracking issue to collect thoughts about multiple-matmul fusion. I've already begun experimenting with this and plan to implement it in the next couple weeks.

Currently we are able to fuse pointwise prologues and epilogues surrounding a single matmul into a single kernel. However, we should also aim to compute fusions like the following in a single kernel:

# Case 1
D1 = matmul(A1, B1);
D2 = matmul(A2, B2);
addOutput(mul(D1, D2));

To compute this, we can either include two separate main loops that compute the output tiles for D1 and D2 in registers, or fuse those two main loops with one another. Using separate main loops might actually be advantageous in this case, since the smem buffers used to load A1 and B1 can be re-used for loading A2 and B2.
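As a rough illustration, here is a host-side sketch of the two structures in NumPy. The tile_k chunk size and function names are illustrative only; this models the per-tile accumulation order, not nvFuser's actual generated code.

```python
import numpy as np

def case1_separate_main_loops(A1, B1, A2, B2, tile_k=64):
    # Two sequential main loops: the buffers holding A1/B1 chunks can be
    # re-used for A2/B2 once the first loop finishes.
    K = A1.shape[1]
    d1 = np.zeros((A1.shape[0], B1.shape[1]))
    d2 = np.zeros((A2.shape[0], B2.shape[1]))
    for k in range(0, K, tile_k):
        d1 += A1[:, k:k + tile_k] @ B1[k:k + tile_k, :]
    for k in range(0, K, tile_k):
        d2 += A2[:, k:k + tile_k] @ B2[k:k + tile_k, :]
    return d1 * d2  # epilogue

def case1_fused_main_loop(A1, B1, A2, B2, tile_k=64):
    # One fused main loop: both accumulators (and both pairs of operand
    # buffers) are live at once, increasing the register/smem footprint.
    K = A1.shape[1]
    d1 = np.zeros((A1.shape[0], B1.shape[1]))
    d2 = np.zeros((A2.shape[0], B2.shape[1]))
    for k in range(0, K, tile_k):
        d1 += A1[:, k:k + tile_k] @ B1[k:k + tile_k, :]
        d2 += A2[:, k:k + tile_k] @ B2[k:k + tile_k, :]
    return d1 * d2  # epilogue
```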

The above example is a simple case where neither of the operands are reused. However, other patterns are imaginable:

# Case 2
D1 = matmul(A, B1);
D2 = matmul(A, B2);
addOutput(mul(D1, D2));

Here we re-use the same A operand with two different B operands. In this case we should fuse the two main loops together, in order to avoid loading A repeatedly within the same kernel.
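For Case 2, the corresponding fused main loop reads each A chunk once per iteration and applies it to both B operands. This is again only a NumPy sketch with an illustrative tile_k, not a description of the scheduled kernel:

```python
import numpy as np

def case2_fused_main_loop(A, B1, B2, tile_k=64):
    K = A.shape[1]
    d1 = np.zeros((A.shape[0], B1.shape[1]))
    d2 = np.zeros((A.shape[0], B2.shape[1]))
    for k in range(0, K, tile_k):
        a_chunk = A[:, k:k + tile_k]          # loaded once per iteration
        d1 += a_chunk @ B1[k:k + tile_k, :]
        d2 += a_chunk @ B2[k:k + tile_k, :]
    return d1 * d2
```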

IO savings

For Case 1 above, assume A1 and A2 are both M by K and B1 and B2 are both K by N. Computing this fusion with three kernels (one per matmul, plus a pointwise kernel for the multiply) requires an amount $X_1$ of IO given by

$$ \begin{align} X_{1,a} &= X_{1,b} = T_M N K + T_N M K + M N \\ X_{1,c} &= 3M N \\ X_1 &= X_{1,a} + X_{1,b} + X_{1,c} = 5MN + 2K(T_M N + T_N M) \end{align} $$

where $T_M$ is the number of tiles in the M dimension and $T_N$ is the number of tiles in the N dimension. The pointwise kernel reads D1 and D2 and writes the result, giving $X_{1,c} = 3MN$.

Now note that when fusing into a single kernel, we need to hold both output tiles in registers. Given that the size of this register buffer is often an active constraint on tile size, we should assume we need to halve the tile size. We could do that by multiplying either $T_M$ or $T_N$ by 2, or more generally by choosing the tile sizes such that the product of tile counts is at least $2 T_M T_N$. Assuming we use $2 T_M$, we can estimate the IO used for a single-kernel approach as $Y_1$ given by

$$ \begin{align} Y_1 &= 2K(2T_M N + T_N M) + MN \end{align} $$

The savings is $X_1 - Y_1 = 4MN - 2K T_M N$, so fusing has an advantage when $2M > K T_M$, i.e. for small-K cases.
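A quick numeric check of this model (the shapes and tile counts below are arbitrary, chosen only to illustrate the threshold):

```python
def case1_io_saving(M, N, K, T_M, T_N):
    X1 = 5 * M * N + 2 * K * (T_M * N + T_N * M)   # three separate kernels
    Y1 = 2 * K * (2 * T_M * N + T_N * M) + M * N   # fused, tile halved in M
    return X1 - Y1

for K in (64, 512):
    saving = case1_io_saving(M=4096, N=4096, K=K, T_M=32, T_N=32)
    print(f"K={K}: saving={saving}, 2M > K*T_M is {2 * 4096 > K * 32}")
# K=64 satisfies 2M > K*T_M and shows a positive saving; K=512 does not.
```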

For the second case, where A is reused between the two matmuls, the IO for the three-kernel unfused approach is the same: $X_2 = X_1 = 5MN + 2K(T_M N + T_N M)$. However, in that case we can fuse the main loop and load A only once, so if we instead halve the tile in the N dimension (i.e. use $2 T_N$), the single-kernel IO is

$$ Y_2 = 2K(T_M N + T_N M) + MN $$

In this case the savings is $X_2 - Y_2 = 4MN$ regardless of K: the operand traffic is unchanged, and we simply avoid writing and re-reading D1 and D2. Note that we could also choose to split the tile in half in the M dimension instead, which would change this result.

Not covered: vertical fusion (attention-like patterns)

This does not cover the attention pattern, where the output of one matmul is used as input to another matmul. This is difficult to fuse since our matmul scheduler only produces a single tile per block, but downstream blocks will need to load multiple tiles. This can be addressed, e.g. by memory-efficient attention or flash attention, but that is considerably more complex than the horizontal (i.e. independent) matmul case considered in this issue.

Plan

  1. Update scheduleMatmul to hold vectors of tensors instead of individual tensors for a, b, and their cached variants like acw_smem, as well as mma_result and smem_epilogue, so that each of these can be multi-valued.
  2. Modify the transform propagations to use these vectors as bounds.
  3. Modify the matmul heuristic (and plugin API) to allow more than a single A and a single B. Consider how best to communicate the structure of the fusion to heuristic plugins, so that we can properly choose parameters. Until we have multiple-matmul aware heuristics, we should just run a heuristic plugin as if a single matmul were being computed, then adjust the tile sizes in an ad hoc manner to ensure memory constraints are satisfied.
  4. Modify compile-time checks to accept multiple matmuls under the following constraints, some of which could be lifted later (a rough check along these lines is sketched after this list):
  • All matmuls have the same M, N, and K shapes (different K would be the first of these constraints to try relaxing in the future).
  • No tensor appears in multiple different roles (this is already checked).
  • The output of any given matmul is not a producer for an operand in any matmul pattern in the fusion.
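As a rough, framework-agnostic illustration of the checks in step 4, here is a small sketch; the MatmulPattern fields and example tensor names are hypothetical stand-ins, not nvFuser's actual pattern-matching structures.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class MatmulPattern:
    a: str        # A-role operand
    b: str        # B-role operand
    out: str      # matmul result
    shape: tuple  # (M, N, K)

def accept_multi_matmul(patterns):
    # All matmuls must share the same M, N, and K (different K could be relaxed later).
    if len({p.shape for p in patterns}) > 1:
        return False
    a_ops = {p.a for p in patterns}
    b_ops = {p.b for p in patterns}
    outs = {p.out for p in patterns}
    # No tensor may appear in multiple different roles.
    if (a_ops & b_ops) or (a_ops & outs) or (b_ops & outs):
        return False
    # No matmul output may produce an operand of another matmul (no vertical fusion).
    if any(p.a in outs or p.b in outs for p in patterns):
        return False
    return True

# Case 2 above is accepted; an attention-like chain is rejected.
case2 = [MatmulPattern("A", "B1", "D1", (1024, 1024, 128)),
         MatmulPattern("A", "B2", "D2", (1024, 1024, 128))]
chained = [MatmulPattern("Q", "K_t", "S", (1024, 1024, 128)),
           MatmulPattern("S", "V", "O", (1024, 1024, 128))]
print(accept_multi_matmul(case2))    # True
print(accept_multi_matmul(chained))  # False
```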