[Tracking] TMA Circular Buffering #2773
rdspring1 added a commit that referenced this pull request on Aug 27, 2024
This PR refactors `CircularBufferLoopCloner` to avoid clang-tidy issues in #2773.

- Track the cloned for-loop instead of its Scope.
- Add virtual methods `processExpr` and `processForLoop` for `TmaCircularBufferLoopCloner` to override.

Details:

```
Error (CLANGTIDY) [bugprone-parent-virtual-call,-warnings-as-errors]
qualified name 'kir::IrVisitor::dispatch' refers to a member overridden
in subclass; did you mean 'nvfuser::CircularBufferLoopCloner'?
```
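The shape of that refactor can be sketched standalone. The following is a schematic illustration with stand-in types, not nvFuser's actual classes: the base visitor funnels recursion through virtual `processExpr`/`processForLoop` hooks, so the TMA subclass overrides the hooks rather than re-implementing `dispatch`, which is the pattern the parent-virtual-call warning flagged.

```cpp
#include <iostream>

// Stand-in types, not nvFuser's kir::Expr / kir::ForLoop.
struct Expr {
  virtual ~Expr() = default;
};
struct ForLoop : Expr {};

class CircularBufferLoopCloner {
 public:
  void dispatch(Expr* e) {
    // Route recursion through the virtual hooks below instead of having
    // subclasses call a parent's virtual dispatch directly.
    if (auto* fl = dynamic_cast<ForLoop*>(e)) {
      processForLoop(fl);
    } else {
      processExpr(e);
    }
  }

 protected:
  virtual void processExpr(Expr*) { std::cout << "clone expr\n"; }
  virtual void processForLoop(ForLoop*) { std::cout << "clone loop\n"; }
};

class TmaCircularBufferLoopCloner : public CircularBufferLoopCloner {
 protected:
  void processExpr(Expr* e) override {
    // A TMA-specific subclass would insert mbarrier synchronization here,
    // then reuse the base cloning behavior.
    CircularBufferLoopCloner::processExpr(e);
  }
};

int main() {
  ForLoop loop;
  TmaCircularBufferLoopCloner cloner;
  cloner.dispatch(&loop);  // prints "clone loop"
}
```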
rdspring1 added a commit that referenced this pull request on Sep 4, 2024
## Summary ##

This PR contains the changes to the allocation lowering pass from #2773.

## Details ##

### GpuLower ###

- `ldst_mbarrier_token_map_` maps each `LoadStoreOp` to its mbarrier tokens, which are represented as a `TensorView` sized by the number of pipeline stages.
- `mbarrier_token_smem_alloc_set_` tracks the `kir::Allocate` expressions for the mbarriers and their tokens.
- `ldst_mbarrier_index_map_` maps the cloned `LoadStoreOp` in the prologue and main loops to their indexed mbarrier.

### Allocation ###

- In the allocation pass, create shared memory allocations and operations around the `LoadStoreOp` expression.

```cpp
// Tokens, mbarriers, init, and inval operations created in the allocation pass.
for (circular_buffer_loop) {
  __shared__ int64_t tokens[num_stages];
  __shared__ int64_t mbarrier[num_stages];
  init(mbarrier);
  cp.async.bulk(data, mbarrier);
  inval(mbarrier);
}
```

### AliasMemory ###

- The mbarrier and its token are mapped together. The token is the mbarrier state of the last phase. For simplicity, mark the token's liveness when the mbarrier is initialized and invalidated:
  - Apply `markWrite` for the mbarrier and its token when the expression is `MBarrierInit`.
  - Apply `markRead` for the mbarrier and its token when the expression is `MBarrierInvalidate`.
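The liveness rule in the last two bullets can be illustrated with a small standalone sketch. The types and names below are stand-ins, not nvFuser's actual pass; the point is that the token's lifetime is tied to its mbarrier, opening at `MBarrierInit` and closing at `MBarrierInvalidate`.

```cpp
#include <iostream>
#include <map>
#include <string>

// Stand-in liveness bookkeeping: an interval opens at the first write and
// closes at the last read of a buffer.
struct Interval {
  int first_write = -1;
  int last_read = -1;
};
std::map<std::string, Interval> liveness;

void markWrite(const std::string& buffer, int pos) {
  if (liveness[buffer].first_write < 0) {
    liveness[buffer].first_write = pos;
  }
}

void markRead(const std::string& buffer, int pos) {
  liveness[buffer].last_read = pos;
}

// The mbarrier and its token live and die together: both are marked written
// at MBarrierInit and read at MBarrierInvalidate.
void handleExpr(const std::string& op, int pos) {
  if (op == "MBarrierInit") {
    markWrite("mbarrier", pos);
    markWrite("token", pos);
  } else if (op == "MBarrierInvalidate") {
    markRead("mbarrier", pos);
    markRead("token", pos);
  }
}

int main() {
  handleExpr("MBarrierInit", 0);        // mbarrier and token live from 0
  handleExpr("MBarrierInvalidate", 9);  // both dead after 9
  std::cout << liveness["token"].first_write << " "
            << liveness["token"].last_read << "\n";  // prints: 0 9
}
```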
rdspring1 added a commit that referenced this pull request on Sep 5, 2024
## Summary ##

This PR contains the changes to the indexing lowering pass from #2773. It is stacked on #2824. Tracking branch: #2773.

## Details ##

- In the circular buffering pass, we manually index the mbarriers and tokens using the index of the prologue, main, and epilogue loops.

```cpp
for (int index : c10::irange(fl->extent())) {
  int stage = index % number_of_pipeline_stages;
  // represented with kir::TensorIndex
  mbarrier_t current_stage_mbarrier = mbarriers[stage];

  int next_stage = (index + number_of_pipeline_stages - 1) % number_of_pipeline_stages;
  // represented with kir::TensorIndex
  mbarrier_t next_stage_mbarrier = mbarriers[next_stage];
}
```

- The handle functions for `kir::MBarrierInit`, `kir::MBarrierInvalidate`, `kir::MBarrierArriveExpectTx`, and `kir::MBarrierWait` are modified to handle `kir::TensorIndex`.
- `u32IndexScalarSmemTv` is modified to get the shared memory pointer address for a `kir::TensorIndex`.
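For context on the last bullet: mbarrier PTX instructions address shared memory with a 32-bit offset, so the lowered code needs a u32 address for the indexed element. Below is a minimal CUDA sketch of that conversion using the standard `__cvta_generic_to_shared` intrinsic; it is illustrative, not nvFuser's `u32IndexScalarSmemTv` itself.

```cpp
#include <cstdint>

// Convert a generic pointer into a 32-bit offset within the shared memory
// window, the form mbarrier PTX instructions expect.
__device__ uint32_t toSmemU32(void* smem_ptr) {
  return static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
}

__global__ void example() {
  __shared__ uint64_t mbarriers[4];
  int stage = threadIdx.x % 4;
  // Index the mbarrier array (a kir::TensorIndex at lowering time), then
  // take the u32 shared-memory address of the selected element.
  uint32_t addr = toSmemU32(&mbarriers[stage]);
  (void)addr;
}
```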
- Add support for `Hopper::electSync`
- Create `ElectSync` PredicateType
- Make mbarrier synchronous
  - mbarrier waits for all threads in CTA
  - All threads issue `arriveExpectTx` to get `mbarrier_token`
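For reference, a hedged sketch of what an elect-sync helper can look like on Hopper (sm_90+) using inline PTX; this illustrates the `elect.sync` instruction itself and is not necessarily nvFuser's runtime implementation. `elect.sync` elects exactly one leader lane from the member mask, so a guard like `first_warp && electSync()` selects a single thread to issue `arriveExpectTx` and the TMA load.

```cpp
#include <cstdint>

// Requires Hopper (sm_90+); compile with, e.g., nvcc -arch=sm_90a.
__device__ bool electSync(uint32_t membermask = 0xFFFFFFFFu) {
  uint32_t is_elected;
  asm volatile(
      "{\n"
      "  .reg .pred P;\n"
      "  elect.sync _|P, %1;\n"     // predicate P is true in one lane only
      "  selp.u32 %0, 1, 0, P;\n"   // materialize the predicate as 0/1
      "}\n"
      : "=r"(is_elected)
      : "r"(membermask));
  return is_elected != 0;
}
```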
rdspring1 force-pushed the tma_circular_buffering branch from 7f9e5f3 to 5e38511 on September 8, 2024
rdspring1 added a commit that referenced this pull request on Sep 29, 2024
## Summary ##

This PR adds support for TMA circular buffering. It is stacked on #2824 and #2825. Tracking branch: #2773.

## Description ##

- The Pre-Prologue and Post-Epilogue loops are created in the allocation pass.
- The Pre-Prologue loop allocates shared memory and initializes mbarriers, while the Post-Epilogue loop invalidates mbarriers.
- In the circular buffer pass, `CloneTmaCircularBufferLoopAndInsertSync` clones operations and inserts mbarrier synchronization logic to create the prologue, main, and epilogue for-loops.
  - The prologue loop copies only the load operations. `arriveExpectTx` and `arrive` expressions are created for cpAsyncBulk load operations.
  - The main loop copies the load and computation operations, adds `arriveExpectTx` and `arrive` for the next stage, and calls `mbarrierWait` for the current stage.
  - The epilogue loop copies only the computation operations and adds `mbarrierWait` for the remaining stages in the pipeline.

## Lowering Details ##

Description of changes in lowering passes.

- `Prologue`, `Main`, and `Epilogue` loops are created by `CloneTmaCircularBufferLoopAndInsertSync`, which is a child class of `CircularBufferLoopCloner`.
- `PrePrologue` and `PostEpilogue` loops are created in the allocation pass.
- `cuTensorMapEncodeTiled` restricts the size of each box dimension to be `<= 256`. You will need to launch multiple load operations to load larger tiles.
- We only allocate `mbarriers` for each stage, so the `expected_transaction` bytes are multiplied by the number of TMA loads per stage.
- The for-loop cloner must account for the nested for-loop structure used to launch multiple TMA loads before adding the `mbarrier_wait` for the stage (see the sketch after the testing list below).

## Loop Structure ##

Description of the for-loop structure for circular buffering.

<details>
<summary>Overview of the Circular Buffer Structure:</summary>

### Pre-prologue loop: ###

- Allocate shared memory for mbarriers and mbarrier tokens
- Initialize mbarriers for all stages

### Prologue loop: ###

- if selected_thread:
  - Issue cp.async.bulk loads for all but the last stage

### Main loop: ###

- if selected_thread:
  - Issue the next cp.async.bulk load for the available stage
- All threads wait until the TMA operation arrives
- Copy body without:
  - shared memory allocations
  - mbarrier_init exprs
  - mbarrier_inval exprs

### Epilogue loop: ###

- All threads wait until the TMA operation arrives
- Copy body without:
  - shared memory allocations
  - issuing cp.async.bulk operations
  - mbarrier_init exprs
  - mbarrier_inval exprs

### Post-epilogue loop: ###

- if selected_thread:
  - Invalidate mbarriers for all stages

</details>

<details>
<summary>Detailed Pseudo-Code:</summary>

```cpp
constexpr int64_t warp_size = 32;
bool first_warp = threadIdx.x < warp_size && threadIdx.y == 0 && threadIdx.z == 0;
```

### Pre-Prologue loop: ###

```cpp
__shared__ __mbarrier_t mbarriers[num_stages];
__shared__ __mbarrier_token_t tokens[num_stages];
for (int64_t loop_index : irange(stages)) {
  if (first_warp && hopper::electSync()) {
    mbarrier_init(mbarriers[loop_index], number_of_arrival_threads);
  }
}
```

### Prologue loop: ###

```cpp
// Launch loads for the first stages-1 stages
for (int64_t loop_index : irange(stages-1)) {
  if (first_warp && hopper::electSync()) {
    tokens[loop_index] = mbarrier::arriveExpectTx(mbarriers[loop_index], expected_transaction_size);
    cpAsyncBulk(mbarriers[loop_index], ...);
  } else {
    tokens[loop_index] = mbarrier::arrive(mbarriers[loop_index]);
  }
}
```

### Main loop: ###

```cpp
// Launch the load for the last available stage. Wait for the current stage in the pipeline.
// Repeat for extent - (stages-1) iterations.
for (int64_t loop_index : irange(N-(stages-1))) {
  current_stage = loop_index % stages;
  load_stage = (loop_index + (stages - 1)) % stages;
  if (first_warp && hopper::electSync()) {
    tokens[load_stage] = mbarrier::arriveExpectTx(mbarriers[load_stage], expected_transaction_size);
    cpAsyncBulk(mbarriers[load_stage], ...);
  } else {
    tokens[load_stage] = mbarrier::arrive(mbarriers[load_stage]);
  }
  mbarrier::wait(tokens[current_stage]);

  // Clone remaining operations
}
```

### Epilogue loop: ###

```cpp
// Wait for the current stage in the pipeline. Repeat for the remaining iterations in the extent.
for (int64_t loop_index : irange(N-(stages-1), N)) {
  current_stage = loop_index % stages;
  mbarrier::wait(tokens[current_stage]);

  // Clone remaining operations
}
```

### Post-Epilogue loop: ###

```cpp
for (int64_t loop_index : irange(stages)) {
  if (first_warp && hopper::electSync()) {
    mbarrier_inval(mbarriers[loop_index]);
  }
}
```

</details>

## Testing Setup ##

- 2 to 4 pipeline stages.
- (128, 500, 1024) outer dimension.
- (128, 1024) inner dimension.

1. Single Dim, including Unroll and Unswitch parallelizations.
2. Multiple Dim
3. Pointwise - One tensor is loaded with TMA circular buffering. The other tensor is loaded with Set circular buffering.
4. PointwiseCpAsync - One tensor is loaded with TMA circular buffering. The other tensor is loaded with CpAsync circular buffering. This test is currently disabled, but will be fixed by #2339.
5. Reduction
6. InnerPersistent - In this schedule, the output TensorView of the cpAsyncBulk load has a serial iterDomain to the right of the computeAt position. A for-loop will launch multiple TMA loads for each pipeline stage (sketched below).
7. Matmul
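To make the multiple-loads-per-stage case concrete (as in the InnerPersistent test above), the cloned main loop takes roughly the following shape. This is pseudo-code in the same style as the blocks above; `num_boxes` and `bytes_per_box` are illustrative names, not actual nvFuser identifiers.

```cpp
// Pseudo-code: main loop when each stage needs num_boxes TMA loads.
// The expected transaction covers all boxes, and the wait is inserted only
// after the nested load loop completes.
for (int64_t loop_index : irange(N-(stages-1))) {
  current_stage = loop_index % stages;
  load_stage = (loop_index + (stages - 1)) % stages;
  if (first_warp && hopper::electSync()) {
    tokens[load_stage] = mbarrier::arriveExpectTx(
        mbarriers[load_stage], num_boxes * bytes_per_box);
    for (int64_t box : irange(num_boxes)) {
      cpAsyncBulk(mbarriers[load_stage], /*box=*/box, ...);
    }
  } else {
    tokens[load_stage] = mbarrier::arrive(mbarriers[load_stage]);
  }
  mbarrier::wait(tokens[current_stage]);

  // Clone remaining computation, including the serial inner loop
}
```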
## Summary ##

This PR adds support for TMA circular buffering. It is based on #1484.

## Description ##

- The circular buffer pass clones operations and inserts mbarrier synchronization logic around the `LoadStoreOp` expression.
- The main loop adds `arrive_expected_tx` for the next stage and `mbarrier_wait` for the current stage.
- The epilogue loop adds `mbarrier_wait` for the remaining stages in the pipeline.

## Lowering Details ##

Description of changes in lowering passes.

- `Prologue`, `Main`, and `Epilogue` loops are created by `TmaCircularBufferLoopCloner`, which is a child class of `CircularBufferLoopCloner`.
- `PrePrologue` and `PostEpilogue` loops are created by `createCpAsyncBulkFixtures`.
- `cuTensorMapEncodeTiled` restricts the size of each box dimension to be `<= 256`. You need to launch multiple load operations to load larger tiles.
- We only allocate `mbarriers` for each stage, so the `expected_transaction` bytes are multiplied by the number of TMA loads per stage.
- The for-loop cloner must account for the nested for-loop structure used to launch multiple TMA loads before adding the `mbarrier_wait` for the stage.
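The expected-transaction bullet amounts to simple arithmetic, sketched below with hypothetical names and the (128, 1024) inner-dimension test shape: each stage's single mbarrier must expect the bytes of every box load issued for that stage, not just one.

```cpp
#include <cstdint>
#include <iostream>

int64_t ceilDiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

int main() {
  int64_t tile_rows = 128, tile_cols = 1024, elem_bytes = 4;
  int64_t box_cols = 256;  // cuTensorMapEncodeTiled: each box dimension <= 256
  int64_t loads_per_stage = ceilDiv(tile_cols, box_cols);      // 4 TMA loads
  int64_t bytes_per_load = tile_rows * box_cols * elem_bytes;  // one box
  // One mbarrier per stage, so it must expect the bytes of all 4 loads.
  int64_t expected_tx = bytes_per_load * loads_per_stage;      // 524288 bytes
  std::cout << loads_per_stage << " loads/stage, expectTx = "
            << expected_tx << " bytes\n";
}
```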
## Future PRs ##

- Replace `TDX == 0 && TDY == 0 && TDZ == 0` with the `Hopper::elect_sync` PTX instruction. `LoadStore` operations add additional `TDY == 0` predicates that conflict with `Hopper::elect_sync`.