KV Cache Improved Flexibility (microsoft#4668)
This change adds the foundation for appropriately supporting two key
KV-cache improvements:

1. Delineation between local/dense KV caches/models at the cache level
in addition to the attention module level.
2. Support for multiple types of disjoint KV caches (such as alternating
local + dense attention GPT-Neo).

Follow up item: Determine appropriate statistics for weighting local +
dense KV block ratios when both are present.
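
As a rough illustration of item 2 (not code from this commit; the layer counts and grouping are hypothetical), alternating local/dense attention layers can be mapped onto two disjoint cache groups, each with its own block accounting:

```python
# Hypothetical sketch: map alternating dense/local attention layers (as in
# GPT-Neo) onto two disjoint KV-cache groups, each with independent block
# accounting. Names and layer counts are illustrative only.
n_layers = 24
dense_layers = [i for i in range(n_layers) if i % 2 == 0]  # global attention
local_layers = [i for i in range(n_layers) if i % 2 == 1]  # windowed attention

cache_groups = {
    0: {"type": "dense", "layers": dense_layers},
    1: {"type": "local", "layers": local_layers},
}
print(cache_groups[1]["type"], len(cache_groups[1]["layers"]))  # local 12
```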

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
cmikeh2 and tjruwase authored Nov 14, 2023
1 parent 5411030 commit 901d807
Showing 10 changed files with 272 additions and 181 deletions.
13 changes: 10 additions & 3 deletions deepspeed/inference/v2/engine_v2.py
@@ -45,13 +45,20 @@ class InferenceEngineV2:
"""

@property
def free_blocks(self) -> int:
def free_blocks(self) -> torch.Tensor:
"""
Number of free KV blocks.
Number of free KV blocks. This is a tensor of shape [n_kv_cache_groups] where each
element is the number of free blocks in the corresponding KV cache group.
"""
return self._state_manager.free_blocks

@property
def n_kv_cache_groups(self) -> int:
"""
Number of KV cache groups.
"""
return self._state_manager.n_kv_cache_groups

def model(self) -> DSInferenceModelBase:
"""
The model implementation.
@@ -143,7 +150,7 @@ def put(self, batch_uids: Iterable[int], batch_tokens: Iterable[torch.Tensor]) -

return logits

def query(self, uid: int, max_request_tokens: int, max_request_blocks) -> Tuple[int, int]:
def query(self, uid: int, max_request_tokens: int, max_request_blocks) -> Tuple[int, torch.Tensor]:
"""
Determine the number of tokens and KV blocks to reserve for a given request. Given a UID
(this UID may not be recognized by the model yet), this will return the number of tokens
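
With `free_blocks` now returning a tensor of shape [n_kv_cache_groups] (and `query` returning a per-group block tensor), a scheduler has to check capacity per group rather than against a single integer. A minimal sketch of that check, with illustrative values (the helper below is not part of this diff):

```python
import torch

def has_capacity(free_blocks: torch.Tensor, required_blocks: torch.Tensor) -> bool:
    """True only if every KV-cache group has enough free blocks.

    Both tensors have shape [n_kv_cache_groups], matching the shape documented
    for InferenceEngineV2.free_blocks above.
    """
    return bool(torch.all(free_blocks >= required_blocks).item())

# Illustrative values: two cache groups (e.g. dense + local).
free = torch.tensor([128, 512], dtype=torch.int32)
needed = torch.tensor([4, 4], dtype=torch.int32)
print(has_capacity(free, needed))  # True
```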

This file was deleted.

@@ -180,7 +180,7 @@ def flattened_param_metadata(self) -> Optional[ModelMetadata]:

@abstractmethod
def get_kv_requirements(self, sequence: DSSequenceDescriptor, max_new_tokens: int,
max_new_blocks: int) -> Tuple[int, int]:
max_new_blocks: Tuple[int, ...]) -> Tuple[int, torch.Tensor]:
"""
Given a sequence and the number of new tokens in the sequence, determine the
number of new KV blocks needed to support the sequence. This method is
@@ -193,9 +193,9 @@ def get_kv_requirements(self, sequence: DSSequenceDescriptor, max_new_tokens: in
max_new_blocks (int): Maximum number of blocks to hypothetically allocate.
Returns:
Tuple[int, int]: The tuple of number of tokens scheduled and number
of blocks allocated. In general, only one of these numbers will match the
corresponding input argument, but this is not guaranteed.
Tuple[int, torch.Tensor]: The tuple of number of tokens scheduled and number
of blocks allocated (per KV cache). In general, only one of these numbers will
match the corresponding input argument, but this is not guaranteed.
"""
raise NotImplementedError()

@@ -212,9 +212,10 @@ def maybe_allocate_kv(self, sequence: DSSequenceDescriptor, n_new_tokens: int) -
raise NotImplementedError()

@abstractmethod
def kv_cache_config(self) -> KVCacheConfig:
def kv_cache_config(self) -> Tuple[KVCacheConfig, ...]:
"""
Return the KV-cache configuration for this model.
Return the KV-cache configuration for this model. This should be a tuple of one or more
KVCacheConfig objects (one for each distinct cache group).
"""
raise NotImplementedError()

@@ -226,7 +227,7 @@ def max_sequence_length(self) -> int:
"""
...

def maybe_free_kv(self, sequence: DSSequenceDescriptor):
def maybe_free_kv(self, sequence: DSSequenceDescriptor) -> None:
"""
After completing a forward pass, determine whether or not there are any KV blocks
that may be freed since they are no longer in use.
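
For a model with more than one cache group, a hypothetical implementation of this contract could bound the scheduled tokens by whichever group can grow the least. The sketch below is standalone and not from this diff; the per-group bookkeeping names are assumptions:

```python
import torch
from typing import Tuple

def get_kv_requirements_two_group(seen_tokens: int,
                                  cur_allocated_blocks: Tuple[int, int],
                                  block_sizes: Tuple[int, int],
                                  max_new_tokens: int,
                                  max_new_blocks: Tuple[int, int]) -> Tuple[int, torch.Tensor]:
    """Hypothetical two-group variant of get_kv_requirements (illustrative only).

    Token capacity is limited by whichever cache group can hold the fewest
    additional tokens once its hypothetical block grant is applied.
    """
    capacities = [
        (max_new_blocks[g] + cur_allocated_blocks[g]) * block_sizes[g] - seen_tokens
        for g in range(2)
    ]
    token_capacity = min(max_new_tokens, *capacities)
    return token_capacity, torch.tensor(max_new_blocks, dtype=torch.int32)

# Example: 100 tokens already seen, one block allocated in each group.
print(get_kv_requirements_two_group(100, (1, 1), (128, 64), 32, (2, 2)))
# (32, tensor([2, 2], dtype=torch.int32))
```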
@@ -324,7 +324,7 @@ def make_attn_layer(self) -> None:
self.attn = heuristics.instantiate_attention(attn_config, self._engine_config)

def get_kv_requirements(self, sequence: DSSequenceDescriptor, max_new_tokens: int,
max_new_blocks: int) -> Tuple[int, int]:
max_new_blocks: int) -> Tuple[int, torch.Tensor]:
"""
See ``DSInferenceModelBase.get_kv_requirements`` for documentation.
@@ -341,7 +341,7 @@ def get_kv_requirements(self, sequence: DSSequenceDescriptor, max_new_tokens: in
token_capacity = (max_new_blocks +
sequence.cur_allocated_blocks) * self.attn.kv_block_size - sequence.seen_tokens

return token_capacity, max_new_blocks
return token_capacity, torch.tensor([max_new_blocks])

def maybe_allocate_kv(self, sequence: DSSequenceDescriptor, n_new_tokens: int) -> None:
"""
@@ -356,7 +356,7 @@ def maybe_allocate_kv(self, sequence: DSSequenceDescriptor, n_new_tokens: int) -
new_blocks = self.state_manager.allocate_blocks(n_needed_blocks)
sequence.extend_kv_cache(new_blocks)

def kv_cache_config(self) -> KVCacheConfig:
def kv_cache_config(self) -> Tuple[KVCacheConfig, ...]:
"""
See ``DSInferenceModelBase.kv_cache_config`` for documentation.
@@ -370,7 +370,7 @@ def kv_cache_config(self) -> KVCacheConfig:
cache_shape=cache_shape,
cache_dtype=self.activation_dtype,
max_blocks_per_allocation_group=max_blocks)
return self._kv_cache_config
return (self._kv_cache_config, )

def prepare_batch(self, wrapped_batch: RaggedBatchWrapper) -> None:
"""
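
The pattern in this file is the single-group adaptation of the new contract: scalar results are wrapped so the rest of the stack can treat every model as having a tuple of cache configs and a per-group block tensor. A tiny illustrative helper (not part of the diff) showing the wrapping:

```python
import torch
from typing import Tuple

def to_group_form(token_capacity: int, max_new_blocks: int) -> Tuple[int, torch.Tensor]:
    """Wrap a single-group block count into the per-group tensor form used by
    the updated get_kv_requirements return type (illustrative helper)."""
    return token_capacity, torch.tensor([max_new_blocks])

tokens, blocks = to_group_form(256, 4)
print(tokens, blocks.shape)  # 256 torch.Size([1]) -- exactly one cache group
```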
112 changes: 80 additions & 32 deletions deepspeed/inference/v2/ragged/kv_cache.py
@@ -39,24 +39,26 @@ def split_kv(kv_cache: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

class BlockedKVCache:

_caches: torch.Tensor
_caches: Tuple[torch.Tensor, ...]
"""
Backing storage for all KV caches. Each cache group's backing store is a 6D tensor with the following shape:
(num_caches, num_blocks, block_size, 2, num_heads, head_size)
"""

_allocator: BlockedAllocator
_allocators: Tuple[BlockedAllocator, ...]
"""
Block allocators for tracking cache usage, one per cache group. These manage the GPU caches.
"""

_config: KVCacheConfig
_configs: Tuple[KVCacheConfig, ...]
"""
Configuration of the KV cache. See ``KVCacheConfig`` for more details.
Configuration of the KV cache(s). See ``KVCacheConfig`` for more details. This enables support
for different types/shapes of KV-caches (e.g. the alternating local and global attention in
GPT-Neo).
"""

def __init__(self,
config: KVCacheConfig,
configs: Tuple[KVCacheConfig, ...],
memory_config: MemoryConfig,
mp_group: Optional[Any] = None,
offload: bool = False) -> None:
@@ -71,17 +73,21 @@ def __init__(self,
blocks (int): The number of blocks to pre-allocate for the cache. If this is set,
slack will be ignored.
"""
self._config = config
self._configs = configs
self._memory_config = memory_config
self._enable_offload = offload

if self._enable_offload:
raise NotImplementedError("Offloading of KV-caches is not yet supported.")

if AllocationMode(self._memory_config.mode) is AllocationMode.RESERVE:
per_block_footprint = reduce(operator.mul, self._config.cache_shape, self._config.block_size)
per_block_footprint *= 2 # for key and value
per_block_footprint *= elem_size(self._config.cache_dtype)
# TODO(cmikeh2): Change the weighting based on the type of the KV-cache

total_per_block_footprint = 0
for config in self._configs:
per_block_footprint = reduce(operator.mul, config.cache_shape, config.block_size)
per_block_footprint *= 2 # for key and value
total_per_block_footprint += per_block_footprint * elem_size(config.cache_dtype)

# Perform a dummy nccl call before calculating available memory; on some systems (H100) we've observed higher memory allocations from NCCL
if dist.get_world_size(group=mp_group) > 1:
@@ -93,15 +99,15 @@ def __init__(self,
total_memory = get_accelerator().total_memory()

inference_logger().debug(
f"Memory usage before KV-cache allocation: total_memory={total_memory}, available_kv_memory={available_kv_memory}, per_block_footprint={per_block_footprint}"
f"Memory usage before KV-cache allocation: total_memory={total_memory}, available_kv_memory={available_kv_memory}, total_per_block_footprint={total_per_block_footprint}"
)

if available_kv_memory < per_block_footprint:
if available_kv_memory < total_per_block_footprint:
raise ValueError(
f"Insufficient memory to allocate KV-caches. Required: {per_block_footprint}, Available: {available_kv_memory}"
f"Insufficient memory to allocate KV-caches. Required: {total_per_block_footprint}, Available: {available_kv_memory}"
)

num_blocks = available_kv_memory // per_block_footprint
num_blocks = available_kv_memory // total_per_block_footprint

# In a multi-process setting, we need to ensure that all processes have the same
# KV cache capacity to ensure scheduling guarantees are equivalent on all ranks.
@@ -117,49 +123,91 @@ def __init__(self,
else: # AllocationMode.ALLOCATE
num_blocks = self._memory_config.size

num_caches = self._config.cache_shape[0]
num_heads = self._config.cache_shape[1]
head_size = self._config.cache_shape[2]
caches = []
allocators = []

for cache_group_id, config in enumerate(self._configs):
num_caches = config.cache_shape[0]
num_heads = config.cache_shape[1]
head_size = config.cache_shape[2]

alloc_shape = (num_caches, num_blocks, config.block_size, 2, num_heads, head_size)
inference_logger().info(
f"Allocating KV-cache {cache_group_id} with shape: {alloc_shape} consisting of {num_blocks} blocks.")
caches.append(torch.empty(alloc_shape, dtype=config.cache_dtype,
device=get_accelerator().current_device()))
allocators.append(BlockedAllocator(num_blocks))

alloc_shape = (num_caches, num_blocks, self._config.block_size, 2, num_heads, head_size)
inference_logger().info(f"Allocating KV-cache with shape: {alloc_shape} consisting of {num_blocks} blocks.")
self._caches = torch.empty(alloc_shape,
dtype=self._config.cache_dtype,
device=get_accelerator().current_device())
self._allocator = BlockedAllocator(num_blocks)
self._caches = tuple(caches)
self._allocators = tuple(allocators)
self._free_blocks = torch.empty(len(self._allocators), dtype=torch.int32, device="cpu")
for i, allocator in enumerate(self._allocators):
self._free_blocks[i] = allocator.free_blocks

def reserve(self, num_blocks: int) -> torch.Tensor:
def reserve(self, num_blocks: int, cache_group: int = 0) -> torch.Tensor:
"""
Reserve a number of blocks from the cache. This will return a 1D tensor of
block_ids that have been marked as reserved.
Parameters:
num_blocks (int): The number of blocks to reserve.
cache_group (int): The cache group to reserve from. Default is 0.
"""
return self._allocator.allocate(num_blocks)
return self._allocators[cache_group].allocate(num_blocks)

def free(self, blocks: Iterable[int]) -> None:
def free(self, blocks: Iterable[int], cache_group: int = 0) -> None:
"""
Free a set of blocks from the cache. This will mark the blocks as free in the
allocator.
Parameters:
blocks (Iterable[int]): The blocks to free.
cache_group (int): The cache group to free from. Default is 0.
"""
self._allocator.free(blocks)
self._allocators[cache_group].free(blocks)

def offload(self, blocks: Iterable[int]) -> torch.Tensor:
def offload(self, blocks: Iterable[int], cache_group: int = 0) -> torch.Tensor:
"""
Offload KV-cache blocks from accelerator memory to the host.
Parameters:
blocks (Iterable[int]): The blocks to offload.
cache_group (int): The cache group to offload from. Default is 0.
"""
raise NotImplementedError("Offloading is not yet supported.")

def restore(self, blocks: Iterable[int]) -> torch.Tensor:
def restore(self, blocks: Iterable[int], cache_group: int = 0) -> torch.Tensor:
"""
Restore KV-cache blocks from the host to accelerator memory.
Parameters:
blocks (Iterable[int]): The blocks to restore.
cache_group (int): The cache group to restore to. Default is 0.
"""
raise NotImplementedError("Offloading is not yet supported.")

def get_cache(self, cache_id: int) -> torch.Tensor:
def get_cache(self, cache_id: int, cache_group: int = 0) -> torch.Tensor:
"""
Get the tensor associated with the given cache ID.
Parameters:
cache_id (int): The ID of the cache tensor to get.
cache_group (int): The cache group to get from. Default is 0.
"""
return self._caches[cache_id]
return self._caches[cache_group][cache_id]

@property
def free_blocks(self):
return self._allocator.free_blocks
def free_blocks(self) -> torch.Tensor:
"""
Return the number of free blocks in each cache group.
"""
for i, allocator in enumerate(self._allocators):
self._free_blocks[i] = allocator.free_blocks
return self._free_blocks

@property
def num_caches(self) -> int:
"""
Return the number of caches.
"""
return len(self._caches)
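
A worked sketch of the RESERVE-mode sizing above, with made-up shapes. The helper mirrors the per-block footprint computation in `__init__` (block_size * prod(cache_shape) * 2 for key/value * element size, summed over cache groups) but is not the library API:

```python
from functools import reduce
import operator

def per_block_bytes(cache_shape, block_size, elem_bytes):
    """Bytes for one KV block of one cache group:
    block_size * prod(cache_shape) * 2 (key + value) * element size."""
    footprint = reduce(operator.mul, cache_shape, block_size)
    return footprint * 2 * elem_bytes

# Two hypothetical cache groups with fp16 (2-byte) elements.
# cache_shape is (num_caches, num_heads, head_size) as documented above.
dense = per_block_bytes((12, 16, 64), block_size=128, elem_bytes=2)
local = per_block_bytes((12, 16, 64), block_size=64, elem_bytes=2)
total_per_block = dense + local

available_kv_memory = 8 * 1024**3  # pretend 8 GiB remain for KV cache
num_blocks = available_kv_memory // total_per_block
print(total_per_block, num_blocks)  # 9437184 910
```

With one allocator per group, `reserve`, `free`, `offload`, `restore`, and `get_cache` all take a `cache_group` index that defaults to 0, so existing single-group call sites keep working unchanged.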
19 changes: 19 additions & 0 deletions deepspeed/inference/v2/ragged/manager_configs.py
@@ -12,8 +12,27 @@
from ..inference_utils import DtypeEnum


class KVCacheType(Enum):

DENSE = "dense"
"""
Dense KV-cache. This is the default type.
"""

LOCAL = "local"
"""
KV-cache that attends to only a local (trailing) window of tokens.
"""


class KVCacheConfig(DeepSpeedConfigModel):

type: KVCacheType = KVCacheType.DENSE
"""
Type of KV-cache to use. This may inform the allocator of the expected access/retention pattern
to enable more efficient memory management.
"""

block_size: int = 128
"""
Number of tokens that may be contained in each cache block.
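
A hedged sketch of building one config per cache group for an alternating dense/local model. Field names (`type`, `block_size`, `cache_shape`, `cache_dtype`) come from this diff and from `kv_cache_config` above; the exact dtype/enum handling in the real config model may differ:

```python
# Hedged sketch; the precise validation/dtype handling of KVCacheConfig may differ.
import torch
from deepspeed.inference.v2.ragged.manager_configs import KVCacheConfig, KVCacheType

num_heads, head_size = 16, 64
n_dense_layers, n_local_layers = 12, 12

dense_cfg = KVCacheConfig(type=KVCacheType.DENSE,
                          block_size=128,
                          cache_shape=(n_dense_layers, num_heads, head_size),
                          cache_dtype=torch.float16)

local_cfg = KVCacheConfig(type=KVCacheType.LOCAL,
                          block_size=128,
                          cache_shape=(n_local_layers, num_heads, head_size),
                          cache_dtype=torch.float16)

# BlockedKVCache and kv_cache_config() now work with a tuple of configs,
# one per disjoint cache group.
configs = (dense_cfg, local_cfg)
```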