diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp index 346c59a1a08..7a65be53d9b 100644 --- a/csrc/device_lower/pass/allocation.cpp +++ b/csrc/device_lower/pass/allocation.cpp @@ -508,10 +508,10 @@ class AllocationInserter : public kir::ExprMutator { init = default_val; } - if (ir_utils::isCpAsyncOp(expr)) { + if (ir_utils::isCpAsyncOp(expr) || ir_utils::isCpAsyncBulk(expr)) { NVF_CHECK( init == nullptr || init->isZero(), - "cp.async initialized with non-zero is not supported"); + "cp.async and cp.async.bulk initialized with non-zero is not supported"); // cp.async will automatically fill zero when out of bound init = nullptr; } diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index 9f4fd14cfa0..239843f52eb 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -3624,9 +3624,8 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) { inlineMost(); - // TODO: looks like this test will hang if I enable this - // tv0c->circularBuffer(/*number_of_stages=*/4); - // tv1c->circularBuffer(/*number_of_stages=*/4); + tv0c->circularBuffer(/*number_of_stages=*/4); + tv1c->circularBuffer(/*number_of_stages=*/4); auto inputs = matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype));