Reshape cache to be xqa kernel compatible
wenscarl committed Sep 5, 2024
1 parent 2ee4528 commit 76b2706
Showing 7 changed files with 206 additions and 1 deletion.
6 changes: 6 additions & 0 deletions csrc/cache.h
@@ -28,6 +28,12 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
const std::string& kv_cache_dtype,
const double k_scale, const double v_scale);

void reshape_and_cache_xqa(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& kv_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
const double k_scale, const double v_scale);

// Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype);
39 changes: 39 additions & 0 deletions csrc/cache_kernels.cu
@@ -329,6 +329,45 @@ void reshape_and_cache_flash(
CALL_RESHAPE_AND_CACHE_FLASH);
}

// KV_T is the data type of the key and value tensors.
// CACHE_T is the stored data type of the kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_RESHAPE_AND_CACHE_XQA(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_xqa_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
value_stride, num_heads, head_size, block_size, k_scale, v_scale);

void reshape_and_cache_xqa(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& kv_cache, // [num_blocks, 2, num_heads, block_size,
// head_size], k_cache, v_cache
torch::Tensor& slot_mapping, // [num_tokens] k and v shared
const std::string& kv_cache_dtype, const double k_scale,
const double v_scale) {
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = kv_cache.size(3);

int key_stride = key.stride(0);
int value_stride = value.stride(0);
int block_stride = kv_cache.stride(1);
  TORCH_CHECK(kv_cache.size(1) == 2,
              "kv_cache must have layout [num_blocks, 2, num_heads, block_size, head_size]");

dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
CALL_RESHAPE_AND_CACHE_XQA);
}

namespace vllm {

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
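As a reading aid for the launch above: the per-token scatter that reshape_and_cache_xqa performs can be expressed with plain PyTorch indexing. The sketch below is illustrative only and mirrors the reference loop used in the new test further down; the helper name reshape_and_cache_xqa_ref is hypothetical, not part of this commit, and it ignores the fp8 scaling path.

import torch

def reshape_and_cache_xqa_ref(key: torch.Tensor, value: torch.Tensor,
                              kv_cache: torch.Tensor,
                              slot_mapping: torch.Tensor) -> None:
    # kv_cache layout: [num_blocks, 2, num_heads, block_size, head_size];
    # index 0 along dim 1 is the K plane, index 1 is the V plane.
    block_size = kv_cache.shape[3]
    for i, slot in enumerate(slot_mapping.tolist()):
        block_idx = slot // block_size
        block_offset = slot % block_size
        kv_cache[block_idx, 0, :, block_offset, :] = key[i]
        kv_cache[block_idx, 1, :, block_offset, :] = value[i]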
10 changes: 10 additions & 0 deletions csrc/torch_bindings.cpp
@@ -318,6 +318,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
&reshape_and_cache_flash);

// Reshape the key and value tensors and cache them.
cache_ops.def(
"reshape_and_cache_xqa(Tensor key, Tensor value,"
" Tensor! kv_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" float k_scale, float v_scale) -> ()");
cache_ops.impl("reshape_and_cache_xqa", torch::kCUDA,
&reshape_and_cache_xqa);

// Convert the key and value cache to fp8 data type.
cache_ops.def(
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, str "
7 changes: 6 additions & 1 deletion tests/kernels/conftest.py
@@ -1,7 +1,8 @@
import pytest

from vllm.utils import (create_kv_caches_with_random,
create_kv_caches_with_random_flash)
create_kv_caches_with_random_flash,
create_kv_caches_with_random_xqa)


@pytest.fixture()
@@ -12,3 +13,7 @@ def kv_cache_factory():
@pytest.fixture()
def kv_cache_factory_flashinfer():
return create_kv_caches_with_random_flash

@pytest.fixture()
def kv_cache_factory_xqa():
return create_kv_caches_with_random_xqa
96 changes: 96 additions & 0 deletions tests/kernels/test_cache.py
@@ -303,6 +303,102 @@ def test_reshape_and_cache_flash(
torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache_xqa(
kv_cache_factory_xqa,
num_tokens: int,
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
seed: int,
device: str,
kv_cache_dtype: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.set_default_device(device)

# Create a random slot mapping.
num_slots = block_size * num_blocks
slot_mapping_lst = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long,
device=device)

qkv = torch.randn(num_tokens,
3,
num_heads,
head_size,
dtype=dtype,
device=device)
_, key, value = qkv.unbind(dim=1)

# Create the KV caches.
kv_caches = kv_cache_factory_xqa(
num_blocks,
block_size,
1,
num_heads,
head_size,
kv_cache_dtype,
dtype,
device=device,
)
kv_cache = kv_caches[0].contiguous()
del kv_caches

# Clone the KV caches.
if kv_cache_dtype == "fp8":
cloned_kv_cache = torch.empty_like(kv_cache, dtype=torch.float16)
ops.convert_fp8(cloned_kv_cache, kv_cache)
else:
cloned_kv_cache = kv_cache.clone()

# Using default kv_scale
k_scale = v_scale = 1.0

# Call the reshape_and_cache_xqa kernel.
ops.reshape_and_cache_xqa(key, value, kv_cache,
slot_mapping, kv_cache_dtype, k_scale, v_scale)

if kv_cache_dtype == "fp8":
result_kv_cache = torch.empty_like(kv_cache, dtype=torch.float16)
ops.convert_fp8(result_kv_cache, kv_cache)

# Run the reference implementation.
# kv_cache layout is [num_blocks, 2, num_heads, block_size, head_size]
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indices_lst = block_indices.cpu().tolist()

block_offsets = slot_mapping % block_size
block_offsets_lst = block_offsets.cpu().tolist()
for i in range(num_tokens):
        block_idx = block_indices_lst[i]
block_offset = block_offsets_lst[i]
cloned_kv_cache[block_idx, 0, :, block_offset, :] = key[i]
cloned_kv_cache[block_idx, 1, :, block_offset, :] = value[i]

if kv_cache_dtype == "fp8":
torch.testing.assert_close(result_kv_cache,
cloned_kv_cache,
atol=0.001,
rtol=0.1)
else:
torch.testing.assert_close(kv_cache, cloned_kv_cache)


@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
14 changes: 14 additions & 0 deletions vllm/_custom_ops.py
@@ -578,6 +578,20 @@ def reshape_and_cache_flash(
v_scale)


def reshape_and_cache_xqa(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
) -> None:
torch.ops._C_cache_ops.reshape_and_cache_xqa(key, value, kv_cache,
slot_mapping, kv_cache_dtype,
k_scale, v_scale)


def copy_blocks(key_caches: List[torch.Tensor],
value_caches: List[torch.Tensor],
block_mapping: torch.Tensor) -> None:
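A minimal end-to-end call through the new wrapper might look as follows, assuming the extension is built with this commit and a CUDA device is available; the shapes, values, and the "auto" cache dtype are illustrative only.

import torch
from vllm import _custom_ops as ops

num_tokens, num_heads, head_size = 4, 8, 64
num_blocks, block_size = 16, 32

key = torch.randn(num_tokens, num_heads, head_size,
                  dtype=torch.float16, device="cuda")
value = torch.randn_like(key)
kv_cache = torch.zeros(num_blocks, 2, num_heads, block_size, head_size,
                       dtype=torch.float16, device="cuda")
# One flat slot index per token; K and V share the same slot.
slot_mapping = torch.tensor([0, 1, 32, 33], dtype=torch.long, device="cuda")

ops.reshape_and_cache_xqa(key, value, kv_cache, slot_mapping,
                          kv_cache_dtype="auto", k_scale=1.0, v_scale=1.0)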
35 changes: 35 additions & 0 deletions vllm/utils.py
@@ -623,6 +623,41 @@ def get_kv_cache_torch_dtype(
return torch_dtype


def create_kv_caches_with_random_xqa(
num_blocks: int,
block_size: int,
num_layers: int,
num_heads: int,
head_size: int,
cache_dtype: Optional[Union[str, torch.dtype]],
model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: int = 0,
device: Optional[str] = "cuda",
) -> List[torch.Tensor]:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)

torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
key_value_cache_shape = (num_blocks, 2, num_heads, block_size, head_size)
scale = head_size**-0.5

kv_caches: List[torch.Tensor] = []
for _ in range(num_layers):
key_value_cache = torch.empty(size=key_value_cache_shape,
dtype=torch_dtype,
device=device)
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
key_value_cache.uniform_(-scale, scale)
elif cache_dtype == 'fp8':
_generate_random_fp8(key_value_cache, -scale, scale)
else:
            raise ValueError(
                f"Does not support kv cache of type {cache_dtype}")
kv_caches.append(key_value_cache)
return kv_caches


def create_kv_caches_with_random_flash(
num_blocks: int,
block_size: int,
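For a quick sanity check of the helper above (a sketch, assuming CUDA is available and the function is imported from vllm.utils), a single layer with these illustrative sizes would yield one tensor in the combined layout:

import torch
from vllm.utils import create_kv_caches_with_random_xqa

caches = create_kv_caches_with_random_xqa(
    num_blocks=16, block_size=32, num_layers=1, num_heads=8, head_size=64,
    cache_dtype="auto", model_dtype=torch.float16, device="cuda")
assert len(caches) == 1
assert caches[0].shape == (16, 2, 8, 32, 64)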
