Indexing changes for TMA Circular Buffering #2825
Conversation
!build
I'm a little confused. Previously there was no `handle` function for `MBarrierArriveExpectTx` and `MBarrierWait`. How were we lowering these ops?
@@ -304,6 +304,10 @@ bool isExtentEqualToMaxParallelTypeExtent(const IterDomain* id);
//! indexing special items in shared memory, like mbarrier.
NVF_API Val* u32IndexScalarSmemTv(TensorView* tv);

//! Get the uint32_t index of a TensorIndex. This is usually used for
nit: I guess you copy-and-pasted the function above, but why is it called "scalar smem TV"? It seems we already have an address as a `Val`, so it doesn't seem to matter whether it's scalar or not.
A single mbarrier and its token are represented as a size-0 `int64_t` TensorView in shared memory, which can be called a scalar smem TV. Without circular buffering, everything is self-contained within the if-statement:

```cpp
if (TDX == 0 && TDY == 0 && TDZ == 0) {
  __shared__ int64_t tokens;
  __shared__ int64_t mbarrier;
  init(mbarrier);
  MBarrierArriveExpectTx(mbarrier, data_size, num_threads);
  cp.async.bulk(data, mbarrier);
  MBarrierWait(mbarrier);
  inval(mbarrier);
}
```
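(Aside, for readers unfamiliar with the PTX side: below is a minimal CUDA sketch of why a 32-bit shared-memory address is what `u32IndexScalarSmemTv` needs to produce for such a scalar smem TV. It is illustrative only, not NVFuser's generated code; `num_threads` is a placeholder.)

```cpp
#include <cstdint>

// Illustrative CUDA (sm_80+), not NVFuser output: PTX mbarrier instructions take
// a shared-space address, so the generic pointer to the size-0 smem TensorView
// is converted to a 32-bit shared-memory address before use.
__global__ void mbarrier_init_example(uint32_t num_threads) {
  __shared__ uint64_t mbarrier;  // stand-in for the size-0 int64_t smem TV
  if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {
    auto addr = static_cast<uint32_t>(__cvta_generic_to_shared(&mbarrier));
    asm volatile("mbarrier.init.shared.b64 [%0], %1;" ::"r"(addr),
                 "r"(num_threads));
  }
  __syncthreads();
}
```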
Where is, for example,
I'm not sure if you're referring to the circular buffering case, but
There should be a pass before IndexLowering that inserts
The
Ah, I see. Got it.
!build
Finished reading, left minor feedback.
if (minit->mbarrier()->isA<TensorView>()) {
  smem_address_ptr =
      lower_utils::u32IndexScalarSmemTv(minit->mbarrier()->as<TensorView>());
} else if (minit->mbarrier()->isA<kir::TensorIndex>()) {
  smem_address_ptr = lower_utils::u32IndexScalarSmemTv(
      minit->mbarrier()->as<kir::TensorIndex>());
} else {
nit: Should we also have a `lower_utils::u32IndexScalarSmemTv(Val*)` that does this dispatch?
Perhaps. I was following a rule of three: two special cases are okay; add a dispatch if we get a third special case.
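For reference, a dispatching overload along the lines of the nit could look roughly like the toy sketch below. The types here are stand-ins so the snippet is self-contained; NVFuser's real `Val`, `isA<T>()`, and `as<T>()` machinery differs.

```cpp
#include <stdexcept>

// Toy stand-ins for the NVFuser class hierarchy so the sketch is self-contained.
struct Val { virtual ~Val() = default; };
struct TensorView : Val {};
namespace kir { struct TensorIndex : Val {}; }

Val* u32IndexScalarSmemTv(TensorView* tv) { return tv; }             // stub body
Val* u32IndexScalarSmemTv(kir::TensorIndex* index) { return index; } // stub body

// The suggested convenience overload: dispatch on the dynamic type of Val*.
// The real code would use isA<T>()/as<T>() rather than dynamic_cast.
Val* u32IndexScalarSmemTv(Val* val) {
  if (auto* tv = dynamic_cast<TensorView*>(val)) {
    return u32IndexScalarSmemTv(tv);
  }
  if (auto* index = dynamic_cast<kir::TensorIndex*>(val)) {
    return u32IndexScalarSmemTv(index);
  }
  throw std::runtime_error("Expected a shared-memory TensorView or TensorIndex");
}
```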
GpuLower::current()->ldstMBarrierIndexMap()[new_ldst] = mbarrier;

GpuLower::current()->propagateExprInfo(ldst, back());
} else {
The circular buffering path will be tested in the tests included in #2773, but do we have an example that exercises this `else` branch?
The TMA tests without circular buffering will exercise the `else` branch. The code for `handleCpAsyncBulkLoad` was moved into the `else` branch in this PR.
@@ -304,6 +304,10 @@ bool isExtentEqualToMaxParallelTypeExtent(const IterDomain* id);
//! indexing special items in shared memory, like mbarrier.
NVF_API Val* u32IndexScalarSmemTv(TensorView* tv);

//! Get the uint32_t index of a TensorIndex. This is usually used for
//! initializing a pipeline of mbarriers.
NVF_API Val* u32IndexScalarSmemTv(kir::TensorIndex* index);
Why is this `NVF_API`?
It feels unnecessary, but this signature is derived from `NVF_API Val* u32IndexScalarSmemTv(TensorView* tv);`.
I'll try this change in a separate PR.
!build
## Summary ##

This PR adds support for TMA circular buffering. It is stacked on #2824 and #2825. Tracking branch: #2773

## Description ##

- The Pre-Prologue and Post-Epilogue loops are created in the allocation pass.
- The Pre-Prologue loop allocates shared memory and initializes mbarriers, while the Post-Epilogue loop invalidates mbarriers.
- In the circular buffer pass, `CloneTmaCircularBufferLoopAndInsertSync` clones operations and inserts mbarrier synchronization logic to create the prologue, main, and epilogue for-loops.
  - The prologue copies only the load operations. `arriveExpectTx` and `arrive` expressions are created for cpAsyncBulk load operations.
  - The main loop copies the load and computation operations, adds `arriveExpectTx` and `arrive` for the next stage, and calls `mbarrierWait` for the current stage.
  - The epilogue copies only the computation operations and adds `mbarrierWait` for the remaining stages in the pipeline.

## Lowering Details ##

Description of changes in the lowering passes:

- The `Prologue`, `Main`, and `Epilogue` loops are created by `CloneTmaCircularBufferLoopAndInsertSync`, which is a child class of `CircularBufferLoopCloner`.
- The `PrePrologue` and `PostEpilogue` loops are created in the allocation pass.
- `cuTensorMapEncodeTiled` restricts the size of each box dimension to be `<= 256`. You will need to launch multiple load operations to load larger tiles.
- We only allocate `mbarriers` for each stage, so the `expected_transaction` bytes are multiplied by the number of TMA loads per stage (see the sketch after this description).
- The for-loop cloner must account for the nested for-loop structure used to launch multiple TMA loads before adding the `mbarrier_wait` for the stage.

## Loop Structure ##

Description of the for-loop structure for circular buffering.

<details>
<summary>Overview of the circular buffer structure:</summary>

### Pre-prologue loop: ###
- Allocate shared memory for mbarriers and mbarrier tokens
- Initialize mbarrier for all stages

### Prologue loop: ###
- if selected_thread:
  - Issue cp async bulks for all but the last stage

### Main loop: ###
- if selected_thread:
  - Issue the next cp async bulk for the available stage
- All threads wait until the TMA operation arrives
- Copy body without
  - shared memory allocations
  - mbarrier_init exprs
  - mbarrier_inval exprs

### Epilogue loop: ###
- All threads wait until the TMA operation arrives
- Copy body without
  - shared memory allocations
  - issuing cp async bulk operations
  - mbarrier_init exprs
  - mbarrier_inval exprs

### Post-epilogue loop: ###
- if selected_thread:
  - Invalidate mbarrier for all stages

</details>

<details>
<summary>Detailed pseudo-code:</summary>

```cpp
constexpr int64_t warp_size = 32;
bool first_warp =
    threadIdx.x < warp_size && threadIdx.y == 0 && threadIdx.z == 0;
```

### Pre-Prologue loop: ###

```cpp
__shared__ __mbarrier_t mbarriers[num_stages];
__shared__ __mbarrier_token_t tokens[num_stages];
for (int64_t loop_index : irange(stages)) {
  if (first_warp && hopper::electSync()) {
    mbarrier_init(mbarriers[loop_index], number_of_arrival_threads);
  }
}
```

### Prologue loop: ###

```cpp
// Launch loads for the first stages-1 iterations.
for (int64_t loop_index : irange(stages - 1)) {
  if (first_warp && hopper::electSync()) {
    tokens[loop_index] = mbarrier::arriveExpectTx(mbarriers[loop_index]);
    cpAsyncBulk(mbarriers[loop_index], ...);
  } else {
    tokens[loop_index] = mbarrier::arrive(mbarriers[loop_index]);
  }
}
```

### Main loop: ###

```cpp
// Launch the load for the last available stage; wait for the current stage in
// the pipeline. Repeat for extent - (stages-1) iterations.
for (int64_t loop_index : irange(N - (stages - 1))) {
  current_stage = loop_index % stage_depth;
  load_stage = (loop_index + (stage_depth - 1)) % stage_depth;
  if (first_warp && hopper::electSync()) {
    tokens[load_stage] = mbarrier::arriveExpectTx(
        mbarriers[load_stage], expected_transaction_size);
    cpAsyncBulk(mbarriers[load_stage], ...);
  } else {
    tokens[load_stage] = mbarrier::arrive(mbarriers[load_stage]);
  }
  mbarrier::wait(tokens[current_stage]);

  // Clone remaining operations
}
```

### Epilogue loop: ###

```cpp
// Wait for the current stage in the pipeline. Repeat for the remaining
// iterations in the extent.
for (int64_t loop_index : irange(N - (stages - 1), N)) {
  current_stage = loop_index % stage_depth;
  mbarrier::wait(tokens[current_stage]);

  // Clone remaining operations
}
```

### Post-Epilogue loop: ###

```cpp
for (int64_t loop_index : irange(stages)) {
  if (first_warp && hopper::electSync()) {
    mbarrier_inval(mbarriers[loop_index]);
  }
}
```

</details>

## Testing Setup ##

- 2 to 4 pipeline stages.
- (128, 500, 1024) outer dimension.
- (128, 1024) inner dimension.

1. Single Dim, including Unroll and Unswitch parallelizations.
2. Multiple Dim
3. Pointwise - One tensor is loaded with TMA circular buffering. The other tensor is loaded with Set circular buffering.
4. PointwiseCpAsync - One tensor is loaded with TMA circular buffering. The other tensor is loaded with CpAsync circular buffering. This test is currently disabled, but will be fixed by #2339.
5. Reduction
6. InnerPersistent - In this schedule, the output TensorView of the cpAsyncBulk load has a serial iterDomain to the right of the computeAt position. A for-loop will launch multiple TMA loads for each pipeline stage.
7. Matmul
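Regarding the expected-transaction note under Lowering Details above, here is a minimal sketch of the bookkeeping, with illustrative names and assuming the tile divides evenly into boxes:

```cpp
#include <cstdint>

// Illustrative helper, not NVFuser code: cuTensorMapEncodeTiled caps each box
// dimension at 256, so one pipeline stage may issue several cp.async.bulk loads,
// and the mbarrier's expected-transaction count is the per-load byte count
// multiplied by the number of loads in that stage.
int64_t expectedTxBytesPerStage(
    int64_t loads_per_stage,
    int64_t elems_per_load,
    int64_t dtype_bytes) {
  return loads_per_stage * elems_per_load * dtype_bytes;
}
```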
## Summary ##

This PR contains the indexing lowering changes from #2773. It is stacked on #2824.

Tracking Branch: #2773

## Details ##

- `kir::MBarrierInit`, `kir::MBarrierInvalidate`, `kir::MBarrierArriveExpectTx`, and `kir::MBarrierWait` are modified to handle `kir::TensorIndex`.
- `u32IndexScalarSmemTv` is modified to get the shared memory pointer address for a `kir::TensorIndex`.
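As a rough CUDA-level picture of what the `kir::TensorIndex` handling amounts to (illustrative only; `kNumStages` and the address math are assumptions, not the actual generated code): with circular buffering the mbarriers live in a per-stage shared-memory array, and the uint32_t value is the shared-space address of the selected element.

```cpp
#include <cstdint>

// Illustrative CUDA, not NVFuser output: each stage's mbarrier address is the
// shared-space address of one element of a small smem array, which is the kind
// of value a u32 index derived from a kir::TensorIndex ends up encoding.
constexpr int kNumStages = 4;  // placeholder stage count

__global__ void stage_address_example() {
  __shared__ uint64_t mbarriers[kNumStages];
  for (int stage = 0; stage < kNumStages; ++stage) {
    auto addr =
        static_cast<uint32_t>(__cvta_generic_to_shared(&mbarriers[stage]));
    (void)addr;  // this address is what the per-stage mbarrier PTX would use
  }
}
```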