From 76b27063eee7f4ed8c2246e68bb2a2d93a48de9d Mon Sep 17 00:00:00 2001 From: shuw Date: Thu, 5 Sep 2024 09:16:52 -0700 Subject: [PATCH] Reshape cache to be xqa kernel compatible --- csrc/cache.h | 6 +++ csrc/cache_kernels.cu | 39 +++++++++++++++ csrc/torch_bindings.cpp | 10 ++++ tests/kernels/conftest.py | 7 ++- tests/kernels/test_cache.py | 96 +++++++++++++++++++++++++++++++++++++ vllm/_custom_ops.py | 14 ++++++ vllm/utils.py | 35 ++++++++++++++ 7 files changed, 206 insertions(+), 1 deletion(-) diff --git a/csrc/cache.h b/csrc/cache.h index 11c4c5001daaa..e29ad86ae4220 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -28,6 +28,12 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, const std::string& kv_cache_dtype, const double k_scale, const double v_scale); +void reshape_and_cache_xqa(torch::Tensor& key, torch::Tensor& value, + torch::Tensor& kv_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, + const double k_scale, const double v_scale); + // Just for unittest void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, const double scale, const std::string& kv_cache_dtype); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 1be806bbfa43c..d05f359f1d58b 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -329,6 +329,45 @@ void reshape_and_cache_flash( CALL_RESHAPE_AND_CACHE_FLASH); } +// KV_T is the stored data type of kv-cache. +// CACHE_T is the data type of key and value tensors. +// KV_DTYPE is the real data type of kv-cache. +#define CALL_RESHAPE_AND_CACHE_XQA(KV_T, CACHE_T, KV_DTYPE) \ + vllm::reshape_and_cache_xqa_kernel \ + <<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(kv_cache.data_ptr()), \ + slot_mapping.data_ptr(), block_stride, key_stride, \ + value_stride, num_heads, head_size, block_size, k_scale, v_scale); + +void reshape_and_cache_xqa( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& kv_cache, // [num_blocks, 2, num_heads, block_size, + // head_size], k_cache, v_cache + torch::Tensor& slot_mapping, // [num_tokens] k and v shared + const std::string& kv_cache_dtype, const double k_scale, + const double v_scale) { + int num_tokens = key.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = kv_cache.size(3); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + int block_stride = kv_cache.stride(1); + TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0)); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, + CALL_RESHAPE_AND_CACHE_XQA); +} + namespace vllm { template diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 7783acd741f5f..3b1f1d2151266 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -318,6 +318,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { cache_ops.impl("reshape_and_cache_flash", torch::kCUDA, &reshape_and_cache_flash); + // Reshape the key and value tensors and cache them. + cache_ops.def( + "reshape_and_cache_xqa(Tensor key, Tensor value," + " Tensor! kv_cache," + " Tensor slot_mapping," + " str kv_cache_dtype," + " float k_scale, float v_scale) -> ()"); + cache_ops.impl("reshape_and_cache_xqa", torch::kCUDA, + &reshape_and_cache_xqa); + // Convert the key and value cache to fp8 data type. cache_ops.def( "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, str " diff --git a/tests/kernels/conftest.py b/tests/kernels/conftest.py index 4f2f9cc3dac7d..c1d5fcdb72e8a 100644 --- a/tests/kernels/conftest.py +++ b/tests/kernels/conftest.py @@ -1,7 +1,8 @@ import pytest from vllm.utils import (create_kv_caches_with_random, - create_kv_caches_with_random_flash) + create_kv_caches_with_random_flash, + create_kv_caches_with_random_xqa) @pytest.fixture() @@ -12,3 +13,7 @@ def kv_cache_factory(): @pytest.fixture() def kv_cache_factory_flashinfer(): return create_kv_caches_with_random_flash + +@pytest.fixture() +def kv_cache_factory_xqa(): + return create_kv_caches_with_random_xqa diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 71d18359164b1..9013f84bf2e4e 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -303,6 +303,102 @@ def test_reshape_and_cache_flash( torch.testing.assert_close(value_cache, cloned_value_cache) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) +@torch.inference_mode() +def test_reshape_and_cache_xqa( + kv_cache_factory_xqa, + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, + device: str, + kv_cache_dtype: str, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + + # Create a random slot mapping. + num_slots = block_size * num_blocks + slot_mapping_lst = random.sample(range(num_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping_lst, + dtype=torch.long, + device=device) + + qkv = torch.randn(num_tokens, + 3, + num_heads, + head_size, + dtype=dtype, + device=device) + _, key, value = qkv.unbind(dim=1) + + # Create the KV caches. + kv_caches = kv_cache_factory_xqa( + num_blocks, + block_size, + 1, + num_heads, + head_size, + kv_cache_dtype, + dtype, + device=device, + ) + kv_cache = kv_caches[0].contiguous() + del kv_caches + + # Clone the KV caches. + if kv_cache_dtype == "fp8": + cloned_kv_cache = torch.empty_like(kv_cache, dtype=torch.float16) + ops.convert_fp8(cloned_kv_cache, kv_cache) + else: + cloned_kv_cache = kv_cache.clone() + + # Using default kv_scale + k_scale = v_scale = 1.0 + + # Call the reshape_and_cache_xqa kernel. + ops.reshape_and_cache_xqa(key, value, kv_cache, + slot_mapping, kv_cache_dtype, k_scale, v_scale) + + if kv_cache_dtype == "fp8": + result_kv_cache = torch.empty_like(kv_cache, dtype=torch.float16) + ops.convert_fp8(result_kv_cache, kv_cache) + + # Run the reference implementation. + # kv_cache layout is [num_blocks, 2, num_heads, block_size, head_size] + block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor") + block_indicies_lst = block_indicies.cpu().tolist() + + block_offsets = slot_mapping % block_size + block_offsets_lst = block_offsets.cpu().tolist() + for i in range(num_tokens): + block_idx = block_indicies_lst[i] + block_offset = block_offsets_lst[i] + cloned_kv_cache[block_idx, 0, :, block_offset, :] = key[i] + cloned_kv_cache[block_idx, 1, :, block_offset, :] = value[i] + + if kv_cache_dtype == "fp8": + torch.testing.assert_close(result_kv_cache, + cloned_kv_cache, + atol=0.001, + rtol=0.1) + else: + torch.testing.assert_close(kv_cache, cloned_kv_cache) + + @pytest.mark.parametrize("direction", COPYING_DIRECTION) @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @pytest.mark.parametrize("num_heads", NUM_HEADS) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index fe254732e7309..24b4bf4c9258a 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -578,6 +578,20 @@ def reshape_and_cache_flash( v_scale) +def reshape_and_cache_xqa( + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + torch.ops._C_cache_ops.reshape_and_cache_xqa(key, value, kv_cache, + slot_mapping, kv_cache_dtype, + k_scale, v_scale) + + def copy_blocks(key_caches: List[torch.Tensor], value_caches: List[torch.Tensor], block_mapping: torch.Tensor) -> None: diff --git a/vllm/utils.py b/vllm/utils.py index 657a3ecef696d..54e372d784906 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -623,6 +623,41 @@ def get_kv_cache_torch_dtype( return torch_dtype +def create_kv_caches_with_random_xqa( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, + seed: int = 0, + device: Optional[str] = "cuda", +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) + key_value_cache_shape = (num_blocks, 2, num_heads, block_size, head_size) + scale = head_size**-0.5 + + kv_caches: List[torch.Tensor] = [] + for _ in range(num_layers): + key_value_cache = torch.empty(size=key_value_cache_shape, + dtype=torch_dtype, + device=device) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + key_value_cache.uniform_(-scale, scale) + elif cache_dtype == 'fp8': + _generate_random_fp8(key_value_cache, -scale, scale) + else: + raise ValueError( + f"Does not support key cache of type {cache_dtype}") + kv_caches.append(key_value_cache) + return kv_caches + + def create_kv_caches_with_random_flash( num_blocks: int, block_size: int,