From 531de0b80d2db4240e5b953dcd63842ea6cc304e Mon Sep 17 00:00:00 2001
From: Dahai Tang
Date: Tue, 3 Dec 2024 11:51:51 +0800
Subject: [PATCH 01/11] Feat(kv store): add a CPU KV Cache Block Store

Conflicts:
	vllm/attention/backends/abstract.py
	vllm/attention/backends/flash_attn.py
	vllm/core/scheduler.py
	vllm/model_executor/models/qwen2.py
	vllm/sequence.py
	vllm/worker/worker.py
	vllm/worker/worker_base.py

Signed-off-by: Dahai Tang
---
 csrc/cache.h                               |  16 +
 csrc/cache_kernels.cu                      |  46 ++
 csrc/kv_store/kv_store.cu                  | 451 +++++++++++++++++++
 csrc/kv_store/kv_store.hpp                 |  52 +++
 csrc/torch_bindings.cpp                    |  26 ++
 examples/offline_inference.py              |  13 +-
 vllm/_custom_ops.py                        |  24 +
 vllm/attention/backends/abstract.py        |   2 +
 vllm/attention/backends/flash_attn.py      |  62 ++-
 vllm/attention/backends/utils.py           |  14 +-
 vllm/block.py                              |   3 +-
 vllm/config.py                             |   8 +
 vllm/core/block/cpu_gpu_block_allocator.py |   9 +
 vllm/core/block/naive_block.py             |   7 +
 vllm/core/block/prefix_caching_block.py    |  46 +-
 vllm/core/block_manager.py                 |   4 +
 vllm/core/scheduler.py                     | 185 ++++++--
 vllm/engine/arg_utils.py                   |   6 +
 vllm/engine/llm_engine.py                  |  21 +-
 vllm/entrypoints/llm.py                    |   2 +
 vllm/model_executor/models/qwen2.py        |  23 +-
 vllm/sequence.py                           |   4 +
 vllm/store/__init__.py                     |   7 +
 vllm/store/kv_store.py                     | 482 +++++++++++++++++++
 vllm/worker/model_runner.py                |   2 +
 vllm/worker/worker.py                      |  61 ++-
 vllm/worker/worker_base.py                 |  59 ++-
 27 files changed, 1585 insertions(+), 50 deletions(-)
 create mode 100644 csrc/kv_store/kv_store.cu
 create mode 100644 csrc/kv_store/kv_store.hpp
 create mode 100644 vllm/store/__init__.py
 create mode 100644 vllm/store/kv_store.py

diff --git a/csrc/cache.h b/csrc/cache.h
index 11c4c5001daaa..ae99cdbaf9487 100644
--- a/csrc/cache.h
+++ b/csrc/cache.h
@@ -8,6 +8,22 @@
 void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                  const torch::Tensor& block_mapping);
 
+void kv_store_copy_incomplete_blocks(torch::Tensor& src, torch::Tensor& dst,
+                                     const int64_t layer_id,
+                                     const torch::Tensor& incomplete_block_mapping);
+void kv_store_copy_blocks2CPU(torch::Tensor& src, torch::Tensor& dst,
+                              const int64_t layer_id,
+                              const torch::Tensor& block_mapping);
+
+void kv_store_copy_blocks2GPU(torch::Tensor& src,
+                              std::vector<torch::Tensor> const& dst,
+                              const int64_t num_layers,
+                              const torch::Tensor& block_mapping,
+                              const torch::Tensor& block_offsets,
+                              const torch::Tensor& req_ids,
+                              std::vector<int64_t> const& events,
+                              const bool is_batch_layer);
+
 // Note: the key_caches and value_caches vectors are constant but
 // not the Tensors they contain. The vectors need to be const refs
 // in order to satisfy pytorch's C++ operator registration code.
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 1be806bbfa43c..81e3399324de9 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1,6 +1,7 @@ #include #include #include +#include #include "cuda_compat.h" #include "dispatch_utils.h" @@ -21,6 +22,14 @@ typedef __hip_bfloat16 __nv_bfloat16; #endif +#include "kv_store/kv_store.hpp" + +namespace { + +KVStore kv_store; + +}; // namespace + void swap_blocks(torch::Tensor& src, torch::Tensor& dst, const torch::Tensor& block_mapping) { torch::Device src_device = src.device(); @@ -62,6 +71,43 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst, } } +// src layout: [2, num_blocks, block_size, num_kv_heads, head_size] +// dst layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size] +void kv_store_copy_incomplete_blocks(torch::Tensor& src, torch::Tensor& dst, + const int64_t layer_id, + const torch::Tensor& incomplete_block_mapping) { + kv_store.CopyIncompleteBlocks(src, dst, layer_id, incomplete_block_mapping); +} + +// src layout: [2, num_blocks, block_size, num_kv_heads, head_size] +// dst layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size] +void kv_store_copy_blocks2CPU(torch::Tensor& src, torch::Tensor& dst, + const int64_t layer_id, + const torch::Tensor& block_mapping) { + kv_store.CopyBlocks2CPU(src, dst, layer_id, block_mapping); +} + +// src layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size] +// kv_caches layout: [laysers, [2, num_blocks, block_size, num_kv_heads, head_size]] +void kv_store_copy_blocks2GPU(torch::Tensor& src, + std::vector const& kv_caches, + const int64_t num_layers, + const torch::Tensor& block_mapping, + const torch::Tensor& block_offsets, + const torch::Tensor& req_ids, + std::vector const& events, + const bool is_batch_layer) { + if (is_batch_layer) { + const int64_t num_requests = req_ids.size(0); + kv_store.CopyBlocks2GPUBatch(src, kv_caches, num_layers, block_mapping, + block_offsets, num_requests, events); + } + else { + kv_store.CopyLayerBlocks2GPU(src, kv_caches, num_layers, block_mapping, + block_offsets, req_ids, events); + } +} + namespace vllm { // Grid: (num_layers, num_pairs) diff --git a/csrc/kv_store/kv_store.cu b/csrc/kv_store/kv_store.cu new file mode 100644 index 0000000000000..14afff472a553 --- /dev/null +++ b/csrc/kv_store/kv_store.cu @@ -0,0 +1,451 @@ +#include "kv_store.hpp" + + + +#define CHECK_CUDA(x) {\ + cudaError_t err = (x);\ + if (err != cudaSuccess) {\ + std::cerr << "CUDA error: " << cudaGetErrorString(err) << " at " << __FILE__ << ":" << __LINE__ << std::endl;\ + std::abort();\ + }\ +} + +// src layout: [2, num_blocks, block_size, num_kv_heads, head_size] +// dst layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size] +void KVStore::CopyIncompleteBlocks(torch::Tensor& src, torch::Tensor& dst, + const int64_t layer_id, + const torch::Tensor& incomplete_block_mapping) { + torch::Device src_device = src.device(); + torch::Device dst_device = dst.device(); + cudaMemcpyKind memcpy_type; + if (src_device.is_cuda() && dst_device.is_cpu()) { + memcpy_type = cudaMemcpyDeviceToHost; + } else { + TORCH_CHECK(false, "only support copy from GPU to CPU"); + } + TORCH_CHECK(incomplete_block_mapping.device().is_cpu(), + "block_mapping must be on CPU"); + + const int64_t slot_size_in_bytes = src.element_size() * src[0][0][0].numel(); + const at::cuda::OptionalCUDAGuard device_guard(src_device); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int64_t num_items = 
incomplete_block_mapping.size(0); + for (size_t i = 0; i < num_items; i++) { + int64_t src_block = incomplete_block_mapping[i][0].item(); + int64_t start_offset = incomplete_block_mapping[i][1].item(); + int64_t end_offset = incomplete_block_mapping[i][2].item(); + int64_t dst_block = incomplete_block_mapping[i][3].item(); + int64_t copy_nbytes = (end_offset - start_offset ) * slot_size_in_bytes; + char* src_ptr = reinterpret_cast(src[0][src_block].data_ptr()); + char* dst_ptr = reinterpret_cast( + dst[dst_block][0][layer_id].data_ptr()); + start_offset *= slot_size_in_bytes; + CHECK_CUDA(cudaMemcpyAsync(dst_ptr + start_offset, src_ptr + start_offset, + copy_nbytes, memcpy_type, stream)); + src_ptr = reinterpret_cast(src[1][src_block].data_ptr()); + dst_ptr = reinterpret_cast(dst[dst_block][1][layer_id].data_ptr()); + CHECK_CUDA(cudaMemcpyAsync(dst_ptr + start_offset, src_ptr + start_offset, + copy_nbytes, memcpy_type, stream)); + } +} + +// src layout: [2, num_blocks, block_size, num_kv_heads, head_size] +// dst layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size] +void KVStore::CopyBlocks2CPU(torch::Tensor& src, torch::Tensor& dst, + const int64_t layer_id, + const torch::Tensor& block_mapping) { + torch::Device src_device = src.device(); + torch::Device dst_device = dst.device(); + cudaMemcpyKind memcpy_type; + if (src_device.is_cuda() && dst_device.is_cpu()) { + memcpy_type = cudaMemcpyDeviceToHost; + } else { + TORCH_CHECK(false, "only support copy from GPU to CPU"); + } + TORCH_CHECK(block_mapping.device().is_cpu(), + "block_mapping must be on CPU"); + const int64_t src_block_numel = src[0][0].numel(); + const int64_t dst_block_numel = dst[0][0][0].numel(); + TORCH_CHECK(src_block_numel == dst_block_numel, + "src and dst must have the same number of elements"); + const int64_t block_size_in_bytes = src.element_size() * src_block_numel; + const at::cuda::OptionalCUDAGuard device_guard(src_device); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int64_t num_blocks = block_mapping.size(0); + for (size_t i = 0; i < num_blocks; i++) { + int64_t src_block_number = block_mapping[i][0].item(); + int64_t dst_block_number = block_mapping[i][1].item(); + // key + char* src_ptr = reinterpret_cast( + src[0][src_block_number].data_ptr()); + char* dst_ptr = reinterpret_cast( + dst[dst_block_number][0][layer_id].data_ptr()); + CHECK_CUDA(cudaMemcpyAsync( + dst_ptr, src_ptr, block_size_in_bytes, memcpy_type, stream)); + // value + src_ptr = reinterpret_cast( + src[1][src_block_number].data_ptr()); + dst_ptr = reinterpret_cast( + dst[dst_block_number][1][layer_id].data_ptr()); + CHECK_CUDA(cudaMemcpyAsync( + dst_ptr, src_ptr, block_size_in_bytes, memcpy_type, stream)); + } +} + + +namespace vllm { + +// Grid: (num_layers, num_blocks) +template +__global__ void kv_store_copy_blocks_kernel( + scalar_t *src, + int64_t *key_cache_ptrs, int64_t *value_cache_ptrs, + const int64_t *block_mapping, const int64_t *block_offsets, + const int request_idx, + const int64_t numel_per_block) { + int64_t layer_idx = blockIdx.x; + int64_t pair_idx = blockIdx.y; + int64_t num_layer = gridDim.x; + scalar_t *key_cache = + reinterpret_cast(key_cache_ptrs[layer_idx]); + scalar_t *value_cache = + reinterpret_cast(value_cache_ptrs[layer_idx]); + int64_t block_mapping_idx = block_offsets[request_idx] + pair_idx; + int64_t dst_block_number = block_mapping[2 * block_mapping_idx + 1]; + scalar_t *key_block = key_cache + dst_block_number * numel_per_block; + scalar_t *src_key_block = (src 
+ pair_idx * 2 * num_layer * numel_per_block + + layer_idx * numel_per_block); + for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) { + key_block[i] = src_key_block[i]; + } + scalar_t *value_block = value_cache + dst_block_number * numel_per_block; + scalar_t *src_value_block = (src + + (pair_idx * 2 + 1) * num_layer * numel_per_block + + layer_idx * numel_per_block); + for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) { + value_block[i] = src_value_block[i]; + } +} + +} // namespace vllm + +std::ostream& operator<<(std::ostream& os, const std::vector& vec) { + os << "["; + for (size_t i = 0; i < vec.size(); i++) { + os << vec[i]; + if (i != vec.size() - 1) { + os << ", "; + } + } + os << "]"; + return os; +} + + +// src layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size] +// kv_caches layout: [laysers, [2, num_blocks, block_size, num_kv_heads, head_size]] +void KVStore::CopyBlocks2GPUBatch(torch::Tensor& src, + std::vector const& kv_caches, + const int64_t num_layers, + const torch::Tensor& block_mapping, + const torch::Tensor& block_offsets, + const int64_t num_requests, + std::vector const& events) { + torch::Device src_device = src.device(); + torch::Device dst_device = kv_caches[0].device(); + cudaMemcpyKind memcpy_type; + if (src_device.is_cpu() && dst_device.is_cuda()) { + memcpy_type = cudaMemcpyHostToDevice; + } else { + TORCH_CHECK(false, "only support copy from CPU to GPU"); + } + TORCH_CHECK(block_mapping.device().is_cpu(), + "block_mapping must be on CPU"); + TORCH_CHECK(block_offsets.device().is_cpu(), + "block_offsets must be on CPU"); + const at::cuda::OptionalCUDAGuard device_guard(dst_device); + auto stream = at::cuda::getCurrentCUDAStream(); + auto block_mapping_gpu = block_mapping.to(dst_device, + block_mapping.scalar_type(), + /*non_blocking=*/true); + auto block_offsets_gpu = block_offsets.to(dst_device, + block_offsets.scalar_type(), + /*non_blocking=*/true); + // Create data structures for the kernel. + // Create an array of pointers to the key and value caches. + int64_t key_cache_ptrs[num_layers]; + int64_t value_cache_ptrs[num_layers]; + for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) { + key_cache_ptrs[layer_idx] = + reinterpret_cast(kv_caches[layer_idx][0].data_ptr()); + value_cache_ptrs[layer_idx] = + reinterpret_cast(kv_caches[layer_idx][1].data_ptr()); + } + // Move the data structures to the GPU. + // NOTE: This synchronizes the CPU and GPU. 
+ torch::Tensor key_cache_ptrs_tensor = + torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64) + .to(dst_device, /*non_blocking=*/true); + torch::Tensor value_cache_ptrs_tensor = + torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64) + .to(dst_device, /*non_blocking=*/true); + + for (size_t i = 0; i < num_requests; i++) { + const int64_t start_idx = block_offsets[i].item(); + const int64_t end_idx = block_offsets[i + 1].item(); + const int64_t num_blocks = end_idx - start_idx; + auto options = torch::TensorOptions() + .dtype(kv_caches[0].dtype()) + .device(dst_device); + std::vector shape = src.sizes().vec(); + shape[0] = num_blocks; + // XXX: may cause out of memory in VLLM framework + torch::Tensor trans_buffer = torch::empty(shape, options); + for (size_t j = 0; j < num_blocks; j++) { + int64_t idx = (start_idx + j); + int64_t src_block_number = block_mapping[idx][0].item(); + char* src_ptr = reinterpret_cast(src[src_block_number].data_ptr()); + char* dst_ptr = reinterpret_cast(trans_buffer[j].data_ptr()); + int64_t trans_nbytes = src[0].element_size() * src[0].numel(); + CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, trans_nbytes, + memcpy_type, stream)); + } + const int numel_per_block = src[0][0][0].numel(); + const dim3 grid(num_layers, num_blocks); + const dim3 block(std::min(1024, numel_per_block)); + VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( + src.scalar_type(), "kv_store_copy_blocks_kernel", ([&] { + vllm::kv_store_copy_blocks_kernel<<>>( + trans_buffer.data_ptr(), + key_cache_ptrs_tensor.data_ptr(), + value_cache_ptrs_tensor.data_ptr(), + block_mapping_gpu.data_ptr(), + block_offsets_gpu.data_ptr(), + i, + numel_per_block); + })); + CHECK_CUDA(cudaEventRecord(reinterpret_cast(events[i]), stream)); + } +} + +namespace vllm { + +// src layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size] +// key layout: [num_blocks, block_size, num_kv_heads, head_size] +// value layout: [num_blocks, block_size, num_kv_heads, head_size] +// Grid: (num_blocks) +template +__global__ void kv_store_copy_blocks_kernel( + scalar_t* src, + scalar_t* key_cache, + scalar_t* value_cache, + const int64_t* block_mapping, + const int64_t* block_offsets, + const int request_idx, + const int64_t layer_id, + const int64_t num_layer, + const int64_t numel_per_block) { + + int pair_idx = blockIdx.x; + int64_t block_mapping_idx = block_offsets[request_idx] + pair_idx; + int64_t src_block_number = block_mapping[2 * block_mapping_idx]; + int64_t dst_block_number = block_mapping[2 * block_mapping_idx + 1]; + scalar_t* src_key_block = src + + src_block_number * 2 * num_layer * numel_per_block + + layer_id * numel_per_block; + scalar_t* dst_key_block = key_cache + + dst_block_number * numel_per_block; + for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) { + dst_key_block[i] = src_key_block[i]; + } + scalar_t* src_value_block = src + + (src_block_number * 2 + 1) * num_layer * numel_per_block + + layer_id * numel_per_block; + scalar_t* dst_value_block = value_cache + + dst_block_number * numel_per_block; + for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) { + dst_value_block[i] = src_value_block[i]; + } +} + +}; // namespace vllm + +namespace { + +// src layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size] +// kv_caches layout: [laysers, [2, num_blocks, block_size, num_kv_heads, head_size]] +void CopyLayerBlocks2GPUKernelFunc( + const torch::Tensor& src, + std::vector const& kv_caches, + const int64_t num_layer, + const torch::Tensor& 
block_mapping, + const torch::Tensor& block_offsets, + const torch::Tensor& req_ids, + const std::vector& events, + const at::cuda::CUDAStream& stream) { // is the current stream + size_t num_requests = req_ids.size(0); + const int numel_per_block = src[0][0][0].numel(); + const int64_t block_nbytes = numel_per_block * src.element_size(); + torch::Device src_device = src.device(); + torch::Device dst_device = kv_caches[0].device(); + cudaMemcpyKind memcpy_type; + if (src_device.is_cpu() && dst_device.is_cuda()) { + memcpy_type = cudaMemcpyHostToDevice; + } else { + TORCH_CHECK(false, "only support copy from CPU to GPU"); + } + TORCH_CHECK(block_mapping.device().is_cpu(), + "block_mapping must be on CPU"); + TORCH_CHECK(block_offsets.device().is_cpu(), + "block_offsets must be on CPU"); + TORCH_CHECK(req_ids.device().is_cpu(), + "req_ids must be on CPU"); + const at::cuda::OptionalCUDAGuard device_guard(dst_device); + auto block_mapping_gpu = block_mapping.to(dst_device, + block_mapping.scalar_type(), + /*non_blocking=*/true); + auto block_offsets_gpu = block_offsets.to(dst_device, + block_offsets.scalar_type(), + /*non_blocking=*/true); + for (size_t i = 0; i < num_requests; i++) { + const int64_t req_id = req_ids[i].item(); + const int64_t start_idx = block_offsets[i].item(); + const int64_t end_idx = block_offsets[i + 1].item(); + const int64_t num_blocks = end_idx - start_idx; + + for (int64_t layer_id = 0; layer_id < num_layer; layer_id++) { + if (num_blocks >= 2) { // if blocks are too many, use kernel + const dim3 grid(num_blocks); + const dim3 block(std::min(1024, numel_per_block)); + VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( + src.scalar_type(), "kv_store_copy_blocks_kernel", ([&] { + vllm::kv_store_copy_blocks_kernel + <<>>( + src.data_ptr(), + kv_caches[layer_id][0].data_ptr(), + kv_caches[layer_id][1].data_ptr(), + block_mapping_gpu.data_ptr(), + block_offsets_gpu.data_ptr(), + i, + layer_id, + num_layer, + numel_per_block); + })); + } + else { + for (size_t j = 0; j < num_blocks; j++) { + int64_t idx = (start_idx + j); + int64_t src_block_number = block_mapping[idx][0].item(); + int64_t dst_block_number = block_mapping[idx][1].item(); + char* src_ptr = reinterpret_cast( + src[src_block_number][0][layer_id].data_ptr()); + char* dst_ptr = reinterpret_cast( + kv_caches[layer_id][0][dst_block_number].data_ptr()); + CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, + block_nbytes, memcpy_type, stream)); + src_ptr = reinterpret_cast( + src[src_block_number][1][layer_id].data_ptr()); + dst_ptr = reinterpret_cast( + kv_caches[layer_id][1][dst_block_number].data_ptr()); + CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, + block_nbytes, memcpy_type, stream)); + } + } + CHECK_CUDA(cudaEventRecord( + reinterpret_cast(events[i * num_layer + layer_id]), + stream)); + } + } +} + +// src layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size] +// kv_caches layout: [laysers, [2, num_blocks, block_size, num_kv_heads, head_size]] +void CopyLayerBlocks2GPUThreadFunc( + const torch::Tensor& src, + std::vector const& kv_caches, + const int64_t num_layer, + const torch::Tensor& block_mapping, + const torch::Tensor& block_offsets, + const torch::Tensor& req_ids, + const std::vector& events, + const at::cuda::CUDAStream& stream) { + size_t num_requests = req_ids.size(0); + const int64_t block_nbytes = + kv_caches[0][0][0].numel() * kv_caches[0].element_size(); + torch::Device src_device = src.device(); + torch::Device dst_device = kv_caches[0].device(); + cudaMemcpyKind memcpy_type; + if 
(src_device.is_cpu() && dst_device.is_cuda()) { + memcpy_type = cudaMemcpyHostToDevice; + } else { + TORCH_CHECK(false, "only support copy from CPU to GPU"); + } + TORCH_CHECK(block_mapping.device().is_cpu(), + "block_mapping must be on CPU"); + TORCH_CHECK(block_offsets.device().is_cpu(), + "block_offsets must be on CPU"); + TORCH_CHECK(req_ids.device().is_cpu(), + "req_ids must be on CPU"); + const at::cuda::OptionalCUDAGuard device_guard(dst_device); + for (size_t i = 0; i < num_requests; i++) { + const int64_t req_id = req_ids[i].item(); + const int64_t start_idx = block_offsets[i].item(); + const int64_t end_idx = block_offsets[i + 1].item(); + const int64_t num_blocks = end_idx - start_idx; + for (int64_t layer_id = 0; layer_id < num_layer; layer_id++) { + for (size_t j = 0; j < num_blocks; j++) { + int64_t idx = (start_idx + j); + int64_t src_block_number = block_mapping[idx][0].item(); + int64_t dst_block_number = block_mapping[idx][1].item(); + char* src_ptr = reinterpret_cast( + src[src_block_number][0][layer_id].data_ptr()); + char* dst_ptr = reinterpret_cast( + kv_caches[layer_id][0][dst_block_number].data_ptr()); + CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, + block_nbytes, memcpy_type, stream)); + src_ptr = reinterpret_cast( + src[src_block_number][1][layer_id].data_ptr()); + dst_ptr = reinterpret_cast( + kv_caches[layer_id][1][dst_block_number].data_ptr()); + CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, + block_nbytes, memcpy_type, stream)); + } + CHECK_CUDA(cudaEventRecord( + reinterpret_cast(events[i * num_layer + layer_id]), + stream)); + } + } +} + +}; // namespace + + +// src layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size] +// kv_caches layout: [laysers, [2, num_blocks, block_size, num_kv_heads, head_size]] +void KVStore::CopyLayerBlocks2GPU( + torch::Tensor& src, + std::vector const& kv_caches, + const int64_t num_layer, + const torch::Tensor& block_mapping, + const torch::Tensor& block_offsets, + const torch::Tensor& req_ids, + const std::vector& events) { + if (block_mapping.size(0) == 0) { + return; + } + auto stream = at::cuda::getCurrentCUDAStream(); + /* + if (_copy_thread.joinable()) { + _copy_thread.join(); + } + _copy_thread = std::thread(CopyLayerBlocks2GPUThreadFunc, + src, kv_caches, num_layer, + block_mapping.clone(), + block_offsets.clone(), + req_ids.clone(), + events, stream); + */ + CopyLayerBlocks2GPUKernelFunc(src, kv_caches, num_layer, + block_mapping, block_offsets, req_ids, events, stream); +} + diff --git a/csrc/kv_store/kv_store.hpp b/csrc/kv_store/kv_store.hpp new file mode 100644 index 0000000000000..5b7c722ab865d --- /dev/null +++ b/csrc/kv_store/kv_store.hpp @@ -0,0 +1,52 @@ +#include +#include +#include +#include + +#include "cuda_compat.h" +#include "dispatch_utils.h" + +#include +#include +#include +#include +#include + +class KVStore { + public: + KVStore() = default; + void CopyIncompleteBlocks(torch::Tensor& src, torch::Tensor& dst, + const int64_t layer_id, + const torch::Tensor& incomplete_block_mapping); + + void CopyBlocks2CPU(torch::Tensor& src, torch::Tensor& dst, + const int64_t layer_id, + const torch::Tensor& block_mapping); + + void CopyBlocks2GPUBatch(torch::Tensor& src, + std::vector const& kv_caches, + const int64_t num_layers, + const torch::Tensor& block_mapping, + const torch::Tensor& block_offsets, + const int64_t num_requests, + std::vector const& events); + + void CopyLayerBlocks2GPU( + torch::Tensor& src, + std::vector const& kv_caches, + const int64_t num_layer, + const torch::Tensor& 
block_mapping, + const torch::Tensor& block_offsets, + const torch::Tensor& req_ids, + const std::vector& events); + + ~KVStore() { + if (_copy_thread.joinable()) { + _copy_thread.join(); + } + } + + private: + std::thread _copy_thread; + +}; diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 4e64b9c92773a..ac8c2fbd91f27 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -404,6 +404,32 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()"); cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks); + // Copy the incomplete blocks from src to dst. + cache_ops.def( + "kv_store_copy_incomplete_blocks(Tensor src, Tensor! dst, " + " int layer_id, " + " Tensor incomplete_block_mapping) -> ()"); + cache_ops.impl("kv_store_copy_incomplete_blocks", torch::kCUDA, + &kv_store_copy_incomplete_blocks); + + // Copy the kv cache blocks from src(GPU) to dst(CPU), used for kv store. + cache_ops.def( + "kv_store_copy_blocks2CPU(Tensor src, Tensor! dst, " + " int layer_id, Tensor block_mapping) -> ()"); + cache_ops.impl("kv_store_copy_blocks2CPU", torch::kCUDA, + &kv_store_copy_blocks2CPU); + + // Copy the kv cache blocks from src(CPU) to dst(GPU), used for kv store. + cache_ops.def( + "kv_store_copy_blocks2GPU(Tensor src, Tensor[](b!) dst, " + " int num_layers," + " Tensor block_mapping, Tensor block_offsets," + " Tensor req_ids," + " int[] events," + " bool is_batch_layer) -> ()"); + cache_ops.impl("kv_store_copy_blocks2GPU", torch::kCUDA, + &kv_store_copy_blocks2GPU); + // Copy the cache blocks from src to dst. cache_ops.def( "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, " diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 23cc6e8539431..97447f703e667 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -1,5 +1,13 @@ from vllm import LLM, SamplingParams +def print_outputs(llm, outputs): + for output in outputs: + prompt = output.prompt + token_ids = output.outputs[0].token_ids + generated_text = llm.get_tokenizer().decode(token_ids) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + # Sample prompts. prompts = [ "Hello, my name is", @@ -16,7 +24,4 @@ # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) # Print the outputs. 
-for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file +print_outputs(llm, outputs) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index c192c9a7b0e4d..ccb0e70e88074 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -923,6 +923,30 @@ def swap_blocks(src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor) -> None: torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping) +def kv_store_copy_incomplete_blocks(src: torch.Tensor, dst: torch.Tensor, + layer_id: int, + incomplete_block_mapping: torch.Tensor) -> None: + torch.ops._C_cache_ops.kv_store_copy_incomplete_blocks(src, dst, + layer_id, + incomplete_block_mapping) + +def kv_store_copy_blocks2CPU(src: torch.Tensor, dst: torch.Tensor, + layer_id: int, + block_mapping: torch.Tensor) -> None: + torch.ops._C_cache_ops.kv_store_copy_blocks2CPU(src, dst, layer_id, + block_mapping) + +def kv_store_copy_blocks2GPU(src: torch.Tensor, dst: List[torch.Tensor], + num_layers: int, + block_mapping: torch.Tensor, + block_offsets: torch.Tensor, + req_ids: torch.Tensor, + events: List[int], # the pointer of cudaEvent_t + is_batch_layer: bool) -> None: + torch.ops._C_cache_ops.kv_store_copy_blocks2GPU( + src, dst, num_layers, + block_mapping, block_offsets, + req_ids, events, is_batch_layer) def convert_fp8(output: torch.Tensor, input: torch.Tensor, diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index aed04361e5fb4..35db3c21b2336 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -3,6 +3,7 @@ from dataclasses import dataclass, fields from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar) +from vllm.store.kv_store import KVStoreMeta import torch @@ -114,6 +115,7 @@ class AttentionMetadata: # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot # in block 0, and 1st slot in block 1, respectively. slot_mapping: torch.Tensor + kv_store_meta: KVStoreMeta = None # The index maps that relate multi-modal embeddings to the corresponding # placeholders. diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index c69e12ad78c44..90e474e515360 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import torch +import numpy as np from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, @@ -16,6 +17,7 @@ compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args, is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set, is_block_tables_empty) +from vllm.store.kv_store import KVStoreMeta from vllm.multimodal import MultiModalPlaceholderMap from vllm.utils import async_tensor_h2d, make_tensor_with_pad @@ -235,6 +237,7 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, + kv_store_meta=self.kv_store_meta, # Begin encoder & cross attn fields below... 
encoder_seq_lens=self.encoder_seq_lens, encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, @@ -285,6 +288,7 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: context_lens_tensor=None, block_tables=block_tables, use_cuda_graph=self.use_cuda_graph, + kv_store_meta=self.kv_store_meta, # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, @@ -384,11 +388,13 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.num_prefill_tokens = 0 self.num_decode_tokens = 0 self.has_prefix_cache_hit = False + self.range_list: List[Tuple[int, int, List[int], bool]] = [] self.input_builder = input_builder self.runner = input_builder.runner self.sliding_window = input_builder.sliding_window self.block_size = input_builder.block_size + self.enable_kv_store = input_builder.enable_kv_store def _add_seq_group( self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", @@ -446,9 +452,12 @@ def _add_seq_group( start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, context_len, self.sliding_window) - compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, - seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) + (range_start, range_end, block_table) = compute_slot_mapping( + is_profile_run, self.slot_mapping, seq_id, + curr_seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + self.range_list.append((range_start, range_end, + block_table, is_prompt)) def _get_graph_runner_block_tables( self, num_seqs: int, @@ -472,7 +481,7 @@ def _get_graph_runner_block_tables( i, :max_blocks] = block_table[:max_blocks] return torch.from_numpy(graph_block_tables).to( - device=self.runner.device, non_blocking=True) + device=self.runner.device, non_blocking=True) def build(self, seq_lens: List[int], query_lens: List[int], cuda_graph_pad_size: int, batch_size: int): @@ -543,6 +552,48 @@ def build(self, seq_lens: List[int], query_lens: List[int], self.multimodal_placeholder_maps.items() } + incomplete_put_block_ids = [] + put_block_ids = [] + assert(len(self.range_list) == len(self.block_tables)) + if (self.enable_kv_store == True): + for (range_start, range_end, seq_block_table, is_prompt) in \ + self.range_list: + if (range_start == range_end) or (is_prompt == False): + continue + block_size = self.block_size + range_end -= 1 + range_start_block_id = range_start // block_size + range_end_block_id = range_end // block_size + range_start_block_offset = range_start % block_size + range_end_block_offset = range_end % block_size + 1 + if (range_start_block_id == range_end_block_id): + incomplete_put_block_ids.append( + [seq_block_table[range_start_block_id], + range_start_block_offset, range_end_block_offset]) + else: + if (range_start_block_offset == 0): + put_block_ids.append( + seq_block_table[range_start_block_id]) + else: + incomplete_put_block_ids.append( + [seq_block_table[range_start_block_id], + range_start_block_offset, block_size]) + put_block_ids.extend( + seq_block_table[ + range_start_block_id + 1:range_end_block_id]) + if (range_end_block_offset == block_size): + put_block_ids.append( + seq_block_table[range_end_block_id]) + else: + incomplete_put_block_ids.append( + [seq_block_table[range_end_block_id], + 0, range_end_block_offset]) + incomplete_put_block_ids_numpy = np.array(incomplete_put_block_ids) + put_block_ids_numpy = np.array(put_block_ids) + incomplete_put_block_ids_cpu = torch.from_numpy( + 
incomplete_put_block_ids_numpy).to("cpu") + put_block_ids_cpu = torch.from_numpy(put_block_ids_numpy).to("cpu") + return FlashAttentionMetadata( num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, @@ -560,6 +611,9 @@ def build(self, seq_lens: List[int], query_lens: List[int], context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, + kv_store_meta=KVStoreMeta(incomplete_put_block_ids_cpu, + put_block_ids_cpu, + torch.Tensor()) ) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 56cc43430301f..174b9b70fa771 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -12,6 +12,7 @@ from vllm.attention.backends.abstract import AttentionType from vllm.multimodal import MultiModalPlaceholderMap from vllm.utils import async_tensor_h2d, make_tensor_with_pad +from vllm.store.kv_store import KVStoreMeta if TYPE_CHECKING: from vllm.worker.model_runner_base import ModelRunnerBase @@ -78,9 +79,11 @@ def _compute_slot_mapping_numpy(slot_mapping: List[int], def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int], seq_id: int, seq_len: int, context_len: int, start_idx: int, block_size: int, - block_tables: Dict[int, List[int]]): + block_tables: Dict[int, List[int]]) \ + -> (int, int, List[int]): """ Compute slot mapping. + return: the range_start and range_end """ if is_profile_run: # During memory profiling, the block tables are not @@ -88,7 +91,7 @@ def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int], # slot mapping. # In embeddings, the block tables are {seq_id: None}. slot_mapping.extend([PAD_SLOT_ID] * seq_len) - return + return (0, 0, []) # Mask the [0, start_idx) tokens of the prompt with # PAD_SLOT_ID, where start_idx is max(0, seq_len - @@ -112,6 +115,12 @@ def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int], else: _compute_slot_mapping_numpy(slot_mapping, block_table, range_start, range_end, block_size) + block_table = block_table[ + (range_start // block_size):((range_end - 1) // block_size) + 1] + start_offset = (range_start // block_size) * block_size + range_start -= start_offset + range_end -= start_offset + return (range_start, range_end, block_table) TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata') @@ -326,6 +335,7 @@ def graph_capture_get_metadata_for_batch( seq_start_loc=None, context_lens_tensor=None, block_tables=self._graph_block_tables[:batch_size], + kv_store_meta = KVStoreMeta.null(), use_cuda_graph=True, ) if is_encoder_decoder_model: diff --git a/vllm/block.py b/vllm/block.py index 47c381c19383b..2dd4f06d10ab4 100644 --- a/vllm/block.py +++ b/vllm/block.py @@ -31,6 +31,7 @@ def __init__( def __repr__(self) -> str: return (f'PhysicalTokenBlock(device={self.device}, ' f'block_number={self.block_number}, ' + f'block_hash={self.block_hash}, ' f'num_hashed_tokens={self.num_hashed_tokens}, ' f'ref_count={self.ref_count}, ' f'last_accessed={self.last_accessed}, ' @@ -38,7 +39,7 @@ def __repr__(self) -> str: class BlockTable: - """Holds a list of blocks with caching of their associated block_ids + """Holds a list of blocks with caching of their associated block_ids """ def __init__(self, blocks: Optional[List[PhysicalTokenBlock]] = None): diff --git a/vllm/config.py b/vllm/config.py index 5f50d65ec87e1..910239844f560 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -754,6 +754,7 @@ def __init__( gpu_memory_utilization: float, swap_space: float, cache_dtype: str, + kv_store_space: 
float, is_attention_free: bool = False, num_gpu_blocks_override: Optional[int] = None, sliding_window: Optional[int] = None, @@ -777,6 +778,13 @@ def __init__( # Will be set after profiling. self.num_gpu_blocks: Optional[int] = None self.num_cpu_blocks: Optional[int] = None + # for vllm.store.kv_store.KVBlockStore + self.enable_kv_store = False + self.kv_store_space_bytes = int(kv_store_space * GiB_bytes) + if (kv_store_space > 0): + self.enable_kv_store = True + self.kv_store = None + self.kv_store_manager = None def metrics_info(self): # convert cache_config to dict(key: str, value: str) for prometheus diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 3197af3c2b7a4..5d3751aa5070a 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -26,6 +26,7 @@ def create( num_gpu_blocks: int, num_cpu_blocks: int, block_size: int, + kv_store_manager: Optional["KVStoreManager"], ) -> DeviceAwareBlockAllocator: """Creates a CpuGpuBlockAllocator instance with the specified configuration. @@ -67,6 +68,7 @@ def create( num_blocks=num_gpu_blocks, block_size=block_size, block_ids=gpu_block_ids, + kv_store_manager=kv_store_manager, ) cpu_allocator: BlockAllocator = NaiveBlockAllocator( @@ -80,6 +82,7 @@ def create( num_blocks=num_gpu_blocks, block_size=block_size, block_ids=gpu_block_ids, + kv_store_manager=kv_store_manager, ) cpu_allocator = PrefixCachingBlockAllocator( @@ -306,6 +309,12 @@ def mark_blocks_as_computed(self, block_ids: List[int]) -> None: device = Device.GPU return self._allocators[device].mark_blocks_as_computed(block_ids) + def mark_blocks_as_cached(self, blocks: List[Block]) -> None: + """Mark blocks as cached, only use for prefix caching with KV Store.""" + # Prefix caching only supported on GPU. + device = Device.GPU + return self._allocators[device].mark_blocks_as_cached(blocks) + def get_common_computed_block_ids( self, computed_seq_block_ids: List[List[int]]) -> List[int]: # Prefix caching only supported on GPU. 
diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index a2af5ad6362c1..fd00ee67f61cc 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -4,6 +4,7 @@ from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter, get_all_blocks_recursively) from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device +from vllm.store.kv_store import KVBlockStoreManager Refcount = int @@ -32,6 +33,7 @@ def __init__( block_size: int, block_ids: Optional[Iterable[int]] = None, block_pool: Optional[BlockPool] = None, + kv_store_manager: Optional[KVBlockStoreManager] = None, ): if block_ids is None: block_ids = range(num_blocks) @@ -46,6 +48,7 @@ def __init__( self._cow_tracker = CopyOnWriteTracker( refcounter=self._refcounter.as_readonly()) + self.kv_store_manager = kv_store_manager if block_pool is None: extra_factor = 4 @@ -101,6 +104,8 @@ def allocate_immutable_blocks( block_size=self._block_size, physical_block_id=block_ids[i]) blocks.append(prev_block) + if (self.kv_store_manager is not None): + self.kv_store_manager.allocate_block(-1, block_ids[i]) return blocks @@ -123,6 +128,8 @@ def allocate_mutable_block(self, token_ids=[], block_size=self._block_size, physical_block_id=block_id) + if (self.kv_store_manager is not None): + self.kv_store_manager.allocate_block(-1, block_id) return block def _allocate_block_id(self) -> BlockId: diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index b736167f6ceb4..b653f7042d1a1 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -13,6 +13,10 @@ NaiveBlockAllocator) from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor from vllm.sequence import Sequence +from vllm.store.kv_store import KVBlockStoreManager + +# NOTE: for debug the kvstore, may cause some assert errors +only_enable_cpu_kvstore = False PrefixHash = int @@ -66,6 +70,7 @@ def __init__( num_blocks: int, block_size: int, block_ids: Optional[Iterable[int]] = None, + kv_store_manager: Optional[KVBlockStoreManager] = None, eviction_policy: EvictionPolicy = EvictionPolicy.LRU, ): if block_ids is None: @@ -117,6 +122,8 @@ def __init__( self.metric_data = CacheMetricData() + self.kv_store_manager = kv_store_manager + # Implements Block.Factory. 
def _create_block( self, @@ -163,13 +170,29 @@ def allocate_immutable_block(self, physical_block_id=None) assert block.content_hash is not None - cached_block_id = self._cached_blocks.get(block.content_hash, None) - if cached_block_id is not None: - self.metric_data.query(hit=True) - block.block_id = cached_block_id - self._incr_refcount_cached_block(block) - return block + if (only_enable_cpu_kvstore == False) or \ + (self.kv_store_manager == None): + cached_block_id = self._cached_blocks.get(block.content_hash, None) + if cached_block_id is not None: + self.metric_data.query(hit=True) + block.block_id = cached_block_id + self._incr_refcount_cached_block(block) + return block self.metric_data.query(hit=False) + if (self.kv_store_manager != None) and \ + (self.kv_store_manager.has(block.content_hash)): + block_id = self._allocate_block_id() + block.block_id = block_id + # print("Found in kv store with hash: ", block.content_hash, + # "block_id: ", block_id) + self.kv_store_manager.remap_block_id( + block.content_hash, block.block_id) + self.kv_store_manager.open_send_flag(block.block_id) + self._cached_blocks[block.content_hash] = block_id + block.computed = True + self._block_tracker[block_id].computed = True + return block + self._block_pool.free_block(block) # No cached block => Allocate a new block @@ -213,6 +236,8 @@ def allocate_mutable_block(self, physical_block_id=block_id) assert not block.computed assert block.content_hash is None + if (self.kv_store_manager != None): + self.kv_store_manager.allocate_block(-1, block.block_id) return block def _incr_refcount_cached_block(self, block: Block) -> None: @@ -445,6 +470,10 @@ def promote_to_immutable_block(self, block: Block) -> BlockId: assert self._refcounter.get(block.block_id) > 0 if block.content_hash not in self._cached_blocks: + if (self.kv_store_manager != None): + self.kv_store_manager.add_hash_map( + block.content_hash, block.block_id) + # No cached content hash => Set this block as cached. 
# Note that this block cannot be marked as computed yet # because other sequences in the same batch cannot reuse @@ -464,6 +493,11 @@ def promote_to_immutable_block(self, block: Block) -> BlockId: # Note that in this case, the block is marked as computed self._incr_refcount_cached_block(block) + if (self.kv_store_manager != None): + self.kv_store_manager.add_hash_map( + block.content_hash, block.block_id) + + return block.block_id def cow_block_if_not_appendable(self, block: Block) -> BlockId: diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 209487c6b4f9e..27a214f359c29 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -12,6 +12,7 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device +from vllm.store.kv_store import KVBlockStoreManager SeqId = int EncoderSeqId = str @@ -62,6 +63,7 @@ def __init__( block_size: int, num_gpu_blocks: int, num_cpu_blocks: int, + kv_store_manager : KVBlockStoreManager, watermark: float = 0.01, sliding_window: Optional[int] = None, enable_caching: bool = False, @@ -95,6 +97,7 @@ def __init__( num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, block_size=block_size, + kv_store_manager=kv_store_manager, ) self.block_tables: Dict[SeqId, BlockTable] = {} @@ -104,6 +107,7 @@ def __init__( self.block_allocator, self.block_size, self.enable_caching) self._last_access_blocks_tracker = LastAccessBlocksTracker( self.block_allocator) + self.kv_store_manager = kv_store_manager def can_allocate(self, seq_group: SequenceGroup, diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index d23009dae01ee..d48becd166e71 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -16,7 +16,8 @@ from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceGroupMetadataDelta, SequenceStatus) -from vllm.utils import Device, PyObjectCache +from vllm.utils import Device, PyObjectCache, set_abort_request_id +from vllm.store.kv_store import KVBlockStoreManager,BlockMappingFromCPU logger = init_logger(__name__) @@ -149,6 +150,7 @@ class SchedulerOutputs: # The number of requests in the running queue running_queue_size: int preempted: int + kv_store_block_mapping_from_cpu: BlockMappingFromCPU def __post_init__(self): # Swap in and swap out should never happen at the same time. @@ -274,6 +276,7 @@ class SchedulerPrefillOutputs: # Ignored sequence groups. 
ignored_seq_groups: List[SequenceGroup] num_lookahead_slots: int + kv_store_block_mapping_from_cpu: BlockMappingFromCPU @classmethod def create_empty(cls) -> "SchedulerPrefillOutputs": @@ -281,6 +284,7 @@ def create_empty(cls) -> "SchedulerPrefillOutputs": seq_groups=[], ignored_seq_groups=[], num_lookahead_slots=0, + kv_store_block_mapping_from_cpu=None, ) @@ -316,6 +320,7 @@ def __init__( self, scheduler_config: SchedulerConfig, cache_config: CacheConfig, + kv_store_manager: KVBlockStoreManager, lora_config: Optional[LoRAConfig], pipeline_parallel_size: int = 1, output_proc_callback: Optional[Callable] = None, @@ -332,6 +337,7 @@ def __init__( or self.cache_config.is_attention_free): version = "placeholder" + logger.info("Using BlockSpaceManager version %s", version) BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( version) @@ -348,12 +354,15 @@ def __init__( block_size=self.cache_config.block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, + kv_store_manager=kv_store_manager, sliding_window=self.cache_config.sliding_window, enable_caching=self.cache_config.enable_prefix_caching) + self.kv_store_manager = kv_store_manager # Sequence groups in the WAITING state. # Contain new prefill or preempted requests. self.waiting: Deque[SequenceGroup] = deque() + self.kv_store_waiting: Deque[SequenceGroup] = deque() # Sequence groups in the RUNNING state. # Contain decode requests. self.running: Deque[SequenceGroup] = deque() @@ -453,7 +462,8 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: if isinstance(request_id, str): request_id = (request_id, ) request_ids = set(request_id) - for state_queue in [self.waiting, self.running, self.swapped]: + for state_queue in [self.waiting, self.running, + self.swapped, self.kv_store_waiting]: aborted_groups: List[SequenceGroup] = [] for seq_group in state_queue: if not request_ids: @@ -490,13 +500,14 @@ def _free_seq_group_cross_attn_blocks( def has_unfinished_seqs(self) -> bool: return len(self.waiting) != 0 or len(self.running) != 0 or len( - self.swapped) != 0 + self.swapped) != 0 or len(self.kv_store_waiting) != 0 def get_prefix_cache_hit_rate(self, device: Device) -> float: return self.block_manager.get_prefix_cache_hit_rate(device) def get_num_unfinished_seq_groups(self) -> int: - return len(self.waiting) + len(self.running) + len(self.swapped) + return len(self.waiting) + len(self.running) + len(self.swapped) + \ + len(self.kv_store_waiting) def get_and_reset_finished_requests_ids(self) -> List[str]: """Flushes the list of request ids of previously finished seq_groups.""" @@ -523,7 +534,7 @@ def _schedule_running( chunked number of tokens are scheduled if `budget.num_batched_tokens` has not enough capacity to schedule all tokens. - + Returns: SchedulerRunningOutputs. 
""" @@ -910,9 +921,87 @@ def _schedule_prefills( ignored_seq_groups: List[SequenceGroup] = [] seq_groups: List[ScheduledSequenceGroup] = [] + kv_store_waiting_queue = self.kv_store_waiting waiting_queue = self.waiting + kv_store_block_mapping = [] + kv_store_block_mapping_offset = [] + kv_store_block_mapping_req_ids = [] + kv_store_block_mapping_cnt = 0 leftover_waiting_sequences: Deque[SequenceGroup] = deque() + kv_store_leftover_waiting_sequences: Deque[SequenceGroup] = deque() + + def _stop_schedule_prefill(num_new_tokens_uncached, + num_new_seqs, + max_num_batched_tokens, + budget): + if (budget.num_batched_tokens >= + self.scheduler_config.max_num_batched_tokens): + return True + if (num_new_tokens_uncached == 0 or + not budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs)): + return True + return False + + kv_store_tmp_queue : Deque[SequenceGroup] = deque() + while self._passed_delay(time.time()) and kv_store_waiting_queue: + if budget.num_curr_prefill_seqs >= self.scheduler_config.max_num_prefill_seqs: + break + + seq_group = kv_store_waiting_queue[0] + + waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) + assert len(waiting_seqs) == 1, ( + "Waiting sequence group should have only one prompt " + "sequence.") + + num_new_tokens_uncached, num_new_tokens_cached = ( + self._get_num_new_uncached_and_cached_tokens( + seq_group, SequenceStatus.WAITING, enable_chunking, + budget)) + num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached + num_new_seqs = seq_group.get_max_num_running_seqs() + + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + assert curr_loras is not None + assert self.lora_config is not None + if (self.lora_enabled and lora_int_id > 0 + and lora_int_id not in curr_loras + and len(curr_loras) >= self.lora_config.max_loras): + # We don't have a space for another LoRA, so + # we ignore this request for now. + kv_store_tmp_queue.appendleft(seq_group) + kv_store_waiting_queue.popleft() + continue + + if (_stop_schedule_prefill(num_new_tokens_uncached, + num_new_seqs, + self.scheduler_config.max_num_batched_tokens, budget)): + break + + # Can schedule this request. + if curr_loras is not None and lora_int_id > 0: + curr_loras.add(lora_int_id) + + kv_store_waiting_queue.popleft() + self._set_running(seq_group) + + seq_groups.append( + ScheduledSequenceGroup(seq_group=seq_group, + token_chunk_size=num_new_tokens)) + budget.add_num_batched_tokens( + seq_group.request_id, + num_batched_tokens=num_new_tokens_uncached, + num_cached_tokens=num_new_tokens_cached, + ) + budget.add_num_seqs(seq_group.request_id, num_new_seqs) + budget.add_num_prefill_seqs(num_new_seqs) + kv_store_waiting_queue.extendleft(kv_store_tmp_queue) + while self._passed_delay(time.time()) and waiting_queue: seq_group = waiting_queue[0] @@ -976,25 +1065,19 @@ def _schedule_prefills( waiting_queue.popleft() continue - if (budget.num_batched_tokens >= - self.scheduler_config.max_num_batched_tokens): - # We've reached the budget limit - since there might be - # continuous prefills in the running queue, we should break - # to avoid scheduling any new prefills. 
- break + if (self.kv_store_manager != None): + self.kv_store_manager.is_prefill = seq_group.is_prefill() - num_new_seqs = seq_group.get_max_num_running_seqs() - if num_new_tokens_uncached == 0 or not budget.can_schedule( - num_new_tokens=num_new_tokens_uncached, - num_new_seqs=num_new_seqs, - ): - break + block_mapping_from_cpu = [] + self._allocate(seq_group) - # Can schedule this request. - if curr_loras is not None and lora_int_id > 0: - curr_loras.add(lora_int_id) - waiting_queue.popleft() - self._allocate_and_set_running(seq_group) + if (self.kv_store_manager != None): + block_ids = self.block_manager.get_block_table( + seq_group.get_seqs()[0]) + block_mapping_from_cpu = \ + self.kv_store_manager.get_block_mapping_from_python( + block_ids) + self.kv_store_manager.is_prefill = False if enable_chunking and self.scheduler_config.is_multi_step: blocks_to_copy: List[Tuple[int, int]] = [] @@ -1013,6 +1096,33 @@ def _schedule_prefills( is_multi_step=self.scheduler_config.is_multi_step, enable_chunking=enable_chunking) + if (len(block_mapping_from_cpu) > 0): + waiting_queue.popleft() + kv_store_leftover_waiting_sequences.appendleft(seq_group) + kv_store_block_mapping.extend( + block_mapping_from_cpu) + kv_store_block_mapping_offset.append(kv_store_block_mapping_cnt) + kv_store_block_mapping_req_ids.append( + seq_group.get_seqs()[0].seq_id) + kv_store_block_mapping_cnt += len(block_mapping_from_cpu) + continue + + num_new_seqs = seq_group.get_max_num_running_seqs() + if (_stop_schedule_prefill(num_new_tokens_uncached, num_new_seqs, + self.scheduler_config.max_num_batched_tokens, + budget)): + # let it to the next running one + waiting_queue.popleft() + kv_store_leftover_waiting_sequences.appendleft(seq_group) + break + + # Can schedule this request. + if curr_loras is not None and lora_int_id > 0: + curr_loras.add(lora_int_id) + + waiting_queue.popleft() + self._set_running(seq_group) + seq_groups.append( ScheduledSequenceGroup(seq_group=seq_group, token_chunk_size=num_new_tokens)) @@ -1023,20 +1133,34 @@ def _schedule_prefills( ) budget.add_num_seqs(seq_group.request_id, num_new_seqs) + self.kv_store_waiting.extendleft(kv_store_leftover_waiting_sequences) + # Queue requests that couldn't be scheduled. waiting_queue.extendleft(leftover_waiting_sequences) if len(seq_groups) > 0: self.prev_prompt = True + if (self.kv_store_manager != None) and \ + (len(kv_store_block_mapping) > 0): + self.kv_store_manager.close_send_flags( + [items[1] + for items in kv_store_block_mapping]) + + kv_store_block_mapping_offset.append(kv_store_block_mapping_cnt) + kv_store_block_mapping_from_cpu = BlockMappingFromCPU( + kv_store_block_mapping, kv_store_block_mapping_offset, + kv_store_block_mapping_req_ids) + return SchedulerPrefillOutputs( seq_groups=seq_groups, ignored_seq_groups=ignored_seq_groups, num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=True, enable_chunking=enable_chunking)) + is_prefill=True, enable_chunking=enable_chunking), + kv_store_block_mapping_from_cpu=kv_store_block_mapping_from_cpu) def _schedule_default(self) -> SchedulerOutputs: """Schedule queued requests. - + The current policy is designed to optimize the throughput. First, it batches as many prefill requests as possible. And it schedules decodes. 
If there's a pressure on GPU memory, decode requests can @@ -1137,11 +1261,13 @@ def _schedule_default(self) -> SchedulerOutputs: num_lookahead_slots=running_scheduled.num_lookahead_slots, running_queue_size=len(self.running), preempted=preempted, + kv_store_block_mapping_from_cpu= \ + prefills.kv_store_block_mapping_from_cpu, ) def _schedule_chunked_prefill(self) -> SchedulerOutputs: """Schedule queued requests. - + Chunked prefill allows to chunk prefill requests, batch them together with decode requests. This policy 1. schedule as many decoding requests as possible. 2. schedule chunked prefill requests that are not @@ -1232,6 +1358,8 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: running_queue_size=len(self.running), preempted=(len(running_scheduled.preempted) + len(running_scheduled.swapped_out)), + kv_store_block_mapping_from_cpu= \ + prefills.kv_store_block_mapping_from_cpu, ) def _schedule(self) -> SchedulerOutputs: @@ -1481,6 +1609,13 @@ def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): seq.status = SequenceStatus.RUNNING + def _allocate(self, seq_group: SequenceGroup) -> None: + self.block_manager.allocate(seq_group) + + def _set_running(self, seq_group: SequenceGroup) -> None: + for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): + seq.status = SequenceStatus.RUNNING + def _append_slots(self, seq_group: SequenceGroup, blocks_to_copy: List[Tuple[int, int]], diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4aa0eebd976c9..59543f9fc5c13 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -120,6 +120,7 @@ class EngineArgs: use_v2_block_manager: bool = True swap_space: float = 4 # GiB cpu_offload_gb: float = 0 # GiB + kv_store_space: float = 0 # GiB gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 @@ -454,6 +455,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=float, default=EngineArgs.swap_space, help='CPU swap space size (GiB) per GPU.') + parser.add_argument('--kv-store-space', + type=float, + default=EngineArgs.kv_store_space, + help='KV store space size (GiB) per GPU.') parser.add_argument( '--cpu-offload-gb', type=float, @@ -1022,6 +1027,7 @@ def create_engine_config(self, gpu_memory_utilization=self.gpu_memory_utilization, swap_space=self.swap_space, cache_dtype=self.kv_cache_dtype, + kv_store_space=self.kv_store_space, is_attention_free=model_config.is_attention_free, num_gpu_blocks_override=self.num_gpu_blocks_override, sliding_window=model_config.get_sliding_window(), diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7911dc8d04500..38e6483bb1e18 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -61,6 +61,7 @@ usage_message) from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind from vllm.version import __version__ as VLLM_VERSION +from vllm.store.kv_store import KVBlockStoreManager logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 @@ -332,6 +333,14 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.input_processor = input_registry.create_input_processor( self.model_config) + self.kv_store_manager: Optional[KVBlockStoreManager] = None + if (self.cache_config.enable_kv_store): + kv_store_manager = KVBlockStoreManager.from_configs( + self.cache_config, self.model_config, self.parallel_config) + self.kv_store_manager = kv_store_manager + 
self.cache_config.kv_store_manager = kv_store_manager + + self.model_executor = executor_class(vllm_config=vllm_config, ) if self.model_config.task != "embedding": @@ -410,7 +419,9 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: # GPU and CPU blocks, which are profiled in the distributed executor. self.scheduler = [ Scheduler( - self.scheduler_config, self.cache_config, self.lora_config, + self.scheduler_config, self.cache_config, + self.kv_store_manager, + self.lora_config, self.parallel_config.pipeline_parallel_size, self.async_callbacks[v_id] if self.model_config.use_async_output_proc else None) @@ -1023,9 +1034,9 @@ def _update_num_computed_tokens_for_multi_step_prefill( This function updates num_computed_tokens for prompt sequences when Multi-Step is enabled. - seq_group: SequenceGroup to update the num_computed_tokens for. + seq_group: SequenceGroup to update the num_computed_tokens for. seq_group_meta: Metadata of the given SequenceGroup. - is_first_step_output: Optional[bool] - + is_first_step_output: Optional[bool] - When available, is_first_step_output indicates if the appended output token is the output of the first-step in multi-step. A value of None indicates that outputs from all steps in @@ -1437,7 +1448,9 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: finished_requests_ids=finished_requests_ids, # We use ExecuteModelRequest to pass the last sampled_token_ids # to each of the non-last PP stages for in-place prepare_input. - last_sampled_token_ids=last_sampled_token_ids) + last_sampled_token_ids=last_sampled_token_ids, + kv_store_block_mapping_from_cpu=\ + scheduler_outputs.kv_store_block_mapping_from_cpu,) if allow_async_output_proc: execute_model_req.async_callback = self.async_callbacks[ diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index a25c401b4ea10..de92aea373685 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -95,6 +95,8 @@ class LLM: the model weights. This virtually increases the GPU memory space you can use to hold the model weights, at the cost of CPU-GPU data transfer for every forward pass. + kv_store_space: The size (GiB) of CPU memory per GPU to use as KV store + space. enforce_eager: Whether to enforce eager execution. If True, we will disable CUDA graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid. 
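The knob added above (EngineArgs.kv_store_space, the --kv-store-space switch, and the LLM docstring entry) can be exercised as in the sketch below. This is illustrative only and not part of the patch: the model name is a placeholder, and whether a non-zero kv_store_space is what flips cache_config.enable_kv_store is decided in vllm/config.py, which this excerpt does not show, so treat that coupling as an assumption.

    from vllm import LLM, SamplingParams

    # Reserve 8 GiB of pinned CPU memory per GPU for the KV block store.
    llm = LLM(
        model="Qwen/Qwen2-7B-Instruct",  # placeholder; qwen2.py is the patched model path below
        kv_store_space=8,                # GiB, forwarded to EngineArgs.kv_store_space
        enable_prefix_caching=True,      # the CPU store keys blocks by prefix-cache content hashes
    )
    outputs = llm.generate(["Hello, world"], SamplingParams(max_tokens=16))

The same value is exposed wherever EngineArgs.add_cli_args is used, for example as --kv-store-space 8 on an API server command line.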
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 7d4cc4b69e614..ec7e975f34e16 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -313,6 +313,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.norm = PPMissingLayer() + self.enable_prefix_caching = cache_config.enable_prefix_caching + self.kv_store = cache_config.kv_store + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -335,15 +338,33 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] + if (self.kv_store != None) and \ + (self.kv_store.batch_layers_to_GPU == True): + self.kv_store.get_stream_sync( + attn_metadata.kv_store_meta.request_ids) + for i in range(self.start_layer, self.end_layer): + layer_id = (i - self.start_layer) + if (self.kv_store is not None) and \ + (self.kv_store.batch_layers_to_GPU == False): + self.kv_store.get_stream_layer_sync( + layer_id, attn_metadata.kv_store_meta.request_ids) layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], + kv_caches[layer_id], attn_metadata, residual, ) + + if (self.kv_store is not None): + self.kv_store.put_block_layer( + attn_metadata.kv_store_meta.incomplete_put_block_ids, + attn_metadata.kv_store_meta.put_block_ids_mapping, + layer_id, kv_caches[layer_id], + torch.cuda.current_stream()) + if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, diff --git a/vllm/sequence.py b/vllm/sequence.py index 669124319c4f4..88e07d770984b 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1266,6 +1266,10 @@ class ExecuteModelRequest( # Async callback async_callback: Optional[Callable] = None + + # for kv store + kv_store_block_mapping_from_cpu: Optional["BlockMappingFromCPU"] = None + @property def is_first_multi_step(self) -> bool: # TODO(will) make this be able to handle batches with variable number of diff --git a/vllm/store/__init__.py b/vllm/store/__init__.py new file mode 100644 index 0000000000000..6d6f5beca8aae --- /dev/null +++ b/vllm/store/__init__.py @@ -0,0 +1,7 @@ +from vllm.store.kv_store import KVBlockStore, KVBlockStoreManager, KVStoreMeta + +__all__ = [ + "KVBlockStore", + "KVBlockStoreManager", + "KVStoreMeta" +] diff --git a/vllm/store/kv_store.py b/vllm/store/kv_store.py new file mode 100644 index 0000000000000..85fe36ca3005d --- /dev/null +++ b/vllm/store/kv_store.py @@ -0,0 +1,482 @@ +import torch +import sys +from collections import deque +from vllm.config import CacheConfig, ModelConfig, ParallelConfig +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, + is_pin_memory_available) +from dataclasses import dataclass +from vllm import _custom_ops as ops +import time +from vllm.logger import init_logger +from typing import Optional, Union + +logger = init_logger(__name__) + +default_mem_size = 4 * 1024 * 1024 * 1024 +batch_layers_transmission_to_GPU = False + +@dataclass +class BlockMappingFromCPU: + block_mapping: torch.Tensor # 2-D tenso + block_offset: torch.Tensor # 1-D tensor, like offset array in CSR format + # the offset of each request in block_mapping + request_ids: torch.Tensor # request IDs + def __init__(self, + block_mapping: list[list[int, int]], + block_offset: list[int], + request_ids: list[int]): + self.block_mapping = torch.tensor(block_mapping, + 
device="cpu", + dtype=torch.int64).view(-1, 2) + self.block_offset = torch.tensor(block_offset, + device="cpu", + dtype=torch.int64).view(-1) + self.request_ids = torch.tensor(request_ids, + device="cpu", + dtype=torch.int64).view(-1) + def __str__(self): + return "block_mapping: " + str(self.block_mapping) + \ + " block_offset: " + str(self.block_offset) + \ + " request_ids: " + str(self.request_ids) + +@dataclass +class KVStoreMeta: + incomplete_put_block_ids: torch.Tensor # 4-D tensor: + # vllm_block_id, + # start_offset,end_offset, + # store_block_id + put_block_ids_mapping: torch.Tensor # 2-D tensor: + # vllm_block_id, store_block_id + request_ids: torch.Tensor # 1-D tensor + + @staticmethod + def null(): + return KVStoreMeta(torch.Tensor(), + torch.Tensor(), + torch.Tensor()) + + def __str__(self): + return "incomplete_put_block_ids: " + str(self.incomplete_put_block_ids) + \ + " put_block_ids_mapping: " + str(self.put_block_ids_mapping) + \ + " request_ids: " + str(self.request_ids) + +class BlockCount: + def __init__(self, block_id, access_count, last_access, block_hash, + send_flag = False): + # XXX: can remove it + self.block_id = block_id + self.access_count = access_count + self.last_access = last_access + self.block_hash = block_hash + self.send_flag = send_flag + def __str__(self): + return "block_id: " + str(self.block_id) + \ + " access_count: " + str(self.access_count) + \ + " last_access: " + str(self.last_access) + \ + " block_hash: " + str(self.block_hash) + +class KVBlockStoreManager: + def __init__(self, + block_head_mem_size: int, # size of each block for key/value + num_layer: int, + num_block_slot: int, # number of slots for each block + mem_size : int = default_mem_size, # total memory size + ): + + t = 2 * num_layer * block_head_mem_size + mem_size = (mem_size // t) * t + self.num_block_slot = num_block_slot + self.num_blocks = (mem_size // t) + self.time_cnt = 0 + self.block_cnt = 0 + self.block_table = [BlockCount(0, 0, 0, 0)] * self.num_blocks + self.hash_block_map: dict[int, int] = {} # hash -> store_block_id + self.gpu_and_store_block_map: dict[int, int] = \ + {} # gpu_block_id -> store_block_id + logger.info("KVBlockStore use %f GB memory per worker, " + "%d blocks, block size = %d", + mem_size / 1024 / 1024 / 1024, + self.num_blocks, + self.num_block_slot) + self.is_prefill = True + + @classmethod + def from_configs(cls, + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig): + dtype = None + if (cache_config.cache_dtype == "auto"): + dtype = model_config.dtype + else: + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + block_size = cache_config.block_size + num_key_value_heads = model_config.get_num_kv_heads(parallel_config) + head_dim = model_config.get_head_size() + num_layers = model_config.get_num_layers(parallel_config) + + block_head_mem_size = (dtype.itemsize * block_size + * num_key_value_heads * head_dim) + return cls(block_head_mem_size, + num_layers, + block_size, + cache_config.kv_store_space_bytes) + + + # allocate a logical block in CPU, and map the GPU block to kv store block + def allocate_block(self, block_hash, gpu_block_id) -> BlockCount: + if (self.is_prefill == False): + return None + ret_block_id = self.block_cnt + if (self.block_cnt == self.num_blocks): + # XXX: use policy to evict + # least used and earliest block + min_access_count = sys.maxsize + for block_count in self.block_table: + if block_count.access_count < min_access_count: + min_access_count = block_count.access_count + 
min_access_blocks = [] + for block_count in self.block_table: + if block_count.access_count == min_access_count: + min_access_blocks.append(block_count.block_id) + min_last_access = sys.maxsize + final_block_id = -1 + for store_block_id in min_access_blocks: + block_count = self.block_table[store_block_id] + if block_count.last_access < min_last_access: + min_last_access = block_count.last_access + final_block_id = store_block_id + assert(final_block_id != -1) + ret_block_id = final_block_id + final_block_count = self.block_table[ret_block_id] + # print("evict block_id: ", final_block_id) + # print("block count data -> ", final_block_count) + del self.hash_block_map[final_block_count.block_hash] + else: + self.block_cnt += 1 + self.hash_block_map[block_hash] = ret_block_id + self.block_table[ret_block_id] = BlockCount(ret_block_id, 1, + self.time_cnt, block_hash) + self.gpu_and_store_block_map[gpu_block_id] = ret_block_id + self.time_cnt += 1 + return self.block_table[ret_block_id] + + def has(self, block_hash: int) -> bool: + if (self.is_prefill == False): + return False + return (block_hash != -1) and \ + (self.hash_block_map.get(block_hash) != None) + + def remap_block_id(self, + block_hash: int, + vllm_block_id: int): + if (self.is_prefill == False): + return + assert(self.hash_block_map.get(block_hash) != None) + store_block_id = self.hash_block_map[block_hash] + self.gpu_and_store_block_map[vllm_block_id] = store_block_id + + def open_send_flag(self, block_id: int): + if (self.is_prefill == False): + return + store_block_id = self.gpu_and_store_block_map[block_id] + self.block_table[store_block_id].send_flag = True + self.block_table[store_block_id].access_count += 1 + self.block_table[store_block_id].last_access = self.time_cnt + self.time_cnt += 1 + + def close_send_flags(self, + vllm_block_ids): + if (len(vllm_block_ids) == 0): + return + # print("vllm_block_ids: ", vllm_block_ids) + for block_id in vllm_block_ids: + store_block_id = self.gpu_and_store_block_map[block_id] + self.block_table[store_block_id].send_flag = False + + def get_put_blocks_mapping(self, + incomplete_ids: torch.Tensor, + block_ids: torch.Tensor) \ + -> (torch.Tensor, torch.Tensor): + if (self.is_prefill == False) or \ + ((incomplete_ids.numel() == 0) and (block_ids.numel() == 0)): + return torch.Tensor(), torch.Tensor() + assert(incomplete_ids.is_cuda == False) + assert(block_ids.is_cuda == False) + # Note: the self.num_block_slot is equal to the vllm block size + incomplete_ids_numpy = incomplete_ids.numpy() + block_ids_numpy = block_ids.numpy() + incomplete_store_ids = torch.empty([incomplete_ids_numpy.shape[0]], + dtype=incomplete_ids.dtype) + store_block_ids = torch.empty([block_ids_numpy.shape[0]], + dtype=block_ids.dtype) + incomplete_store_ids_cpu = incomplete_store_ids.numpy() + store_block_ids_cpu = store_block_ids.numpy() + for i, incomplete_id in enumerate(incomplete_ids_numpy): + store_block_id = self.gpu_and_store_block_map[incomplete_id[0]] + incomplete_store_ids_cpu[i] = store_block_id + for i, block_id in enumerate(block_ids_numpy): + store_block_id = self.gpu_and_store_block_map[block_id] + store_block_ids_cpu[i] = store_block_id + + # XXX: need to pre-allocate the another dimension in attn_meta builder? 
+ return (torch.cat((incomplete_ids, + incomplete_store_ids.view( + incomplete_store_ids.shape[0], 1)), + dim=1), + torch.stack((block_ids, store_block_ids), dim=1)) + + def get_block_mapping_from_torch(self, vllm_block_ids: torch.Tensor) \ + -> torch.Tensor: + if (self.is_prefill == False) or \ + (vllm_block_ids.numel() == 0): + return torch.Tensor() + ret_block_ids = torch.empty(vllm_block_ids.view(-1).shape, + dtype=vllm_block_ids.dtype) + ret_vllm_block_ids = torch.empty(vllm_block_ids.view(-1).shape, + dtype=vllm_block_ids.dtype) + ret_block_ids_cpu = ret_block_ids.view(-1).numpy() + ret_vllm_block_ids_cpu = ret_vllm_block_ids.view(-1).numpy() + cnt = 0 + for i, vllm_block_id in \ + enumerate(vllm_block_ids.view(-1).cpu().numpy()): + assert(self.gpu_and_store_block_map.get(vllm_block_id) != None) + store_block_id = self.gpu_and_store_block_map[vllm_block_id] + if (self.block_table[store_block_id].send_flag): + ret_block_ids_cpu[cnt] = store_block_id + ret_vllm_block_ids_cpu[cnt] = vllm_block_id + cnt += 1 + ret_block_ids.resize_([cnt]) + ret_vllm_block_ids.resize_([cnt]) + ret_tensor = torch.stack((ret_block_ids, ret_vllm_block_ids), dim=1) + return ret_tensor + + def get_block_mapping_from_python(self, vllm_block_ids: list[int]) \ + -> list[tuple[int, int]]: + if (self.is_prefill == False) or \ + (len(vllm_block_ids) == 0): + return [] + ret = [] + for vllm_block_id in vllm_block_ids: + assert(self.gpu_and_store_block_map.get(vllm_block_id) != None) + store_block_id = self.gpu_and_store_block_map[vllm_block_id] + if (self.block_table[store_block_id].send_flag): + ret.append([store_block_id, vllm_block_id]) + return ret + + def update_hash(self, old_hash: int, new_hash: int): + if (self.is_prefill == False): + return + assert(self.hash_block_map.get(old_hash) != None) + store_block_id = self.hash_block_map[old_hash] + del self.hash_block_map[old_hash] + self.hash_block_map[new_hash] = store_block_id + self.block_table[store_block_id].block_hash = new_hash + + # used to add a block_hash mapping when turn mutable + # to immutable in BlockManager v2 + def add_hash_map(self, block_hash: int, vllm_block_id: int): + if (self.is_prefill == False): + return + assert(self.gpu_and_store_block_map.get(vllm_block_id) != None) + store_block_id = self.gpu_and_store_block_map[vllm_block_id] + self.hash_block_map[block_hash] = store_block_id + self.block_table[store_block_id].block_hash = block_hash + +class EventPool: + def __init__(self, + reserve_num_requests: int, + num_layers: int, + device: torch.device): + self.reserve_num_requests = reserve_num_requests + self.num_layers = num_layers + self.event_queue: deque[torch.cuda.Event] = deque() + self.device = device + with torch.cuda.device(device): + for i in range(reserve_num_requests): + event = torch.cuda.Event() + # create the detail new event + event.record() + event.synchronize() + self.event_queue.append(event) + + def get_event(self) -> torch.cuda.Event: + if (len(self.event_queue) == 0): + with torch.cuda.device(self.device): + event = torch.cuda.Event() + # create the detail new event + event.record() + event.synchronize() + self.event_queue.append(event) + return self.event_queue.popleft() + + def put_event(self, event: torch.cuda.Event): + self.event_queue.append(event) + + def get_events(self, num_events: int) -> list[torch.cuda.Event]: + ret = [] + for i in range(num_events): + ret.append(self.get_event()) + return ret + + def put_events(self, events: list[torch.cuda.Event]): + for event in events: + self.event_queue.append(event) + 
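For orientation, the CSR-style layout that BlockMappingFromCPU (defined at the top of this file) carries from the scheduler to the workers can be read as in the snippet below. The numbers are invented and only illustrate the indexing convention: each block_mapping row is a [store_block_id, vllm_block_id] pair, block_offset is a prefix-offset array with one trailing total so request i owns rows offset[i]:offset[i+1], and request_ids holds one seq id per request.

    from vllm.store.kv_store import BlockMappingFromCPU

    # Two requests: the first restores two store blocks, the second restores one.
    mapping = BlockMappingFromCPU(
        block_mapping=[[5, 12], [6, 13], [9, 40]],  # [store_block_id, vllm_block_id]
        block_offset=[0, 2, 3],                     # start offset per request, plus the total
        request_ids=[101, 102],                     # seq ids (hypothetical values)
    )
    for i in range(mapping.request_ids.numel()):
        lo = int(mapping.block_offset[i])
        hi = int(mapping.block_offset[i + 1])
        rows = mapping.block_mapping[lo:hi]         # this request's (store, gpu) block pairs
        print(i, rows.tolist())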
+class KVBlockStore: + def __init__(self, + block_head_mem_size: int, # size of each block for key/value + num_layer: int, + num_block_slot: int, # number of slots for each block + data_type : torch.dtype, + device: torch.device, + mem_size : int = default_mem_size, # total memory size + ): + + t = 2 * num_layer * block_head_mem_size + mem_size = (mem_size // t) * t + assert(mem_size % (2 * num_layer * block_head_mem_size) == 0) + assert(block_head_mem_size % data_type.itemsize == 0) + num_item = (block_head_mem_size // data_type.itemsize // num_block_slot) + self.block_head_mem_size = block_head_mem_size + self.num_block_slot = num_block_slot + self.num_blocks = (mem_size // t) + self.num_item = num_item + self.device = device + self.num_layer = num_layer + self.event_map: dict[int, + Optional[torch.cuda.Event, + list[torch.cuda.Event]]] = {} + self.batch_layers_to_GPU = batch_layers_transmission_to_GPU + with torch.cuda.device(device): + self.store = torch.empty([self.num_blocks, + 2, + num_layer, + num_block_slot, + num_item], + dtype=data_type, + device="cpu").pin_memory() + self.get_stream = torch.cuda.Stream() + self.put_stream = torch.cuda.Stream() + self.put_event = torch.cuda.Event() + self.event_pool = EventPool(100, num_layer, device) + + @classmethod + def from_configs(cls, + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, + device: torch.device): + dtype = None + if (cache_config.cache_dtype == "auto"): + dtype = model_config.dtype + else: + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + block_size = cache_config.block_size + num_key_value_heads = model_config.get_num_kv_heads(parallel_config) + head_dim = model_config.get_head_size() + num_layers = model_config.get_num_layers(parallel_config) + + block_head_mem_size = (dtype.itemsize * block_size + * num_key_value_heads * head_dim) + return cls(block_head_mem_size, + num_layers, + block_size, + dtype, + device, + cache_config.kv_store_space_bytes) + + def put_block_layer(self, + incomplete_block_ids: torch.Tensor, + block_ids_mapping: torch.Tensor, + layer_id: int, + kv_cache: torch.Tensor, + forward_stream: torch.cuda.Stream): + if (incomplete_block_ids.numel() == 0) and \ + (block_ids_mapping.numel() == 0): + return + assert(incomplete_block_ids.is_cuda == False) + assert(block_ids_mapping.is_cuda == False) + incomplete_block_ids_numpy = incomplete_block_ids.numpy() + self.put_event.record(forward_stream) + self.put_event.wait(self.put_stream) + if (block_ids_mapping.numel() != 0): + with torch.cuda.stream(self.put_stream): + ops.kv_store_copy_blocks2CPU( + kv_cache, self.store, layer_id, + block_ids_mapping) + if (incomplete_block_ids.numel() != 0): + with torch.cuda.stream(self.put_stream): + ops.kv_store_copy_incomplete_blocks(kv_cache, self.store, + layer_id, + incomplete_block_ids) + + def get_blocks(self, + block_mapping_from_cpu: BlockMappingFromCPU, + kv_caches: list[torch.Tensor]): + block_mapping_tensor = block_mapping_from_cpu.block_mapping + block_offset_tensor = block_mapping_from_cpu.block_offset + request_ids_tensor = block_mapping_from_cpu.request_ids + request_ids_numpy = block_mapping_from_cpu.request_ids.numpy() + if (block_mapping_tensor.numel() == 0) or \ + (len(request_ids_numpy) == 0): + return + is_batch_layer = self.batch_layers_to_GPU + event_list = [] + if (is_batch_layer): + # if batch layer, we need to allocate one event for each request + request_last_events = [] + for idx, req_id in enumerate(request_ids_numpy): + event = 
self.event_pool.get_event() + self.event_map[req_id] = event + event_list.append(event) + else: + # if not batch layer, we need to allocate the events for each layer + for req_id in request_ids_numpy: + event_list_tmp = self.event_pool.get_events(self.num_layer) + self.event_map[req_id] = event_list_tmp + event_list.extend(event_list_tmp) + with torch.cuda.stream(self.get_stream): + ops.kv_store_copy_blocks2GPU( + self.store, kv_caches, + self.num_layer, + block_mapping_tensor, + block_offset_tensor, + request_ids_tensor, + [event.cuda_event for event in event_list], + is_batch_layer) + + # pair used with get_blocks_batch + def get_stream_sync(self, request_ids: torch.Tensor): + if (request_ids.numel() == 0): + return + for req_id in request_ids.numpy(): + if (self.event_map.get(req_id) == None): + continue + event = self.event_map[req_id] + event.synchronize() + # recycle the events + self.event_pool.put_event(event) + del self.event_map[req_id] + + # pair used with get_layer_blocks/get_blocks + def get_stream_layer_sync(self, + layer_id: int, + request_ids: torch.Tensor): + if (request_ids.numel() == 0): + return + for req_id in request_ids.numpy(): + if (self.event_map.get(req_id) == None): + continue + self.event_map[req_id][layer_id].synchronize() + if (layer_id == self.num_layer - 1): + # recycle the events + for req_id in request_ids.numpy(): + if (self.event_map.get(req_id) == None): + continue + event_list = self.event_map[req_id] + self.event_pool.put_events(event_list) + del self.event_map[req_id] + + def put_stream_sync(self): + self.put_stream.synchronize() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c9f06eef3f907..e30981ec453c5 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -466,6 +466,7 @@ def __init__(self, self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper self.finished_requests_ids = finished_requests_ids self.decode_only = True + self.enable_kv_store = self.runner.enable_kv_store # Intermediate data (data in CPU before going to GPU) for # the current sequence group. @@ -1015,6 +1016,7 @@ def __init__( self.is_driver_worker = is_driver_worker self.return_hidden_states = return_hidden_states + self.enable_kv_store = cache_config.enable_kv_store self.device = self.device_config.device self.pin_memory = is_pin_memory_available() diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 094dd5a5d08b3..5e0417f5271cc 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -28,10 +28,11 @@ from vllm.worker.pooling_model_runner import PoolingModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, WorkerInput) +from vllm.utils import init_logger +from vllm.store.kv_store import KVBlockStore, KVStoreMeta logger = init_logger(__name__) - class Worker(LocalOrDistributedWorkerBase): """A worker class that executes (a partition of) the model on a GPU. @@ -55,6 +56,14 @@ def __init__( self.rank = rank self.distributed_init_method = distributed_init_method self.is_driver_worker = is_driver_worker + + if (self.cache_config.enable_kv_store): + self.cache_config.kv_store = KVBlockStore.from_configs( + self.cache_config, self.model_config, self.parallel_config, + torch.device(f"cuda:{self.local_rank}")) + self.kv_store = self.cache_config.kv_store + self.kv_store_manager = self.cache_config.kv_store_manager + if is_driver_worker: assert rank % self.parallel_config.tensor_parallel_size == 0, \ "Driver worker should be rank 0 of tensor parallel group." 
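The put_stream/get_stream pair above follows a standard CUDA side-stream pattern: get_blocks issues the H2D copies on get_stream and records one event per layer (or per request) from the event pool, and the qwen2.py forward loop calls get_stream_layer_sync(layer_id, request_ids) so that layer i only waits on its own copy. A self-contained restatement with plain torch calls, illustrative only; tensor sizes and names are placeholders, not taken from the patch.

    import torch

    num_layers, numel = 2, 1 << 20
    cpu_store = [torch.randn(numel, pin_memory=True) for _ in range(num_layers)]
    gpu_cache = [torch.empty(numel, device="cuda") for _ in range(num_layers)]

    copy_stream = torch.cuda.Stream()
    events = [torch.cuda.Event() for _ in range(num_layers)]

    # Producer: queue all H2D copies up front on the side stream,
    # recording one event per layer as its copy is issued.
    with torch.cuda.stream(copy_stream):
        for i in range(num_layers):
            gpu_cache[i].copy_(cpu_store[i], non_blocking=True)
            events[i].record(copy_stream)

    # Consumer: before layer i's attention touches gpu_cache[i], wait only on
    # that layer's event (get_stream_layer_sync does a host-side synchronize()),
    # so later copies overlap earlier layers' compute.
    for i in range(num_layers):
        events[i].synchronize()
        # ... attention for layer i reads gpu_cache[i] here ...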
@@ -112,6 +121,37 @@ def __init__( else: self.profiler = None + def prepare_kv_store_meta(self, + is_prefill: bool, + incomplete_put_block_ids: torch.Tensor, + put_block_ids: torch.Tensor, + seq_g_list: List[SequenceGroupMetadata]) \ + -> KVStoreMeta: + ret_incomplete_put_blocks = torch.Tensor() + ret_put_blocks_mapping = torch.Tensor() + ret_seq_g_ids = torch.Tensor() + if (self.local_rank == 0) and (self.kv_store_manager is not None): + self.kv_store_manager.is_prefill = is_prefill + (ret_incomplete_put_blocks, ret_put_blocks_mapping) = \ + self.kv_store_manager.get_put_blocks_mapping( + incomplete_put_block_ids, put_block_ids) + self.kv_store_manager.is_prefill = False + if (is_prefill) and (ret_incomplete_put_blocks.numel() + \ + ret_put_blocks_mapping.numel() > 0): + # XXX: use first seq_id representing the seq_group id + seq_g_ids = [seq_g.get_first_seq_id() for seq_g in seq_g_list] + ret_seq_g_ids = torch.tensor(seq_g_ids, + device="cpu", + dtype=torch.int64).view(-1) + return KVStoreMeta(ret_incomplete_put_blocks, + ret_put_blocks_mapping, + ret_seq_g_ids) + + def put_stream_sync(self): + if (self.kv_store is not None): + self.kv_store.put_stream_sync() + + def start_profile(self): if self.profiler is None: raise RuntimeError("Profiler is not enabled.") @@ -144,6 +184,11 @@ def init_device(self) -> None: else: raise RuntimeError( f"Not support device type: {self.device_config.device}") + + # use higher priority stream for forward step + compute_stream = torch.cuda.Stream(priority=-10) + torch.cuda.set_stream(compute_stream) + # Initialize the distributed environment. init_worker_distributed_environment(self.vllm_config, self.rank, self.distributed_init_method, @@ -341,6 +386,8 @@ def prepare_worker_input( blocks_to_copy=blocks_to_copy, virtual_engine=virtual_engine, num_steps=num_steps, + kv_store_block_mapping_from_cpu=\ + execute_model_req.kv_store_block_mapping_from_cpu, ) @torch.inference_mode() @@ -359,6 +406,18 @@ def execute_worker(self, worker_input: WorkerInput) -> None: and worker_input.blocks_to_copy.numel() > 0): self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy) + @torch.inference_mode() + def issue_blocks_copy(self, worker_input: WorkerInput) -> None: + if (self.kv_store == None): + return + kv_store_block_mapping_from_cpu = \ + worker_input.kv_store_block_mapping_from_cpu + kv_caches = (self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None) + self.kv_store.get_blocks( + worker_input.kv_store_block_mapping_from_cpu, + kv_caches) + def _get_cached_seq_group_metadata( self, seq_group_metadata_list: List[Union[SequenceGroupMetadata, diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 7c0bc5a678956..218919511553a 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -12,12 +12,14 @@ from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.platforms import current_platform -from vllm.sequence import ExecuteModelRequest, IntermediateTensors +from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, + SequenceGroupMetadata) from vllm.utils import (enable_trace_function_call_for_thread, resolve_obj_by_qualname, update_environment_variables) from vllm.worker.model_runner_base import (BroadcastableModelInput, ModelRunnerBase, ModelRunnerInputBase) +from vllm.store.kv_store import KVStoreMeta,BlockMappingFromCPU logger = init_logger(__name__) @@ -148,6 +150,7 @@ class WorkerInput: blocks_to_copy: Optional[torch.Tensor] = 
None virtual_engine: int = 0 num_steps: int = 1 + kv_store_block_mapping_from_cpu: Optional[BlockMappingFromCPU] = None @classmethod def from_broadcasted_tensor_dict( @@ -165,6 +168,11 @@ def from_broadcasted_tensor_dict( blocks_to_copy=tensor_dict.pop("blocks_to_copy"), virtual_engine=tensor_dict["virtual_engine"], num_steps=tensor_dict.pop("num_steps"), + kv_store_block_mapping_from_cpu=BlockMappingFromCPU( + tensor_dict.pop("kv_block_mapping"), + tensor_dict.pop("kv_block_mapping_offsets"), + tensor_dict.pop("kv_block_mapping_req_ids") + ) ) def as_broadcastable_tensor_dict( @@ -179,6 +187,12 @@ def as_broadcastable_tensor_dict( "blocks_to_copy": self.blocks_to_copy, "virtual_engine": self.virtual_engine, "num_steps": self.num_steps, + "kv_block_mapping": \ + self.kv_store_block_mapping_from_cpu.block_mapping, + "kv_block_mapping_offsets": \ + self.kv_store_block_mapping_from_cpu.block_offset, + "kv_block_mapping_req_ids": \ + self.kv_store_block_mapping_from_cpu.request_ids } return tensor_dict @@ -230,6 +244,34 @@ def prepare_worker_input( """ raise NotImplementedError + @abstractmethod + def prepare_kv_store_meta(self, + is_prefill: bool, + incomplete_put_block_ids: torch.Tensor, + put_block_ids: torch.Tensor, + seq_g_list: List[SequenceGroupMetadata]) \ + -> KVStoreMeta: + """ + Prepare the KVStoreMeta for the worker. This is called by the driver + worker to prepare the metadata for the KVStore. + """ + pass + + @abstractmethod + def put_stream_sync(self) -> None: + """ + Synchronize the stream. + """ + raise NotImplementedError + + @abstractmethod + def issue_blocks_copy(self, worker_input: WorkerInput) -> None: + """ + Issue the copy of the blocks from CPU to GPU with the given indices. + """ + raise NotImplementedError + + @abstractmethod def execute_worker(self, worker_input: WorkerInput) -> None: """ @@ -271,6 +313,16 @@ def _get_driver_input_and_broadcast( execute_model_req.virtual_engine, execute_model_req.finished_requests_ids)) + incomplete_put_block_ids = \ + model_input.attn_metadata.kv_store_meta.incomplete_put_block_ids + put_block_ids = \ + model_input.attn_metadata.kv_store_meta.put_block_ids_mapping + model_input.attn_metadata.kv_store_meta = \ + self.prepare_kv_store_meta(model_input.is_prompt, + incomplete_put_block_ids, + put_block_ids, + execute_model_req.seq_group_metadata_list) + kwargs = extract_previous_hidden_states(execute_model_req) if self.do_metadata_broadcast: @@ -325,6 +377,8 @@ def execute_model( self.execute_worker(worker_input) + self.issue_blocks_copy(worker_input) + # If there is no input, we don't need to execute the model. if worker_input.num_seq_groups == 0: return [] @@ -349,6 +403,7 @@ def execute_model( **kwargs, ) + self.put_stream_sync() model_execute_time = time.perf_counter() - start_time if not get_pp_group().is_last_rank: # output is IntermediateTensors @@ -475,7 +530,7 @@ def extract_previous_hidden_states( data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]) -> \ Dict[str, torch.Tensor]: """If data contains previous_hidden_states, extract it. This returns a dict - which can be used directly as additional kwargs in any following + which can be used directly as additional kwargs in any following execute_model calls. 
This is used in draft models like EAGLE.""" output = {} From 8613037db14ada84ea68ce000d58de155cc9d6fa Mon Sep 17 00:00:00 2001 From: Dahai Tang Date: Tue, 3 Dec 2024 21:21:29 +0800 Subject: [PATCH 02/11] Fix(kv store): fix some building and running bugs Signed-off-by: Dahai Tang --- CMakeLists.txt | 1 + vllm/attention/backends/abstract.py | 2 +- vllm/core/scheduler.py | 5 +---- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c78cdc77a7e42..cb590fb190598 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -187,6 +187,7 @@ message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}") set(VLLM_EXT_SRC "csrc/cache_kernels.cu" + "csrc/kv_store/kv_store.cu" "csrc/attention/paged_attention_v1.cu" "csrc/attention/paged_attention_v2.cu" "csrc/pos_encoding_kernels.cu" diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 35db3c21b2336..80b8dfed6e900 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -115,7 +115,7 @@ class AttentionMetadata: # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot # in block 0, and 1st slot in block 1, respectively. slot_mapping: torch.Tensor - kv_store_meta: KVStoreMeta = None + kv_store_meta: KVStoreMeta # The index maps that relate multi-modal embeddings to the corresponding # placeholders. diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index d48becd166e71..407715b57942c 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -16,7 +16,7 @@ from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceGroupMetadataDelta, SequenceStatus) -from vllm.utils import Device, PyObjectCache, set_abort_request_id +from vllm.utils import Device, PyObjectCache from vllm.store.kv_store import KVBlockStoreManager,BlockMappingFromCPU logger = init_logger(__name__) @@ -947,8 +947,6 @@ def _stop_schedule_prefill(num_new_tokens_uncached, kv_store_tmp_queue : Deque[SequenceGroup] = deque() while self._passed_delay(time.time()) and kv_store_waiting_queue: - if budget.num_curr_prefill_seqs >= self.scheduler_config.max_num_prefill_seqs: - break seq_group = kv_store_waiting_queue[0] @@ -999,7 +997,6 @@ def _stop_schedule_prefill(num_new_tokens_uncached, num_cached_tokens=num_new_tokens_cached, ) budget.add_num_seqs(seq_group.request_id, num_new_seqs) - budget.add_num_prefill_seqs(num_new_seqs) kv_store_waiting_queue.extendleft(kv_store_tmp_queue) while self._passed_delay(time.time()) and waiting_queue: From f47bbcea979660c9117738672aa1ef54ba851883 Mon Sep 17 00:00:00 2001 From: Dahai Tang Date: Wed, 4 Dec 2024 14:58:36 +0800 Subject: [PATCH 03/11] Refactor: pass lint check Signed-off-by: Dahai Tang --- vllm/attention/backends/abstract.py | 2 - vllm/attention/backends/flash_attn.py | 57 ++--- vllm/attention/backends/utils.py | 11 +- vllm/core/block/cpu_gpu_block_allocator.py | 10 +- vllm/core/block/prefix_caching_block.py | 35 +-- vllm/core/block_manager.py | 8 +- vllm/engine/llm_engine.py | 8 +- vllm/sequence.py | 6 +- vllm/store/kv_store.py | 270 ++++++++++----------- vllm/worker/worker.py | 34 +-- vllm/worker/worker_base.py | 49 ++-- 11 files changed, 238 insertions(+), 252 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 80b8dfed6e900..aed04361e5fb4 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -3,7 +3,6 @@ from dataclasses import dataclass, fields from typing 
import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar) -from vllm.store.kv_store import KVStoreMeta import torch @@ -115,7 +114,6 @@ class AttentionMetadata: # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot # in block 0, and 1st slot in block 1, respectively. slot_mapping: torch.Tensor - kv_store_meta: KVStoreMeta # The index maps that relate multi-modal embeddings to the corresponding # placeholders. diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 90e474e515360..35685917d668c 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -4,8 +4,8 @@ from itertools import accumulate from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type -import torch import numpy as np +import torch from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, @@ -17,8 +17,8 @@ compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args, is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set, is_block_tables_empty) -from vllm.store.kv_store import KVStoreMeta from vllm.multimodal import MultiModalPlaceholderMap +from vllm.store.kv_store import KVStoreMeta from vllm.utils import async_tensor_h2d, make_tensor_with_pad if TYPE_CHECKING: @@ -131,6 +131,9 @@ class FlashAttentionMetadata(AttentionMetadata): # captured. block_tables: Optional[torch.Tensor] + # CPU KV store metadata + kv_store_meta: Optional[KVStoreMeta] + # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. @@ -453,11 +456,11 @@ def _add_seq_group( context_len, self.sliding_window) (range_start, range_end, block_table) = compute_slot_mapping( - is_profile_run, self.slot_mapping, seq_id, - curr_seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) - self.range_list.append((range_start, range_end, - block_table, is_prompt)) + is_profile_run, self.slot_mapping, seq_id, curr_seq_len, + context_len, start_idx, self.block_size, + inter_data.block_tables) + self.range_list.append( + (range_start, range_end, block_table, is_prompt)) def _get_graph_runner_block_tables( self, num_seqs: int, @@ -481,7 +484,7 @@ def _get_graph_runner_block_tables( i, :max_blocks] = block_table[:max_blocks] return torch.from_numpy(graph_block_tables).to( - device=self.runner.device, non_blocking=True) + device=self.runner.device, non_blocking=True) def build(self, seq_lens: List[int], query_lens: List[int], cuda_graph_pad_size: int, batch_size: int): @@ -554,11 +557,11 @@ def build(self, seq_lens: List[int], query_lens: List[int], incomplete_put_block_ids = [] put_block_ids = [] - assert(len(self.range_list) == len(self.block_tables)) - if (self.enable_kv_store == True): + assert (len(self.range_list) == len(self.block_tables)) + if (self.enable_kv_store): for (range_start, range_end, seq_block_table, is_prompt) in \ self.range_list: - if (range_start == range_end) or (is_prompt == False): + if (range_start == range_end) or (not is_prompt): continue block_size = self.block_size range_end -= 1 @@ -567,31 +570,33 @@ def build(self, seq_lens: List[int], query_lens: List[int], range_start_block_offset = range_start % block_size range_end_block_offset = range_end % block_size + 1 if (range_start_block_id == range_end_block_id): - incomplete_put_block_ids.append( - 
[seq_block_table[range_start_block_id], - range_start_block_offset, range_end_block_offset]) + incomplete_put_block_ids.append([ + seq_block_table[range_start_block_id], + range_start_block_offset, range_end_block_offset + ]) else: if (range_start_block_offset == 0): put_block_ids.append( seq_block_table[range_start_block_id]) else: - incomplete_put_block_ids.append( - [seq_block_table[range_start_block_id], - range_start_block_offset, block_size]) - put_block_ids.extend( - seq_block_table[ - range_start_block_id + 1:range_end_block_id]) + incomplete_put_block_ids.append([ + seq_block_table[range_start_block_id], + range_start_block_offset, block_size + ]) + put_block_ids.extend(seq_block_table[range_start_block_id + + 1:range_end_block_id]) if (range_end_block_offset == block_size): put_block_ids.append( seq_block_table[range_end_block_id]) else: - incomplete_put_block_ids.append( - [seq_block_table[range_end_block_id], - 0, range_end_block_offset]) + incomplete_put_block_ids.append([ + seq_block_table[range_end_block_id], 0, + range_end_block_offset + ]) incomplete_put_block_ids_numpy = np.array(incomplete_put_block_ids) put_block_ids_numpy = np.array(put_block_ids) incomplete_put_block_ids_cpu = torch.from_numpy( - incomplete_put_block_ids_numpy).to("cpu") + incomplete_put_block_ids_numpy).to("cpu") put_block_ids_cpu = torch.from_numpy(put_block_ids_numpy).to("cpu") return FlashAttentionMetadata( @@ -612,9 +617,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], block_tables=block_tables, use_cuda_graph=use_captured_graph, kv_store_meta=KVStoreMeta(incomplete_put_block_ids_cpu, - put_block_ids_cpu, - torch.Tensor()) - ) + put_block_ids_cpu, torch.Tensor())) class FlashAttentionImpl(AttentionImpl): diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 174b9b70fa771..1cd11e262748e 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -11,8 +11,8 @@ AttentionState) from vllm.attention.backends.abstract import AttentionType from vllm.multimodal import MultiModalPlaceholderMap -from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.store.kv_store import KVStoreMeta +from vllm.utils import async_tensor_h2d, make_tensor_with_pad if TYPE_CHECKING: from vllm.worker.model_runner_base import ModelRunnerBase @@ -79,8 +79,7 @@ def _compute_slot_mapping_numpy(slot_mapping: List[int], def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int], seq_id: int, seq_len: int, context_len: int, start_idx: int, block_size: int, - block_tables: Dict[int, List[int]]) \ - -> (int, int, List[int]): + block_tables: Dict[int, List[int]]): """ Compute slot mapping. 
return: the range_start and range_end @@ -115,8 +114,8 @@ def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int], else: _compute_slot_mapping_numpy(slot_mapping, block_table, range_start, range_end, block_size) - block_table = block_table[ - (range_start // block_size):((range_end - 1) // block_size) + 1] + block_table = block_table[(range_start // + block_size):((range_end - 1) // block_size) + 1] start_offset = (range_start // block_size) * block_size range_start -= start_offset range_end -= start_offset @@ -335,7 +334,7 @@ def graph_capture_get_metadata_for_batch( seq_start_loc=None, context_lens_tensor=None, block_tables=self._graph_block_tables[:batch_size], - kv_store_meta = KVStoreMeta.null(), + kv_store_meta=KVStoreMeta.null(), use_cuda_graph=True, ) if is_encoder_decoder_model: diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 5d3751aa5070a..d950cf82b77eb 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -5,6 +5,7 @@ from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator from vllm.platforms import current_platform +from vllm.store.kv_store import KVBlockStoreManager from vllm.utils import Device @@ -26,7 +27,7 @@ def create( num_gpu_blocks: int, num_cpu_blocks: int, block_size: int, - kv_store_manager: Optional["KVStoreManager"], + kv_store_manager: Optional["KVBlockStoreManager"], ) -> DeviceAwareBlockAllocator: """Creates a CpuGpuBlockAllocator instance with the specified configuration. @@ -89,6 +90,7 @@ def create( num_blocks=num_cpu_blocks, block_size=block_size, block_ids=cpu_block_ids, + kv_store_manager=None, ) else: raise ValueError(f"Unknown allocator type {allocator_type=}") @@ -309,12 +311,6 @@ def mark_blocks_as_computed(self, block_ids: List[int]) -> None: device = Device.GPU return self._allocators[device].mark_blocks_as_computed(block_ids) - def mark_blocks_as_cached(self, blocks: List[Block]) -> None: - """Mark blocks as cached, only use for prefix caching with KV Store.""" - # Prefix caching only supported on GPU. - device = Device.GPU - return self._allocators[device].mark_blocks_as_cached(blocks) - def get_common_computed_block_ids( self, computed_seq_block_ids: List[List[int]]) -> List[int]: # Prefix caching only supported on GPU. 
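The FlashAttention metadata builder above converts each prompt's written slot range [range_start, range_end) into KV-store put work: fully written blocks land in put_block_ids, while the partially written head and tail blocks land in incomplete_put_block_ids together with their intra-block offsets. The same logic is restated below as a standalone helper for readability; this is a sketch only, and its name and the worked example are not part of the patch.

    def split_range(block_table, range_start, range_end, block_size):
        """Split [range_start, range_end) into (incomplete, complete) blocks."""
        incomplete, complete = [], []   # [[block_id, start_off, end_off]], [block_id]
        if range_start == range_end:
            return incomplete, complete
        last = range_end - 1
        first_blk, last_blk = range_start // block_size, last // block_size
        start_off, end_off = range_start % block_size, last % block_size + 1
        if first_blk == last_blk:
            incomplete.append([block_table[first_blk], start_off, end_off])
            return incomplete, complete
        if start_off == 0:
            complete.append(block_table[first_blk])
        else:
            incomplete.append([block_table[first_blk], start_off, block_size])
        complete.extend(block_table[first_blk + 1:last_blk])
        if end_off == block_size:
            complete.append(block_table[last_blk])
        else:
            incomplete.append([block_table[last_blk], 0, end_off])
        return incomplete, complete

    # block_size=16, block_table=[7, 8, 9], slots [5, 37) written this step:
    # -> incomplete = [[7, 5, 16], [9, 0, 5]], complete = [8]
    print(split_range([7, 8, 9], 5, 37, 16))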
diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index b653f7042d1a1..6d6026e3bd5a5 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -2,8 +2,8 @@ import sys from bisect import bisect_left from os.path import commonprefix -from typing import (Callable, Dict, FrozenSet, Iterable, List, Optional, Set, - Tuple) +from typing import (TYPE_CHECKING, Callable, Dict, FrozenSet, Iterable, List, + Optional, Set, Tuple) from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, get_all_blocks_recursively) @@ -13,7 +13,9 @@ NaiveBlockAllocator) from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor from vllm.sequence import Sequence -from vllm.store.kv_store import KVBlockStoreManager + +if TYPE_CHECKING: + from vllm.store.kv_store import KVBlockStoreManager # NOTE: for debug the kvstore, may cause some assert errors only_enable_cpu_kvstore = False @@ -69,8 +71,8 @@ def __init__( self, num_blocks: int, block_size: int, + kv_store_manager: Optional["KVBlockStoreManager"], block_ids: Optional[Iterable[int]] = None, - kv_store_manager: Optional[KVBlockStoreManager] = None, eviction_policy: EvictionPolicy = EvictionPolicy.LRU, ): if block_ids is None: @@ -170,8 +172,8 @@ def allocate_immutable_block(self, physical_block_id=None) assert block.content_hash is not None - if (only_enable_cpu_kvstore == False) or \ - (self.kv_store_manager == None): + if (not only_enable_cpu_kvstore) or \ + (self.kv_store_manager is None): cached_block_id = self._cached_blocks.get(block.content_hash, None) if cached_block_id is not None: self.metric_data.query(hit=True) @@ -179,14 +181,14 @@ def allocate_immutable_block(self, self._incr_refcount_cached_block(block) return block self.metric_data.query(hit=False) - if (self.kv_store_manager != None) and \ + if (self.kv_store_manager is not None) and \ (self.kv_store_manager.has(block.content_hash)): block_id = self._allocate_block_id() block.block_id = block_id # print("Found in kv store with hash: ", block.content_hash, # "block_id: ", block_id) - self.kv_store_manager.remap_block_id( - block.content_hash, block.block_id) + self.kv_store_manager.remap_block_id(block.content_hash, + block.block_id) self.kv_store_manager.open_send_flag(block.block_id) self._cached_blocks[block.content_hash] = block_id block.computed = True @@ -236,7 +238,7 @@ def allocate_mutable_block(self, physical_block_id=block_id) assert not block.computed assert block.content_hash is None - if (self.kv_store_manager != None): + if (self.kv_store_manager is not None): self.kv_store_manager.allocate_block(-1, block.block_id) return block @@ -470,9 +472,9 @@ def promote_to_immutable_block(self, block: Block) -> BlockId: assert self._refcounter.get(block.block_id) > 0 if block.content_hash not in self._cached_blocks: - if (self.kv_store_manager != None): - self.kv_store_manager.add_hash_map( - block.content_hash, block.block_id) + if (self.kv_store_manager is not None): + self.kv_store_manager.add_hash_map(block.content_hash, + block.block_id) # No cached content hash => Set this block as cached. 
# Note that this block cannot be marked as computed yet @@ -493,10 +495,9 @@ def promote_to_immutable_block(self, block: Block) -> BlockId: # Note that in this case, the block is marked as computed self._incr_refcount_cached_block(block) - if (self.kv_store_manager != None): - self.kv_store_manager.add_hash_map( - block.content_hash, block.block_id) - + if (self.kv_store_manager is not None): + self.kv_store_manager.add_hash_map(block.content_hash, + block.block_id) return block.block_id diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 27a214f359c29..09a51a1ef7c02 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -1,5 +1,5 @@ """A block manager that manages token blocks.""" -from typing import Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional from typing import Sequence as GenericSequence from typing import Tuple @@ -12,7 +12,9 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device -from vllm.store.kv_store import KVBlockStoreManager + +if TYPE_CHECKING: + from vllm.store.kv_store import KVBlockStoreManager SeqId = int EncoderSeqId = str @@ -63,7 +65,7 @@ def __init__( block_size: int, num_gpu_blocks: int, num_cpu_blocks: int, - kv_store_manager : KVBlockStoreManager, + kv_store_manager: Optional["KVBlockStoreManager"], watermark: float = 0.01, sliding_window: Optional[int] = None, enable_caching: bool = False, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 38e6483bb1e18..3ec7999940162 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -50,6 +50,7 @@ SequenceGroup, SequenceGroupBase, SequenceGroupMetadata, SequenceGroupOutput, SequenceStatus) +from vllm.store.kv_store import KVBlockStoreManager from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) from vllm.transformers_utils.config import try_get_generation_config @@ -61,7 +62,6 @@ usage_message) from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind from vllm.version import __version__ as VLLM_VERSION -from vllm.store.kv_store import KVBlockStoreManager logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 @@ -336,11 +336,10 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.kv_store_manager: Optional[KVBlockStoreManager] = None if (self.cache_config.enable_kv_store): kv_store_manager = KVBlockStoreManager.from_configs( - self.cache_config, self.model_config, self.parallel_config) + self.cache_config, self.model_config, self.parallel_config) self.kv_store_manager = kv_store_manager self.cache_config.kv_store_manager = kv_store_manager - self.model_executor = executor_class(vllm_config=vllm_config, ) if self.model_config.task != "embedding": @@ -420,8 +419,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.scheduler = [ Scheduler( self.scheduler_config, self.cache_config, - self.kv_store_manager, - self.lora_config, + self.kv_store_manager, self.lora_config, self.parallel_config.pipeline_parallel_size, self.async_callbacks[v_id] if self.model_config.use_async_output_proc else None) diff --git a/vllm/sequence.py b/vllm/sequence.py index 88e07d770984b..2d06b5ddbc51a 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -8,7 +8,7 @@ from functools import reduce from typing import Any, Callable, DefaultDict, Dict, List, Mapping, Optional from typing import Sequence as GenericSequence -from typing import Set, 
Tuple, Union +from typing import TYPE_CHECKING, Set, Tuple, Union import msgspec import torch @@ -20,6 +20,9 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import RequestOutputKind, SamplingParams +if TYPE_CHECKING: + from vllm.store.kv_store import BlockMappingFromCPU + VLLM_TOKEN_ID_ARRAY_TYPE = "l" VLLM_INVALID_TOKEN_ID = -1 @@ -1266,7 +1269,6 @@ class ExecuteModelRequest( # Async callback async_callback: Optional[Callable] = None - # for kv store kv_store_block_mapping_from_cpu: Optional["BlockMappingFromCPU"] = None diff --git a/vllm/store/kv_store.py b/vllm/store/kv_store.py index 85fe36ca3005d..2107cd62bc1d3 100644 --- a/vllm/store/kv_store.py +++ b/vllm/store/kv_store.py @@ -1,30 +1,32 @@ -import torch import sys from collections import deque -from vllm.config import CacheConfig, ModelConfig, ParallelConfig -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, - is_pin_memory_available) from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +import torch + from vllm import _custom_ops as ops -import time from vllm.logger import init_logger -from typing import Optional, Union +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE + +if TYPE_CHECKING: + from vllm.config import CacheConfig, ModelConfig, ParallelConfig logger = init_logger(__name__) default_mem_size = 4 * 1024 * 1024 * 1024 batch_layers_transmission_to_GPU = False + @dataclass class BlockMappingFromCPU: block_mapping: torch.Tensor # 2-D tenso - block_offset: torch.Tensor # 1-D tensor, like offset array in CSR format - # the offset of each request in block_mapping - request_ids: torch.Tensor # request IDs - def __init__(self, - block_mapping: list[list[int, int]], - block_offset: list[int], - request_ids: list[int]): + block_offset: torch.Tensor # 1-D tensor, like offset array in CSR format + # the offset of each request in block_mapping + request_ids: torch.Tensor # request IDs + + def __init__(self, block_mapping: list[list[int, int]], + block_offset: list[int], request_ids: list[int]): self.block_mapping = torch.tensor(block_mapping, device="cpu", dtype=torch.int64).view(-1, 2) @@ -34,54 +36,65 @@ def __init__(self, self.request_ids = torch.tensor(request_ids, device="cpu", dtype=torch.int64).view(-1) + def __str__(self): return "block_mapping: " + str(self.block_mapping) + \ " block_offset: " + str(self.block_offset) + \ " request_ids: " + str(self.request_ids) + @dataclass class KVStoreMeta: - incomplete_put_block_ids: torch.Tensor # 4-D tensor: - # vllm_block_id, - # start_offset,end_offset, - # store_block_id - put_block_ids_mapping: torch.Tensor # 2-D tensor: - # vllm_block_id, store_block_id - request_ids: torch.Tensor # 1-D tensor + incomplete_put_block_ids: torch.Tensor # 4-D tensor: + # vllm_block_id, + # start_offset,end_offset, + # store_block_id + put_block_ids_mapping: torch.Tensor # 2-D tensor: + # vllm_block_id, store_block_id + request_ids: torch.Tensor # 1-D tensor @staticmethod def null(): - return KVStoreMeta(torch.Tensor(), - torch.Tensor(), - torch.Tensor()) + return KVStoreMeta(torch.Tensor(), torch.Tensor(), torch.Tensor()) def __str__(self): - return "incomplete_put_block_ids: " + str(self.incomplete_put_block_ids) + \ + return "incomplete_put_block_ids: " + \ + str(self.incomplete_put_block_ids) + \ " put_block_ids_mapping: " + str(self.put_block_ids_mapping) + \ " request_ids: " + str(self.request_ids) + class BlockCount: - def __init__(self, block_id, access_count, last_access, block_hash, - send_flag = False): + + 
def __init__(self, + block_id, + access_count, + last_access, + block_hash, + send_flag=False): # XXX: can remove it self.block_id = block_id self.access_count = access_count self.last_access = last_access self.block_hash = block_hash self.send_flag = send_flag + def __str__(self): return "block_id: " + str(self.block_id) + \ " access_count: " + str(self.access_count) + \ " last_access: " + str(self.last_access) + \ " block_hash: " + str(self.block_hash) + class KVBlockStoreManager: - def __init__(self, - block_head_mem_size: int, # size of each block for key/value - num_layer: int, - num_block_slot: int, # number of slots for each block - mem_size : int = default_mem_size, # total memory size - ): + + def __init__( + self, + block_head_mem_size: int, # size of each block for key/value + num_layer: int, + num_block_slot: int, # number of slots for each block + mem_size: int = default_mem_size, # total memory size + ): t = 2 * num_layer * block_head_mem_size mem_size = (mem_size // t) * t @@ -90,21 +103,19 @@ def __init__(self, self.time_cnt = 0 self.block_cnt = 0 self.block_table = [BlockCount(0, 0, 0, 0)] * self.num_blocks - self.hash_block_map: dict[int, int] = {} # hash -> store_block_id + self.hash_block_map: dict[int, int] = {} # hash -> store_block_id self.gpu_and_store_block_map: dict[int, int] = \ {} # gpu_block_id -> store_block_id - logger.info("KVBlockStore use %f GB memory per worker, " - "%d blocks, block size = %d", - mem_size / 1024 / 1024 / 1024, - self.num_blocks, - self.num_block_slot) + logger.info( + "KVBlockStore use %f GB memory per worker, " + "%d blocks, block size = %d", mem_size / 1024 / 1024 / 1024, + self.num_blocks, self.num_block_slot) self.is_prefill = True @classmethod - def from_configs(cls, - cache_config: CacheConfig, - model_config: ModelConfig, - parallel_config: ParallelConfig): + def from_configs(cls, cache_config: Optional["CacheConfig"], + model_config: Optional["ModelConfig"], + parallel_config: Optional["ParallelConfig"]): dtype = None if (cache_config.cache_dtype == "auto"): dtype = model_config.dtype @@ -115,17 +126,14 @@ def from_configs(cls, head_dim = model_config.get_head_size() num_layers = model_config.get_num_layers(parallel_config) - block_head_mem_size = (dtype.itemsize * block_size - * num_key_value_heads * head_dim) - return cls(block_head_mem_size, - num_layers, - block_size, + block_head_mem_size = (dtype.itemsize * block_size * + num_key_value_heads * head_dim) + return cls(block_head_mem_size, num_layers, block_size, cache_config.kv_store_space_bytes) - # allocate a logical block in CPU, and map the GPU block to kv store block def allocate_block(self, block_hash, gpu_block_id) -> BlockCount: - if (self.is_prefill == False): + if (not self.is_prefill): return None ret_block_id = self.block_cnt if (self.block_cnt == self.num_blocks): @@ -146,7 +154,7 @@ def allocate_block(self, block_hash, gpu_block_id) -> BlockCount: if block_count.last_access < min_last_access: min_last_access = block_count.last_access final_block_id = store_block_id - assert(final_block_id != -1) + assert (final_block_id != -1) ret_block_id = final_block_id final_block_count = self.block_table[ret_block_id] # print("evict block_id: ", final_block_id) @@ -156,28 +164,26 @@ def allocate_block(self, block_hash, gpu_block_id) -> BlockCount: self.block_cnt += 1 self.hash_block_map[block_hash] = ret_block_id self.block_table[ret_block_id] = BlockCount(ret_block_id, 1, - self.time_cnt, block_hash) + self.time_cnt, block_hash) self.gpu_and_store_block_map[gpu_block_id] = 
ret_block_id self.time_cnt += 1 return self.block_table[ret_block_id] def has(self, block_hash: int) -> bool: - if (self.is_prefill == False): + if (not self.is_prefill): return False return (block_hash != -1) and \ - (self.hash_block_map.get(block_hash) != None) + (self.hash_block_map.get(block_hash) is not None) - def remap_block_id(self, - block_hash: int, - vllm_block_id: int): - if (self.is_prefill == False): + def remap_block_id(self, block_hash: int, vllm_block_id: int): + if (not self.is_prefill): return - assert(self.hash_block_map.get(block_hash) != None) + assert (self.hash_block_map.get(block_hash) is not None) store_block_id = self.hash_block_map[block_hash] self.gpu_and_store_block_map[vllm_block_id] = store_block_id def open_send_flag(self, block_id: int): - if (self.is_prefill == False): + if (not self.is_prefill): return store_block_id = self.gpu_and_store_block_map[block_id] self.block_table[store_block_id].send_flag = True @@ -185,8 +191,7 @@ def open_send_flag(self, block_id: int): self.block_table[store_block_id].last_access = self.time_cnt self.time_cnt += 1 - def close_send_flags(self, - vllm_block_ids): + def close_send_flags(self, vllm_block_ids): if (len(vllm_block_ids) == 0): return # print("vllm_block_ids: ", vllm_block_ids) @@ -198,11 +203,11 @@ def get_put_blocks_mapping(self, incomplete_ids: torch.Tensor, block_ids: torch.Tensor) \ -> (torch.Tensor, torch.Tensor): - if (self.is_prefill == False) or \ + if (not self.is_prefill) or \ ((incomplete_ids.numel() == 0) and (block_ids.numel() == 0)): return torch.Tensor(), torch.Tensor() - assert(incomplete_ids.is_cuda == False) - assert(block_ids.is_cuda == False) + assert (not incomplete_ids.is_cuda) + assert (not block_ids.is_cuda) # Note: the self.num_block_slot is equal to the vllm block size incomplete_ids_numpy = incomplete_ids.numpy() block_ids_numpy = block_ids.numpy() @@ -220,27 +225,27 @@ def get_put_blocks_mapping(self, store_block_ids_cpu[i] = store_block_id # XXX: need to pre-allocate the another dimension in attn_meta builder? 
- return (torch.cat((incomplete_ids, - incomplete_store_ids.view( - incomplete_store_ids.shape[0], 1)), - dim=1), - torch.stack((block_ids, store_block_ids), dim=1)) + return (torch.cat( + (incomplete_ids, + incomplete_store_ids.view(incomplete_store_ids.shape[0], 1)), + dim=1), torch.stack((block_ids, store_block_ids), dim=1)) def get_block_mapping_from_torch(self, vllm_block_ids: torch.Tensor) \ -> torch.Tensor: - if (self.is_prefill == False) or \ + if (not self.is_prefill) or \ (vllm_block_ids.numel() == 0): return torch.Tensor() ret_block_ids = torch.empty(vllm_block_ids.view(-1).shape, - dtype=vllm_block_ids.dtype) + dtype=vllm_block_ids.dtype) ret_vllm_block_ids = torch.empty(vllm_block_ids.view(-1).shape, - dtype=vllm_block_ids.dtype) + dtype=vllm_block_ids.dtype) ret_block_ids_cpu = ret_block_ids.view(-1).numpy() ret_vllm_block_ids_cpu = ret_vllm_block_ids.view(-1).numpy() cnt = 0 for i, vllm_block_id in \ enumerate(vllm_block_ids.view(-1).cpu().numpy()): - assert(self.gpu_and_store_block_map.get(vllm_block_id) != None) + assert (self.gpu_and_store_block_map.get(vllm_block_id) + is not None) store_block_id = self.gpu_and_store_block_map[vllm_block_id] if (self.block_table[store_block_id].send_flag): ret_block_ids_cpu[cnt] = store_block_id @@ -253,21 +258,22 @@ def get_block_mapping_from_torch(self, vllm_block_ids: torch.Tensor) \ def get_block_mapping_from_python(self, vllm_block_ids: list[int]) \ -> list[tuple[int, int]]: - if (self.is_prefill == False) or \ + if (not self.is_prefill) or \ (len(vllm_block_ids) == 0): return [] ret = [] for vllm_block_id in vllm_block_ids: - assert(self.gpu_and_store_block_map.get(vllm_block_id) != None) + assert (self.gpu_and_store_block_map.get(vllm_block_id) + is not None) store_block_id = self.gpu_and_store_block_map[vllm_block_id] if (self.block_table[store_block_id].send_flag): ret.append([store_block_id, vllm_block_id]) return ret def update_hash(self, old_hash: int, new_hash: int): - if (self.is_prefill == False): + if (not self.is_prefill): return - assert(self.hash_block_map.get(old_hash) != None) + assert (self.hash_block_map.get(old_hash) is not None) store_block_id = self.hash_block_map[old_hash] del self.hash_block_map[old_hash] self.hash_block_map[new_hash] = store_block_id @@ -276,17 +282,17 @@ def update_hash(self, old_hash: int, new_hash: int): # used to add a block_hash mapping when turn mutable # to immutable in BlockManager v2 def add_hash_map(self, block_hash: int, vllm_block_id: int): - if (self.is_prefill == False): + if (not self.is_prefill): return - assert(self.gpu_and_store_block_map.get(vllm_block_id) != None) + assert (self.gpu_and_store_block_map.get(vllm_block_id) is not None) store_block_id = self.gpu_and_store_block_map[vllm_block_id] self.hash_block_map[block_hash] = store_block_id self.block_table[store_block_id].block_hash = block_hash + class EventPool: - def __init__(self, - reserve_num_requests: int, - num_layers: int, + + def __init__(self, reserve_num_requests: int, num_layers: int, device: torch.device): self.reserve_num_requests = reserve_num_requests self.num_layers = num_layers @@ -323,49 +329,48 @@ def put_events(self, events: list[torch.cuda.Event]): for event in events: self.event_queue.append(event) + class KVBlockStore: - def __init__(self, - block_head_mem_size: int, # size of each block for key/value - num_layer: int, - num_block_slot: int, # number of slots for each block - data_type : torch.dtype, - device: torch.device, - mem_size : int = default_mem_size, # total memory size - ): + + def 
__init__( + self, + block_head_mem_size: int, # size of each block for key/value + num_layer: int, + num_block_slot: int, # number of slots for each block + data_type: torch.dtype, + device: torch.device, + mem_size: int = default_mem_size, # total memory size + ): t = 2 * num_layer * block_head_mem_size mem_size = (mem_size // t) * t - assert(mem_size % (2 * num_layer * block_head_mem_size) == 0) - assert(block_head_mem_size % data_type.itemsize == 0) - num_item = (block_head_mem_size // data_type.itemsize // num_block_slot) + assert (mem_size % (2 * num_layer * block_head_mem_size) == 0) + assert (block_head_mem_size % data_type.itemsize == 0) + num_item = (block_head_mem_size // data_type.itemsize // + num_block_slot) self.block_head_mem_size = block_head_mem_size self.num_block_slot = num_block_slot self.num_blocks = (mem_size // t) self.num_item = num_item self.device = device self.num_layer = num_layer - self.event_map: dict[int, - Optional[torch.cuda.Event, - list[torch.cuda.Event]]] = {} + self.event_map: dict[int, Optional[torch.cuda.Event, + list[torch.cuda.Event]]] = {} self.batch_layers_to_GPU = batch_layers_transmission_to_GPU with torch.cuda.device(device): - self.store = torch.empty([self.num_blocks, - 2, - num_layer, - num_block_slot, - num_item], - dtype=data_type, - device="cpu").pin_memory() + self.store = torch.empty( + [self.num_blocks, 2, num_layer, num_block_slot, num_item], + dtype=data_type, + device="cpu").pin_memory() self.get_stream = torch.cuda.Stream() self.put_stream = torch.cuda.Stream() self.put_event = torch.cuda.Event() self.event_pool = EventPool(100, num_layer, device) @classmethod - def from_configs(cls, - cache_config: CacheConfig, - model_config: ModelConfig, - parallel_config: ParallelConfig, + def from_configs(cls, cache_config: Optional["CacheConfig"], + model_config: Optional["ModelConfig"], + parallel_config: Optional["ParallelConfig"], device: torch.device): dtype = None if (cache_config.cache_dtype == "auto"): @@ -377,39 +382,31 @@ def from_configs(cls, head_dim = model_config.get_head_size() num_layers = model_config.get_num_layers(parallel_config) - block_head_mem_size = (dtype.itemsize * block_size - * num_key_value_heads * head_dim) - return cls(block_head_mem_size, - num_layers, - block_size, - dtype, - device, + block_head_mem_size = (dtype.itemsize * block_size * + num_key_value_heads * head_dim) + return cls(block_head_mem_size, num_layers, block_size, dtype, device, cache_config.kv_store_space_bytes) - def put_block_layer(self, - incomplete_block_ids: torch.Tensor, - block_ids_mapping: torch.Tensor, - layer_id: int, + def put_block_layer(self, incomplete_block_ids: torch.Tensor, + block_ids_mapping: torch.Tensor, layer_id: int, kv_cache: torch.Tensor, forward_stream: torch.cuda.Stream): if (incomplete_block_ids.numel() == 0) and \ (block_ids_mapping.numel() == 0): return - assert(incomplete_block_ids.is_cuda == False) - assert(block_ids_mapping.is_cuda == False) - incomplete_block_ids_numpy = incomplete_block_ids.numpy() + assert (not incomplete_block_ids.is_cuda) + assert (not block_ids_mapping.is_cuda) self.put_event.record(forward_stream) self.put_event.wait(self.put_stream) if (block_ids_mapping.numel() != 0): with torch.cuda.stream(self.put_stream): - ops.kv_store_copy_blocks2CPU( - kv_cache, self.store, layer_id, - block_ids_mapping) + ops.kv_store_copy_blocks2CPU(kv_cache, self.store, layer_id, + block_ids_mapping) if (incomplete_block_ids.numel() != 0): with torch.cuda.stream(self.put_stream): 
ops.kv_store_copy_incomplete_blocks(kv_cache, self.store, - layer_id, - incomplete_block_ids) + layer_id, + incomplete_block_ids) def get_blocks(self, block_mapping_from_cpu: BlockMappingFromCPU, @@ -425,7 +422,6 @@ def get_blocks(self, event_list = [] if (is_batch_layer): # if batch layer, we need to allocate one event for each request - request_last_events = [] for idx, req_id in enumerate(request_ids_numpy): event = self.event_pool.get_event() self.event_map[req_id] = event @@ -438,20 +434,16 @@ def get_blocks(self, event_list.extend(event_list_tmp) with torch.cuda.stream(self.get_stream): ops.kv_store_copy_blocks2GPU( - self.store, kv_caches, - self.num_layer, - block_mapping_tensor, - block_offset_tensor, - request_ids_tensor, - [event.cuda_event for event in event_list], - is_batch_layer) + self.store, kv_caches, self.num_layer, block_mapping_tensor, + block_offset_tensor, request_ids_tensor, + [event.cuda_event for event in event_list], is_batch_layer) # pair used with get_blocks_batch def get_stream_sync(self, request_ids: torch.Tensor): if (request_ids.numel() == 0): return for req_id in request_ids.numpy(): - if (self.event_map.get(req_id) == None): + if (self.event_map.get(req_id) is None): continue event = self.event_map[req_id] event.synchronize() @@ -460,19 +452,17 @@ def get_stream_sync(self, request_ids: torch.Tensor): del self.event_map[req_id] # pair used with get_layer_blocks/get_blocks - def get_stream_layer_sync(self, - layer_id: int, - request_ids: torch.Tensor): + def get_stream_layer_sync(self, layer_id: int, request_ids: torch.Tensor): if (request_ids.numel() == 0): return for req_id in request_ids.numpy(): - if (self.event_map.get(req_id) == None): + if (self.event_map.get(req_id) is None): continue self.event_map[req_id][layer_id].synchronize() if (layer_id == self.num_layer - 1): # recycle the events for req_id in request_ids.numpy(): - if (self.event_map.get(req_id) == None): + if (self.event_map.get(req_id) is None): continue event_list = self.event_map[req_id] self.event_pool.put_events(event_list) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 5e0417f5271cc..d8d9a5ec2bf2b 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -22,17 +22,17 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SequenceGroupMetadata, SequenceGroupMetadataDelta) +from vllm.store.kv_store import KVBlockStore, KVStoreMeta, BlockMappingFromCPU from vllm.worker.cache_engine import CacheEngine from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner from vllm.worker.pooling_model_runner import PoolingModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, WorkerInput) -from vllm.utils import init_logger -from vllm.store.kv_store import KVBlockStore, KVStoreMeta logger = init_logger(__name__) + class Worker(LocalOrDistributedWorkerBase): """A worker class that executes (a partition of) the model on a GPU. 
@@ -59,8 +59,8 @@ def __init__( if (self.cache_config.enable_kv_store): self.cache_config.kv_store = KVBlockStore.from_configs( - self.cache_config, self.model_config, self.parallel_config, - torch.device(f"cuda:{self.local_rank}")) + self.cache_config, self.model_config, self.parallel_config, + torch.device(f"cuda:{self.local_rank}")) self.kv_store = self.cache_config.kv_store self.kv_store_manager = self.cache_config.kv_store_manager @@ -122,7 +122,7 @@ def __init__( self.profiler = None def prepare_kv_store_meta(self, - is_prefill: bool, + is_prefill: Optional[bool], incomplete_put_block_ids: torch.Tensor, put_block_ids: torch.Tensor, seq_g_list: List[SequenceGroupMetadata]) \ @@ -143,15 +143,13 @@ def prepare_kv_store_meta(self, ret_seq_g_ids = torch.tensor(seq_g_ids, device="cpu", dtype=torch.int64).view(-1) - return KVStoreMeta(ret_incomplete_put_blocks, - ret_put_blocks_mapping, + return KVStoreMeta(ret_incomplete_put_blocks, ret_put_blocks_mapping, ret_seq_g_ids) def put_stream_sync(self): if (self.kv_store is not None): self.kv_store.put_stream_sync() - def start_profile(self): if self.profiler is None: raise RuntimeError("Profiler is not enabled.") @@ -386,8 +384,12 @@ def prepare_worker_input( blocks_to_copy=blocks_to_copy, virtual_engine=virtual_engine, num_steps=num_steps, - kv_store_block_mapping_from_cpu=\ - execute_model_req.kv_store_block_mapping_from_cpu, + kv_store_block_mapping=\ + execute_model_req.kv_store_block_mapping_from_cpu.block_mapping, + kv_store_block_offsets=\ + execute_model_req.kv_store_block_mapping_from_cpu.block_offset, + kv_store_block_req_ids=\ + execute_model_req.kv_store_block_mapping_from_cpu.request_ids, ) @torch.inference_mode() @@ -408,15 +410,15 @@ def execute_worker(self, worker_input: WorkerInput) -> None: @torch.inference_mode() def issue_blocks_copy(self, worker_input: WorkerInput) -> None: - if (self.kv_store == None): + if (self.kv_store is None): return - kv_store_block_mapping_from_cpu = \ - worker_input.kv_store_block_mapping_from_cpu kv_caches = (self.kv_cache[worker_input.virtual_engine] if self.kv_cache is not None else None) - self.kv_store.get_blocks( - worker_input.kv_store_block_mapping_from_cpu, - kv_caches) + self.kv_store.get_blocks(BlockMappingFromCPU( + worker_input.kv_store_block_mapping, + worker_input.kv_store_block_offsets, + worker_input.kv_store_block_req_ids), + kv_caches) def _get_cached_seq_group_metadata( self, diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 218919511553a..67093024b0c94 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -14,12 +14,12 @@ from vllm.platforms import current_platform from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SequenceGroupMetadata) +from vllm.store.kv_store import BlockMappingFromCPU, KVStoreMeta from vllm.utils import (enable_trace_function_call_for_thread, resolve_obj_by_qualname, update_environment_variables) from vllm.worker.model_runner_base import (BroadcastableModelInput, ModelRunnerBase, ModelRunnerInputBase) -from vllm.store.kv_store import KVStoreMeta,BlockMappingFromCPU logger = init_logger(__name__) @@ -150,7 +150,9 @@ class WorkerInput: blocks_to_copy: Optional[torch.Tensor] = None virtual_engine: int = 0 num_steps: int = 1 - kv_store_block_mapping_from_cpu: Optional[BlockMappingFromCPU] = None + kv_store_block_mapping: Optional[torch.Tensor] = None + kv_store_block_offsets: Optional[torch.Tensor] = None + kv_store_block_req_ids: Optional[torch.Tensor] = None @classmethod def 
from_broadcasted_tensor_dict( @@ -161,18 +163,16 @@ def from_broadcasted_tensor_dict( Pop fields from the given tensor_dict and populate a new instance of WorkerInput. """ - return cls( - num_seq_groups=tensor_dict.pop("num_seq_groups"), - blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"), - blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), - blocks_to_copy=tensor_dict.pop("blocks_to_copy"), - virtual_engine=tensor_dict["virtual_engine"], - num_steps=tensor_dict.pop("num_steps"), - kv_store_block_mapping_from_cpu=BlockMappingFromCPU( - tensor_dict.pop("kv_block_mapping"), - tensor_dict.pop("kv_block_mapping_offsets"), - tensor_dict.pop("kv_block_mapping_req_ids") - ) + return cls(num_seq_groups=tensor_dict.pop("num_seq_groups"), + blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"), + blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), + blocks_to_copy=tensor_dict.pop("blocks_to_copy"), + virtual_engine=tensor_dict["virtual_engine"], + kv_store_block_mapping=tensor_dict.pop("kv_block_mapping"), + kv_store_block_offsets=tensor_dict.pop( + "kv_block_mapping_offsets"), + kv_store_block_req_ids=tensor_dict.pop( + "kv_block_mapping_req_ids"), ) def as_broadcastable_tensor_dict( @@ -187,12 +187,9 @@ def as_broadcastable_tensor_dict( "blocks_to_copy": self.blocks_to_copy, "virtual_engine": self.virtual_engine, "num_steps": self.num_steps, - "kv_block_mapping": \ - self.kv_store_block_mapping_from_cpu.block_mapping, - "kv_block_mapping_offsets": \ - self.kv_store_block_mapping_from_cpu.block_offset, - "kv_block_mapping_req_ids": \ - self.kv_store_block_mapping_from_cpu.request_ids + "kv_block_mapping": self.kv_store_block_mapping, + "kv_block_mapping_offsets": self.kv_store_block_offsets, + "kv_block_mapping_req_ids": self.kv_store_block_req_ids, } return tensor_dict @@ -246,7 +243,7 @@ def prepare_worker_input( @abstractmethod def prepare_kv_store_meta(self, - is_prefill: bool, + is_prefill: Optional[bool], incomplete_put_block_ids: torch.Tensor, put_block_ids: torch.Tensor, seq_g_list: List[SequenceGroupMetadata]) \ @@ -271,7 +268,6 @@ def issue_blocks_copy(self, worker_input: WorkerInput) -> None: """ raise NotImplementedError - @abstractmethod def execute_worker(self, worker_input: WorkerInput) -> None: """ @@ -307,11 +303,10 @@ def _get_driver_input_and_broadcast( worker_input: WorkerInput = self.prepare_worker_input( execute_model_req=execute_model_req) - model_input: ModelRunnerInputBase = ( - self.model_runner.prepare_model_input( - execute_model_req.seq_group_metadata_list, - execute_model_req.virtual_engine, - execute_model_req.finished_requests_ids)) + model_input = (self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list, + execute_model_req.virtual_engine, + execute_model_req.finished_requests_ids)) incomplete_put_block_ids = \ model_input.attn_metadata.kv_store_meta.incomplete_put_block_ids From 3e9fd6a959c7f55965cd8a211277d14db6b558cf Mon Sep 17 00:00:00 2001 From: Dahai Tang Date: Wed, 4 Dec 2024 15:12:06 +0800 Subject: [PATCH 04/11] Refactor: pass lint check Signed-off-by: Dahai Tang --- vllm/worker/worker_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 67093024b0c94..705aa588075e1 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -168,6 +168,7 @@ def from_broadcasted_tensor_dict( blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), blocks_to_copy=tensor_dict.pop("blocks_to_copy"), virtual_engine=tensor_dict["virtual_engine"], + 
num_steps=tensor_dict.pop("num_steps"), kv_store_block_mapping=tensor_dict.pop("kv_block_mapping"), kv_store_block_offsets=tensor_dict.pop( "kv_block_mapping_offsets"), From f272bd770e16b59378c9689b0f9560c22e60bc69 Mon Sep 17 00:00:00 2001 From: Dahai Tang Date: Wed, 4 Dec 2024 15:57:18 +0800 Subject: [PATCH 05/11] Feat(lint): pass the clang format checker Signed-off-by: Dahai Tang --- csrc/cache.h | 25 ++- csrc/cache_kernels.cu | 33 ++-- csrc/kv_store/kv_store.cu | 330 +++++++++++++++++--------------------- csrc/torch_bindings.cpp | 6 +- 4 files changed, 177 insertions(+), 217 deletions(-) diff --git a/csrc/cache.h b/csrc/cache.h index ae99cdbaf9487..a86cfa16fdd42 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -8,21 +8,18 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst, const torch::Tensor& block_mapping); -void kv_store_copy_incomplete_blocks(torch::Tensor& src, torch::Tensor& dst, - const int64_t layer_id, - const torch::Tensor& incomplete_block_mapping); +void kv_store_copy_incomplete_blocks( + torch::Tensor& src, torch::Tensor& dst, const int64_t layer_id, + const torch::Tensor& incomplete_block_mapping); void kv_store_copy_blocks2CPU(torch::Tensor& src, torch::Tensor& dst, - const int64_t layer_id, - const torch::Tensor& block_mapping); - -void kv_store_copy_blocks2GPU(torch::Tensor& src, - std::vector const& dst, - const int64_t num_layers, - const torch::Tensor& block_mapping, - const torch::Tensor& block_offsets, - const torch::Tensor& req_ids, - std::vector const& events, - const bool is_batch_layer); + const int64_t layer_id, + const torch::Tensor& block_mapping); + +void kv_store_copy_blocks2GPU( + torch::Tensor& src, std::vector const& dst, + const int64_t num_layers, const torch::Tensor& block_mapping, + const torch::Tensor& block_offsets, const torch::Tensor& req_ids, + std::vector const& events, const bool is_batch_layer); // Note: the key_caches and value_caches vectors are constant but // not the Tensors they contain. 
The vectors need to be const refs diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 81e3399324de9..343a368cb0968 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -28,7 +28,7 @@ namespace { KVStore kv_store; -}; // namespace +}; // namespace void swap_blocks(torch::Tensor& src, torch::Tensor& dst, const torch::Tensor& block_mapping) { @@ -73,8 +73,8 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst, // src layout: [2, num_blocks, block_size, num_kv_heads, head_size] // dst layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size] -void kv_store_copy_incomplete_blocks(torch::Tensor& src, torch::Tensor& dst, - const int64_t layer_id, +void kv_store_copy_incomplete_blocks( + torch::Tensor& src, torch::Tensor& dst, const int64_t layer_id, const torch::Tensor& incomplete_block_mapping) { kv_store.CopyIncompleteBlocks(src, dst, layer_id, incomplete_block_mapping); } @@ -82,29 +82,26 @@ void kv_store_copy_incomplete_blocks(torch::Tensor& src, torch::Tensor& dst, // src layout: [2, num_blocks, block_size, num_kv_heads, head_size] // dst layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size] void kv_store_copy_blocks2CPU(torch::Tensor& src, torch::Tensor& dst, - const int64_t layer_id, - const torch::Tensor& block_mapping) { + const int64_t layer_id, + const torch::Tensor& block_mapping) { kv_store.CopyBlocks2CPU(src, dst, layer_id, block_mapping); } // src layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size] -// kv_caches layout: [laysers, [2, num_blocks, block_size, num_kv_heads, head_size]] -void kv_store_copy_blocks2GPU(torch::Tensor& src, - std::vector const& kv_caches, - const int64_t num_layers, - const torch::Tensor& block_mapping, - const torch::Tensor& block_offsets, - const torch::Tensor& req_ids, - std::vector const& events, - const bool is_batch_layer) { +// kv_caches layout: [laysers, [2, num_blocks, block_size, num_kv_heads, +// head_size]] +void kv_store_copy_blocks2GPU( + torch::Tensor& src, std::vector const& kv_caches, + const int64_t num_layers, const torch::Tensor& block_mapping, + const torch::Tensor& block_offsets, const torch::Tensor& req_ids, + std::vector const& events, const bool is_batch_layer) { if (is_batch_layer) { const int64_t num_requests = req_ids.size(0); kv_store.CopyBlocks2GPUBatch(src, kv_caches, num_layers, block_mapping, - block_offsets, num_requests, events); - } - else { + block_offsets, num_requests, events); + } else { kv_store.CopyLayerBlocks2GPU(src, kv_caches, num_layers, block_mapping, - block_offsets, req_ids, events); + block_offsets, req_ids, events); } } diff --git a/csrc/kv_store/kv_store.cu b/csrc/kv_store/kv_store.cu index 14afff472a553..dd38941419fa5 100644 --- a/csrc/kv_store/kv_store.cu +++ b/csrc/kv_store/kv_store.cu @@ -1,19 +1,19 @@ #include "kv_store.hpp" - - -#define CHECK_CUDA(x) {\ - cudaError_t err = (x);\ - if (err != cudaSuccess) {\ - std::cerr << "CUDA error: " << cudaGetErrorString(err) << " at " << __FILE__ << ":" << __LINE__ << std::endl;\ - std::abort();\ - }\ -} +#define CHECK_CUDA(x) \ + { \ + cudaError_t err = (x); \ + if (err != cudaSuccess) { \ + std::cerr << "CUDA error: " << cudaGetErrorString(err) << " at " \ + << __FILE__ << ":" << __LINE__ << std::endl; \ + std::abort(); \ + } \ + } // src layout: [2, num_blocks, block_size, num_kv_heads, head_size] // dst layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size] -void KVStore::CopyIncompleteBlocks(torch::Tensor& src, torch::Tensor& dst, - const int64_t layer_id, 
+void KVStore::CopyIncompleteBlocks( + torch::Tensor& src, torch::Tensor& dst, const int64_t layer_id, const torch::Tensor& incomplete_block_mapping) { torch::Device src_device = src.device(); torch::Device dst_device = dst.device(); @@ -24,36 +24,36 @@ void KVStore::CopyIncompleteBlocks(torch::Tensor& src, torch::Tensor& dst, TORCH_CHECK(false, "only support copy from GPU to CPU"); } TORCH_CHECK(incomplete_block_mapping.device().is_cpu(), - "block_mapping must be on CPU"); + "block_mapping must be on CPU"); const int64_t slot_size_in_bytes = src.element_size() * src[0][0][0].numel(); const at::cuda::OptionalCUDAGuard device_guard(src_device); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const int64_t num_items = incomplete_block_mapping.size(0); for (size_t i = 0; i < num_items; i++) { - int64_t src_block = incomplete_block_mapping[i][0].item(); + int64_t src_block = incomplete_block_mapping[i][0].item(); int64_t start_offset = incomplete_block_mapping[i][1].item(); - int64_t end_offset = incomplete_block_mapping[i][2].item(); - int64_t dst_block = incomplete_block_mapping[i][3].item(); - int64_t copy_nbytes = (end_offset - start_offset ) * slot_size_in_bytes; + int64_t end_offset = incomplete_block_mapping[i][2].item(); + int64_t dst_block = incomplete_block_mapping[i][3].item(); + int64_t copy_nbytes = (end_offset - start_offset) * slot_size_in_bytes; char* src_ptr = reinterpret_cast(src[0][src_block].data_ptr()); - char* dst_ptr = reinterpret_cast( - dst[dst_block][0][layer_id].data_ptr()); + char* dst_ptr = + reinterpret_cast(dst[dst_block][0][layer_id].data_ptr()); start_offset *= slot_size_in_bytes; CHECK_CUDA(cudaMemcpyAsync(dst_ptr + start_offset, src_ptr + start_offset, - copy_nbytes, memcpy_type, stream)); + copy_nbytes, memcpy_type, stream)); src_ptr = reinterpret_cast(src[1][src_block].data_ptr()); dst_ptr = reinterpret_cast(dst[dst_block][1][layer_id].data_ptr()); CHECK_CUDA(cudaMemcpyAsync(dst_ptr + start_offset, src_ptr + start_offset, - copy_nbytes, memcpy_type, stream)); + copy_nbytes, memcpy_type, stream)); } } // src layout: [2, num_blocks, block_size, num_kv_heads, head_size] // dst layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size] void KVStore::CopyBlocks2CPU(torch::Tensor& src, torch::Tensor& dst, - const int64_t layer_id, - const torch::Tensor& block_mapping) { + const int64_t layer_id, + const torch::Tensor& block_mapping) { torch::Device src_device = src.device(); torch::Device dst_device = dst.device(); cudaMemcpyKind memcpy_type; @@ -62,12 +62,11 @@ void KVStore::CopyBlocks2CPU(torch::Tensor& src, torch::Tensor& dst, } else { TORCH_CHECK(false, "only support copy from GPU to CPU"); } - TORCH_CHECK(block_mapping.device().is_cpu(), - "block_mapping must be on CPU"); + TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); const int64_t src_block_numel = src[0][0].numel(); const int64_t dst_block_numel = dst[0][0][0].numel(); TORCH_CHECK(src_block_numel == dst_block_numel, - "src and dst must have the same number of elements"); + "src and dst must have the same number of elements"); const int64_t block_size_in_bytes = src.element_size() * src_block_numel; const at::cuda::OptionalCUDAGuard device_guard(src_device); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -76,58 +75,53 @@ void KVStore::CopyBlocks2CPU(torch::Tensor& src, torch::Tensor& dst, int64_t src_block_number = block_mapping[i][0].item(); int64_t dst_block_number = block_mapping[i][1].item(); // key - char* src_ptr = 
reinterpret_cast( - src[0][src_block_number].data_ptr()); - char* dst_ptr = reinterpret_cast( - dst[dst_block_number][0][layer_id].data_ptr()); - CHECK_CUDA(cudaMemcpyAsync( - dst_ptr, src_ptr, block_size_in_bytes, memcpy_type, stream)); + char* src_ptr = + reinterpret_cast(src[0][src_block_number].data_ptr()); + char* dst_ptr = + reinterpret_cast(dst[dst_block_number][0][layer_id].data_ptr()); + CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, block_size_in_bytes, + memcpy_type, stream)); // value - src_ptr = reinterpret_cast( - src[1][src_block_number].data_ptr()); - dst_ptr = reinterpret_cast( - dst[dst_block_number][1][layer_id].data_ptr()); - CHECK_CUDA(cudaMemcpyAsync( - dst_ptr, src_ptr, block_size_in_bytes, memcpy_type, stream)); + src_ptr = reinterpret_cast(src[1][src_block_number].data_ptr()); + dst_ptr = + reinterpret_cast(dst[dst_block_number][1][layer_id].data_ptr()); + CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, block_size_in_bytes, + memcpy_type, stream)); } } - namespace vllm { // Grid: (num_layers, num_blocks) template __global__ void kv_store_copy_blocks_kernel( - scalar_t *src, - int64_t *key_cache_ptrs, int64_t *value_cache_ptrs, - const int64_t *block_mapping, const int64_t *block_offsets, - const int request_idx, - const int64_t numel_per_block) { + scalar_t* src, int64_t* key_cache_ptrs, int64_t* value_cache_ptrs, + const int64_t* block_mapping, const int64_t* block_offsets, + const int request_idx, const int64_t numel_per_block) { int64_t layer_idx = blockIdx.x; int64_t pair_idx = blockIdx.y; int64_t num_layer = gridDim.x; - scalar_t *key_cache = - reinterpret_cast(key_cache_ptrs[layer_idx]); - scalar_t *value_cache = - reinterpret_cast(value_cache_ptrs[layer_idx]); + scalar_t* key_cache = reinterpret_cast(key_cache_ptrs[layer_idx]); + scalar_t* value_cache = + reinterpret_cast(value_cache_ptrs[layer_idx]); int64_t block_mapping_idx = block_offsets[request_idx] + pair_idx; int64_t dst_block_number = block_mapping[2 * block_mapping_idx + 1]; - scalar_t *key_block = key_cache + dst_block_number * numel_per_block; - scalar_t *src_key_block = (src + pair_idx * 2 * num_layer * numel_per_block - + layer_idx * numel_per_block); + scalar_t* key_block = key_cache + dst_block_number * numel_per_block; + scalar_t* src_key_block = (src + pair_idx * 2 * num_layer * numel_per_block + + layer_idx * numel_per_block); for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) { key_block[i] = src_key_block[i]; } - scalar_t *value_block = value_cache + dst_block_number * numel_per_block; - scalar_t *src_value_block = (src - + (pair_idx * 2 + 1) * num_layer * numel_per_block - + layer_idx * numel_per_block); + scalar_t* value_block = value_cache + dst_block_number * numel_per_block; + scalar_t* src_value_block = + (src + (pair_idx * 2 + 1) * num_layer * numel_per_block + + layer_idx * numel_per_block); for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) { value_block[i] = src_value_block[i]; } } -} // namespace vllm +} // namespace vllm std::ostream& operator<<(std::ostream& os, const std::vector& vec) { os << "["; @@ -141,16 +135,16 @@ std::ostream& operator<<(std::ostream& os, const std::vector& vec) { return os; } - // src layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size] -// kv_caches layout: [laysers, [2, num_blocks, block_size, num_kv_heads, head_size]] +// kv_caches layout: [laysers, [2, num_blocks, block_size, num_kv_heads, +// head_size]] void KVStore::CopyBlocks2GPUBatch(torch::Tensor& src, - std::vector const& kv_caches, - const int64_t 
num_layers, - const torch::Tensor& block_mapping, - const torch::Tensor& block_offsets, - const int64_t num_requests, - std::vector const& events) { + std::vector const& kv_caches, + const int64_t num_layers, + const torch::Tensor& block_mapping, + const torch::Tensor& block_offsets, + const int64_t num_requests, + std::vector const& events) { torch::Device src_device = src.device(); torch::Device dst_device = kv_caches[0].device(); cudaMemcpyKind memcpy_type; @@ -159,18 +153,16 @@ void KVStore::CopyBlocks2GPUBatch(torch::Tensor& src, } else { TORCH_CHECK(false, "only support copy from CPU to GPU"); } - TORCH_CHECK(block_mapping.device().is_cpu(), - "block_mapping must be on CPU"); - TORCH_CHECK(block_offsets.device().is_cpu(), - "block_offsets must be on CPU"); + TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); + TORCH_CHECK(block_offsets.device().is_cpu(), "block_offsets must be on CPU"); const at::cuda::OptionalCUDAGuard device_guard(dst_device); auto stream = at::cuda::getCurrentCUDAStream(); - auto block_mapping_gpu = block_mapping.to(dst_device, - block_mapping.scalar_type(), - /*non_blocking=*/true); - auto block_offsets_gpu = block_offsets.to(dst_device, - block_offsets.scalar_type(), - /*non_blocking=*/true); + auto block_mapping_gpu = + block_mapping.to(dst_device, block_mapping.scalar_type(), + /*non_blocking=*/true); + auto block_offsets_gpu = + block_offsets.to(dst_device, block_offsets.scalar_type(), + /*non_blocking=*/true); // Create data structures for the kernel. // Create an array of pointers to the key and value caches. int64_t key_cache_ptrs[num_layers]; @@ -194,9 +186,8 @@ void KVStore::CopyBlocks2GPUBatch(torch::Tensor& src, const int64_t start_idx = block_offsets[i].item(); const int64_t end_idx = block_offsets[i + 1].item(); const int64_t num_blocks = end_idx - start_idx; - auto options = torch::TensorOptions() - .dtype(kv_caches[0].dtype()) - .device(dst_device); + auto options = + torch::TensorOptions().dtype(kv_caches[0].dtype()).device(dst_device); std::vector shape = src.sizes().vec(); shape[0] = num_blocks; // XXX: may cause out of memory in VLLM framework @@ -207,24 +198,24 @@ void KVStore::CopyBlocks2GPUBatch(torch::Tensor& src, char* src_ptr = reinterpret_cast(src[src_block_number].data_ptr()); char* dst_ptr = reinterpret_cast(trans_buffer[j].data_ptr()); int64_t trans_nbytes = src[0].element_size() * src[0].numel(); - CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, trans_nbytes, - memcpy_type, stream)); + CHECK_CUDA( + cudaMemcpyAsync(dst_ptr, src_ptr, trans_nbytes, memcpy_type, stream)); } const int numel_per_block = src[0][0][0].numel(); const dim3 grid(num_layers, num_blocks); const dim3 block(std::min(1024, numel_per_block)); VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( src.scalar_type(), "kv_store_copy_blocks_kernel", ([&] { - vllm::kv_store_copy_blocks_kernel<<>>( - trans_buffer.data_ptr(), - key_cache_ptrs_tensor.data_ptr(), - value_cache_ptrs_tensor.data_ptr(), - block_mapping_gpu.data_ptr(), - block_offsets_gpu.data_ptr(), - i, - numel_per_block); + vllm::kv_store_copy_blocks_kernel + <<>>( + trans_buffer.data_ptr(), + key_cache_ptrs_tensor.data_ptr(), + value_cache_ptrs_tensor.data_ptr(), + block_mapping_gpu.data_ptr(), + block_offsets_gpu.data_ptr(), i, numel_per_block); })); - CHECK_CUDA(cudaEventRecord(reinterpret_cast(events[i]), stream)); + CHECK_CUDA( + cudaEventRecord(reinterpret_cast(events[i]), stream)); } } @@ -236,53 +227,43 @@ namespace vllm { // Grid: (num_blocks) template __global__ void 
kv_store_copy_blocks_kernel( - scalar_t* src, - scalar_t* key_cache, - scalar_t* value_cache, - const int64_t* block_mapping, - const int64_t* block_offsets, - const int request_idx, - const int64_t layer_id, - const int64_t num_layer, + scalar_t* src, scalar_t* key_cache, scalar_t* value_cache, + const int64_t* block_mapping, const int64_t* block_offsets, + const int request_idx, const int64_t layer_id, const int64_t num_layer, const int64_t numel_per_block) { - int pair_idx = blockIdx.x; int64_t block_mapping_idx = block_offsets[request_idx] + pair_idx; int64_t src_block_number = block_mapping[2 * block_mapping_idx]; int64_t dst_block_number = block_mapping[2 * block_mapping_idx + 1]; - scalar_t* src_key_block = src - + src_block_number * 2 * num_layer * numel_per_block - + layer_id * numel_per_block; - scalar_t* dst_key_block = key_cache - + dst_block_number * numel_per_block; + scalar_t* src_key_block = src + + src_block_number * 2 * num_layer * numel_per_block + + layer_id * numel_per_block; + scalar_t* dst_key_block = key_cache + dst_block_number * numel_per_block; for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) { dst_key_block[i] = src_key_block[i]; } - scalar_t* src_value_block = src - + (src_block_number * 2 + 1) * num_layer * numel_per_block - + layer_id * numel_per_block; - scalar_t* dst_value_block = value_cache - + dst_block_number * numel_per_block; + scalar_t* src_value_block = + src + (src_block_number * 2 + 1) * num_layer * numel_per_block + + layer_id * numel_per_block; + scalar_t* dst_value_block = value_cache + dst_block_number * numel_per_block; for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) { dst_value_block[i] = src_value_block[i]; } } -}; // namespace vllm +}; // namespace vllm namespace { // src layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size] -// kv_caches layout: [laysers, [2, num_blocks, block_size, num_kv_heads, head_size]] +// kv_caches layout: [laysers, [2, num_blocks, block_size, num_kv_heads, +// head_size]] void CopyLayerBlocks2GPUKernelFunc( - const torch::Tensor& src, - std::vector const& kv_caches, - const int64_t num_layer, - const torch::Tensor& block_mapping, - const torch::Tensor& block_offsets, - const torch::Tensor& req_ids, + const torch::Tensor& src, std::vector const& kv_caches, + const int64_t num_layer, const torch::Tensor& block_mapping, + const torch::Tensor& block_offsets, const torch::Tensor& req_ids, const std::vector& events, - const at::cuda::CUDAStream& stream) { // is the current stream + const at::cuda::CUDAStream& stream) { // is the current stream size_t num_requests = req_ids.size(0); const int numel_per_block = src[0][0][0].numel(); const int64_t block_nbytes = numel_per_block * src.element_size(); @@ -294,19 +275,16 @@ void CopyLayerBlocks2GPUKernelFunc( } else { TORCH_CHECK(false, "only support copy from CPU to GPU"); } - TORCH_CHECK(block_mapping.device().is_cpu(), - "block_mapping must be on CPU"); - TORCH_CHECK(block_offsets.device().is_cpu(), - "block_offsets must be on CPU"); - TORCH_CHECK(req_ids.device().is_cpu(), - "req_ids must be on CPU"); + TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); + TORCH_CHECK(block_offsets.device().is_cpu(), "block_offsets must be on CPU"); + TORCH_CHECK(req_ids.device().is_cpu(), "req_ids must be on CPU"); const at::cuda::OptionalCUDAGuard device_guard(dst_device); - auto block_mapping_gpu = block_mapping.to(dst_device, - block_mapping.scalar_type(), - /*non_blocking=*/true); - auto block_offsets_gpu = 
block_offsets.to(dst_device, - block_offsets.scalar_type(), - /*non_blocking=*/true); + auto block_mapping_gpu = + block_mapping.to(dst_device, block_mapping.scalar_type(), + /*non_blocking=*/true); + auto block_offsets_gpu = + block_offsets.to(dst_device, block_offsets.scalar_type(), + /*non_blocking=*/true); for (size_t i = 0; i < num_requests; i++) { const int64_t req_id = req_ids[i].item(); const int64_t start_idx = block_offsets[i].item(); @@ -314,25 +292,21 @@ void CopyLayerBlocks2GPUKernelFunc( const int64_t num_blocks = end_idx - start_idx; for (int64_t layer_id = 0; layer_id < num_layer; layer_id++) { - if (num_blocks >= 2) { // if blocks are too many, use kernel + if (num_blocks >= 2) { // if blocks are too many, use kernel const dim3 grid(num_blocks); const dim3 block(std::min(1024, numel_per_block)); VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( src.scalar_type(), "kv_store_copy_blocks_kernel", ([&] { vllm::kv_store_copy_blocks_kernel - <<>>( - src.data_ptr(), - kv_caches[layer_id][0].data_ptr(), - kv_caches[layer_id][1].data_ptr(), - block_mapping_gpu.data_ptr(), - block_offsets_gpu.data_ptr(), - i, - layer_id, - num_layer, - numel_per_block); + <<>>( + src.data_ptr(), + kv_caches[layer_id][0].data_ptr(), + kv_caches[layer_id][1].data_ptr(), + block_mapping_gpu.data_ptr(), + block_offsets_gpu.data_ptr(), i, layer_id, + num_layer, numel_per_block); })); - } - else { + } else { for (size_t j = 0; j < num_blocks; j++) { int64_t idx = (start_idx + j); int64_t src_block_number = block_mapping[idx][0].item(); @@ -341,37 +315,34 @@ void CopyLayerBlocks2GPUKernelFunc( src[src_block_number][0][layer_id].data_ptr()); char* dst_ptr = reinterpret_cast( kv_caches[layer_id][0][dst_block_number].data_ptr()); - CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, - block_nbytes, memcpy_type, stream)); + CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, block_nbytes, + memcpy_type, stream)); src_ptr = reinterpret_cast( src[src_block_number][1][layer_id].data_ptr()); dst_ptr = reinterpret_cast( kv_caches[layer_id][1][dst_block_number].data_ptr()); - CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, - block_nbytes, memcpy_type, stream)); + CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, block_nbytes, + memcpy_type, stream)); } } CHECK_CUDA(cudaEventRecord( - reinterpret_cast(events[i * num_layer + layer_id]), - stream)); + reinterpret_cast(events[i * num_layer + layer_id]), + stream)); } } } // src layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size] -// kv_caches layout: [laysers, [2, num_blocks, block_size, num_kv_heads, head_size]] +// kv_caches layout: [laysers, [2, num_blocks, block_size, num_kv_heads, +// head_size]] void CopyLayerBlocks2GPUThreadFunc( - const torch::Tensor& src, - std::vector const& kv_caches, - const int64_t num_layer, - const torch::Tensor& block_mapping, - const torch::Tensor& block_offsets, - const torch::Tensor& req_ids, - const std::vector& events, - const at::cuda::CUDAStream& stream) { + const torch::Tensor& src, std::vector const& kv_caches, + const int64_t num_layer, const torch::Tensor& block_mapping, + const torch::Tensor& block_offsets, const torch::Tensor& req_ids, + const std::vector& events, const at::cuda::CUDAStream& stream) { size_t num_requests = req_ids.size(0); const int64_t block_nbytes = - kv_caches[0][0][0].numel() * kv_caches[0].element_size(); + kv_caches[0][0][0].numel() * kv_caches[0].element_size(); torch::Device src_device = src.device(); torch::Device dst_device = kv_caches[0].device(); cudaMemcpyKind memcpy_type; @@ -380,12 +351,9 @@ void 
CopyLayerBlocks2GPUThreadFunc( } else { TORCH_CHECK(false, "only support copy from CPU to GPU"); } - TORCH_CHECK(block_mapping.device().is_cpu(), - "block_mapping must be on CPU"); - TORCH_CHECK(block_offsets.device().is_cpu(), - "block_offsets must be on CPU"); - TORCH_CHECK(req_ids.device().is_cpu(), - "req_ids must be on CPU"); + TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); + TORCH_CHECK(block_offsets.device().is_cpu(), "block_offsets must be on CPU"); + TORCH_CHECK(req_ids.device().is_cpu(), "req_ids must be on CPU"); const at::cuda::OptionalCUDAGuard device_guard(dst_device); for (size_t i = 0; i < num_requests; i++) { const int64_t req_id = req_ids[i].item(); @@ -401,35 +369,34 @@ void CopyLayerBlocks2GPUThreadFunc( src[src_block_number][0][layer_id].data_ptr()); char* dst_ptr = reinterpret_cast( kv_caches[layer_id][0][dst_block_number].data_ptr()); - CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, - block_nbytes, memcpy_type, stream)); + CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, block_nbytes, memcpy_type, + stream)); src_ptr = reinterpret_cast( src[src_block_number][1][layer_id].data_ptr()); dst_ptr = reinterpret_cast( kv_caches[layer_id][1][dst_block_number].data_ptr()); - CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, - block_nbytes, memcpy_type, stream)); + CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, block_nbytes, memcpy_type, + stream)); } CHECK_CUDA(cudaEventRecord( - reinterpret_cast(events[i * num_layer + layer_id]), - stream)); + reinterpret_cast(events[i * num_layer + layer_id]), + stream)); } } } -}; // namespace - +}; // namespace // src layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size] -// kv_caches layout: [laysers, [2, num_blocks, block_size, num_kv_heads, head_size]] -void KVStore::CopyLayerBlocks2GPU( - torch::Tensor& src, - std::vector const& kv_caches, - const int64_t num_layer, - const torch::Tensor& block_mapping, - const torch::Tensor& block_offsets, - const torch::Tensor& req_ids, - const std::vector& events) { +// kv_caches layout: [laysers, [2, num_blocks, block_size, num_kv_heads, +// head_size]] +void KVStore::CopyLayerBlocks2GPU(torch::Tensor& src, + std::vector const& kv_caches, + const int64_t num_layer, + const torch::Tensor& block_mapping, + const torch::Tensor& block_offsets, + const torch::Tensor& req_ids, + const std::vector& events) { if (block_mapping.size(0) == 0) { return; } @@ -445,7 +412,6 @@ void KVStore::CopyLayerBlocks2GPU( req_ids.clone(), events, stream); */ - CopyLayerBlocks2GPUKernelFunc(src, kv_caches, num_layer, - block_mapping, block_offsets, req_ids, events, stream); + CopyLayerBlocks2GPUKernelFunc(src, kv_caches, num_layer, block_mapping, + block_offsets, req_ids, events, stream); } - diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index ac8c2fbd91f27..ed9e3592ab64e 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -410,14 +410,14 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " int layer_id, " " Tensor incomplete_block_mapping) -> ()"); cache_ops.impl("kv_store_copy_incomplete_blocks", torch::kCUDA, - &kv_store_copy_incomplete_blocks); + &kv_store_copy_incomplete_blocks); // Copy the kv cache blocks from src(GPU) to dst(CPU), used for kv store. cache_ops.def( "kv_store_copy_blocks2CPU(Tensor src, Tensor! 
dst, " " int layer_id, Tensor block_mapping) -> ()"); cache_ops.impl("kv_store_copy_blocks2CPU", torch::kCUDA, - &kv_store_copy_blocks2CPU); + &kv_store_copy_blocks2CPU); // Copy the kv cache blocks from src(CPU) to dst(GPU), used for kv store. cache_ops.def( @@ -428,7 +428,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " int[] events," " bool is_batch_layer) -> ()"); cache_ops.impl("kv_store_copy_blocks2GPU", torch::kCUDA, - &kv_store_copy_blocks2GPU); + &kv_store_copy_blocks2GPU); // Copy the cache blocks from src to dst. cache_ops.def( From 243e4f91f8308120fd3283fcad419b3a240528ed Mon Sep 17 00:00:00 2001 From: Dahai Tang Date: Wed, 4 Dec 2024 16:25:35 +0800 Subject: [PATCH 06/11] Refactor: pass more detail lint checker Signed-off-by: Dahai Tang --- csrc/cache_kernels.cu | 2 +- csrc/kv_store/kv_store.cu | 8 ++++---- vllm/core/scheduler.py | 15 ++++++++------- vllm/model_executor/models/qwen2.py | 6 +++--- vllm/store/kv_store.py | 21 +++++++++++++++------ vllm/worker/worker_base.py | 2 +- 6 files changed, 32 insertions(+), 22 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 343a368cb0968..20a485832f3b9 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -88,7 +88,7 @@ void kv_store_copy_blocks2CPU(torch::Tensor& src, torch::Tensor& dst, } // src layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size] -// kv_caches layout: [laysers, [2, num_blocks, block_size, num_kv_heads, +// kv_caches layout: [layers, [2, num_blocks, block_size, num_kv_heads, // head_size]] void kv_store_copy_blocks2GPU( torch::Tensor& src, std::vector const& kv_caches, diff --git a/csrc/kv_store/kv_store.cu b/csrc/kv_store/kv_store.cu index dd38941419fa5..f11fa45813eda 100644 --- a/csrc/kv_store/kv_store.cu +++ b/csrc/kv_store/kv_store.cu @@ -136,7 +136,7 @@ std::ostream& operator<<(std::ostream& os, const std::vector& vec) { } // src layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size] -// kv_caches layout: [laysers, [2, num_blocks, block_size, num_kv_heads, +// kv_caches layout: [layers, [2, num_blocks, block_size, num_kv_heads, // head_size]] void KVStore::CopyBlocks2GPUBatch(torch::Tensor& src, std::vector const& kv_caches, @@ -256,7 +256,7 @@ __global__ void kv_store_copy_blocks_kernel( namespace { // src layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size] -// kv_caches layout: [laysers, [2, num_blocks, block_size, num_kv_heads, +// kv_caches layout: [layers, [2, num_blocks, block_size, num_kv_heads, // head_size]] void CopyLayerBlocks2GPUKernelFunc( const torch::Tensor& src, std::vector const& kv_caches, @@ -333,7 +333,7 @@ void CopyLayerBlocks2GPUKernelFunc( } // src layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size] -// kv_caches layout: [laysers, [2, num_blocks, block_size, num_kv_heads, +// kv_caches layout: [layers, [2, num_blocks, block_size, num_kv_heads, // head_size]] void CopyLayerBlocks2GPUThreadFunc( const torch::Tensor& src, std::vector const& kv_caches, @@ -388,7 +388,7 @@ void CopyLayerBlocks2GPUThreadFunc( }; // namespace // src layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size] -// kv_caches layout: [laysers, [2, num_blocks, block_size, num_kv_heads, +// kv_caches layout: [layers, [2, num_blocks, block_size, num_kv_heads, // head_size]] void KVStore::CopyLayerBlocks2GPU(torch::Tensor& src, std::vector const& kv_caches, diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 407715b57942c..481d0fcaf22c7 100644 --- 
a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -284,7 +284,7 @@ def create_empty(cls) -> "SchedulerPrefillOutputs": seq_groups=[], ignored_seq_groups=[], num_lookahead_slots=0, - kv_store_block_mapping_from_cpu=None, + kv_store_block_mapping_from_cpu=BlockMappingFromCPU.null(), ) @@ -935,15 +935,16 @@ def _stop_schedule_prefill(num_new_tokens_uncached, num_new_seqs, max_num_batched_tokens, budget): + ret = False if (budget.num_batched_tokens >= self.scheduler_config.max_num_batched_tokens): - return True + ret = True if (num_new_tokens_uncached == 0 or not budget.can_schedule( num_new_tokens=num_new_tokens_uncached, num_new_seqs=num_new_seqs)): - return True - return False + ret = True + return ret kv_store_tmp_queue : Deque[SequenceGroup] = deque() while self._passed_delay(time.time()) and kv_store_waiting_queue: @@ -1062,13 +1063,13 @@ def _stop_schedule_prefill(num_new_tokens_uncached, waiting_queue.popleft() continue - if (self.kv_store_manager != None): + if (self.kv_store_manager is not None): self.kv_store_manager.is_prefill = seq_group.is_prefill() block_mapping_from_cpu = [] self._allocate(seq_group) - if (self.kv_store_manager != None): + if (self.kv_store_manager is not None): block_ids = self.block_manager.get_block_table( seq_group.get_seqs()[0]) block_mapping_from_cpu = \ @@ -1137,7 +1138,7 @@ def _stop_schedule_prefill(num_new_tokens_uncached, if len(seq_groups) > 0: self.prev_prompt = True - if (self.kv_store_manager != None) and \ + if (self.kv_store_manager is not None) and \ (len(kv_store_block_mapping) > 0): self.kv_store_manager.close_send_flags( [items[1] diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index ec7e975f34e16..49b56201e7546 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -338,15 +338,15 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - if (self.kv_store != None) and \ - (self.kv_store.batch_layers_to_GPU == True): + if (self.kv_store is not None) and \ + (self.kv_store.batch_layers_to_GPU): self.kv_store.get_stream_sync( attn_metadata.kv_store_meta.request_ids) for i in range(self.start_layer, self.end_layer): layer_id = (i - self.start_layer) if (self.kv_store is not None) and \ - (self.kv_store.batch_layers_to_GPU == False): + (not self.kv_store.batch_layers_to_GPU): self.kv_store.get_stream_layer_sync( layer_id, attn_metadata.kv_store_meta.request_ids) layer = self.layers[i] diff --git a/vllm/store/kv_store.py b/vllm/store/kv_store.py index 2107cd62bc1d3..1ed2b915e25f7 100644 --- a/vllm/store/kv_store.py +++ b/vllm/store/kv_store.py @@ -20,10 +20,12 @@ @dataclass class BlockMappingFromCPU: - block_mapping: torch.Tensor # 2-D tenso - block_offset: torch.Tensor # 1-D tensor, like offset array in CSR format + # 2-D tensor + block_mapping: Optional[torch.Tensor] + # 1-D tensor, like offset array in CSR format # the offset of each request in block_mapping - request_ids: torch.Tensor # request IDs + block_offset: Optional[torch.Tensor] + request_ids: Optional[torch.Tensor] # request IDs def __init__(self, block_mapping: list[list[int, int]], block_offset: list[int], request_ids: list[int]): @@ -37,6 +39,11 @@ def __init__(self, block_mapping: list[list[int, int]], device="cpu", dtype=torch.int64).view(-1) + @staticmethod + def null(): + return BlockMappingFromCPU( + torch.Tensor(), torch.Tensor(), torch.Tensor()) + def __str__(self): return "block_mapping: " + 
str(self.block_mapping) + \ " block_offset: " + str(self.block_offset) + \ @@ -45,12 +52,14 @@ def __str__(self): @dataclass class KVStoreMeta: - incomplete_put_block_ids: torch.Tensor # 4-D tensor: + # 4-D tensor: # vllm_block_id, # start_offset,end_offset, # store_block_id - put_block_ids_mapping: torch.Tensor # 2-D tensor: + incomplete_put_block_ids: torch.Tensor + # 2-D tensor: # vllm_block_id, store_block_id + put_block_ids_mapping: torch.Tensor request_ids: torch.Tensor # 1-D tensor @staticmethod @@ -257,7 +266,7 @@ def get_block_mapping_from_torch(self, vllm_block_ids: torch.Tensor) \ return ret_tensor def get_block_mapping_from_python(self, vllm_block_ids: list[int]) \ - -> list[tuple[int, int]]: + -> list[list[int, int]]: if (not self.is_prefill) or \ (len(vllm_block_ids) == 0): return [] diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 705aa588075e1..58411d675ccc4 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -14,7 +14,7 @@ from vllm.platforms import current_platform from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SequenceGroupMetadata) -from vllm.store.kv_store import BlockMappingFromCPU, KVStoreMeta +from vllm.store.kv_store import KVStoreMeta from vllm.utils import (enable_trace_function_call_for_thread, resolve_obj_by_qualname, update_environment_variables) from vllm.worker.model_runner_base import (BroadcastableModelInput, From 2a9deb63d94c171f159609c5c859aba603a086d3 Mon Sep 17 00:00:00 2001 From: Dahai Tang Date: Wed, 4 Dec 2024 16:32:35 +0800 Subject: [PATCH 07/11] isort the import order Signed-off-by: Dahai Tang --- examples/offline_inference.py | 1 + vllm/core/scheduler.py | 2 +- vllm/sequence.py | 5 +++-- vllm/worker/worker.py | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 97447f703e667..c67967baed67d 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -1,5 +1,6 @@ from vllm import LLM, SamplingParams + def print_outputs(llm, outputs): for output in outputs: prompt = output.prompt diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 481d0fcaf22c7..c1a0895b6f9bc 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -16,8 +16,8 @@ from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceGroupMetadataDelta, SequenceStatus) +from vllm.store.kv_store import BlockMappingFromCPU, KVBlockStoreManager from vllm.utils import Device, PyObjectCache -from vllm.store.kv_store import KVBlockStoreManager,BlockMappingFromCPU logger = init_logger(__name__) diff --git a/vllm/sequence.py b/vllm/sequence.py index 2d06b5ddbc51a..d0c8d83601ec4 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -6,9 +6,10 @@ from collections import defaultdict from dataclasses import dataclass, field from functools import reduce -from typing import Any, Callable, DefaultDict, Dict, List, Mapping, Optional +from typing import (TYPE_CHECKING, Any, Callable, DefaultDict, Dict, List, + Mapping, Optional) from typing import Sequence as GenericSequence -from typing import TYPE_CHECKING, Set, Tuple, Union +from typing import Set, Tuple, Union import msgspec import torch diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index d8d9a5ec2bf2b..bd85cb81edbc6 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -22,7 +22,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (ExecuteModelRequest, 
IntermediateTensors, SequenceGroupMetadata, SequenceGroupMetadataDelta) -from vllm.store.kv_store import KVBlockStore, KVStoreMeta, BlockMappingFromCPU +from vllm.store.kv_store import BlockMappingFromCPU, KVBlockStore, KVStoreMeta from vllm.worker.cache_engine import CacheEngine from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner From d9b650933b91adfcc76f05a3edb4c68488e5e7ec Mon Sep 17 00:00:00 2001 From: Dahai Tang Date: Wed, 4 Dec 2024 17:01:09 +0800 Subject: [PATCH 08/11] Refactor: lint checker Signed-off-by: Dahai Tang --- vllm/_custom_ops.py | 47 ++++++++++++++++------------- vllm/core/scheduler.py | 47 ++++++++++++++--------------- vllm/engine/arg_utils.py | 2 +- vllm/model_executor/models/qwen2.py | 11 +++---- vllm/store/__init__.py | 6 +--- vllm/store/kv_store.py | 7 ++--- vllm/worker/worker.py | 8 ++--- vllm/worker/worker_base.py | 21 ++++++------- 8 files changed, 72 insertions(+), 77 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ccb0e70e88074..8d3bc0d19e6a7 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -923,30 +923,35 @@ def swap_blocks(src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor) -> None: torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping) -def kv_store_copy_incomplete_blocks(src: torch.Tensor, dst: torch.Tensor, - layer_id: int, - incomplete_block_mapping: torch.Tensor) -> None: - torch.ops._C_cache_ops.kv_store_copy_incomplete_blocks(src, dst, - layer_id, - incomplete_block_mapping) + +def kv_store_copy_incomplete_blocks( + src: torch.Tensor, dst: torch.Tensor, layer_id: int, + incomplete_block_mapping: torch.Tensor) -> None: + torch.ops._C_cache_ops.kv_store_copy_incomplete_blocks( + src, dst, layer_id, incomplete_block_mapping) + def kv_store_copy_blocks2CPU(src: torch.Tensor, dst: torch.Tensor, - layer_id: int, - block_mapping: torch.Tensor) -> None: + layer_id: int, + block_mapping: torch.Tensor) -> None: torch.ops._C_cache_ops.kv_store_copy_blocks2CPU(src, dst, layer_id, - block_mapping) - -def kv_store_copy_blocks2GPU(src: torch.Tensor, dst: List[torch.Tensor], - num_layers: int, - block_mapping: torch.Tensor, - block_offsets: torch.Tensor, - req_ids: torch.Tensor, - events: List[int], # the pointer of cudaEvent_t - is_batch_layer: bool) -> None: - torch.ops._C_cache_ops.kv_store_copy_blocks2GPU( - src, dst, num_layers, - block_mapping, block_offsets, - req_ids, events, is_batch_layer) + block_mapping) + + +def kv_store_copy_blocks2GPU( + src: torch.Tensor, + dst: List[torch.Tensor], + num_layers: int, + block_mapping: torch.Tensor, + block_offsets: torch.Tensor, + req_ids: torch.Tensor, + events: List[int], # the pointer of cudaEvent_t + is_batch_layer: bool) -> None: + torch.ops._C_cache_ops.kv_store_copy_blocks2GPU(src, dst, num_layers, + block_mapping, + block_offsets, req_ids, + events, is_batch_layer) + def convert_fp8(output: torch.Tensor, input: torch.Tensor, diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c1a0895b6f9bc..9bb0a89ce3433 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -462,8 +462,9 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: if isinstance(request_id, str): request_id = (request_id, ) request_ids = set(request_id) - for state_queue in [self.waiting, self.running, - self.swapped, self.kv_store_waiting]: + for state_queue in [ + self.waiting, self.running, self.swapped, self.kv_store_waiting + ]: aborted_groups: 
List[SequenceGroup] = [] for seq_group in state_queue: if not request_ids: @@ -931,22 +932,19 @@ def _schedule_prefills( leftover_waiting_sequences: Deque[SequenceGroup] = deque() kv_store_leftover_waiting_sequences: Deque[SequenceGroup] = deque() - def _stop_schedule_prefill(num_new_tokens_uncached, - num_new_seqs, - max_num_batched_tokens, - budget): + def _stop_schedule_prefill(num_new_tokens_uncached, num_new_seqs, + max_num_batched_tokens, budget): ret = False if (budget.num_batched_tokens >= self.scheduler_config.max_num_batched_tokens): ret = True - if (num_new_tokens_uncached == 0 or - not budget.can_schedule( - num_new_tokens=num_new_tokens_uncached, - num_new_seqs=num_new_seqs)): + if (num_new_tokens_uncached == 0 or not budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs)): ret = True return ret - kv_store_tmp_queue : Deque[SequenceGroup] = deque() + kv_store_tmp_queue: Deque[SequenceGroup] = deque() while self._passed_delay(time.time()) and kv_store_waiting_queue: seq_group = kv_store_waiting_queue[0] @@ -977,8 +975,8 @@ def _stop_schedule_prefill(num_new_tokens_uncached, kv_store_waiting_queue.popleft() continue - if (_stop_schedule_prefill(num_new_tokens_uncached, - num_new_seqs, + if (_stop_schedule_prefill( + num_new_tokens_uncached, num_new_seqs, self.scheduler_config.max_num_batched_tokens, budget)): break @@ -1071,7 +1069,7 @@ def _stop_schedule_prefill(num_new_tokens_uncached, if (self.kv_store_manager is not None): block_ids = self.block_manager.get_block_table( - seq_group.get_seqs()[0]) + seq_group.get_seqs()[0]) block_mapping_from_cpu = \ self.kv_store_manager.get_block_mapping_from_python( block_ids) @@ -1097,18 +1095,18 @@ def _stop_schedule_prefill(num_new_tokens_uncached, if (len(block_mapping_from_cpu) > 0): waiting_queue.popleft() kv_store_leftover_waiting_sequences.appendleft(seq_group) - kv_store_block_mapping.extend( - block_mapping_from_cpu) - kv_store_block_mapping_offset.append(kv_store_block_mapping_cnt) + kv_store_block_mapping.extend(block_mapping_from_cpu) + kv_store_block_mapping_offset.append( + kv_store_block_mapping_cnt) kv_store_block_mapping_req_ids.append( - seq_group.get_seqs()[0].seq_id) + seq_group.get_seqs()[0].seq_id) kv_store_block_mapping_cnt += len(block_mapping_from_cpu) continue num_new_seqs = seq_group.get_max_num_running_seqs() - if (_stop_schedule_prefill(num_new_tokens_uncached, num_new_seqs, - self.scheduler_config.max_num_batched_tokens, - budget)): + if (_stop_schedule_prefill( + num_new_tokens_uncached, num_new_seqs, + self.scheduler_config.max_num_batched_tokens, budget)): # let it to the next running one waiting_queue.popleft() kv_store_leftover_waiting_sequences.appendleft(seq_group) @@ -1141,13 +1139,12 @@ def _stop_schedule_prefill(num_new_tokens_uncached, if (self.kv_store_manager is not None) and \ (len(kv_store_block_mapping) > 0): self.kv_store_manager.close_send_flags( - [items[1] - for items in kv_store_block_mapping]) + [items[1] for items in kv_store_block_mapping]) kv_store_block_mapping_offset.append(kv_store_block_mapping_cnt) kv_store_block_mapping_from_cpu = BlockMappingFromCPU( - kv_store_block_mapping, kv_store_block_mapping_offset, - kv_store_block_mapping_req_ids) + kv_store_block_mapping, kv_store_block_mapping_offset, + kv_store_block_mapping_req_ids) return SchedulerPrefillOutputs( seq_groups=seq_groups, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 59543f9fc5c13..15e85395a96dd 100644 --- a/vllm/engine/arg_utils.py +++ 
b/vllm/engine/arg_utils.py @@ -120,7 +120,7 @@ class EngineArgs: use_v2_block_manager: bool = True swap_space: float = 4 # GiB cpu_offload_gb: float = 0 # GiB - kv_store_space: float = 0 # GiB + kv_store_space: float = 0 # GiB gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 49b56201e7546..6b96b748f54e2 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -341,14 +341,14 @@ def forward( if (self.kv_store is not None) and \ (self.kv_store.batch_layers_to_GPU): self.kv_store.get_stream_sync( - attn_metadata.kv_store_meta.request_ids) + attn_metadata.kv_store_meta.request_ids) for i in range(self.start_layer, self.end_layer): layer_id = (i - self.start_layer) if (self.kv_store is not None) and \ (not self.kv_store.batch_layers_to_GPU): self.kv_store.get_stream_layer_sync( - layer_id, attn_metadata.kv_store_meta.request_ids) + layer_id, attn_metadata.kv_store_meta.request_ids) layer = self.layers[i] hidden_states, residual = layer( positions, @@ -360,10 +360,9 @@ def forward( if (self.kv_store is not None): self.kv_store.put_block_layer( - attn_metadata.kv_store_meta.incomplete_put_block_ids, - attn_metadata.kv_store_meta.put_block_ids_mapping, - layer_id, kv_caches[layer_id], - torch.cuda.current_stream()) + attn_metadata.kv_store_meta.incomplete_put_block_ids, + attn_metadata.kv_store_meta.put_block_ids_mapping, + layer_id, kv_caches[layer_id], torch.cuda.current_stream()) if not get_pp_group().is_last_rank: return IntermediateTensors({ diff --git a/vllm/store/__init__.py b/vllm/store/__init__.py index 6d6f5beca8aae..50cf3e56e478e 100644 --- a/vllm/store/__init__.py +++ b/vllm/store/__init__.py @@ -1,7 +1,3 @@ from vllm.store.kv_store import KVBlockStore, KVBlockStoreManager, KVStoreMeta -__all__ = [ - "KVBlockStore", - "KVBlockStoreManager", - "KVStoreMeta" -] +__all__ = ["KVBlockStore", "KVBlockStoreManager", "KVStoreMeta"] diff --git a/vllm/store/kv_store.py b/vllm/store/kv_store.py index 1ed2b915e25f7..acd0265bbe71c 100644 --- a/vllm/store/kv_store.py +++ b/vllm/store/kv_store.py @@ -41,8 +41,8 @@ def __init__(self, block_mapping: list[list[int, int]], @staticmethod def null(): - return BlockMappingFromCPU( - torch.Tensor(), torch.Tensor(), torch.Tensor()) + return BlockMappingFromCPU(torch.Tensor(), torch.Tensor(), + torch.Tensor()) def __str__(self): return "block_mapping: " + str(self.block_mapping) + \ @@ -417,8 +417,7 @@ def put_block_layer(self, incomplete_block_ids: torch.Tensor, layer_id, incomplete_block_ids) - def get_blocks(self, - block_mapping_from_cpu: BlockMappingFromCPU, + def get_blocks(self, block_mapping_from_cpu: BlockMappingFromCPU, kv_caches: list[torch.Tensor]): block_mapping_tensor = block_mapping_from_cpu.block_mapping block_offset_tensor = block_mapping_from_cpu.block_offset diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index bd85cb81edbc6..0bf0b1df3fd1a 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -414,10 +414,10 @@ def issue_blocks_copy(self, worker_input: WorkerInput) -> None: return kv_caches = (self.kv_cache[worker_input.virtual_engine] if self.kv_cache is not None else None) - self.kv_store.get_blocks(BlockMappingFromCPU( - worker_input.kv_store_block_mapping, - worker_input.kv_store_block_offsets, - worker_input.kv_store_block_req_ids), + self.kv_store.get_blocks( + BlockMappingFromCPU(worker_input.kv_store_block_mapping, + 
worker_input.kv_store_block_offsets, + worker_input.kv_store_block_req_ids), kv_caches) def _get_cached_seq_group_metadata( diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 58411d675ccc4..c6548651be7ad 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -163,17 +163,16 @@ def from_broadcasted_tensor_dict( Pop fields from the given tensor_dict and populate a new instance of WorkerInput. """ - return cls(num_seq_groups=tensor_dict.pop("num_seq_groups"), - blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"), - blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), - blocks_to_copy=tensor_dict.pop("blocks_to_copy"), - virtual_engine=tensor_dict["virtual_engine"], - num_steps=tensor_dict.pop("num_steps"), - kv_store_block_mapping=tensor_dict.pop("kv_block_mapping"), - kv_store_block_offsets=tensor_dict.pop( - "kv_block_mapping_offsets"), - kv_store_block_req_ids=tensor_dict.pop( - "kv_block_mapping_req_ids"), + return cls( + num_seq_groups=tensor_dict.pop("num_seq_groups"), + blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"), + blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), + blocks_to_copy=tensor_dict.pop("blocks_to_copy"), + virtual_engine=tensor_dict["virtual_engine"], + num_steps=tensor_dict.pop("num_steps"), + kv_store_block_mapping=tensor_dict.pop("kv_block_mapping"), + kv_store_block_offsets=tensor_dict.pop("kv_block_mapping_offsets"), + kv_store_block_req_ids=tensor_dict.pop("kv_block_mapping_req_ids"), ) def as_broadcastable_tensor_dict( From 11d934f7b42099e36ebbc0681304df92eeaf1f97 Mon Sep 17 00:00:00 2001 From: Dahai Tang Date: Thu, 5 Dec 2024 02:17:08 +0000 Subject: [PATCH 09/11] Feat: move functions about cpu kv_store in worker.py to worker_base Signed-off-by: Dahai Tang --- vllm/worker/worker.py | 52 +-------------------------- vllm/worker/worker_base.py | 73 +++++++++++++++++++++++++------------- 2 files changed, 49 insertions(+), 76 deletions(-) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 0bf0b1df3fd1a..617d5fe18cc36 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -22,7 +22,6 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SequenceGroupMetadata, SequenceGroupMetadataDelta) -from vllm.store.kv_store import BlockMappingFromCPU, KVBlockStore, KVStoreMeta from vllm.worker.cache_engine import CacheEngine from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner @@ -50,20 +49,12 @@ def __init__( is_driver_worker: bool = False, model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, ) -> None: - WorkerBase.__init__(self, vllm_config) + WorkerBase.__init__(self, vllm_config, local_rank) self.parallel_config.rank = rank - self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method self.is_driver_worker = is_driver_worker - if (self.cache_config.enable_kv_store): - self.cache_config.kv_store = KVBlockStore.from_configs( - self.cache_config, self.model_config, self.parallel_config, - torch.device(f"cuda:{self.local_rank}")) - self.kv_store = self.cache_config.kv_store - self.kv_store_manager = self.cache_config.kv_store_manager - if is_driver_worker: assert rank % self.parallel_config.tensor_parallel_size == 0, \ "Driver worker should be rank 0 of tensor parallel group." 
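As a reading aid for the worker-side restore path just above (issue_blocks_copy -> KVBlockStore.get_blocks), here is a minimal Python sketch of the CSR-style layout that BlockMappingFromCPU appears to carry: a flat list of block pairs, one start offset per request plus a trailing total (mirroring how the scheduler appends kv_store_block_mapping_cnt), and one request id per slice. The concrete ids and the (GPU block, CPU store block) column order are illustrative assumptions, not values taken from this patch.

import torch

# Two hypothetical prefill requests: request 7 restores three blocks from the
# CPU store, request 9 restores two. Each row pairs a vLLM (GPU) block id with
# a CPU store block id; the column order is assumed for illustration.
block_mapping = torch.tensor(
    [[12, 0], [13, 1], [14, 2],   # request 7
     [40, 5], [41, 6]],           # request 9
    dtype=torch.int64)
# One start offset per request plus the trailing total, as accumulated in
# _schedule_prefills before everything is wrapped in BlockMappingFromCPU.
block_offset = torch.tensor([0, 3, 5], dtype=torch.int64)
request_ids = torch.tensor([7, 9], dtype=torch.int64)

# Recover the per-request slices that the CPU->GPU block copy consumes.
for i, req_id in enumerate(request_ids.tolist()):
    start, end = block_offset[i].item(), block_offset[i + 1].item()
    print(f"request {req_id}: {block_mapping[start:end].tolist()}")

On the worker side, get_blocks consumes exactly this triple after the three fields have been broadcast through WorkerInput.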
@@ -121,35 +112,6 @@ def __init__( else: self.profiler = None - def prepare_kv_store_meta(self, - is_prefill: Optional[bool], - incomplete_put_block_ids: torch.Tensor, - put_block_ids: torch.Tensor, - seq_g_list: List[SequenceGroupMetadata]) \ - -> KVStoreMeta: - ret_incomplete_put_blocks = torch.Tensor() - ret_put_blocks_mapping = torch.Tensor() - ret_seq_g_ids = torch.Tensor() - if (self.local_rank == 0) and (self.kv_store_manager is not None): - self.kv_store_manager.is_prefill = is_prefill - (ret_incomplete_put_blocks, ret_put_blocks_mapping) = \ - self.kv_store_manager.get_put_blocks_mapping( - incomplete_put_block_ids, put_block_ids) - self.kv_store_manager.is_prefill = False - if (is_prefill) and (ret_incomplete_put_blocks.numel() + \ - ret_put_blocks_mapping.numel() > 0): - # XXX: use first seq_id representing the seq_group id - seq_g_ids = [seq_g.get_first_seq_id() for seq_g in seq_g_list] - ret_seq_g_ids = torch.tensor(seq_g_ids, - device="cpu", - dtype=torch.int64).view(-1) - return KVStoreMeta(ret_incomplete_put_blocks, ret_put_blocks_mapping, - ret_seq_g_ids) - - def put_stream_sync(self): - if (self.kv_store is not None): - self.kv_store.put_stream_sync() - def start_profile(self): if self.profiler is None: raise RuntimeError("Profiler is not enabled.") @@ -408,18 +370,6 @@ def execute_worker(self, worker_input: WorkerInput) -> None: and worker_input.blocks_to_copy.numel() > 0): self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy) - @torch.inference_mode() - def issue_blocks_copy(self, worker_input: WorkerInput) -> None: - if (self.kv_store is None): - return - kv_caches = (self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None) - self.kv_store.get_blocks( - BlockMappingFromCPU(worker_input.kv_store_block_mapping, - worker_input.kv_store_block_offsets, - worker_input.kv_store_block_req_ids), - kv_caches) - def _get_cached_seq_group_metadata( self, seq_group_metadata_list: List[Union[SequenceGroupMetadata, diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index c6548651be7ad..fb1c659d47a01 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -14,7 +14,7 @@ from vllm.platforms import current_platform from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SequenceGroupMetadata) -from vllm.store.kv_store import KVStoreMeta +from vllm.store.kv_store import BlockMappingFromCPU, KVBlockStore, KVStoreMeta from vllm.utils import (enable_trace_function_call_for_thread, resolve_obj_by_qualname, update_environment_variables) from vllm.worker.model_runner_base import (BroadcastableModelInput, @@ -33,6 +33,7 @@ class WorkerBase(ABC): def __init__( self, vllm_config: VllmConfig, + local_rank: int = 0, ) -> None: self.vllm_config = vllm_config self.model_config = vllm_config.model_config @@ -46,6 +47,15 @@ def __init__( self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config self.kv_transfer_config = vllm_config.kv_transfer_config + self.local_rank = local_rank + + if (self.cache_config.enable_kv_store): + self.cache_config.kv_store = KVBlockStore.from_configs( + self.cache_config, self.model_config, self.parallel_config, + torch.device(f"cuda:{self.local_rank}")) + self.kv_store = self.cache_config.kv_store + self.kv_store_manager = self.cache_config.kv_store_manager + @abstractmethod def init_device(self) -> None: @@ -242,38 +252,51 @@ def prepare_worker_input( raise NotImplementedError @abstractmethod + def 
execute_worker(self, worker_input: WorkerInput) -> None: + """ + Process an execution request. + """ + raise NotImplementedError + def prepare_kv_store_meta(self, is_prefill: Optional[bool], incomplete_put_block_ids: torch.Tensor, put_block_ids: torch.Tensor, seq_g_list: List[SequenceGroupMetadata]) \ -> KVStoreMeta: - """ - Prepare the KVStoreMeta for the worker. This is called by the driver - worker to prepare the metadata for the KVStore. - """ - pass - - @abstractmethod - def put_stream_sync(self) -> None: - """ - Synchronize the stream. - """ - raise NotImplementedError + ret_incomplete_put_blocks = torch.Tensor() + ret_put_blocks_mapping = torch.Tensor() + ret_seq_g_ids = torch.Tensor() + if (self.local_rank == 0) and (self.kv_store_manager is not None): + self.kv_store_manager.is_prefill = is_prefill + (ret_incomplete_put_blocks, ret_put_blocks_mapping) = \ + self.kv_store_manager.get_put_blocks_mapping( + incomplete_put_block_ids, put_block_ids) + self.kv_store_manager.is_prefill = False + if (is_prefill) and (ret_incomplete_put_blocks.numel() + \ + ret_put_blocks_mapping.numel() > 0): + # XXX: use first seq_id representing the seq_group id + seq_g_ids = [seq_g.get_first_seq_id() for seq_g in seq_g_list] + ret_seq_g_ids = torch.tensor(seq_g_ids, + device="cpu", + dtype=torch.int64).view(-1) + return KVStoreMeta(ret_incomplete_put_blocks, ret_put_blocks_mapping, + ret_seq_g_ids) + + def put_stream_sync(self): + if (self.kv_store is not None): + self.kv_store.put_stream_sync() - @abstractmethod def issue_blocks_copy(self, worker_input: WorkerInput) -> None: - """ - Issue the copy of the blocks from CPU to GPU with the given indices. - """ - raise NotImplementedError - - @abstractmethod - def execute_worker(self, worker_input: WorkerInput) -> None: - """ - Process an execution request. - """ - raise NotImplementedError + if (self.kv_store is None): + return + kv_caches = (self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None) + self.kv_store.get_blocks( + BlockMappingFromCPU(worker_input.kv_store_block_mapping, + worker_input.kv_store_block_offsets, + worker_input.kv_store_block_req_ids), + kv_caches) def _get_worker_input_from_broadcast( self From 068342c9d5fe76956f122a2797abc9071293dfa5 Mon Sep 17 00:00:00 2001 From: Dahai Tang Date: Thu, 5 Dec 2024 06:51:15 +0000 Subject: [PATCH 10/11] Feat: add kv_store_meta for common attention meta Signed-off-by: Dahai Tang --- vllm/attention/backends/abstract.py | 4 ++ vllm/attention/backends/blocksparse_attn.py | 2 + vllm/attention/backends/flash_attn.py | 3 -- vllm/attention/backends/flashinfer.py | 2 + vllm/attention/backends/placeholder_attn.py | 4 ++ vllm/attention/backends/utils.py | 58 +++++++++++++++++++-- vllm/attention/layer.py | 47 +++++++++++++---- vllm/model_executor/models/qwen2.py | 22 +------- vllm/store/kv_store.py | 23 +++++++- vllm/worker/worker_base.py | 1 - 10 files changed, 125 insertions(+), 41 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index aed04361e5fb4..8db1b3b23f95f 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -7,6 +7,7 @@ import torch from vllm.multimodal import MultiModalPlaceholderMap +from vllm.store.kv_store import KVStoreMeta if TYPE_CHECKING: from vllm.worker.model_runner_base import (ModelRunnerBase, @@ -115,6 +116,9 @@ class AttentionMetadata: # in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor + # CPU KV store metadata + kv_store_meta: KVStoreMeta + # The index maps that relate multi-modal embeddings to the corresponding # placeholders. # diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 99cb84346d84e..e91286fcbd37f 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -222,6 +222,7 @@ def prefill_metadata( slot_mapping=self.slot_mapping[:self.num_prefill_tokens], multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, + kv_store_meta=self.kv_store_meta, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, @@ -251,6 +252,7 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], multi_modal_placeholder_index_maps=None, + kv_store_meta=self.kv_store_meta, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_query_len=None, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 35685917d668c..45fd33e46099c 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -131,9 +131,6 @@ class FlashAttentionMetadata(AttentionMetadata): # captured. block_tables: Optional[torch.Tensor] - # CPU KV store metadata - kv_store_meta: Optional[KVStoreMeta] - # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index e367468d05d26..3a15fb556924f 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -30,6 +30,7 @@ compute_slot_mapping_start_idx, is_block_tables_empty) from vllm.attention.ops.paged_attn import PagedAttention +from vllm.store.kv_store import KVStoreMeta from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, make_tensor_with_pad) @@ -711,6 +712,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, multi_modal_placeholder_index_maps=placeholder_index_maps, + kv_store_meta=KVStoreMeta.null(), num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, max_prefill_seq_len=max_prefill_seq_len, diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 888adbffb8578..a9c20a4ec3dac 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -9,6 +9,7 @@ AttentionMetadataBuilder) from vllm.attention.backends.utils import CommonAttentionState from vllm.multimodal import MultiModalPlaceholderMap +from vllm.store.kv_store import KVStoreMeta if TYPE_CHECKING: from vllm.worker.model_runner import ModelInputForGPUBuilder @@ -139,6 +140,7 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=self. 
multi_modal_placeholder_index_maps, + kv_store_meta=self.kv_store_meta, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_decode_query_len=0, @@ -172,6 +174,7 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + kv_store_meta=self.kv_store_meta, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_decode_query_len=self.max_decode_query_len, @@ -316,6 +319,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_prefills=self.num_prefills, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=placeholder_index_maps, + kv_store_meta=KVStoreMeta.null(), num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 1cd11e262748e..c4093ba51d310 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -147,6 +147,8 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.sliding_window = input_builder.sliding_window self.block_size = input_builder.block_size + self.enable_kv_store = input_builder.enable_kv_store + self.range_list: List[Tuple[int, int, List[int], bool]] = [] def _add_seq_group( self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", @@ -199,9 +201,12 @@ def _add_seq_group( start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, context_len, self.sliding_window) - compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, - seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) + (range_start, range_end, block_table) = compute_slot_mapping( + is_profile_run, self.slot_mapping, seq_id, seq_len, + context_len, start_idx, self.block_size, + inter_data.block_tables) + self.range_list.append( + (range_start, range_end, block_table, is_prompt)) def build(self, seq_lens: List[int], query_lens: List[int], cuda_graph_pad_size: int, batch_size: int): @@ -268,6 +273,50 @@ def build(self, seq_lens: List[int], query_lens: List[int], self.multimodal_placeholder_maps.items() } + incomplete_put_block_ids = [] + put_block_ids = [] + assert (len(self.range_list) == len(self.block_tables)) + if (self.enable_kv_store): + for (range_start, range_end, seq_block_table, is_prompt) in \ + self.range_list: + if (range_start == range_end) or (not is_prompt): + continue + block_size = self.block_size + range_end -= 1 + range_start_block_id = range_start // block_size + range_end_block_id = range_end // block_size + range_start_block_offset = range_start % block_size + range_end_block_offset = range_end % block_size + 1 + if (range_start_block_id == range_end_block_id): + incomplete_put_block_ids.append([ + seq_block_table[range_start_block_id], + range_start_block_offset, range_end_block_offset + ]) + else: + if (range_start_block_offset == 0): + put_block_ids.append( + seq_block_table[range_start_block_id]) + else: + incomplete_put_block_ids.append([ + seq_block_table[range_start_block_id], + range_start_block_offset, block_size + ]) + put_block_ids.extend(seq_block_table[range_start_block_id + + 1:range_end_block_id]) + if (range_end_block_offset == block_size): + put_block_ids.append( + seq_block_table[range_end_block_id]) + else: + incomplete_put_block_ids.append([ + seq_block_table[range_end_block_id], 0, + range_end_block_offset + ]) + 
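+        # put_block_ids now holds the vLLM block ids fully covered by this
+        # prefill step's token range, while incomplete_put_block_ids holds
+        # [block_id, start_offset, end_offset) rows for the partially filled
+        # edge blocks (end_offset is exclusive, hence the `% block_size + 1`
+        # above). Both lists are converted to CPU tensors below and handed to
+        # the KV store through KVStoreMeta.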
incomplete_put_block_ids_numpy = np.array(incomplete_put_block_ids) + put_block_ids_numpy = np.array(put_block_ids) + incomplete_put_block_ids_cpu = torch.from_numpy( + incomplete_put_block_ids_numpy).to("cpu") + put_block_ids_cpu = torch.from_numpy(put_block_ids_numpy).to("cpu") + return self._metadata_cls( # type: ignore num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, @@ -284,7 +333,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, - ) + kv_store_meta=KVStoreMeta(incomplete_put_block_ids_cpu, + put_block_ids_cpu, torch.Tensor())) class CommonAttentionState(AttentionState): diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index e024eef286f05..efd0789316b8c 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -119,6 +119,17 @@ def __init__( compilation_config.static_forward_context[prefix] = self self.layer_name = prefix + assert (cache_config is not None) + self.kv_store = cache_config.kv_store + + def put_blocks(self, incomplete_put_block_ids: torch.Tensor, + put_block_ids_mapping: torch.Tensor, + kv_cache: torch.Tensor) -> None: + if self.kv_store is not None: + self.kv_store.put_blocks(incomplete_put_block_ids, + put_block_ids_mapping, kv_cache, + torch.cuda.current_stream()) + def forward( self, query: torch.Tensor, @@ -129,15 +140,23 @@ def forward( attn_type: str = AttentionType.DECODER, ) -> torch.Tensor: + if (self.kv_store is not None): + self.kv_store.get_blocks_sync( + attn_metadata.kv_store_meta.request_ids) + if self.use_direct_call: - return self.impl.forward(query, - key, - value, - kv_cache, - attn_metadata, - self._k_scale, - self._v_scale, - attn_type=attn_type) + ret = self.impl.forward(query, + key, + value, + kv_cache, + attn_metadata, + self._k_scale, + self._v_scale, + attn_type=attn_type) + self.put_blocks( + attn_metadata.kv_store_meta.incomplete_put_block_ids, + attn_metadata.kv_store_meta.put_block_ids_mapping, kv_cache) + return ret elif self.use_output: output = torch.empty_like(query) hidden_size = query.size(-1) @@ -153,11 +172,17 @@ def forward( torch.ops.vllm.unified_attention_with_output( query, key, value, output, kv_cache, attn_type, self.layer_name) + self.put_blocks( + attn_metadata.kv_store_meta.incomplete_put_block_ids, + attn_metadata.kv_store_meta.put_block_ids_mapping, kv_cache) return output.view(-1, hidden_size) else: - return torch.ops.vllm.unified_attention(query, key, value, - kv_cache, attn_type, - self.layer_name) + ret = torch.ops.vllm.unified_attention(query, key, value, kv_cache, + attn_type, self.layer_name) + self.put_blocks( + attn_metadata.kv_store_meta.incomplete_put_block_ids, + attn_metadata.kv_store_meta.put_block_ids_mapping, kv_cache) + return ret def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 6b96b748f54e2..7d4cc4b69e614 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -313,9 +313,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.norm = PPMissingLayer() - self.enable_prefix_caching = cache_config.enable_prefix_caching - self.kv_store = cache_config.kv_store - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -338,32 +335,15 @@ def forward( assert intermediate_tensors is not None hidden_states = 
intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - if (self.kv_store is not None) and \ - (self.kv_store.batch_layers_to_GPU): - self.kv_store.get_stream_sync( - attn_metadata.kv_store_meta.request_ids) - for i in range(self.start_layer, self.end_layer): - layer_id = (i - self.start_layer) - if (self.kv_store is not None) and \ - (not self.kv_store.batch_layers_to_GPU): - self.kv_store.get_stream_layer_sync( - layer_id, attn_metadata.kv_store_meta.request_ids) layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[layer_id], + kv_caches[i - self.start_layer], attn_metadata, residual, ) - - if (self.kv_store is not None): - self.kv_store.put_block_layer( - attn_metadata.kv_store_meta.incomplete_put_block_ids, - attn_metadata.kv_store_meta.put_block_ids_mapping, - layer_id, kv_caches[layer_id], torch.cuda.current_stream()) - if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, diff --git a/vllm/store/kv_store.py b/vllm/store/kv_store.py index acd0265bbe71c..79166e9d4fe39 100644 --- a/vllm/store/kv_store.py +++ b/vllm/store/kv_store.py @@ -366,6 +366,8 @@ def __init__( self.event_map: dict[int, Optional[torch.cuda.Event, list[torch.cuda.Event]]] = {} self.batch_layers_to_GPU = batch_layers_transmission_to_GPU + self._put_blocks_cur_layer = 0 + self._get_blocks_sync_cur_layer = 0 with torch.cuda.device(device): self.store = torch.empty( [self.num_blocks, 2, num_layer, num_block_slot, num_item], @@ -417,6 +419,16 @@ def put_block_layer(self, incomplete_block_ids: torch.Tensor, layer_id, incomplete_block_ids) + # called by each attention layer + def put_blocks(self, incomplete_block_ids: torch.Tensor, + block_ids_mapping: torch.Tensor, kv_cache: torch.Tensor, + forward_stream: torch.cuda.Stream): + layer_id = self._put_blocks_cur_layer + self.put_block_layer(incomplete_block_ids, block_ids_mapping, layer_id, + kv_cache, forward_stream) + self._put_blocks_cur_layer = \ + (self._put_blocks_cur_layer + 1) % self.num_layer + def get_blocks(self, block_mapping_from_cpu: BlockMappingFromCPU, kv_caches: list[torch.Tensor]): block_mapping_tensor = block_mapping_from_cpu.block_mapping @@ -447,7 +459,7 @@ def get_blocks(self, block_mapping_from_cpu: BlockMappingFromCPU, [event.cuda_event for event in event_list], is_batch_layer) # pair used with get_blocks_batch - def get_stream_sync(self, request_ids: torch.Tensor): + def get_stream_batch_sync(self, request_ids: torch.Tensor): if (request_ids.numel() == 0): return for req_id in request_ids.numpy(): @@ -476,5 +488,14 @@ def get_stream_layer_sync(self, layer_id: int, request_ids: torch.Tensor): self.event_pool.put_events(event_list) del self.event_map[req_id] + def get_blocks_sync(self, request_ids: torch.Tensor): + layer_id = self._get_blocks_sync_cur_layer + if (self.batch_layers_to_GPU) and (layer_id == 0): + self.get_stream_batch_sync(request_ids) + if (not self.batch_layers_to_GPU): + self.get_stream_layer_sync(layer_id, request_ids) + self._get_blocks_sync_cur_layer = \ + (self._get_blocks_sync_cur_layer + 1) % self.num_layer + def put_stream_sync(self): self.put_stream.synchronize() diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index fb1c659d47a01..36fabfc43c3b9 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -56,7 +56,6 @@ def __init__( self.kv_store = self.cache_config.kv_store self.kv_store_manager = self.cache_config.kv_store_manager - @abstractmethod def init_device(self) -> None: 
"""Initialize device state, such as loading the model or other on-device From 8d2816b713cfcb7ce844ab48bda08c13bc33ac11 Mon Sep 17 00:00:00 2001 From: Dahai Tang Date: Thu, 5 Dec 2024 07:47:09 +0000 Subject: [PATCH 11/11] Fix: attn make_metadata with kv_store_meta Signed-off-by: Dahai Tang --- tests/kernels/utils.py | 9 +++++++-- vllm/attention/backends/flashinfer.py | 4 +++- vllm/worker/hpu_model_runner.py | 8 ++++++-- vllm/worker/tpu_model_runner.py | 3 +++ vllm/worker/xpu_model_runner.py | 3 +++ 5 files changed, 22 insertions(+), 5 deletions(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index e7865fb2500ef..44aeac6f940d7 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -13,6 +13,7 @@ from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.model_executor.layers.activation import SiluAndMul +from vllm.store.kv_store import KVStoreMeta from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) @@ -925,7 +926,9 @@ def make_test_metadata( cross_slot_mapping=(None if cross_kv_mmap is None else cross_kv_mmap.slot_mapping), cross_block_tables=(None if cross_kv_mmap is None else - cross_kv_mmap.block_tables)) + cross_kv_mmap.block_tables), + kv_store_meta=KVStoreMeta.null(), + ) else: # not is_prompt # Decode-phase scenario @@ -975,7 +978,9 @@ def make_test_metadata( cross_slot_mapping=(None if cross_kv_mmap is None else cross_kv_mmap.slot_mapping), cross_block_tables=(None if cross_kv_mmap is None else - cross_kv_mmap.block_tables)) + cross_kv_mmap.block_tables), + kv_store_meta=KVStoreMeta.null(), + ) def assert_actual_matches_ideal(test_params: PhaseTestParameters, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 3a15fb556924f..2b34cf56bd28f 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -237,7 +237,9 @@ def graph_capture_get_metadata_for_batch( q_data_type=self.runner.model_config.dtype, use_cuda_graph=True, decode_wrapper=self._graph_decode_wrapper, - prefill_wrapper=None) + prefill_wrapper=None, + kv_store_meta=KVStoreMeta.null(), + ) attn_metadata.begin_forward() return attn_metadata diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 99cf9a7e67256..e48324c225258 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -40,6 +40,7 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, SequenceData, SequenceGroupMetadata) +from vllm.store.kv_store import KVStoreMeta from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, @@ -888,7 +889,8 @@ def _prepare_prompt( num_decode_tokens=0, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps= - None # FIXME(kzawora): mutli-modality will not work here + None, # FIXME(kzawora): mutli-modality will not work here + kv_store_meta=KVStoreMeta.null(), ) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) @@ -1042,7 +1044,9 @@ def _prepare_decode( num_prefill_tokens=0, num_decode_tokens=num_decode_tokens, slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None) + multi_modal_placeholder_index_maps=None, + kv_store_meta=KVStoreMeta.null(), + ) return PrepareDecodeMetadata(input_tokens=input_tokens, input_positions=input_positions, attn_metadata=attn_metadata, diff --git a/vllm/worker/tpu_model_runner.py 
b/vllm/worker/tpu_model_runner.py index 9a054eb8a4cf7..acde12b9580ca 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -19,6 +19,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, Logprob, SequenceGroupMetadata, SequenceOutput) +from vllm.store.kv_store import KVStoreMeta from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, _add_attn_metadata_broadcastable_dict, @@ -183,6 +184,7 @@ def _dummy_run( num_prefill_tokens=batch_size * seq_len, num_decode_tokens=0, slot_mapping=slot_mapping, + kv_store_meta=KVStoreMeta.null(), multi_modal_placeholder_index_maps=None, block_tables=None, context_lens=None, @@ -205,6 +207,7 @@ def _dummy_run( block_tables=block_tables, context_lens=context_lens, effective_query_lens=effective_query_lens, + kv_store_meta=KVStoreMeta.null(), ) else: assert seq_len == 1 diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index e6322e095bbb9..cce722df55e66 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -22,6 +22,7 @@ MultiModalRegistry) from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata +from vllm.store.kv_store import KVStoreMeta from vllm.utils import DeviceMemoryProfiler, make_tensor_with_pad from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata from vllm.worker.model_runner_base import ( @@ -270,6 +271,7 @@ def _prepare_prompt( num_prefill_tokens=num_prompt_tokens, num_decode_tokens=0, block_tables=torch.tensor([], device=self.device, dtype=torch.int), + kv_store_meta=KVStoreMeta.null(), ) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) @@ -354,6 +356,7 @@ def _prepare_decode( num_decode_tokens=len(input_tokens), num_prefills=0, block_tables=block_tables, + kv_store_meta=KVStoreMeta.null(), ) return ( input_tokens,