Add KV-Cache int8 quant support #10354

Open · wants to merge 17 commits into base: main
60 changes: 55 additions & 5 deletions csrc/attention/attention_kernels.cuh
@@ -24,6 +24,10 @@

#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include <string>
#include <cstdint>
#include "dtype_fp8.cuh"
#include "../quantization/int8_kvcache/quant_utils.cuh"

#ifdef USE_ROCM
#include <hip/hip_bf16.h>
@@ -105,7 +109,10 @@ __device__ void paged_attention_kernel(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const int quant_group,
const float* __restrict__ k_scales,
const float* __restrict__ v_scales,
const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int seq_idx = blockIdx.y;
@@ -151,6 +158,16 @@
const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / num_kv_heads;
const int kv_head_idx = head_idx / num_queries_per_kv;
// Select the dequantization scales for this KV head. kInt8Group128 stores
// one scale per quant_group contiguous elements of the flattened
// [num_kv_heads, head_size] dimension; the other modes use a single scale.
float k_scale = 0.f;
float v_scale = 0.f;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kInt8Group128) {
  const int64_t tgt_kvs_idx = (kv_head_idx * HEAD_SIZE) / quant_group;
  k_scale = k_scales[tgt_kvs_idx];
  v_scale = v_scales[tgt_kvs_idx];
} else {
  k_scale = k_scales[0];
  v_scale = v_scales[0];
}
const float alibi_slope =
alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];

@@ -280,6 +297,17 @@
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
k_vecs[j] = *reinterpret_cast<const K_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
// int8 kv-cache
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kInt8Group0) {
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = int8::scaled_vec_conversion_int8<K_vec, Quant_vec>(
k_vec_quant, k_scale, 0);
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kInt8Group128) {
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = int8::scaled_vec_conversion_int8<K_vec, Quant_vec>(
k_vec_quant, k_scale, 0);
} else {
// Vector conversion from Quant_vec to K_vec.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
@@ -410,6 +438,21 @@

if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
// int8 kv-cache
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kInt8Group0) {
V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = int8::scaled_vec_conversion_int8<V_vec, V_quant_vec>(v_quant_vec,
v_scale,
0);
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kInt8Group128) {
V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = int8::scaled_vec_conversion_int8<V_vec, V_quant_vec>(v_quant_vec,
v_scale,
0);
} else {
V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
@@ -513,15 +556,19 @@ __global__ void paged_attention_v1_kernel(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const int quant_group,
const float* __restrict__ k_scales,
const float* __restrict__ v_scales,
const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
kv_head_stride, quant_group, k_scales, v_scales,
tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
}
@@ -549,14 +596,17 @@ __global__ void paged_attention_v2_kernel(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const int quant_group,
const float* __restrict__ k_scales,
const float* __restrict__ v_scales,
const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
kv_block_stride, kv_head_stride, quant_group, k_scales, v_scales, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
}
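Note (reviewer sketch, not part of the diff): the branches above dequantize int8 cache entries with int8::scaled_vec_conversion_int8 from csrc/quantization/int8_kvcache/quant_utils.cuh, which this PR adds but which is not shown on this page. A plausible scalar form, assuming symmetric int8 quantization with an optional zero point (the kernel passes 0):

// Hypothetical scalar analogue of scaled_vec_conversion_int8; the PR's real
// helper is vectorized and lives in csrc/quantization/int8_kvcache/quant_utils.cuh.
__device__ __forceinline__ float dequant_int8(int8_t q, float scale,
                                              float zero_point) {
  return (static_cast<float>(q) - zero_point) * scale;
}

For the per-group mode (kInt8Group128), the scale index computed earlier is (kv_head_idx * HEAD_SIZE) / quant_group, so with HEAD_SIZE = 128 and quant_group = 128 each KV head gets exactly one scale.
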
4 changes: 4 additions & 0 deletions csrc/attention/dtype_float16.cuh
@@ -66,6 +66,10 @@ template <>
struct FloatVec<uint4> {
using Type = Float8_;
};
template<>
struct FloatVec<uint8_t> {
using Type = float;
};

// Utility functions for type conversions.
inline __device__ uint32_t h0_h0(uint16_t a) {
4 changes: 4 additions & 0 deletions csrc/attention/dtype_fp8.cuh
@@ -15,6 +15,10 @@ enum class Fp8KVCacheDataType {
kAuto = 0,
kFp8E4M3 = 1,
kFp8E5M2 = 2,
// Layer-wise (per-tensor scale) int8 kv cache
kInt8Group0 = 3,
// Group-wise int8 kv cache (one scale per 128-element group)
kInt8Group128 = 4,
};

// fp8 vector types for quantization of kv cache
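Note (reviewer sketch, not part of the diff): kInt8Group0 denotes a single per-tensor (layer-wise) scale, while kInt8Group128 denotes one scale per quant_group contiguous elements over the flattened [num_kv_heads, head_size] dimension, matching the indexing in attention_kernels.cuh. A host-side sketch of the implied scale-table sizes, with illustrative names:

#include <cstdint>

// Illustrative only; the actual scale-table layout is defined by the Python
// side of this PR, which is not shown on this page.
inline int64_t num_kv_scales(int64_t num_kv_heads, int64_t head_size,
                             int64_t quant_group, bool per_tensor) {
  if (per_tensor) {
    return 1;  // kInt8Group0: one scale for the whole layer
  }
  // kInt8Group128: e.g. head_size == quant_group == 128 gives one scale per head.
  return (num_kv_heads * head_size) / quant_group;
}
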
23 changes: 17 additions & 6 deletions csrc/attention/paged_attention_v1.cu
@@ -41,7 +41,8 @@
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
quant_group, k_scales_ptr, v_scales_ptr, \
tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);

@@ -53,8 +54,11 @@ void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const c10::optional<torch::Tensor>& alibi_slopes,
int quant_group,
torch::Tensor& k_scales,
torch::Tensor& v_scales,
const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
@@ -78,6 +82,8 @@
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
float* k_scales_ptr = reinterpret_cast<float*>(k_scales.data_ptr());
float* v_scales_ptr = reinterpret_cast<float*>(v_scales.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();

@@ -135,10 +141,12 @@
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
seq_lens, max_seq_len, alibi_slopes, quant_group, k_scales, v_scales, \
tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);


#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
if (is_block_sparse) { \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
@@ -177,7 +185,10 @@ void paged_attention_v1(
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const std::string& kv_cache_dtype,
const int64_t quant_group,
torch::Tensor& k_scales,
torch::Tensor& v_scales,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
@@ -190,4 +201,4 @@
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
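Note (reviewer sketch, not part of the diff): the scalar k_scale/v_scale arguments become float32 scale tensors plus a quant_group size. A caller-side sketch with illustrative shapes and values; the kv_cache_dtype string that routes to the int8 path is determined by DISPATCH_BY_KV_CACHE_DTYPE and is not shown here:

// Unit scales, one per quant_group elements of [num_kv_heads, head_size];
// the variable names below are illustrative assumptions.
const int64_t quant_group = 128;
torch::Tensor k_scales = torch::ones(
    {num_kv_heads * head_size / quant_group},
    torch::TensorOptions().dtype(torch::kFloat32).device(key_cache.device()));
torch::Tensor v_scales = torch::ones_like(k_scales);
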
23 changes: 17 additions & 6 deletions csrc/attention/paged_attention_v2.cu
@@ -37,7 +37,8 @@
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
kv_block_stride, kv_head_stride, \
quant_group, k_scales_ptr, v_scales_ptr, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
@@ -54,8 +55,11 @@ void paged_attention_v2_launcher(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const c10::optional<torch::Tensor>& alibi_slopes,
int quant_group,
torch::Tensor& k_scales,
torch::Tensor& v_scales,
const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
@@ -82,6 +86,8 @@
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
float* k_scales_ptr = reinterpret_cast<float*>(k_scales.data_ptr());
float* v_scales_ptr = reinterpret_cast<float*>(v_scales.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();

@@ -142,7 +148,8 @@
IS_BLOCK_SPARSE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
quant_group, k_scales, v_scales, \
tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);

@@ -188,16 +195,20 @@ void paged_attention_v2(
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const std::string& kv_cache_dtype,
const int64_t quant_group,
torch::Tensor& k_scales,
torch::Tensor& v_scales,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);

DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE)
}

#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
12 changes: 8 additions & 4 deletions csrc/cache.h
@@ -18,16 +18,20 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype, const double k_scale,
const double v_scale);
const std::string& kv_cache_dtype,
const int64_t quant_group,
torch::Tensor& k_scales,
torch::Tensor& v_scales);

void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
const double k_scale, const double v_scale);

const int64_t quant_group,
torch::Tensor& k_scales,
torch::Tensor& v_scales);

// Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype);
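Note (reviewer sketch, not part of the diff): reshape_and_cache and reshape_and_cache_flash now take the same quant_group/k_scales/v_scales arguments, so the write path quantizes incoming keys and values to int8. A plausible scalar form of that quantization (the name, clamping, and rounding mode are assumptions, not the PR's exact code):

// Hypothetical inverse of the dequantization in the attention kernel; the PR's
// vectorized version would live in csrc/quantization/int8_kvcache/quant_utils.cuh.
__device__ __forceinline__ int8_t quant_int8(float x, float scale,
                                             float zero_point) {
  float q = x / scale + zero_point;
  q = fminf(fmaxf(q, -128.f), 127.f);             // clamp to the int8 range
  return static_cast<int8_t>(__float2int_rn(q));  // round to nearest
}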