[V1] Further reduce CPU overheads in flash-attn (#10989)
Signed-off-by: Woosuk Kwon <[email protected]>
WoosukKwon authored Dec 9, 2024
1 parent edc4fa3 · commit 3b61cb4
Showing 2 changed files with 28 additions and 7 deletions.
csrc/cache_kernels.cu (14 changes: 12 additions & 2 deletions)
@@ -307,10 +307,20 @@ void reshape_and_cache_flash(
     torch::Tensor& key_cache,  // [num_blocks, block_size, num_heads, head_size]
     torch::Tensor&
         value_cache,  // [num_blocks, block_size, num_heads, head_size]
-    torch::Tensor& slot_mapping,  // [num_tokens]
+    torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
     const std::string& kv_cache_dtype, const double k_scale,
     const double v_scale) {
-  int num_tokens = key.size(0);
+  // NOTE(woosuk): In vLLM V1, key.size(0) can be different from
+  // slot_mapping.size(0) because of padding for CUDA graphs.
+  // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
+  // both include padding.
+  // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
+  // since key includes padding for CUDA graphs, while slot_mapping does not.
+  // In this case, slot_mapping.size(0) represents the actual number of tokens
+  // before padding.
+  // For compatibility with both cases, we use slot_mapping.size(0) as the
+  // number of tokens.
+  int num_tokens = slot_mapping.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
   int block_size = key_cache.size(1);
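
The effect of this change on the caller can be pictured with a small sketch (hypothetical shapes, not part of the commit): under V1's piece-wise CUDA graphs, key and value are padded up to the captured batch size while slot_mapping keeps one entry per real token, so deriving num_tokens from slot_mapping.size(0) means the kernel only writes the actual tokens and leaves the padded tail untouched.

import torch

# Illustrative sketch only (hypothetical shapes, not part of the commit).
# In vLLM V1, key/value may be padded up to the CUDA-graph batch size,
# while slot_mapping holds one entry per *actual* token.
num_actual_tokens = 5        # tokens produced by the scheduler
padded_num_tokens = 8        # batch size captured by the CUDA graph
num_heads, head_size = 4, 64

key = torch.randn(padded_num_tokens, num_heads, head_size)          # padded
value = torch.randn(padded_num_tokens, num_heads, head_size)        # padded
slot_mapping = torch.arange(num_actual_tokens, dtype=torch.int64)   # unpadded

# After this commit, the kernel derives the token count from slot_mapping,
# so the padded tail of key/value is simply never written to the cache.
num_tokens = slot_mapping.size(0)
assert num_tokens == num_actual_tokens
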
vllm/v1/attention/backends/flash_attn.py (21 changes: 16 additions & 5 deletions)
@@ -138,14 +138,25 @@ def forward(
             # Profiling run.
             return output
 
-        num_actual_tokens = attn_metadata.num_actual_tokens
+        # IMPORTANT!
+        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
+        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
+        # in this method. For example, `view` and `slice` (or `[:n]`) operations
+        # are surprisingly slow even in the case they do not invoke any GPU ops.
+        # Minimize the PyTorch ops in this method as much as possible.
+        # Whenever making a change in this method, please benchmark the
+        # performance to make sure it does not introduce any overhead.
 
+        num_actual_tokens = attn_metadata.num_actual_tokens
         # Reshape the input keys and values and store them in the cache.
-        key_cache = kv_cache[0]
-        value_cache = kv_cache[1]
+        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
+        # not padded. However, we don't need to do key[:num_actual_tokens] and
+        # value[:num_actual_tokens] because the reshape_and_cache_flash op uses
+        # the slot_mapping's shape to determine the number of actual tokens.
+        key_cache, value_cache = kv_cache.unbind(0)
         torch.ops._C_cache_ops.reshape_and_cache_flash(
-            key[:num_actual_tokens],
-            value[:num_actual_tokens],
+            key,
+            value,
             key_cache,
             value_cache,
             attn_metadata.slot_mapping,
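
The rationale in the NOTE above can be sanity-checked with a rough micro-benchmark (a sketch under assumed shapes and iteration counts, not from the commit; old_path/new_path are illustrative names): indexing kv_cache twice and slicing key/value each create new tensor views on the CPU even though no GPU kernel is launched, which is the per-call overhead the rewritten hot path tries to avoid.

import time

import torch

# Rough micro-benchmark sketch (assumed shapes; absolute numbers vary by machine).
# Compares the old hot path (two index ops plus two slices) against the new one
# (a single unbind and no slicing), mirroring the change in forward().
kv_cache = torch.randn(2, 256, 16, 8, 64)  # [2, num_blocks, block_size, num_heads, head_size]
key = torch.randn(8192, 8, 64)
value = torch.randn(8192, 8, 64)
num_actual_tokens = 4096

def old_path():
    key_cache = kv_cache[0]
    value_cache = kv_cache[1]
    return key[:num_actual_tokens], value[:num_actual_tokens], key_cache, value_cache

def new_path():
    key_cache, value_cache = kv_cache.unbind(0)
    return key, value, key_cache, value_cache

for fn in (old_path, new_path):
    start = time.perf_counter()
    for _ in range(100_000):
        fn()
    print(f"{fn.__name__}: {time.perf_counter() - start:.3f}s")
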
