From 9cc373f39036af789fb1ffc1e06b23766996d3f4 Mon Sep 17 00:00:00 2001 From: Charlie Fu Date: Thu, 19 Sep 2024 12:37:57 -0500 Subject: [PATCH] [Kernel][Amd] Add fp8 kv cache support for rocm custom paged attention (#8577) --- csrc/rocm/attention.cu | 240 +++++++++++++------- csrc/rocm/ops.h | 3 +- csrc/rocm/torch_bindings.cpp | 3 +- tests/kernels/test_attention.py | 251 ++++++--------------- vllm/_custom_ops.py | 4 +- vllm/attention/backends/rocm_flash_attn.py | 28 +-- 6 files changed, 246 insertions(+), 283 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 8fa7c862fbfa8..b48348a515c8d 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -18,8 +18,11 @@ #include #include #include +#include "cuda_compat.h" #include +#include "../attention/dtype_fp8.cuh" +#include "../quantization/fp8/amd/quant_utils.cuh" #if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \ defined(__gfx941__) || defined(__gfx942__)) @@ -38,7 +41,6 @@ #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) -#define WARP_SIZE 64 #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support @@ -60,6 +62,8 @@ typedef struct _B16x8 { _B16x4 xy[2]; } _B16x8; +using _B8x8 = uint2; + ////// Non temporal load stores /////// template @@ -168,18 +172,40 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, } } +template +__device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input, + const float scale) { + union alignas(16) { + uint4 u4; + _B16x8 u16x8; + vllm::bf16_8_t b16x8; + } tmp; + if constexpr (std::is_same::value) { + tmp.u4 = vllm::fp8::scaled_convert(input, scale); + return tmp.u16x8; + } else if constexpr (std::is_same::value) { + tmp.b16x8 = vllm::fp8::scaled_convert( + input, scale); + return tmp.u16x8; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + /////////////////////////////////////// // grid (num_seqs, num_partitions,num_heads/gqa_ratio) // block (partition size) -template __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ context_lens, // [num_seqs] @@ -192,10 +218,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, // head_size] scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] - #if 0 - scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size] - #endif - int max_ctx_blocks) { + int max_ctx_blocks, float k_scale, float v_scale) { constexpr int NWARPS = NUM_THREADS / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE; const int laneid = threadIdx.x % WARP_SIZE; @@ -222,12 +245,14 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( constexpr int x = 
16 / sizeof(scalar_t); constexpr int KHELOOP = HEAD_SIZE / x; _B16x8 Klocal[KHELOOP]; + _B8x8 Klocalb8[KHELOOP]; constexpr int VHELOOP = HEAD_SIZE / WARP_SIZE; // v head_size dimension is distributed across lanes constexpr int VTLOOP = 8; // 16 separate 4xtokens across warp -> 16/2 // 8xtokens _B16x8 Vlocal[VHELOOP][VTLOOP]; + _B8x8 Vlocalb8[VHELOOP][VTLOOP]; floatx4 dout[QHLOOP]; float qk_max[QHLOOP]; #pragma unroll @@ -279,6 +304,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block; vphysical_blocks[b] = block_table[vblock_idx_ctx]; } + // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems const scalar_t* q_ptr = q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE; @@ -298,17 +324,29 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( Qlocal[QHLOOP - 1].xy[1] = {0}; } - const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride + - wg_start_kv_head_idx * kv_head_stride; + const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride + + wg_start_kv_head_idx * kv_head_stride; const int physical_block_offset = local_token_idx % BLOCK_SIZE; // since x=half8, physical_block_offset // is already cast as _H8 - - const _B16x8* k_ptrh8 = reinterpret_cast(k_ptr); + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + const _B16x8* k_ptrh8 = reinterpret_cast(k_ptr); + #pragma unroll + for (int d = 0; d < KHELOOP; d++) { + Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset]; + } + } else { + constexpr int X = 16 / sizeof(cache_t); + const cache_t* k_ptr2 = k_ptr + physical_block_offset * X; #pragma unroll - for (int d = 0; d < KHELOOP; d++) { - Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset]; + for (int d = 0; d < KHELOOP; d++) { + const int head_elem = d * 8; + const int offset1 = head_elem / X; + const int offset2 = head_elem % X; + const cache_t* k_ptr3 = k_ptr2 + offset1 * BLOCK_SIZE * X + offset2; + Klocalb8[d] = *reinterpret_cast(k_ptr3); + } } float alibi_slope[QHLOOP]; @@ -322,30 +360,66 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } - const scalar_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; - const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); - // iterate over each v block + const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); + // iterate over each v block #pragma unroll - for (int b = 0; b < VBLOCKS; b++) { - // int32 physical_block_number leads to overflow when multiplied with - // kv_block_stride - const int64_t vphysical_block_number = - static_cast(vphysical_blocks[b]); - const _B16x8* v_ptrh8b = - v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; - // iterate over each head elem (within head_size) + for (int b = 0; b < VBLOCKS; b++) { + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t vphysical_block_number = + static_cast(vphysical_blocks[b]); + const _B16x8* v_ptrh8b = + v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; + // iterate over each head elem (within head_size) + #pragma unroll + for (int h = 0; h < VHELOOP; h++) { + const int head_size_elem = h * WARP_SIZE + laneid; + const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; + // iterate over all velems within block + #pragma unroll + for (int d = 0; d < 
BLOCK_SIZE / 8; d++) { + Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; + } + } + } + } else { + const _B8x8* v_ptrh8 = reinterpret_cast(v_ptr); + // iterate over each v block + #pragma unroll + for (int b = 0; b < VBLOCKS; b++) { + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t vphysical_block_number = + static_cast(vphysical_blocks[b]); + const _B8x8* v_ptrh8b = + v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; + // iterate over each head elem (within head_size) #pragma unroll - for (int h = 0; h < VHELOOP; h++) { - const int head_size_elem = h * WARP_SIZE + laneid; - const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; - // iterate over all velems within block + for (int h = 0; h < VHELOOP; h++) { + const int head_size_elem = h * WARP_SIZE + laneid; + const _B8x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; + // iterate over all velems within block #pragma unroll - for (int d = 0; d < BLOCK_SIZE / 8; d++) { - Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; + for (int d = 0; d < BLOCK_SIZE / 8; d++) { + // Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; + const _B8x8 Vlocalb8 = v_ptrh8be[d]; + Vlocal[h][b * BLOCK_SIZE / 8 + d] = + scaled_convert_b8x8(Vlocalb8, v_scale); + } } } } + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + #pragma unroll + for (int d = 0; d < KHELOOP; d++) { + Klocal[d] = + scaled_convert_b8x8(Klocalb8[d], k_scale); + } + } + #pragma unroll for (int h = 0; h < QHLOOP; h++) { dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], @@ -794,14 +868,16 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support -template __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ context_lens, // [num_seqs] @@ -814,10 +890,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, // head_size] scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] - #if 0 - scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size] - #endif - int max_ctx_blocks) { + int max_ctx_blocks, float k_scale, float v_scale) { UNREACHABLE_CODE } @@ -839,26 +912,24 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support #define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ - paged_attention_ll4mi_QKV_kernel \ + paged_attention_ll4mi_QKV_kernel \ <<>>( \ query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ - exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks); + 
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \ + k_scale, v_scale); -template +template void paged_attention_custom_launcher( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, const int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& context_lens, - int max_context_len, -#if 0 - torch::Tensor& qk_out, - torch::Tensor& softmax_out, -#endif - const c10::optional& alibi_slopes) { - + int max_context_len, const c10::optional& alibi_slopes, + float k_scale, float v_scale) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -878,14 +949,10 @@ void paged_attention_custom_launcher( float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + KVT* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* context_lens_ptr = context_lens.data_ptr(); -#if 0 - T* qk_out_ptr = reinterpret_cast(qk_out.data_ptr()); - T* softmax_out_ptr = reinterpret_cast(softmax_out.data_ptr()); -#endif const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); const int max_num_partitions = @@ -972,32 +1039,32 @@ void paged_attention_custom_launcher( } } -#define CALL_CUSTOM_LAUNCHER(T, BLK_SIZE, HEAD_SIZE) \ - paged_attention_custom_launcher( \ - out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, context_lens, max_context_len, \ - alibi_slopes); +#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ + paged_attention_custom_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, max_context_len, \ + alibi_slopes, k_scale, v_scale); -#define CALL_CUSTOM_LAUNCHER_BLK(T, HEAD_SIZE) \ +#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \ switch (block_size) { \ case 16: \ - CALL_CUSTOM_LAUNCHER(T, 16, HEAD_SIZE); \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \ break; \ case 32: \ - CALL_CUSTOM_LAUNCHER(T, 32, HEAD_SIZE); \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \ break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ break; \ } -#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T) \ +#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \ switch (head_size) { \ case 64: \ - CALL_CUSTOM_LAUNCHER_BLK(T, 64); \ + CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); \ break; \ case 128: \ - CALL_CUSTOM_LAUNCHER_BLK(T, 128); \ + CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128); \ break; \ default: \ TORCH_CHECK(false, "Unsupported head size: ", head_size); \ @@ -1020,19 +1087,34 @@ void paged_attention( torch::Tensor& context_lens, // [num_seqs] int64_t block_size, int64_t max_context_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype) { - assert(kv_cache_dtype == "auto"); + const std::string& kv_cache_dtype, double k_scale, double v_scale) { const int head_size = query.size(2); - if (query.dtype() == at::ScalarType::Half) { - CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16); - } else if (query.dtype() == at::ScalarType::BFloat16) { - 
CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16); + if (kv_cache_dtype == "auto") { + if (query.dtype() == at::ScalarType::Half) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16, + vllm::Fp8KVCacheDataType::kAuto); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, __hip_bfloat16, + vllm::Fp8KVCacheDataType::kAuto); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } + } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") { + if (query.dtype() == at::ScalarType::Half) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + TORCH_CHECK(false, "Unsupported KV cache dtype: ", kv_cache_dtype); } } #undef WARP_SIZE #undef MAX #undef MIN -#undef DIVIDE_ROUND_UP +#undef DIVIDE_ROUND_UP \ No newline at end of file diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index 4a07a3f1775bd..9f085115a3956 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -10,4 +10,5 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& context_lens, int64_t block_size, int64_t max_context_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype); + const std::string& kv_cache_dtype, double k_scale, + double v_scale); diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index 082e314587908..a283d4263d293 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -26,7 +26,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { " Tensor context_lens, int block_size," " int max_context_len," " Tensor? 
alibi_slopes," - " str kv_cache_dtype) -> ()"); + " str kv_cache_dtype," + " float k_scale, float v_scale) -> ()"); rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention); } diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 4bd6f7863a658..ecab512cba16f 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -31,8 +31,7 @@ # FlashAttention forward only supports head dimension at most 128 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 -HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256 - ] if not is_hip() else [64, 80, 96, 112, 128] +HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256] BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] @@ -114,7 +113,8 @@ def ref_single_query_cached_kv_attention( output[i].copy_(out, non_blocking=True) -@pytest.mark.parametrize("version", ["v1", "v2"]) +@pytest.mark.parametrize( + "version", ["v1", "v2"] if not is_hip() else ["v1", "v2", "rocm"]) @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -137,7 +137,8 @@ def test_paged_attention( seed: int, device: str, ) -> None: - if kv_cache_dtype == "fp8" and head_size % 16: + if ((kv_cache_dtype == "fp8" and head_size % 16) + or (version == "rocm" and head_size not in (64, 128))): pytest.skip() seed_everything(seed) @@ -206,7 +207,7 @@ def test_paged_attention( kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), cond=(head_size == HEAD_SIZES[0])) - elif version == "v2": + elif version in ("v2", "rocm"): num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape @@ -219,32 +220,61 @@ def test_paged_attention( dtype=torch.float32, ) max_logits = torch.empty_like(exp_sums) - ops.paged_attention_v2( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - seq_lens, - block_size, - max_seq_len, - alibi_slopes, - kv_cache_dtype, - k_scale, - v_scale, - ) - - opcheck(torch.ops._C.paged_attention_v2, - (output, exp_sums, max_logits, tmp_output, query, key_cache, - value_cache, num_kv_heads, scale, block_tables, seq_lens, - block_size, max_seq_len, alibi_slopes, kv_cache_dtype, - k_scale, v_scale, 0, 0, 0, 64, 0), - cond=(head_size == HEAD_SIZES[0])) + if version == "v2": + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + + opcheck(torch.ops._C.paged_attention_v2, + (output, exp_sums, max_logits, tmp_output, query, + key_cache, value_cache, num_kv_heads, scale, block_tables, + seq_lens, block_size, max_seq_len, alibi_slopes, + kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), + cond=(head_size == HEAD_SIZES[0])) + + else: + ops.paged_attention_rocm( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + + opcheck(torch.ops._rocm_C.paged_attention, + (output, exp_sums, max_logits, tmp_output, query, + key_cache, value_cache, num_kv_heads, scale, block_tables, + seq_lens, block_size, max_seq_len, alibi_slopes, + 
kv_cache_dtype, k_scale, v_scale), + cond=(head_size == HEAD_SIZES[0])) else: raise AssertionError(f"Unknown version: {version}") @@ -328,162 +358,6 @@ def ref_multi_query_kv_attention( return torch.cat(ref_outputs, dim=0) -@pytest.mark.parametrize("version", ["rocm"]) -@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", [64, 128]) # only test 64 128 -@pytest.mark.parametrize("use_alibi", USE_ALIBI) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("kv_cache_dtype", ["auto"]) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skipif(not is_hip(), reason="only for rocm") -def test_paged_attention_rocm( - kv_cache_factory, - version: str, - num_seqs: int, - num_heads: Tuple[int, int], - head_size: int, - use_alibi: bool, - block_size: int, - dtype: torch.dtype, - kv_cache_dtype: str, - seed: int, - device: str, -) -> None: - seed_everything(seed) - torch.set_default_device(device) - scale = float(1.0 / (head_size**0.5)) - num_query_heads, num_kv_heads = num_heads - query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) - query.uniform_(-scale, scale) - - assert num_query_heads % num_kv_heads == 0 - num_queries_per_kv = num_query_heads // num_kv_heads - alibi_slopes = None - if use_alibi: - alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) - - context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] - context_lens[-1] = MAX_SEQ_LEN - #context_lens = [8192 for _ in range(num_seqs)] - max_context_len = max(context_lens) - context_lens = torch.tensor(context_lens, dtype=torch.int) - #print('>>> ctx lens', context_lens) - - # Create the block tables. - max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size - block_tables = [] - for _ in range(num_seqs): - block_table = [ - random.randint(0, NUM_BLOCKS - 1) - for _ in range(max_num_blocks_per_seq) - ] - block_tables.append(block_table) - block_tables = torch.tensor(block_tables, dtype=torch.int) - - # Create the KV caches. - key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, - num_kv_heads, head_size, - kv_cache_dtype, dtype, seed, - device) - key_cache, value_cache = key_caches[0], value_caches[0] - - # TODO(charlifu) enable fp8 kv cache - # Using default kv_scale - # kv_scale = 1.0 - - # Call the paged attention kernel. - output = torch.empty_like(query) - PARTITION_SIZE_ROCM = 256 - num_partitions = ((max_context_len + PARTITION_SIZE_ROCM - 1) // - PARTITION_SIZE_ROCM) - assert PARTITION_SIZE_ROCM % block_size == 0 - num_seqs, num_heads, head_size = output.shape - tmp_output = torch.empty( - size=(num_seqs, num_heads, num_partitions, head_size), - dtype=output.dtype, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, num_partitions), - dtype=torch.float32, - ) - max_logits = torch.empty_like(exp_sums) - if version == "rocm": - ops.paged_attention_rocm( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - context_lens, - block_size, - max_context_len, - alibi_slopes, - kv_cache_dtype, - ) - else: - raise AssertionError(f"Unknown version: {version}") - - # Run the reference implementation. - if kv_cache_dtype == "fp8": - # Convert cache data back to dtype. 
- x = 16 // torch.tensor([], dtype=dtype).element_size() - key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, - block_size, x) - dequantized_key_cache = torch.empty(size=key_cache_shape, - dtype=dtype, - device=device) - ops.convert_fp8(key_cache, dequantized_key_cache) - key_cache = dequantized_key_cache - - value_cache_shape = value_cache.shape - dequantized_value_cache = torch.empty(size=value_cache_shape, - dtype=dtype, - device=device) - ops.convert_fp8(value_cache, dequantized_value_cache) - value_cache = dequantized_value_cache - - ref_output = torch.empty_like(query) - ref_single_query_cached_kv_attention( - ref_output, - query, - num_queries_per_kv, - key_cache, - value_cache, - block_tables, - context_lens, - scale, - alibi_slopes, - ) - - # NOTE(woosuk): Due to the kernel-level differences in the two - # implementations, there is a small numerical difference in the two - # outputs. Thus, we use a relaxed tolerance for the test. - atol = get_default_atol(output) if is_hip() else 1e-3 - rtol = get_default_rtol(output) if is_hip() else 1e-5 - - # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, - # so we use a relaxed tolerance for the test. - atol, rtol = 1e-4, 1e-5 - if dtype == torch.bfloat16: - atol, rtol = 2e-4, 1e-5 - if use_alibi: - if dtype == torch.half: - atol, rtol = 5e-4, 1e-5 - if dtype == torch.bfloat16: - atol, rtol = 1e-3, 1e-5 - if kv_cache_dtype == "fp8": - atol, rtol = 1e-2, 1e-5 - assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) - - # TODO(woosuk): Add tests for USE_ALIBI=True. @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -491,7 +365,8 @@ def test_paged_attention_rocm( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skipif(is_hip(), reason="skip for rocm") +@pytest.mark.skipif(is_hip(), + reason="Xformers backend is not supported on ROCm.") @torch.inference_mode() def test_multi_query_kv_attention( num_seqs: int, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ff5aa8bee3c27..678700055c992 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -146,12 +146,14 @@ def paged_attention_rocm( max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, + k_scale: float, + v_scale: float, ) -> None: torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype) + kv_cache_dtype, k_scale, v_scale) # pos encoding ops diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 6bd276ade1d41..70e6857584ace 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -17,8 +17,8 @@ logger = init_logger(__name__) -_PARTITION_SIZE = 256 -ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName +_PARTITION_SIZE_ROCM = 512 +_ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName class ROCmFlashAttentionBackend(AttentionBackend): @@ -489,14 +489,15 @@ def forward( num_seqs, num_heads, head_size = decode_query.shape block_size = value_cache.shape[3] gqa_ratio = num_heads // self.num_kv_heads - use_custom = use_rocm_custom_paged_attention( - decode_query.dtype, head_size, block_size, self.kv_cache_dtype, - gqa_ratio, decode_meta.max_decode_seq_len) + use_custom = 
_use_rocm_custom_paged_attention( + decode_query.dtype, head_size, block_size, gqa_ratio, + decode_meta.max_decode_seq_len) if use_custom: max_seq_len = decode_meta.max_decode_seq_len - max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // - _PARTITION_SIZE) - assert _PARTITION_SIZE % block_size == 0 + max_num_partitions = ( + (max_seq_len + _PARTITION_SIZE_ROCM - 1) // + _PARTITION_SIZE_ROCM) + assert _PARTITION_SIZE_ROCM % block_size == 0 tmp_output = torch.empty( size=(num_seqs, num_heads, max_num_partitions, head_size), dtype=output.dtype, @@ -524,6 +525,8 @@ def forward( max_seq_len, self.alibi_slopes, self.kv_cache_dtype, + k_scale, + v_scale, ) else: output[num_prefill_tokens:] = PagedAttention.forward_decode( @@ -580,12 +583,11 @@ def _sdpa_attention( return output -def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, - block_size: int, kv_cache_dtype: str, - gqa_ratio: int, max_seq_len: int) -> bool: +def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, + block_size: int, gqa_ratio: int, + max_seq_len: int) -> bool: # rocm custom page attention not support on navi (gfx1*) - return (not ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16) + return (not _ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16) and (head_size == 64 or head_size == 128) and (block_size == 16 or block_size == 32) - and kv_cache_dtype == "auto" and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
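
Note on the new k_scale / v_scale arguments: inside the kernel, scaled_convert_b8x8 turns each packed group of eight fp8 cache bytes (_B8x8) back into the query dtype and multiplies by the per-tensor scale before the QK^T and softmax*V products. Below is a minimal, illustrative Python sketch of that dequantization semantics only, not vLLM code: dequantize_fp8_kv is a hypothetical helper, it assumes PyTorch >= 2.1 for the float8 dtypes, and it uses float8_e4m3fn for portability, whereas the ROCm cache actually stores the e4m3fnuz variant (max finite value 240 instead of 448).

import torch

def dequantize_fp8_kv(cache_u8: torch.Tensor, scale: float,
                      out_dtype: torch.dtype = torch.float16) -> torch.Tensor:
    # Reinterpret the raw cache bytes as fp8, widen to the query dtype,
    # then undo the quantization with the per-tensor scale.
    return cache_u8.view(torch.float8_e4m3fn).to(out_dtype) * scale

# Round-trip a fake value cache through quantization and the dequant helper.
# Shape is [num_blocks, num_kv_heads, head_size, block_size] (made-up sizes).
v_ref = torch.randn(4, 2, 64, 16, dtype=torch.float16)
v_scale = v_ref.abs().max().item() / 240.0   # simple per-tensor scale; 240 is the e4m3fnuz max
v_cache_u8 = (v_ref / v_scale).to(torch.float8_e4m3fn).view(torch.uint8)
v_dequant = dequantize_fp8_kv(v_cache_u8, v_scale)
print((v_dequant - v_ref).abs().max())       # small fp8 quantization error

The kernel does the same thing eight elements at a time in registers (Klocalb8 / Vlocalb8 are converted into Klocal / Vlocal), which is why the launcher only needs the two scalar scales rather than per-block scale tensors.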
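
For the fp8 path the key cache keeps its [num_blocks, num_kv_heads, head_size/x, block_size, x] layout, but with cache_t = uint8 the packing factor becomes X = 16 / sizeof(cache_t) = 16, so the kernel locates each group of eight head elements with offset1 = head_elem / X and offset2 = head_elem % X before loading one _B8x8. The sketch below (plain Python, illustrative only; fp8_key_group is a hypothetical helper and the tensor sizes are made up) reproduces that address arithmetic and cross-checks it against direct 5-D indexing.

import torch

NUM_BLOCKS, NUM_KV_HEADS, HEAD_SIZE, BLOCK_SIZE, X = 3, 2, 128, 16, 16
k_cache = torch.arange(
    NUM_BLOCKS * NUM_KV_HEADS * (HEAD_SIZE // X) * BLOCK_SIZE * X,
    dtype=torch.int32).reshape(NUM_BLOCKS, NUM_KV_HEADS, HEAD_SIZE // X,
                               BLOCK_SIZE, X)

def fp8_key_group(block: int, kv_head: int, token_in_block: int, d: int):
    # Return the 8 contiguous cache elements the kernel loads as one _B8x8.
    flat = k_cache.flatten()
    # k_ptr2 in the kernel: block base + kv-head base + token offset.
    base = (block * k_cache.stride(0) + kv_head * k_cache.stride(1)
            + token_in_block * X)
    head_elem = d * 8
    offset1, offset2 = head_elem // X, head_elem % X
    start = base + offset1 * BLOCK_SIZE * X + offset2   # k_ptr3 in the kernel
    return flat[start:start + 8]

# Cross-check against direct indexing of the same 8 head elements.
blk, h, tok, d = 1, 0, 5, 3
expected = k_cache[blk, h, (d * 8) // X, tok, (d * 8) % X:(d * 8) % X + 8]
assert torch.equal(fp8_key_group(blk, h, tok, d), expected)
print("kernel-style and direct indexing agree")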