Add warp specialization as a circular buffering type (#3511)
This PR adds warp specialization as a new type of circular buffering. Today, we already support pipelined circular buffering, and we can optionally choose whether to use block-sync or mbarrier for handling WAR hazards. If we choose mbarrier for handling WAR hazards, we generate a kernel like the one below:

```python
# Mark buffer[i] as empty and ready to be loaded
for i in range(prefetch + 1):
    arrive(war_mbarrier[i])

# Prologue: thanks to the previous arrives, all the loads will just go
# through and no wait is needed
for i in range(prefetch):
    wait war_mbarrier[i]
    arrive-expect-tx raw_mbarrier[i]
    load data[i] into buffer[i]

# Main loop:
for i in range(data.size - prefetch):
    if elect-sync:
        wait war_mbarrier[(i + prefetch) % stage]
        arrive-expect-tx raw_mbarrier[(i + prefetch) % stage]
        load data[i + prefetch] into buffer[(i + prefetch) % stage]
    wait raw_mbarrier[i % stage]
    mma on buffer[i % stage] for data[i]
    wait until there are at most stage - prefetch - 1 pending mma
    arrive war_mbarrier[(i + prefetch + 1) % stage]

# Epilogue
for i in range(data.size - prefetch, data.size):
    wait raw_mbarrier[i % stage]
    mma on buffer[i % stage] for data[i]

wait until there are at most 0 pending mma
write result back to gmem
```

The kernel above has the following problems:

1. The MMA loop is not clean. One thread does the extra work of loading while the other threads in the warp group just wait for that one thread to finish. (Note that mma is a warp-group collective, so all threads in the warp group must arrive at that instruction for it to start.) Ideally, we would have a loop containing only mma and nothing else; the extra instructions can increase latency.
2. There is a false dependency between the load of `data[i + prefetch]` and the computation on `data[i]`. These two operations touch different data, so in theory they should not depend on each other, and whichever gets its mbarrier cleared first should go first. However, simply because code executes from top to bottom, the mma has to wait until the load is issued. This further increases latency.

With these problems observed, it is natural to ask: why not use different warps for load and compute? The load code and the compute code in the main loop are completely independent, and both the RAW and WAR hazards are handled by mbarriers, which live in shared memory and are accessible across the entire CTA. All the preconditions for warp specialization are therefore met; we just need to put different IR nodes into different places. This PR adds warp specialization. The generated code is similar to the pipelined code that uses mbarrier for WAR, but actually simpler. It looks like below (assuming warp specialization on TIDy):

```python
if threadIdx.y == blockDim.y - 1:
    # If we use warp specialization on TIDy, then the blockDim.y of the
    # kernel will be (whatever_value_inferred_from_schedule + 1), and the
    # last threadIdx.y will be used as the load warp
    for i in range(data.size):
        wait war_mbarrier[i % stage]
        load data[i] to buffer[i % stage]
else:
    # Every threadIdx.y other than the last will be used for compute
    for i in range(prefetch + 1):
        arrive war_mbarrier[i % stage]
    for i in range(data.size):
        wait raw_mbarrier[i % stage]
        compute buffer[i % stage]
        wait until there are at most stage - prefetch - 1 pending mma
        arrive war_mbarrier[(i + prefetch + 1) % stage]
```

This new way of doing circular buffering is intended to be computation-agnostic: it should work on whatever kernel we are scheduling, not just matmuls.
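To make the producer/consumer handshake concrete, here is a minimal host-side Python simulation of the warp-specialized main loop above, with `threading.Semaphore`s standing in for the RAW/WAR mbarriers and a doubling as a stand-in for the real compute. All names (`STAGES`, `PREFETCH`, `load_warp`, etc.) are illustrative for this sketch only, not nvFuser APIs.

```python
import threading

STAGES = 4      # circular buffer depth ("stage" above)
PREFETCH = 2    # prefetch distance
data = list(range(10))
buffer = [None] * STAGES
results = []

# raw[s]: slot s holds fresh data   (load warp -> compute warps)
# war[s]: slot s may be overwritten (compute warps -> load warp)
raw = [threading.Semaphore(0) for _ in range(STAGES)]
war = [threading.Semaphore(0) for _ in range(STAGES)]

def load_warp():
    # Mirrors the `if threadIdx.y == blockDim.y - 1` branch
    for i in range(len(data)):
        war[i % STAGES].acquire()      # wait war_mbarrier[i % stage]
        buffer[i % STAGES] = data[i]   # load data[i] to buffer[i % stage]
        raw[i % STAGES].release()      # signal raw_mbarrier[i % stage]

def compute_warps():
    # Mirrors the `else` branch
    for i in range(PREFETCH + 1):
        war[i % STAGES].release()      # initial arrives: slots start empty
    for i in range(len(data)):
        raw[i % STAGES].acquire()               # wait raw_mbarrier[i % stage]
        results.append(buffer[i % STAGES] * 2)  # compute buffer[i % stage]
        # Compute is synchronous here, so unlike async mma there is no
        # "at most stage - prefetch - 1 pending" wait before this arrive.
        war[(i + PREFETCH + 1) % STAGES].release()

t = threading.Thread(target=load_warp)
t.start()
compute_warps()
t.join()
assert results == [x * 2 for x in data]
```

Note how neither loop ever waits on the other except through the per-slot semaphores: whichever side gets its "mbarrier" cleared first proceeds first, which is exactly the false dependency the warp-specialized form removes.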
But note that today, there are some strong limitations that make it less applicable:

1. The computation cannot have a hardcoded `blockDim` in it, so block reductions will not work (the sketch after this list illustrates why). I believe this will be easy to fix, but it is beyond the scope of this PR.
2. Because the warp-specialized parallel type will no longer be exact, thread predicates will be generated for it. Predicate elimination is not yet smart enough to know that code in the compute warps is already predicated and does not need to be predicated again. This limitation also means the computation cannot be tensor core operations (`MmaOp`), so this PR does not actually work with matmul yet.

Besides the above limitations, I believe this new circular buffer type is pretty generic, and in the future we should be able to try it with TMA in perf tuning.
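On limitation 1, here is a tiny, purely illustrative Python sketch (made-up names and values, no nvFuser APIs) of why hardcoded `blockDim` breaks: warp specialization pads the launch by one extra `threadIdx.y` row for the load warp, so any computation written against `blockDim.y` would also count the load warp.

```python
compute_rows = 2                      # the blockDim.y the schedule inferred
padded_blockDim_y = compute_rows + 1  # +1 row reserved for the load warp

partial = {0: 1.0, 1: 2.0}            # one partial sum per *compute* row

# A block reduction hardcoded against blockDim.y would iterate over the
# padded extent and touch the load warp's row, which holds no partial:
#   sum(partial[t] for t in range(padded_blockDim_y))  # KeyError on t == 2

# The reduction must instead run over the unpadded compute extent:
total = sum(partial[t] for t in range(compute_rows))
assert total == 3.0
```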
---------

Co-authored-by: Ryan Spring <[email protected]>
Showing 10 changed files with 268 additions and 31 deletions.