XQA kernel for masked multihead attention #9672

Closed
wants to merge 5 commits
13 changes: 11 additions & 2 deletions CMakeLists.txt
@@ -188,8 +188,17 @@ set(VLLM_EXT_SRC
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/xqa/decoder_mmha/xqa_kernel_launcher.cu"
"csrc/xqa/decoder_mmha/decoder_xqa_impl_precompiled.cpp"
"csrc/xqa/decoder_mmha/decoder_xqa_runner.cpp"
"csrc/xqa/decoder_mmha/decoder_xqa_impl_common.cpp"
"csrc/xqa/decoder_mmha/decoder_xqa_impl.cpp"
"csrc/xqa/decoder_mmha/env_utils.cpp"
"csrc/torch_bindings.cpp")

file(GLOB XQA_CUBIN_CPP_SOURCES "csrc/xqa/kernels/decoderMaskedMultiheadAttention/cubin/*.cubin.cpp")
list(APPEND VLLM_EXT_SRC ${XQA_CUBIN_CPP_SOURCES})

if(VLLM_GPU_LANG STREQUAL "CUDA")
include(FetchContent)
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
@@ -319,7 +328,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
message(STATUS "Enabling C extension.")
add_dependencies(default _C)

message(STATUS "Enabling moe extension.")
add_dependencies(default _moe_C)
# message(STATUS "Enabling moe extension.")
#add_dependencies(default _moe_C)

endif()
5 changes: 5 additions & 0 deletions csrc/cache.h
@@ -28,6 +28,11 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
const std::string& kv_cache_dtype,
const double k_scale, const double v_scale);

void reshape_and_cache_xqa(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& kv_cache, torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
const double k_scale, const double v_scale);

// Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype);
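
For orientation while reading this declaration, below is a minimal sketch of allocating a cache tensor in the layout the new op expects. The layout is taken from the kernel comments in csrc/cache_kernels.cu later in this diff; the concrete sizes, dtype, and device are illustrative only.

```python
import torch

# kv_cache layout used by reshape_and_cache_xqa (per the kernel comments below):
# [num_blocks, 2, num_heads, block_size, head_size], K in plane 0, V in plane 1.
# Sizes and dtype here are placeholders, not values mandated by this PR.
num_blocks, num_kv_heads, block_size, head_size = 1024, 8, 64, 128
kv_cache = torch.empty(num_blocks, 2, num_kv_heads, block_size, head_size,
                       dtype=torch.float16, device="cuda")
```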
82 changes: 82 additions & 0 deletions csrc/cache_kernels.cu
@@ -245,6 +245,50 @@ __global__ void reshape_and_cache_flash_kernel(
}
}
}

template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_xqa_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
cache_t* __restrict__ kv_cache, // [num_blocks, 2, num_heads, block_size,
// head_size], k_cache, v_cache
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, const int key_stride, const int value_stride,
const int num_heads, const int head_size, const int block_size,
const float k_scale, const float v_scale) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
return;
}
const int64_t block_idx = slot_idx / block_size * 2;
const int64_t block_offset = slot_idx % block_size;
const int n = num_heads * head_size;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int64_t src_key_idx = token_idx * key_stride + i;
const int64_t src_value_idx = token_idx * value_stride + i;
const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int64_t tgt_key_idx = block_idx * block_stride +
head_idx * block_size * head_size +
block_offset * head_size + head_offset;
const int64_t tgt_value_idx = (block_idx + 1) * block_stride +
head_idx * block_size * head_size +
block_offset * head_size + head_offset;
scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx];
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
kv_cache[tgt_key_idx] = tgt_key;
kv_cache[tgt_value_idx] = tgt_value;
} else {
kv_cache[tgt_key_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
kv_cache[tgt_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
}
}
}
} // namespace vllm

// KV_T is the stored data type of kv-cache.
@@ -329,6 +373,44 @@ void reshape_and_cache_flash(
CALL_RESHAPE_AND_CACHE_FLASH);
}

// KV_T is the data type of the key and value tensors.
// CACHE_T is the stored data type of the kv-cache.
// KV_DTYPE is the real data type of the kv-cache.
#define CALL_RESHAPE_AND_CACHE_XQA(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_xqa_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
value_stride, num_heads, head_size, block_size, k_scale, v_scale);

void reshape_and_cache_xqa(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& kv_cache, // [num_blocks, 2, num_heads, block_size,
// head_size], k_cache, v_cache
torch::Tensor& slot_mapping, // [num_tokens] k and v shared
const std::string& kv_cache_dtype, const double k_scale,
const double v_scale) {
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = kv_cache.size(3);

int key_stride = key.stride(0);
int value_stride = value.stride(0);
int block_stride = kv_cache.stride(1);

dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
CALL_RESHAPE_AND_CACHE_XQA);
}

namespace vllm {

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
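To make the new kernel's indexing easier to review, here is a pure-PyTorch sketch of what reshape_and_cache_xqa writes on the unquantized path (kv_cache_dtype == "auto"), assuming contiguous tensors. The function name is illustrative and not part of this PR; the fp8 branch would additionally scale and convert each element.

```python
import torch

def reshape_and_cache_xqa_ref(key, value, kv_cache, slot_mapping):
    """Reference for the kAuto path of reshape_and_cache_xqa_kernel.

    key, value:   [num_tokens, num_heads, head_size]
    kv_cache:     [num_blocks, 2, num_heads, block_size, head_size]
                  (K lives in plane 0, V in plane 1 of each block)
    slot_mapping: [num_tokens]; -1 marks padded tokens, which are skipped
    """
    block_size = kv_cache.size(3)
    for token_idx in range(key.size(0)):
        slot = int(slot_mapping[token_idx])
        if slot < 0:  # padded token
            continue
        block_idx = slot // block_size   # kernel uses 2 * this because it flattens [num_blocks, 2]
        block_offset = slot % block_size
        # Scatter the whole [num_heads, head_size] slice for this token.
        kv_cache[block_idx, 0, :, block_offset, :] = key[token_idx]
        kv_cache[block_idx, 1, :, block_offset, :] = value[token_idx]
```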
8 changes: 7 additions & 1 deletion csrc/ops.h
@@ -25,7 +25,13 @@ void paged_attention_v2(
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);

void xqa_paged_attention(torch::Tensor& out, torch::Tensor& query,
torch::Tensor& key_value_cache, int64_t num_heads,
int64_t num_kv_heads, int64_t rotary_embedding_dim,
double scale, torch::Tensor& block_tables,
torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const std::string kv_cache_dtype,
double k_scale, double v_scale);
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
double epsilon);

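A hedged sketch of how this entry point is reached from Python once the extension is built. The torch.ops._C namespace and the tensor shapes below (which mirror the existing paged_attention ops), as well as the placeholder values for rotary_embedding_dim, scale, and kv_cache_dtype, are assumptions rather than facts established by this diff.

```python
import torch

num_seqs, num_heads, num_kv_heads, head_size = 4, 32, 8, 128
num_blocks, block_size, max_seq_len = 8192, 64, 1024
max_blocks_per_seq = max_seq_len // block_size

query = torch.empty(num_seqs, num_heads, head_size, dtype=torch.float16, device="cuda")
out = torch.empty_like(query)
kv_cache = torch.empty(num_blocks, 2, num_kv_heads, block_size, head_size,
                       dtype=torch.float16, device="cuda")
block_tables = torch.zeros(num_seqs, max_blocks_per_seq, dtype=torch.int32, device="cuda")
seq_lens = torch.full((num_seqs,), 512, dtype=torch.int32, device="cuda")

torch.ops._C.xqa_paged_attention(
    out, query, kv_cache,
    num_heads, num_kv_heads,
    0,                   # rotary_embedding_dim (placeholder)
    head_size ** -0.5,   # scale
    block_tables, seq_lens,
    block_size, max_seq_len,
    "auto",              # kv_cache_dtype (placeholder)
    1.0, 1.0)            # k_scale, v_scale
```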
20 changes: 19 additions & 1 deletion csrc/torch_bindings.cpp
@@ -46,7 +46,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);

// PagedAttention xqa.
ops.def(
"xqa_paged_attention("
" Tensor! out,"
" Tensor query, Tensor key_value_cache,"
" int num_heads, int num_kv_heads, int rotary_embedding_dim,"
" float scale, Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, str kv_cache_dtype,"
" float k_scale, float v_scale) -> ()");
ops.impl("xqa_paged_attention", torch::kCUDA, &xqa_paged_attention);
// Activation ops
// Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
@@ -318,6 +327,15 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
&reshape_and_cache_flash);

// Reshape the key and value tensors and cache them.
cache_ops.def(
"reshape_and_cache_xqa(Tensor key, Tensor value,"
" Tensor! kv_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" float k_scale, float v_scale) -> ()");
cache_ops.impl("reshape_and_cache_xqa", torch::kCUDA, &reshape_and_cache_xqa);

// Convert the key and value cache to fp8 data type.
cache_ops.def(
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, str "
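Similarly, a hedged sketch of driving the new cache op through the binding registered above. The _C_cache_ops library name assumes TORCH_EXTENSION_NAME expands to _C, as in vLLM's default build; the sizes and values are placeholders.

```python
import torch

num_tokens, num_kv_heads, head_size = 16, 8, 128
num_blocks, block_size = 1024, 64

key = torch.randn(num_tokens, num_kv_heads, head_size, dtype=torch.float16, device="cuda")
value = torch.randn_like(key)
kv_cache = torch.zeros(num_blocks, 2, num_kv_heads, block_size, head_size,
                       dtype=torch.float16, device="cuda")
# One physical slot per token; a -1 entry would mark padding the kernel skips.
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device="cuda")

torch.ops._C_cache_ops.reshape_and_cache_xqa(
    key, value, kv_cache, slot_mapping,
    "auto",    # kv_cache_dtype: store without fp8 conversion
    1.0, 1.0)  # k_scale, v_scale (ignored on the "auto" path)
```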