This is just a tracking issue to collect thoughts about multiple-matmul fusion. I've already begun experimenting with this and plan to implement it in the next couple weeks.
Currently we are able to fuse pointwise prologues and pointwise epilogues surrounding a single matmul into a single kernel. However, we should also aim to handle fusions containing more than one matmul in a single kernel. The simplest example (Case 1 below) has two independent matmuls, `D1 = matmul(A1, B1)` and `D2 = matmul(A2, B2)`, whose results feed a common pointwise epilogue.
To compute this, we have a choice between using two separate main loops to compute the output tiles for `D1` and `D2` in registers, or fusing those two main loops with one another. Using separate main loops might actually be advantageous in this case, since we can re-use the smem buffers used to load `A1` and `B1` for loading `A2` and `B2`.
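As a rough illustration of the separate-main-loops option, here is a deliberately simplified, hand-written CUDA sketch. It is not what nvFuser would generate (no software pipelining, vectorized loads, or mma instructions), and the kernel name and `TILE = 16` are illustrative assumptions only. One thread block computes the co-located output tiles of both `D1` and `D2`, and the second main loop re-uses the same shared-memory buffers that staged `A1` and `B1`:

```cuda
#define TILE 16  // illustrative tile size only

// Launch with dim3 block(TILE, TILE) and a grid covering the N x M output.
__global__ void two_matmuls_two_main_loops(
    const float* A1, const float* B1, float* D1,
    const float* A2, const float* B2, float* D2,
    int M, int N, int K) {
  __shared__ float sA[TILE][TILE];  // re-used: stages A1 tiles, then A2 tiles
  __shared__ float sB[TILE][TILE];  // re-used: stages B1 tiles, then B2 tiles

  const int row = blockIdx.y * TILE + threadIdx.y;
  const int col = blockIdx.x * TILE + threadIdx.x;
  float acc1 = 0.0f;  // element of the D1 output tile, held in registers
  float acc2 = 0.0f;  // element of the D2 output tile, held in registers

  // Main loop 1: accumulate A1 @ B1.
  for (int k0 = 0; k0 < K; k0 += TILE) {
    sA[threadIdx.y][threadIdx.x] =
        (row < M && k0 + threadIdx.x < K) ? A1[row * K + k0 + threadIdx.x] : 0.0f;
    sB[threadIdx.y][threadIdx.x] =
        (k0 + threadIdx.y < K && col < N) ? B1[(k0 + threadIdx.y) * N + col] : 0.0f;
    __syncthreads();
    for (int kk = 0; kk < TILE; ++kk) {
      acc1 += sA[threadIdx.y][kk] * sB[kk][threadIdx.x];
    }
    __syncthreads();  // buffers may be overwritten by the next iteration
  }

  // Main loop 2: accumulate A2 @ B2, staged through the SAME smem buffers.
  for (int k0 = 0; k0 < K; k0 += TILE) {
    sA[threadIdx.y][threadIdx.x] =
        (row < M && k0 + threadIdx.x < K) ? A2[row * K + k0 + threadIdx.x] : 0.0f;
    sB[threadIdx.y][threadIdx.x] =
        (k0 + threadIdx.y < K && col < N) ? B2[(k0 + threadIdx.y) * N + col] : 0.0f;
    __syncthreads();
    for (int kk = 0; kk < TILE; ++kk) {
      acc2 += sA[threadIdx.y][kk] * sB[kk][threadIdx.x];
    }
    __syncthreads();
  }

  // Epilogue: any fused pointwise ops on acc1/acc2 would go here.
  if (row < M && col < N) {
    D1[row * N + col] = acc1;
    D2[row * N + col] = acc2;
  }
}
```

Note that both accumulators stay live for the whole kernel while the smem footprint remains that of a single matmul; the extra register pressure is what motivates the halved tile size assumed in the IO estimates below.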
The above example is a simple case in which none of the operands is reused. However, other patterns are imaginable. In Case 2 we re-use the same A operand with two different B operands: `D1 = matmul(A, B1)` and `D2 = matmul(A, B2)`. Notice that in this case we should fuse the two main loops together, in order to avoid loading A repeatedly in the same kernel.
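A minimal sketch of that fused main loop, in the same simplified style as above (again a hand-written illustration, not scheduler-generated code): each `A` tile is staged into shared memory once per K-step and multiplied against both the `B1` and `B2` tiles before it is replaced.

```cuda
#define TILE 16  // same illustrative tile size as the previous sketch

// Launch with dim3 block(TILE, TILE) and a grid covering the N x M output.
__global__ void shared_a_fused_main_loop(
    const float* A, const float* B1, float* D1,
    const float* B2, float* D2,
    int M, int N, int K) {
  __shared__ float sA[TILE][TILE];   // each A tile is staged once per K-step...
  __shared__ float sB1[TILE][TILE];  // ...and multiplied against both B tiles
  __shared__ float sB2[TILE][TILE];

  const int row = blockIdx.y * TILE + threadIdx.y;
  const int col = blockIdx.x * TILE + threadIdx.x;
  float acc1 = 0.0f;  // element of D1 = A @ B1
  float acc2 = 0.0f;  // element of D2 = A @ B2

  for (int k0 = 0; k0 < K; k0 += TILE) {
    sA[threadIdx.y][threadIdx.x] =
        (row < M && k0 + threadIdx.x < K) ? A[row * K + k0 + threadIdx.x] : 0.0f;
    sB1[threadIdx.y][threadIdx.x] =
        (k0 + threadIdx.y < K && col < N) ? B1[(k0 + threadIdx.y) * N + col] : 0.0f;
    sB2[threadIdx.y][threadIdx.x] =
        (k0 + threadIdx.y < K && col < N) ? B2[(k0 + threadIdx.y) * N + col] : 0.0f;
    __syncthreads();
    for (int kk = 0; kk < TILE; ++kk) {
      acc1 += sA[threadIdx.y][kk] * sB1[kk][threadIdx.x];
      acc2 += sA[threadIdx.y][kk] * sB2[kk][threadIdx.x];
    }
    __syncthreads();
  }

  if (row < M && col < N) {
    D1[row * N + col] = acc1;
    D2[row * N + col] = acc2;
  }
}
```

Compared with running the two main loops back to back, this loads A from global memory only once, at the cost of a second smem buffer for the extra B operand.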
IO savings
For Case 1 above, assume `A1` and `A2` are both M by K and `B1` and `B2` are both K by N. Computing this fusion with three kernels (the two matmul kernels plus a pointwise kernel that consumes both results) requires an amount $X_1$ of IO given by
$$
\begin{align}
X_{1,a} &= X_{1,b} = T_M N K + T_N M K + M N \\
X_{1,c} &= 4M N \\
X_1 &= X_{1,a} + X_{1,b} + X_{1,c} = 6MN + 2K(T_M N + T_N M)
\end{align}
$$
where $T_M$ is the number of tiles in the M dimension and $T_N$ is the number of tiles in the N dimension.
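For reference, the per-kernel term above is just the usual tile-wise accounting: each of the $T_M T_N$ output tiles loads one $\frac{M}{T_M} \times K$ tile of its A operand and one $K \times \frac{N}{T_N}$ tile of its B operand, and each matmul kernel writes its $M \times N$ result once:

$$
T_M T_N \cdot \frac{M}{T_M} K + T_M T_N \cdot K \frac{N}{T_N} + M N = T_N M K + T_M N K + M N
$$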
Now note that when fusing into a single kernel, we need to hold both output tiles in registers. Given that the size of this register buffer is often an active constraint on tile size, we should assume we need to halve the tile size. We could do that by doubling either $T_M$ or $T_N$, or more generally by choosing the new tile counts so that their product is at least $2 T_M T_N$. Assuming we use $2 T_M$ tiles in the M dimension, we can estimate the IO used for a single-kernel approach as $Y_1$ given by
$$
\begin{align}
Y_1 &= 2K(2T_M N + T_N M) + 2MN
\end{align}
$$
The savings is $X_1-Y_1=4MN - 2KT_M N$, so there is an advantage whenever $2M > KT_M$, i.e. for small-K cases.
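As a quick numeric sanity check of that condition, the snippet below evaluates the savings expression $4MN - 2KT_M N$ for one small-K and one large-K problem. The problem sizes and the 128-row M tile are illustrative assumptions only, not suggested heuristics.

```cuda
#include <cstdio>

// Case 1 savings in elements of global-memory traffic: 4*M*N - 2*K*T_M*N,
// which is positive exactly when 2*M > K*T_M.
long long case1_savings(long long M, long long N, long long K, long long tileM) {
  long long T_M = (M + tileM - 1) / tileM;  // number of tiles in the M dimension
  return 4 * M * N - 2 * K * T_M * N;
}

int main() {
  // Small K: fusing the two matmuls into one kernel saves traffic.
  printf("M=4096 N=4096 K=64   : %lld\n", case1_savings(4096, 4096, 64, 128));
  // Large K: the halved tile costs more extra B traffic than the fusion saves.
  printf("M=4096 N=4096 K=4096 : %lld\n", case1_savings(4096, 4096, 4096, 128));
  return 0;
}
```

With these numbers the single-kernel approach wins at $K=64$ but loses at $K=4096$, matching the small-K intuition.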
For the second case, where A is reused between the two matmuls, the IO for a three-kernel unfused approach is the same: $X_2=X_1=6MN + 2K(T_MN + T_NM)$. However, since in that case we can fuse the main loops and load A only once, we can get away with
$$
Y_2 = 2(T_M N + T_N M) K +MN
$$
In this case the savings is $X_2-Y_2=4MN - KT_M N$, and we have an advantage whenever $4M > KT_M$. Note that we could also choose to split the tile in half in the N dimension instead, which would change this condition.
Not covered: vertical fusion (attention-like patterns)
This does not cover the attention pattern, where the output of one matmul is used as input to another matmul. This is difficult to fuse since our matmul scheduler only produces a single tile per block, but downstream blocks will need to load multiple tiles. This can be addressed, e.g. by memory-efficient attention or flash attention, but that is considerably more complex than the horizontal (i.e. independent) matmul case considered in this issue.
Plan
- Update `scheduleMatmul` to hold vectors of tensors instead of individual tensors for `a`, `b`, their cached variants like `acw_smem`, as well as `mma_result` and `smem_epilogue`. Each of these can be multi-valued (see the sketch after this list).
- Modify the transform propagations to use these vectors as bounds.
- Modify the matmul heuristic (and plugin API) to allow more than a single A and a single B. Consider how best to communicate the structure of the fusion to heuristic plugins, so that we can properly choose parameters. Until we have multiple-matmul-aware heuristics, we should just run a heuristic plugin as if a single matmul were being computed, then adjust the tile sizes in an ad hoc manner to ensure memory constraints are satisfied.
- Modify compile-time checks to accept multiple matmuls under the following constraints (some of which could be lifted later):
  - All matmuls have the same M, N, and K shapes (different K would be the first of these constraints to try relaxing in the future).
  - No tensor appears in multiple different roles (this is already checked).
  - The output of any given matmul is not a producer for an operand in any matmul pattern in the fusion.
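As a very rough sketch of the first item, the general shape of the change might look something like the following. The struct and field names here are hypothetical illustrations only, not nvFuser's actual data structures, and `TensorView` is merely forward-declared as a stand-in for the scheduler's tensor handle type.

```cuda
#include <vector>

struct TensorView;  // stand-in for the scheduler's tensor handle type

// Hypothetical container: one entry per matmul in the fusion, replacing the
// current single-valued a / b / acw_smem / mma_result / smem_epilogue tensors.
struct MultiMatmulTensors {
  std::vector<TensorView*> as;              // A operands, one per matmul
  std::vector<TensorView*> bs;              // B operands, one per matmul
  std::vector<TensorView*> acw_smems;       // cached smem copies of the A operands
  std::vector<TensorView*> mma_results;     // accumulator tensors, one per matmul
  std::vector<TensorView*> smem_epilogues;  // smem epilogue staging, one per matmul
};
```

The second plan item would then iterate over these vectors and use each entry as a propagation bound, rather than the single tensors used today.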