Indexing changes for TMA Circular Buffering #2825

Merged: 4 commits into main, Sep 5, 2024
Conversation

rdspring1 (Collaborator) commented Aug 21, 2024

Summary

This PR contains the changes to the indexing lowering pass from #2773. It is stacked on #2824.
Tracking Branch: #2773

Details

  • In the circular buffering pass, we manually index the mbarriers and tokens using the index of the prologue, main, and epilogue loops (a standalone sketch of this stage arithmetic follows this list).
for (int index : c10::irange(fl->extent())) {
  int stage = index % number_of_pipeline_stages;
  mbarrier_t current_stage_mbarrier = mbarriers[stage];  // represented with kir::TensorIndex

  int next_stage = (index + number_of_pipeline_stages - 1) % number_of_pipeline_stages;
  mbarrier_t next_stage_mbarrier = mbarriers[next_stage];  // represented with kir::TensorIndex
}
  • The handle functions for kir::MBarrierInit, kir::MBarrierInvalidate, kir::MBarrierArriveExpectTx, and kir::MBarrierWait are modified to handle kir::TensorIndex.
  • u32IndexScalarSmemTv is modified to get the shared memory pointer address for a kir::TensorIndex.
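As referenced above, a minimal standalone sketch of the stage arithmetic (plain C++; the stage count and loop extent are assumed values for illustration, not taken from this PR):

```cpp
#include <iostream>

int main() {
  constexpr int number_of_pipeline_stages = 3;  // assumed stage count
  constexpr int extent = 8;                     // assumed loop extent

  for (int index = 0; index < extent; ++index) {
    // Stage whose mbarrier is used at this iteration.
    int stage = index % number_of_pipeline_stages;
    // Stage used for issuing the next load; it trails the current stage by
    // one slot as the circular buffer wraps around.
    int next_stage =
        (index + number_of_pipeline_stages - 1) % number_of_pipeline_stages;
    std::cout << "index=" << index << " stage=" << stage
              << " next_stage=" << next_stage << "\n";
  }
  return 0;
}
```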

@rdspring1 rdspring1 changed the base branch from main to tma_cb_alloc August 22, 2024 00:31
rdspring1 (Collaborator Author):

!build

naoyam (Collaborator) left a comment:

I'm a little confused. Previously there was no handle function for MBarrierArriveExpectTx and MBarrierWait. How were we lowering these ops?

@@ -304,6 +304,10 @@ bool isExtentEqualToMaxParallelTypeExtent(const IterDomain* id);
//! indexing special items in shared memory, like mbarrier.
NVF_API Val* u32IndexScalarSmemTv(TensorView* tv);

//! Get the uint32_t index of a TensorIndex. This is usually used for
Collaborator:

nit: I guess you copy-and-pasted the function above, but why is it called "scalar smem TV"? It seems we already have an address as a Val, so it doesn't seem to matter whether it's scalar or not.

rdspring1 (Collaborator Author):

A single mbarrier and its token are represented as a size-0 int64_t TensorView in shared memory, which is why it is called a scalar smem TV.
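For illustration only, a rough CUDA sketch of the two layouts this naming distinguishes (the stage count is an assumption; this is not the actual generated kernel):

```cpp
#include <cstdint>

constexpr int num_stages = 2;  // assumed stage count

__global__ void sketch() {
  // Without circular buffering: a single mbarrier/token pair, i.e. the
  // "scalar" smem TV the helper was originally written for.
  __shared__ int64_t mbarrier;
  __shared__ int64_t token;

  // With circular buffering: one mbarrier/token slot per pipeline stage.
  // Indexing these arrays per stage is what produces a kir::TensorIndex
  // during lowering.
  __shared__ int64_t mbarriers[num_stages];
  __shared__ int64_t tokens[num_stages];

  // Unused in this illustrative stub.
  (void)mbarrier; (void)token; (void)mbarriers; (void)tokens;
}
```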

rdspring1 (Collaborator Author):

Without circular buffering, void IndexLowering::handleCpAsyncBulkLoad(const LoadStoreOp* ldst) adds MBarrierArriveExpectTx and MBarrierWait. See https://github.com/NVIDIA/Fuser/blob/main/csrc/device_lower/pass/index.cpp#L1421-L1446

Everything is self-contained within the if-statement.

  if (TDX == 0 && TDY == 0 && TDZ == 0) {
    __shared__ int64_t tokens;
    __shared__ int64_t mbarrier;
    init(mbarrier);
    MBarrierArriveExpectTx(mbarrier, data_size, num_threads);
    cp.async.bulk(data, mbarrier);
    MBarrierWait(mbarrier);
    inval(mbarrier);
  }

naoyam (Collaborator) commented Aug 22, 2024:

Where is, for example, kir::MBarrierArriveExpectTx first inserted?

rdspring1 (Collaborator Author):

I'm not sure if you're referring to the circular buffering case, but kir::MBarrierArriveExpectTx is inserted at https://github.com/NVIDIA/Fuser/blob/main/csrc/device_lower/pass/index.cpp#L1434-L1435 for regular cpAsyncBulk load.

naoyam (Collaborator) commented Aug 23, 2024:

> I'm not sure if you're referring to the circular buffering case, but kir::MBarrierArriveExpectTx is inserted at https://github.com/NVIDIA/Fuser/blob/main/csrc/device_lower/pass/index.cpp#L1434-L1446 for regular cpAsyncBulk load.

There should be a pass before IndexLowering that inserts kir::MBarrierArriveExpectTx, right? Otherwise, I don't understand why we would need IndexLowering::handle(kir::MBarrierArriveExpectTx*). That handler would never be called unless MBarrierArriveExpectTx IR nodes already exist.

rdspring1 (Collaborator Author):

The kir::MBarrierArriveExpectTx nodes are inserted in the circular buffering pass. Those changes are in the last PR #2833.
See https://github.com/NVIDIA/Fuser/pull/2833/files#diff-804c3cd07d9989909d30d986a6e04a00157713200cb74b255990d88cdc15ce44R663.

naoyam (Collaborator) commented Aug 23, 2024:

> The kir::MBarrierArriveExpectTx nodes are inserted in the circular buffering pass. Those changes are in the last PR #2833. See https://github.com/NVIDIA/Fuser/pull/2833/files#diff-804c3cd07d9989909d30d986a6e04a00157713200cb74b255990d88cdc15ce44R663.

Ah, I see. Got it.

rdspring1 (Collaborator Author):

!build

zasdfgbnm (Collaborator) left a comment:

Finished reading; left minor feedback.

Comment on lines +1409 to +1415
  if (minit->mbarrier()->isA<TensorView>()) {
    smem_address_ptr =
        lower_utils::u32IndexScalarSmemTv(minit->mbarrier()->as<TensorView>());
  } else if (minit->mbarrier()->isA<kir::TensorIndex>()) {
    smem_address_ptr = lower_utils::u32IndexScalarSmemTv(
        minit->mbarrier()->as<kir::TensorIndex>());
  } else {
Collaborator:

nit: Should we also have a lower_utils::u32IndexScalarSmemTv(Val*) that does this dispatch?

rdspring1 (Collaborator Author):

Perhaps. I was following the rule of three: two special cases are okay; add a dispatch if we get a third special case.
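For reference, a minimal sketch of the dispatching overload suggested here; it is not part of this PR and assumes the two existing overloads plus nvFuser's isA/as casts and the NVF_ERROR macro:

```cpp
// Hypothetical convenience overload: dispatch on the runtime type of the
// mbarrier Val and forward to the existing helpers.
Val* u32IndexScalarSmemTv(Val* val) {
  if (val->isA<TensorView>()) {
    return u32IndexScalarSmemTv(val->as<TensorView>());
  }
  if (val->isA<kir::TensorIndex>()) {
    return u32IndexScalarSmemTv(val->as<kir::TensorIndex>());
  }
  NVF_ERROR(false, "Expected a TensorView or kir::TensorIndex");
  return nullptr;
}
```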

    GpuLower::current()->ldstMBarrierIndexMap()[new_ldst] = mbarrier;

    GpuLower::current()->propagateExprInfo(ldst, back());
  } else {
Collaborator:

The circular buffering path will be tested in the tests included in #2773, but do we have an example that exercises this else branch?

rdspring1 (Collaborator Author):

The TMA tests without circular buffering will exercise the else branch. The code for handleCpAsyncBulkLoad was moved into the else branch in this PR.

@@ -304,6 +304,10 @@ bool isExtentEqualToMaxParallelTypeExtent(const IterDomain* id);
//! indexing special items in shared memory, like mbarrier.
NVF_API Val* u32IndexScalarSmemTv(TensorView* tv);

//! Get the uint32_t index of a TensorIndex. This is usually used for
//! initializing a pipeline of mbarriers.
NVF_API Val* u32IndexScalarSmemTv(kir::TensorIndex* index);
Collaborator:

Why is this NVF_API?

rdspring1 (Collaborator Author):

It feels unnecessary, but this signature is derived from NVF_API Val* u32IndexScalarSmemTv(TensorView* tv);

rdspring1 (Collaborator Author):

I'll try this change in a separate PR.

csrc/device_lower/pass/index.cpp (outdated review thread, resolved)
rdspring1 (Collaborator Author):

!build

Base automatically changed from tma_cb_alloc to main September 4, 2024 16:42
@rdspring1 rdspring1 merged commit 752c0fe into main Sep 5, 2024
36 checks passed
@rdspring1 rdspring1 deleted the tma_cb_index branch September 5, 2024 02:22
rdspring1 added a commit that referenced this pull request Sep 29, 2024
## Summary ##
This PR adds support for TMA circular buffering. It is stacked on #2824
and #2825.
Tracking branch: #2773

## Description ##
- The Pre-Prologue and Post-Epilogue loops are created in the allocation
pass.
- Pre-Prologue loop allocates shared memory and initializes mbarriers,
while Post-Epilogue loop invalidates mbarriers.
- In the circular buffer pass, `CloneTmaCircularBufferLoopAndInsertSync`
clones operations and inserts mbarrier synchronization logic to create
the prologue, main, and epilogue for-loops.
- Prologue copies only the load operations. `arriveExpectTx` and
`arrive` expressions are created for cpAsyncBulk load operations.
- Main loop copies the load and computation operations, adds
`arriveExpectTx` and `arrive` for the next stage, and calls `mbarrierWait`
for the current stage.
- Epilogue copies only the computation operations and adds
`mbarrierWait` for remaining stages in the pipeline.


## Lowering Details ##
Description of changes in lowering passes.
- `Prologue`, `Main`, and `Epilogue` loops are created by
`CloneTmaCircularBufferLoopAndInsertSync` which is a child class of
`CircularBufferLoopCloner`.
- `PrePrologue` and `PostEpilogue` loops are created in the allocation
pass.
- The `cuTensorMapEncodeTiled` API restricts the size of each box dimension
to be `<= 256`. You will need to launch multiple load operations to load
larger tiles.
- We allocate `mbarriers` per stage rather than per load, so the
`expected_transaction` byte count is multiplied by the number of TMA loads
per stage (see the sketch after this list).
- The for-loop cloner must account for the nested for-loop structure
used to launch multiple TMA loads before adding the `mbarrier_wait` for
the stage.
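A minimal arithmetic sketch of the expected-transaction scaling mentioned above (plain C++; all values are assumed for illustration):

```cpp
#include <cstdint>
#include <iostream>

int main() {
  // Assumed example: each box dimension is capped at 256 elements, so a
  // larger tile is split into several TMA loads that share one mbarrier.
  constexpr int64_t bytes_per_tma_load = 256 * sizeof(float);
  constexpr int64_t tma_loads_per_stage = 4;

  // One mbarrier per stage tracks every load issued into that stage, so the
  // expected-transaction byte count passed to arriveExpectTx covers all of
  // the stage's loads.
  constexpr int64_t expected_transaction_bytes =
      bytes_per_tma_load * tma_loads_per_stage;

  std::cout << "expected transaction bytes per stage: "
            << expected_transaction_bytes << "\n";
  return 0;
}
```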


## Loop Structure ##
Description of for-loop structure for circular buffering.

<details>
<summary>Overview Circular Buffer Structure:</summary>

### Pre-prologue loop: ###
- Allocate shared memory for mbarriers and mbarrier tokens
- Initialize mbarrier for all stages

### Prologue loop: ###
- if selected_thread:
  - Issue cp async bulks for all but last stage

### Main loop: ###
- if selected_thread:
  - Issue next cp async bulk for available stage
- All threads wait until tma operation arrives
- Copy body without
  - shared memory allocations
  - mbarrier_init exprs
  - mbarrier_inval exprs

### Epilogue loop: ###
- All threads wait until tma operation arrives
- Copy body without
  - shared memory allocations
  - issuing cp async bulk operations
  - mbarrier_init exprs
  - mbarrier_inval exprs

### Post-epilogue loop: ###
- if selected_thread:
  - Invalidate mbarrier for all stages
</details>

<details>
<summary>Detailed Pseudo-Code:</summary>

```cpp
constexpr int64_t warp_size = 32;
bool first_warp = threadIdx.x < warp_size && threadIdx.y == 0 && threadIdx.z == 0;
```

### Pre-Prologue loop: ###
```cpp
__shared__ __mbarrier_t mbarrier[stages];
__shared__ __mbarrier_token_t tokens[stages];
for (int64_t loop_index : irange(stages)) {
  if (first_warp && hopper::electSync()) {
    mbarrier_init(mbarrier[loop_index], number_of_arrival_threads);
  }
}
```

### Prologue loop: ###
```cpp
// Launch loads for the first stages-1 stages
for (int64_t loop_index : irange(stages-1)) {
  if (first_warp && hopper::electSync()) {
    tokens[loop_index] =
      mbarrier::arriveExpectTx(mbarrier[loop_index], expected_transaction_size);
    cpAsyncBulk(mbarrier[loop_index], ...);
  } else {
    tokens[loop_index] = mbarrier::arrive(mbarrier[loop_index]);
  }
}
```

### Main loop: ###
```cpp
// Launch the load for the last available stage. Wait for the current stage in the pipeline.
// Repeat for extent - (stages-1) iterations.
for (int64_t loop_index : irange(N-(stages-1))) {
  int64_t current_stage = loop_index % stages;
  int64_t load_stage = (loop_index + (stages - 1)) % stages;
  if (first_warp && hopper::electSync()) {
    tokens[load_stage] =
      mbarrier::arriveExpectTx(mbarrier[load_stage], expected_transaction_size);
    cpAsyncBulk(mbarrier[load_stage], ...);
  } else {
    tokens[load_stage] = mbarrier::arrive(mbarrier[load_stage]);
  }
  mbarrier::wait(tokens[current_stage]);

  // Clone remaining operations
}
```

### Epilogue loop: ###
```cpp
// Wait for the current stage in the pipeline. Repeat for the remaining iterations of the extent.
for (int64_t loop_index : irange(N-(stages-1), N)) {
  int64_t current_stage = loop_index % stages;
  mbarrier::wait(tokens[current_stage]);

  // Clone remaining operations
}
```

### Post-Epilogue loop: ###
```cpp
for (int64_t loop_index : irange(stages)) {
  if (first_warp && hopper::electSync()) {
    mbarrier_inval(mbarrier[loop_index]);
  }
}
```
</details>

## Testing Setup ##
- 2 to 4 pipeline stages.
- (128, 500, 1024) outer dimension.
- (128, 1024) inner dimension.

1. Single Dim including Unroll and Unswitch parallelizations.
2. Multiple Dim
3. Pointwise 
- One tensor is loaded with TMA circular buffering. The other tensor is
loaded with Set circular buffering.
4. PointwiseCpAsync
- One tensor is loaded with TMA circular buffering. The other tensor is
loaded with CpAsync circular buffering. This test is currently disabled,
but will be fixed by #2339.
5. Reduction
6. InnerPersistent
- In this schedule, the output TensorView of the cpAsyncBulk load has a
serial iterDomain to the right of the computeAt position. A for-loop will
launch multiple TMA loads for each pipeline stage (see the sketch after
this list).
7. Matmul
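Following the pseudo-code style above, a rough sketch of the nested-load pattern the InnerPersistent test exercises (names such as `num_tma_loads_per_stage` and `bytes_per_load` are hypothetical):

```cpp
// Inside the cloned main loop: a serial iterDomain to the right of the
// computeAt position becomes an inner for-loop that issues several TMA loads
// into the same stage before the stage's wait is reached.
if (first_warp && hopper::electSync()) {
  tokens[load_stage] = mbarrier::arriveExpectTx(
      mbarrier[load_stage], num_tma_loads_per_stage * bytes_per_load);
  for (int64_t i : irange(num_tma_loads_per_stage)) {
    cpAsyncBulk(mbarrier[load_stage], /* box offset derived from i */ ...);
  }
} else {
  tokens[load_stage] = mbarrier::arrive(mbarrier[load_stage]);
}
mbarrier::wait(tokens[current_stage]);
```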