From 888f7206e7402b6777eed68f7c6d6f657f74dd81 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 7 Oct 2024 09:35:17 -0700 Subject: [PATCH] Don't initialize TMA output buffer (#3105) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit TMA automatically fills zeros. Besides, the initialization would race with TMA itself, as there is no sync between the initialization and TMA. Matmul perf after enabling circular buffering: ```markdown Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name -------- --------------- --------- --------- --------- -------- -------- ----------- ---------------------------------------------------------------------------------------------------- 57.3 3218718 1 3218718.0 3218718.0 3218718 3218718 0.0 ::nvfuser_none_f0_c0_r0_g0(::Tensor<::__half, (int)3, (int)3>, … 12.5 700153 1 700153.0 700153.0 700153 700153 0.0 nvjet_hsh_192x192_64x3_2x1_v_bz_coopB_NTN ``` --- csrc/device_lower/pass/allocation.cpp | 4 ++-- tests/cpp/test_matmul.cpp | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp index 346c59a1a08..7a65be53d9b 100644 --- a/csrc/device_lower/pass/allocation.cpp +++ b/csrc/device_lower/pass/allocation.cpp @@ -508,10 +508,10 @@ class AllocationInserter : public kir::ExprMutator { init = default_val; } - if (ir_utils::isCpAsyncOp(expr)) { + if (ir_utils::isCpAsyncOp(expr) || ir_utils::isCpAsyncBulk(expr)) { NVF_CHECK( init == nullptr || init->isZero(), - "cp.async initialized with non-zero is not supported"); + "cp.async and cp.async.bulk initialized with non-zero is not supported"); // cp.async will automatically fill zero when out of bound init = nullptr; } diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index 9f4fd14cfa0..239843f52eb 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -3624,9 +3624,8 @@ TEST_F(HopperMatmulTest, 
HSH_NT_128BSwizzle) { inlineMost(); - // TODO: looks like this test will hang if I enable this - // tv0c->circularBuffer(/*number_of_stages=*/4); - // tv1c->circularBuffer(/*number_of_stages=*/4); + tv0c->circularBuffer(/*number_of_stages=*/4); + tv1c->circularBuffer(/*number_of_stages=*/4); auto inputs = matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype));