Skip to content

Commit

Permalink
Don't initialize TMA output buffer (#3105)
Browse files Browse the repository at this point in the history
TMA will automatically fill zero, besides, the initialization will race
with TMA itself as there is no sync between initialization and TMA.

Matmul perf after enabling circular buffering:
```markdown
 Time (%)  Total Time (ns)  Instances  Avg (ns)   Med (ns)   Min (ns)  Max (ns)  StdDev (ns)                                                  Name

 --------  ---------------  ---------  ---------  ---------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     57.3          3218718          1  3218718.0  3218718.0   3218718   3218718          0.0  <unnamed>::nvfuser_none_f0_c0_r0_g0(<unnamed>::Tensor<<unnamed>::__half, (int)3, (int)3>, <unnamed>…
     12.5           700153          1   700153.0   700153.0    700153    700153          0.0  nvjet_hsh_192x192_64x3_2x1_v_bz_coopB_NTN
```
  • Loading branch information
zasdfgbnm authored Oct 7, 2024
1 parent 2b9e9d6 commit 888f720
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
4 changes: 2 additions & 2 deletions csrc/device_lower/pass/allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -508,10 +508,10 @@ class AllocationInserter : public kir::ExprMutator {
init = default_val;
}

if (ir_utils::isCpAsyncOp(expr)) {
if (ir_utils::isCpAsyncOp(expr) || ir_utils::isCpAsyncBulk(expr)) {
NVF_CHECK(
init == nullptr || init->isZero(),
"cp.async initialized with non-zero is not supported");
"cp.async and cp.async.bulk initialized with non-zero is not supported");
// cp.async will automatically fill zero when out of bound
init = nullptr;
}
Expand Down
5 changes: 2 additions & 3 deletions tests/cpp/test_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3624,9 +3624,8 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) {

inlineMost();

// TODO: looks like this test will hang if I enable this
// tv0c->circularBuffer(/*number_of_stages=*/4);
// tv1c->circularBuffer(/*number_of_stages=*/4);
tv0c->circularBuffer(/*number_of_stages=*/4);
tv1c->circularBuffer(/*number_of_stages=*/4);

auto inputs =
matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype));
Expand Down

0 comments on commit 888f720

Please sign in to comment.