
[Core] Efficient CPU prefix caching for the prefill step #10888

Closed · wants to merge 11 commits
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -187,6 +187,7 @@ message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")

set(VLLM_EXT_SRC
  "csrc/cache_kernels.cu"
  "csrc/kv_store/kv_store.cu"
  "csrc/attention/paged_attention_v1.cu"
  "csrc/attention/paged_attention_v2.cu"
  "csrc/pos_encoding_kernels.cu"
13 changes: 13 additions & 0 deletions csrc/cache.h
@@ -8,6 +8,19 @@
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                 const torch::Tensor& block_mapping);

void kv_store_copy_incomplete_blocks(
    torch::Tensor& src, torch::Tensor& dst, const int64_t layer_id,
    const torch::Tensor& incomplete_block_mapping);
void kv_store_copy_blocks2CPU(torch::Tensor& src, torch::Tensor& dst,
                              const int64_t layer_id,
                              const torch::Tensor& block_mapping);

void kv_store_copy_blocks2GPU(
    torch::Tensor& src, std::vector<torch::Tensor> const& dst,
    const int64_t num_layers, const torch::Tensor& block_mapping,
    const torch::Tensor& block_offsets, const torch::Tensor& req_ids,
    std::vector<long> const& events, const bool is_batch_layer);

// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
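Reviewer aside: a minimal caller sketch for the new entry points. It assumes the [num_pairs, 2] (src, dst) row convention that the existing swap_blocks uses for block_mapping, and the CPU-store layout documented in csrc/cache_kernels.cu below. All names and sizes here are illustrative, not part of the diff.

// --- Illustrative sketch, not part of the diff ---
// Allocate a pinned CPU store with the documented layout
// [num_blocks, 2, num_layers, block_size, num_kv_heads*head_size],
// then offload two full GPU blocks of one layer into it.
// Declarations from csrc/cache.h are assumed visible.
#include <torch/torch.h>

torch::Tensor make_cpu_store(int64_t num_cpu_blocks, int64_t num_layers,
                             int64_t block_size, int64_t num_kv_heads,
                             int64_t head_size) {
  return torch::empty(
      {num_cpu_blocks, 2, num_layers, block_size, num_kv_heads * head_size},
      torch::TensorOptions().dtype(torch::kFloat16).pinned_memory(true));
}

void offload_layer(torch::Tensor& gpu_kv,     // [2, num_blocks, block_size,
                                              //  num_kv_heads, head_size]
                   torch::Tensor& cpu_store,  // layout as above
                   int64_t layer_id) {
  // Assumed (src, dst) row convention, as in swap_blocks: copy GPU blocks
  // 0 and 1 into CPU-store blocks 4 and 5 for this layer.
  auto block_mapping = torch::tensor({{0, 4}, {1, 5}}, torch::kInt64);
  kv_store_copy_blocks2CPU(gpu_kv, cpu_store, layer_id, block_mapping);
}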
43 changes: 43 additions & 0 deletions csrc/cache_kernels.cu
@@ -1,6 +1,7 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <iostream>

#include "cuda_compat.h"
#include "dispatch_utils.h"
@@ -21,6 +22,14 @@
typedef __hip_bfloat16 __nv_bfloat16;
#endif

#include "kv_store/kv_store.hpp"

namespace {

// File-local KV store instance shared by the wrapper functions below.
KVStore kv_store;

}  // namespace

void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                 const torch::Tensor& block_mapping) {
  torch::Device src_device = src.device();
@@ -62,6 +71,40 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
  }
}

// src layout: [2, num_blocks, block_size, num_kv_heads, head_size]
// dst layout: [num_blocks, 2, num_layers, block_size, num_kv_heads*head_size]
void kv_store_copy_incomplete_blocks(
    torch::Tensor& src, torch::Tensor& dst, const int64_t layer_id,
    const torch::Tensor& incomplete_block_mapping) {
  kv_store.CopyIncompleteBlocks(src, dst, layer_id, incomplete_block_mapping);
}

// src layout: [2, num_blocks, block_size, num_kv_heads, head_size]
// dst layout: [num_blocks, 2, num_layers, block_size, num_kv_heads*head_size]
void kv_store_copy_blocks2CPU(torch::Tensor& src, torch::Tensor& dst,
                              const int64_t layer_id,
                              const torch::Tensor& block_mapping) {
  kv_store.CopyBlocks2CPU(src, dst, layer_id, block_mapping);
}

// src layout: [num_blocks, 2, num_layers, block_size, num_kv_heads*head_size]
// kv_caches layout: [layers, [2, num_blocks, block_size, num_kv_heads,
// head_size]]
void kv_store_copy_blocks2GPU(
    torch::Tensor& src, std::vector<torch::Tensor> const& kv_caches,
    const int64_t num_layers, const torch::Tensor& block_mapping,
    const torch::Tensor& block_offsets, const torch::Tensor& req_ids,
    std::vector<long> const& events, const bool is_batch_layer) {
  if (is_batch_layer) {
    // Batched path: restore all layers for every request in one call.
    const int64_t num_requests = req_ids.size(0);
    kv_store.CopyBlocks2GPUBatch(src, kv_caches, num_layers, block_mapping,
                                 block_offsets, num_requests, events);
  } else {
    // Per-layer path: restore each layer's blocks separately.
    kv_store.CopyLayerBlocks2GPU(src, kv_caches, num_layers, block_mapping,
                                 block_offsets, req_ids, events);
  }
}

namespace vllm {

// Grid: (num_layers, num_pairs)
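Reviewer aside: a sketch of driving the restore path from the caller's side, under the same assumptions as the sketch above. Only kv_store_copy_blocks2GPU's signature comes from the diff; the tensor names and the opaque per-request event handles (passed as long) are assumptions.

// --- Illustrative sketch, not part of the diff ---
// Restores cached prefix blocks for a batch of requests, choosing between
// the batched and per-layer paths dispatched on is_batch_layer above.
#include <torch/torch.h>
#include <vector>

void restore_prefix(torch::Tensor& cpu_store,
                    std::vector<torch::Tensor> const& kv_caches,  // per layer
                    torch::Tensor const& block_mapping,
                    torch::Tensor const& block_offsets,
                    torch::Tensor const& req_ids,
                    std::vector<long> const& events,
                    bool batch_all_layers) {
  const int64_t num_layers = static_cast<int64_t>(kv_caches.size());
  // is_batch_layer == true copies every layer in one batched transfer;
  // false takes the layer-by-layer path (see the dispatch above).
  kv_store_copy_blocks2GPU(cpu_store, kv_caches, num_layers, block_mapping,
                           block_offsets, req_ids, events,
                           /*is_batch_layer=*/batch_all_layers);
}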