From ec4aaad8124baadc7954e30c612ca9444b22d7e7 Mon Sep 17 00:00:00 2001
From: rasmith
Date: Sat, 21 Sep 2024 04:20:54 -0500
Subject: [PATCH] [Kernel][Triton][AMD] Remove tl.atomic_add from awq_gemm_kernel, 2-5x speedup MI300, minor improvement for MI250 (#8646)

---
 .../layers/quantization/awq_triton.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/awq_triton.py b/vllm/model_executor/layers/quantization/awq_triton.py
index d0b210c3a2747..bbb7fc8ad5087 100644
--- a/vllm/model_executor/layers/quantization/awq_triton.py
+++ b/vllm/model_executor/layers/quantization/awq_triton.py
@@ -209,12 +209,9 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
     c = accumulator.to(c_ptr.type.element_ty)
     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + N * offs_cm[:, None] + offs_cn[None, :]
+    c_ptrs = c_ptr + pid_z * N * M + N * offs_cm[:, None] + offs_cn[None, :]
     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
-    if SPLIT_K == 1:
-        tl.store(c_ptrs, c, mask=c_mask)
-    else:
-        tl.atomic_add(c_ptrs, c, mask=c_mask)
+    tl.store(c_ptrs, c, mask=c_mask)
 
 
 # qweights - [K , M // 8], int32
@@ -295,7 +292,9 @@ def awq_gemm_triton(input: torch.Tensor,
         split_k_iters,
     )
 
-    result = torch.zeros((M, N), dtype=scales.dtype, device=input.device)
+    result = torch.zeros((split_k_iters, M, N),
+                         dtype=scales.dtype,
+                         device=input.device)
 
     # A = input, B = qweight, C = result
     # A = M x K, B = K x N, C = M x N
@@ -313,4 +312,6 @@ def awq_gemm_triton(input: torch.Tensor,
                           BLOCK_SIZE_K=block_size_k,
                           SPLIT_K=split_k_iters)
 
+    result = result.sum(0)
+
     return result