diff --git a/CMakeLists.txt b/CMakeLists.txt index 546ef690dec0c..aaab472eef143 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -221,7 +221,8 @@ endif() set(VLLM_EXT_SRC "csrc/cache_kernels.cu" - "csrc/attention/attention_kernels.cu" + "csrc/attention/paged_attention_v1.cu" + "csrc/attention/paged_attention_v2.cu" "csrc/pos_encoding_kernels.cu" "csrc/activation_kernels.cu" "csrc/layernorm_kernels.cu" diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cuh similarity index 64% rename from csrc/attention/attention_kernels.cu rename to csrc/attention/attention_kernels.cuh index 95f7b610c1933..563e1438f0b01 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cuh @@ -670,334 +670,6 @@ __global__ void paged_attention_v2_reduce_kernel( } // namespace vllm -#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ - VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ - ((void*)vllm::paged_attention_v1_kernel), \ - shared_mem_size); \ - vllm::paged_attention_v1_kernel \ - <<>>( \ - out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ - scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ - alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ - k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ - blocksparse_vert_stride, blocksparse_block_size, \ - blocksparse_head_sliding_step); - -// TODO(woosuk): Tune NUM_THREADS. 
-template -#else - int NUM_THREADS = 128> -#endif -void paged_attention_v1_launcher( - torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const c10::optional& alibi_slopes, float k_scale, - float v_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { - int num_seqs = query.size(0); - int num_heads = query.size(1); - int head_size = query.size(2); - int max_num_blocks_per_seq = block_tables.size(1); - int q_stride = query.stride(0); - int kv_block_stride = key_cache.stride(0); - int kv_head_stride = key_cache.stride(1); - - [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); - assert(head_size % thread_group_size == 0); - - // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = - alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; - - T* out_ptr = reinterpret_cast(out.data_ptr()); - T* query_ptr = reinterpret_cast(query.data_ptr()); - CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* block_tables_ptr = block_tables.data_ptr(); - int* seq_lens_ptr = seq_lens.data_ptr(); - - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_seq_len = - DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; - int logits_size = padded_max_seq_len * sizeof(float); - int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); - // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len - // Keep that in sync with the logic here! 
- int shared_mem_size = std::max(logits_size, outputs_size); - - dim3 grid(num_heads, num_seqs, 1); - dim3 block(NUM_THREADS); - const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - switch (head_size) { - // NOTE(woosuk): To reduce the compilation time, we only compile for the - // head sizes that we use in the model. However, we can easily extend this - // to support any head size which is a multiple of 16. - case 64: - LAUNCH_PAGED_ATTENTION_V1(64); - break; - case 80: - LAUNCH_PAGED_ATTENTION_V1(80); - break; - case 96: - LAUNCH_PAGED_ATTENTION_V1(96); - break; - case 112: - LAUNCH_PAGED_ATTENTION_V1(112); - break; - case 120: - LAUNCH_PAGED_ATTENTION_V1(120); - break; - case 128: - LAUNCH_PAGED_ATTENTION_V1(128); - break; - case 192: - LAUNCH_PAGED_ATTENTION_V1(192); - break; - case 256: - LAUNCH_PAGED_ATTENTION_V1(256); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; - } -} - -#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ - paged_attention_v1_launcher( \ - out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ - seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ - blocksparse_local_blocks, blocksparse_vert_stride, \ - blocksparse_block_size, blocksparse_head_sliding_step); - -#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ - if (is_block_sparse) { \ - CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ - } else { \ - CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ - } - -// NOTE(woosuk): To reduce the compilation time, we omitted block sizes -// 1, 2, 4, 64, 128, 256. 
-#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ - switch (block_size) { \ - case 8: \ - CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ - break; \ - case 16: \ - CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ - break; \ - case 32: \ - CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ - } - -void paged_attention_v1( - torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& query, // [num_seqs, num_heads, head_size] - torch::Tensor& - key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& - value_cache, // [num_blocks, num_heads, head_size, block_size] - int64_t num_kv_heads, // [num_heads] - double scale, - torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& seq_lens, // [num_seqs] - int64_t block_size, int64_t max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, - const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, - const int64_t blocksparse_head_sliding_step) { - const bool is_block_sparse = (blocksparse_vert_stride > 1); - - DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, - CALL_V1_LAUNCHER_BLOCK_SIZE) -} - -#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ - vllm::paged_attention_v2_kernel \ - <<>>( \ - exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ - value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ - seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ - blocksparse_local_blocks, blocksparse_vert_stride, \ - blocksparse_block_size, blocksparse_head_sliding_step); \ - vllm::paged_attention_v2_reduce_kernel \ - <<>>( \ - out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ 
- max_num_partitions); - -template -#else - int NUM_THREADS = 128, int PARTITION_SIZE = 512> -#endif -void paged_attention_v2_launcher( - torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, - torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const c10::optional& alibi_slopes, float k_scale, - float v_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { - int num_seqs = query.size(0); - int num_heads = query.size(1); - int head_size = query.size(2); - int max_num_blocks_per_seq = block_tables.size(1); - int q_stride = query.stride(0); - int kv_block_stride = key_cache.stride(0); - int kv_head_stride = key_cache.stride(1); - - [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); - assert(head_size % thread_group_size == 0); - - // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = - alibi_slopes - ? 
reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; - - T* out_ptr = reinterpret_cast(out.data_ptr()); - float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); - float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); - T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); - T* query_ptr = reinterpret_cast(query.data_ptr()); - CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* block_tables_ptr = block_tables.data_ptr(); - int* seq_lens_ptr = seq_lens.data_ptr(); - - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); - int logits_size = PARTITION_SIZE * sizeof(float); - int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); - - // For paged attention v2 kernel. - dim3 grid(num_heads, num_seqs, max_num_partitions); - int shared_mem_size = std::max(logits_size, outputs_size); - // For paged attention v2 reduce kernel. - dim3 reduce_grid(num_heads, num_seqs); - int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); - - dim3 block(NUM_THREADS); - const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - switch (head_size) { - // NOTE(woosuk): To reduce the compilation time, we only compile for the - // head sizes that we use in the model. However, we can easily extend this - // to support any head size which is a multiple of 16. 
- case 64: - LAUNCH_PAGED_ATTENTION_V2(64); - break; - case 80: - LAUNCH_PAGED_ATTENTION_V2(80); - break; - case 96: - LAUNCH_PAGED_ATTENTION_V2(96); - break; - case 112: - LAUNCH_PAGED_ATTENTION_V2(112); - break; - case 120: - LAUNCH_PAGED_ATTENTION_V2(120); - break; - case 128: - LAUNCH_PAGED_ATTENTION_V2(128); - break; - case 192: - LAUNCH_PAGED_ATTENTION_V2(192); - break; - case 256: - LAUNCH_PAGED_ATTENTION_V2(256); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; - } -} - -#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ - paged_attention_v2_launcher( \ - out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ - k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ - blocksparse_vert_stride, blocksparse_block_size, \ - blocksparse_head_sliding_step); - -#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ - if (is_block_sparse) { \ - CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ - } else { \ - CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ - } - -// NOTE(woosuk): To reduce the compilation time, we omitted block sizes -// 1, 2, 4, 64, 128, 256. 
-#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ - switch (block_size) { \ - case 8: \ - CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ - break; \ - case 16: \ - CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ - break; \ - case 32: \ - CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ - } - -void paged_attention_v2( - torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& - tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - torch::Tensor& query, // [num_seqs, num_heads, head_size] - torch::Tensor& - key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& - value_cache, // [num_blocks, num_heads, head_size, block_size] - int64_t num_kv_heads, // [num_heads] - double scale, - torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& seq_lens, // [num_seqs] - int64_t block_size, int64_t max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, - const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, - const int64_t blocksparse_head_sliding_step) { - const bool is_block_sparse = (blocksparse_vert_stride > 1); - DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, - CALL_V2_LAUNCHER_BLOCK_SIZE) -} - #undef WARP_SIZE #undef MAX #undef MIN diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu new file mode 100644 index 0000000000000..69072c21903a9 --- /dev/null +++ b/csrc/attention/paged_attention_v1.cu @@ -0,0 +1,209 @@ +/* + * Adapted from + * 
https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "attention_kernels.cuh" + +#ifndef USE_ROCM + #define WARP_SIZE 32 +#else + #define WARP_SIZE warpSize +#endif + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ + ((void*)vllm::paged_attention_v1_kernel), \ + shared_mem_size); \ + vllm::paged_attention_v1_kernel \ + <<>>( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ + scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); + +// TODO(woosuk): Tune NUM_THREADS. 
+template +void paged_attention_v1_launcher( + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes, float k_scale, + float v_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int padded_max_seq_len = + DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_seq_len * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len + // Keep that in sync with the logic here! 
+ int shared_mem_size = std::max(logits_size, outputs_size); + + dim3 grid(num_heads, num_seqs, 1); + dim3 block(NUM_THREADS); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. + case 64: + LAUNCH_PAGED_ATTENTION_V1(64); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V1(80); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V1(96); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V1(112); + break; + case 120: + LAUNCH_PAGED_ATTENTION_V1(120); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V1(128); + break; + case 192: + LAUNCH_PAGED_ATTENTION_V1(192); + break; + case 256: + LAUNCH_PAGED_ATTENTION_V1(256); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE, \ + NUM_THREADS) \ + paged_attention_v1_launcher( \ + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ + seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); + +#define CALL_V1_LAUNCHER_W_NUM_THREADS(T, CACHE_T, BLOCK_SIZE, \ + IS_FP8_KV_CACHE, IS_BLOCK_SPARSE) \ + switch (num_threads) { \ + case 128: \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, \ + IS_BLOCK_SPARSE, 128); \ + break; \ + case 1024: \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, \ + IS_BLOCK_SPARSE, 1024); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported num threads: ", num_threads); \ + break; \ + } + +#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + if (is_block_sparse) { \ + CALL_V1_LAUNCHER_W_NUM_THREADS(T, 
CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, \ + true); \ + } else { \ + CALL_V1_LAUNCHER_W_NUM_THREADS(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, \ + false); \ + } + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. +#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +void paged_attention_v1( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int64_t block_size, int64_t max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step, const int64_t num_threads) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); + + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V1_LAUNCHER_BLOCK_SIZE) +} + +#undef WARP_SIZE +#undef MAX +#undef MIN +#undef DIVIDE_ROUND_UP \ No newline at end of file diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu new file mode 100644 index 0000000000000..13e070f30f643 --- /dev/null +++ b/csrc/attention/paged_attention_v2.cu @@ -0,0 +1,227 @@ 
+/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "attention_kernels.cuh" + +#ifdef USE_ROCM + #include + #include "../quantization/fp8/amd/quant_utils.cuh" +typedef __hip_bfloat16 __nv_bfloat16; +#else + #include "../quantization/fp8/nvidia/quant_utils.cuh" +#endif + +#ifndef USE_ROCM + #define WARP_SIZE 32 +#else + #define WARP_SIZE warpSize +#endif + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? 
(a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + +#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ + vllm::paged_attention_v2_kernel \ + <<>>( \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ + value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); \ + vllm::paged_attention_v2_reduce_kernel \ + <<>>( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ + max_num_partitions); + +template +void paged_attention_v2_launcher( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes, float k_scale, + float v_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? 
reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); + int logits_size = PARTITION_SIZE * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + + // For paged attention v2 kernel. + dim3 grid(num_heads, num_seqs, max_num_partitions); + int shared_mem_size = std::max(logits_size, outputs_size); + // For paged attention v2 reduce kernel. + dim3 reduce_grid(num_heads, num_seqs); + int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); + + dim3 block(NUM_THREADS); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. 
+ case 64: + LAUNCH_PAGED_ATTENTION_V2(64); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V2(80); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V2(96); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V2(112); + break; + case 120: + LAUNCH_PAGED_ATTENTION_V2(120); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V2(128); + break; + case 192: + LAUNCH_PAGED_ATTENTION_V2(192); + break; + case 256: + LAUNCH_PAGED_ATTENTION_V2(256); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE, \ + NUM_THREADS, PARTITION_SIZE) \ + paged_attention_v2_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ + k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); + +#define CALL_V2_LAUNCHER_W_NUM_THREADS(T, CACHE_T, BLOCK_SIZE, \ + IS_FP8_KV_CACHE, IS_BLOCK_SPARSE) \ + switch (num_threads) { \ + case 128: \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, \ + IS_BLOCK_SPARSE, 128, 512); \ + break; \ + case 1024: \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, \ + IS_BLOCK_SPARSE, 1024, 1024); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported num threads: ", num_threads); \ + break; \ + } + +#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + if (is_block_sparse) { \ + CALL_V2_LAUNCHER_W_NUM_THREADS(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, \ + true); \ + } else { \ + CALL_V2_LAUNCHER_W_NUM_THREADS(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, \ + false); \ + } + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. 
+#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +void paged_attention_v2( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& + tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int64_t block_size, int64_t max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step, const int64_t num_threads) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V2_LAUNCHER_BLOCK_SIZE) +} + +#undef WARP_SIZE +#undef MAX +#undef MIN +#undef DIVIDE_ROUND_UP \ No newline at end of file diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 57f7a45785d11..4454eeffa102c 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -27,6 +27,11 @@ using __nv_bfloat162 = __hip_bfloat162; #include 
"quantization/fp8/nvidia/quant_utils.cuh" #endif +#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \ + defined(__gfx941__) || defined(__gfx942__)) + #define __HIP__MI300_MI250__ +#endif + namespace vllm { template @@ -67,6 +72,8 @@ struct __align__(16) vec8_t { __device__ scalar_t sum() const { return x + y + z + w + u + v + s + t; } }; +#ifdef __HIP__MI300_MI250__ + // TODO(woosuk): Further optimize this kernel. template __global__ void rms_norm_kernel( @@ -110,6 +117,41 @@ __global__ void rms_norm_kernel( } } +#else + +// TODO(maleksan): Investigate why vectorization doesn't work for Navi. +template +__global__ void rms_norm_kernel( + scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, const int hidden_size) { + __shared__ float s_variance; + float variance = 0.0f; + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + const float x = (float)input[blockIdx.x * hidden_size + idx]; + variance += x * x; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float x = (float)input[blockIdx.x * hidden_size + idx]; + out[blockIdx.x * hidden_size + idx] = + ((scalar_t)(x * s_variance)) * weight[idx]; + } +} + +#endif + template __global__ void scaled_rms_norm_kernel( c10::Float8_e4m3fnuz* __restrict__ out, // [..., hidden_size] @@ -144,7 +186,7 @@ __global__ void scaled_rms_norm_kernel( /* Converter structs for the conversion from torch types to HIP/CUDA types, and the associated type conversions within HIP/CUDA. 
These helpers need - to be implemented for now because the relevant type conversion + to be implemented for now because the relevant type conversion operators/constructors are not consistently implemented by HIP/CUDA, so a generic conversion via type casts cannot be implemented. diff --git a/csrc/ops.h b/csrc/ops.h index f98c817fda646..88eb3b6f84fbe 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -37,7 +37,7 @@ void paged_attention_v1( const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, - const int64_t blocksparse_head_sliding_step); + const int64_t blocksparse_head_sliding_step, const int64_t num_threads); void paged_attention_v2( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, @@ -48,7 +48,7 @@ void paged_attention_v2( const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, - const int64_t blocksparse_head_sliding_step); + const int64_t blocksparse_head_sliding_step, const int64_t num_threads); void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double epsilon); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 6da7c1fa30257..e6f1c96e3201e 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -33,7 +33,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " str kv_cache_dtype, float k_scale, float v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," - " int blocksparse_head_sliding_step) -> ()"); + " int blocksparse_head_sliding_step," + " int num_threads) -> ()"); ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1); // PagedAttention V2. 
@@ -47,7 +48,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " str kv_cache_dtype, float k_scale, float v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," - " int blocksparse_head_sliding_step) -> ()"); + " int blocksparse_head_sliding_step," + " int num_threads) -> ()"); ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); // Activation ops diff --git a/docs/source/dev/kernel/paged_attention.rst b/docs/source/dev/kernel/paged_attention.rst index ba4f7a2718158..65a7a1ce260f7 100644 --- a/docs/source/dev/kernel/paged_attention.rst +++ b/docs/source/dev/kernel/paged_attention.rst @@ -2,7 +2,7 @@ vLLM Paged Attention ==================== - Currently, vLLM utilizes its own implementation of a multi-head query - attention kernel (``csrc/attention/attention_kernels.cu``). + attention kernel (``csrc/attention/attention_kernels.cuh``, with launchers in ``csrc/attention/paged_attention_v1.cu`` and ``csrc/attention/paged_attention_v2.cu``). This kernel is designed to be compatible with vLLM's paged KV caches, where the key and value cache are stored in separate blocks (note that this block concept differs from the GPU diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 42700fee3bda0..a5161a1aa1278 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -130,7 +130,9 @@ def paged_attention_v1( seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, - blocksparse_head_sliding_step) + blocksparse_head_sliding_step, + num_threads = 1024 if current_platform.is_rocm() \ + and not is_navi() else 128) def paged_attention_v2( @@ -162,7 +164,9 @@ def paged_attention_v2( num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, - blocksparse_block_size, blocksparse_head_sliding_step) + blocksparse_block_size, blocksparse_head_sliding_step, + num_threads = 1024 
if current_platform.is_rocm() \ + and not is_navi() else 128) def paged_attention_rocm( diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 27dda6f5536a5..5390d0b7dd46c 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -6,12 +6,13 @@ from vllm import _custom_ops as ops from vllm.platforms import current_platform from vllm.triton_utils import HAS_TRITON +from vllm.utils import is_navi if HAS_TRITON: from vllm.attention.ops.prefix_prefill import context_attention_fwd # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. -_PARTITION_SIZE = 512 if not current_platform.is_rocm() else 1024 +_PARTITION_SIZE = 512 if not current_platform.is_rocm() or is_navi() else 1024 @dataclass