diff --git a/aphrodite/engine/aphrodite_engine.py b/aphrodite/engine/aphrodite_engine.py index 5bef62c75..4140bdd12 100644 --- a/aphrodite/engine/aphrodite_engine.py +++ b/aphrodite/engine/aphrodite_engine.py @@ -23,7 +23,7 @@ ExecuteModelRequest, PoolerOutput, SamplerOutput, Sequence, SequenceGroup, SequenceGroupMetadata, SequenceStatus) -from aphrodite.common.utils import Counter +from aphrodite.common.utils import Counter, Device from aphrodite.engine.args_tools import EngineArgs from aphrodite.engine.metrics_types import StatLoggerBase, Stats from aphrodite.engine.output_processor.interfaces import ( @@ -1290,6 +1290,13 @@ def _get_stats( else: cpu_cache_usage_sys = 0.0 + # Prefix Cache Hit Rate. Note that we always use + # the cache hit rate of the first virtual engine. + cpu_prefix_cache_hit_rate = self.scheduler[ + 0].get_prefix_cache_hit_rate(Device.CPU) + gpu_prefix_cache_hit_rate = self.scheduler[ + 0].get_prefix_cache_hit_rate(Device.GPU) + # Iteration stats num_prompt_tokens_iter = 0 num_generation_tokens_iter = 0 @@ -1400,6 +1407,10 @@ def _get_stats( gpu_cache_usage_sys=gpu_cache_usage_sys, cpu_cache_usage_sys=cpu_cache_usage_sys, + # Prefix Cache Hit Rate + cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate, + gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate, + # Iteration stats num_prompt_tokens_iter=num_prompt_tokens_iter, num_generation_tokens_iter=num_generation_tokens_iter, diff --git a/aphrodite/engine/metrics.py b/aphrodite/engine/metrics.py index 8daecce32..acd19585c 100644 --- a/aphrodite/engine/metrics.py +++ b/aphrodite/engine/metrics.py @@ -70,6 +70,18 @@ def __init__(self, labelnames: List[str], max_model_len: int): documentation="CPU KV-cache usage. 
1 means 100 percent usage.", labelnames=labelnames, multiprocess_mode="sum") + + # Prefix caching block hit rate + self.gauge_cpu_prefix_cache_hit_rate = self._gauge_cls( + name="aphrodite:cpu_prefix_cache_hit_rate", + documentation="CPU prefix cache block hit rate.", + labelnames=labelnames, + multiprocess_mode="sum") + self.gauge_gpu_prefix_cache_hit_rate = self._gauge_cls( + name="aphrodite:gpu_prefix_cache_hit_rate", + documentation="GPU prefix cache block hit rate.", + labelnames=labelnames, + multiprocess_mode="sum") # Iteration stats self.counter_num_preemption = self._counter_cls( @@ -347,6 +359,13 @@ def log(self, stats: Stats) -> None: f"CPU KV cache usage: {stats.cpu_cache_usage_sys * 100:.1f}%." ) + if (stats.cpu_prefix_cache_hit_rate >= 0 + or stats.gpu_prefix_cache_hit_rate >= 0): + logger.info( + "Prefix cache hit rate: " + f"GPU: {stats.gpu_prefix_cache_hit_rate * 100:.2f}%, " + f"CPU: {stats.cpu_prefix_cache_hit_rate * 100:.2f}%") + if self.spec_decode_metrics is not None: logger.info( self._format_spec_decode_metrics_str( @@ -418,6 +437,10 @@ def _log_prometheus(self, stats: Stats) -> None: stats.gpu_cache_usage_sys) self._log_gauge(self.metrics.gauge_cpu_cache_usage, stats.cpu_cache_usage_sys) + self._log_gauge(self.metrics.gauge_cpu_prefix_cache_hit_rate, + stats.cpu_prefix_cache_hit_rate) + self._log_gauge(self.metrics.gauge_gpu_prefix_cache_hit_rate, + stats.gpu_prefix_cache_hit_rate) # Iteration level data self._log_counter(self.metrics.counter_num_preemption, diff --git a/aphrodite/engine/metrics_types.py b/aphrodite/engine/metrics_types.py index 5d337f2e1..c150469f4 100644 --- a/aphrodite/engine/metrics_types.py +++ b/aphrodite/engine/metrics_types.py @@ -28,6 +28,9 @@ class Stats: # KV Cache Usage in % gpu_cache_usage_sys: float cpu_cache_usage_sys: float + # Prefix caching block hit rate + cpu_prefix_cache_hit_rate: float + gpu_prefix_cache_hit_rate: float # Iteration stats (should have _iter suffix) num_prompt_tokens_iter: int 
num_generation_tokens_iter: int diff --git a/aphrodite/processing/block/common.py b/aphrodite/processing/block/common.py index 14164376d..4dc0ffce5 100644 --- a/aphrodite/processing/block/common.py +++ b/aphrodite/processing/block/common.py @@ -1,4 +1,5 @@ from collections import deque +from dataclasses import dataclass from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple from aphrodite.processing.block.interfaces import Block, BlockAllocator @@ -282,6 +283,54 @@ def ids(self) -> List[int]: return self._block_ids +@dataclass +class CacheMetricData: + """A utility dataclass to maintain cache metric. + To avoid overflow, we maintain the hit rate in block granularity, so that + we can maintain a single hit rate for n_completed_block x block_size, + and calculate the real time hit rate by the following: + BS = The number of queries per block. + nB = The number of completed blocks. + HR = hit rate of (nB x BS) queries. + Q = current number of queries (< BS). + H = current number of hits (< BS). + hit rate = ((HR x nB) + (H / Q) x (Q / BS)) / (nB + Q / BS) + """ + num_completed_blocks: int = 0 + completed_block_cache_hit_rate: float = 0.0 + num_incompleted_block_queries: int = 0 + num_incompleted_block_hit: int = 0 + block_size: int = 1000 + def query(self, hit: bool): + self.num_incompleted_block_queries += 1 + self.num_incompleted_block_hit += 1 if hit else 0 + # When a block is completed, update the cache hit rate + # and reset the incomplete numbers. 
+ if self.num_incompleted_block_queries == self.block_size: + hit_rate = (self.num_incompleted_block_hit / + self.num_incompleted_block_queries) + self.completed_block_cache_hit_rate = ( + self.completed_block_cache_hit_rate * self.num_completed_blocks + + hit_rate) / (self.num_completed_blocks + 1) + self.num_incompleted_block_queries = 0 + self.num_incompleted_block_hit = 0 + self.num_completed_blocks += 1 + def get_hit_rate(self): + incomplete_ratio = self.num_incompleted_block_queries / self.block_size + total_blocks = self.num_completed_blocks + incomplete_ratio + if total_blocks == 0: + return 0.0 + completed_block_hit, incompleted_block_hit = 0.0, 0.0 + if self.num_completed_blocks > 0: + completed_block_hit = (self.completed_block_cache_hit_rate * + self.num_completed_blocks) + if self.num_incompleted_block_queries > 0: + incompleted_hit_rate = (self.num_incompleted_block_hit / + self.num_incompleted_block_queries) + incompleted_block_hit = (incompleted_hit_rate * incomplete_ratio) + return (completed_block_hit + incompleted_block_hit) / total_blocks + + def get_all_blocks_recursively(last_block: Block) -> List[Block]: """Retrieves all the blocks in a sequence starting from the last block. diff --git a/aphrodite/processing/block/cpu_gpu_block_allocator.py b/aphrodite/processing/block/cpu_gpu_block_allocator.py index 7c8c7a270..0c9bd1db9 100644 --- a/aphrodite/processing/block/cpu_gpu_block_allocator.py +++ b/aphrodite/processing/block/cpu_gpu_block_allocator.py @@ -326,6 +326,11 @@ def get_common_computed_block_ids( def all_block_ids(self) -> FrozenSet[int]: return frozenset(self._block_ids_to_allocator.keys()) + def get_prefix_cache_hit_rate(self, device: Device) -> float: + """Prefix cache hit rate. 
-1 means not supported or disabled.""" + assert device in self._allocators + return self._allocators[device].get_prefix_cache_hit_rate() + def get_and_reset_swaps(self) -> List[Tuple[int, int]]: """Returns and clears the mapping of source to destination block IDs. Will be called after every swapping operations for now, and after every diff --git a/aphrodite/processing/block/interfaces.py b/aphrodite/processing/block/interfaces.py index cadf247cc..0f13db929 100644 --- a/aphrodite/processing/block/interfaces.py +++ b/aphrodite/processing/block/interfaces.py @@ -186,6 +186,11 @@ def get_num_blocks_touched(self, num_lookahead_slots: int = 0) -> int: pass + @abstractmethod + def get_prefix_cache_hit_rate(self) -> float: + """Prefix cache hit rate. -1 means not supported or disabled.""" + pass + class NoFreeBlocksError(ValueError): pass @@ -278,3 +283,8 @@ def allocate_or_get_null_block(self) -> Block: There is at most one null block per allocator. """ pass + + @abstractmethod + def get_prefix_cache_hit_rate(self, device: Device) -> float: + """Prefix cache hit rate. 
-1 means not supported or disabled.""" + pass diff --git a/aphrodite/processing/block/naive_block.py b/aphrodite/processing/block/naive_block.py index 9610df9df..82eed6784 100644 --- a/aphrodite/processing/block/naive_block.py +++ b/aphrodite/processing/block/naive_block.py @@ -343,6 +343,9 @@ def swap_in(self, blocks: List[Block]) -> None: block.block_id = block_id # Assign block_id + def get_prefix_cache_hit_rate(self) -> float: + return -1 + class NaiveBlock(Block): """An implementation of the Block class that does not support prefix diff --git a/aphrodite/processing/block/prefix_caching_block.py b/aphrodite/processing/block/prefix_caching_block.py index 797c1ce3e..59f034af1 100644 --- a/aphrodite/processing/block/prefix_caching_block.py +++ b/aphrodite/processing/block/prefix_caching_block.py @@ -4,7 +4,8 @@ from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple from aphrodite.common.utils import cdiv -from aphrodite.processing.block.common import (CopyOnWriteTracker, +from aphrodite.processing.block.common import (CacheMetricData, + CopyOnWriteTracker, get_all_blocks_recursively) from aphrodite.processing.block.interfaces import (Block, BlockAllocator, BlockId, Device) @@ -109,6 +110,8 @@ def __init__( self._cow_tracker = CopyOnWriteTracker( refcounter=self._refcounter.as_readonly()) + self.metric_data = CacheMetricData() + # Implements Block.Factory. 
def _create_block( self, @@ -157,9 +160,11 @@ def allocate_immutable_block(self, cached_block_id = self._cached_blocks.get(block.content_hash, None) if cached_block_id is not None: + self.metric_data.query(hit=True) block.block_id = cached_block_id self._incr_refcount_cached_block(block) return block + self.metric_data.query(hit=False) self._block_pool.free_block(block) # No cached block => Allocate a new block @@ -406,6 +411,9 @@ def get_physical_block_id(self, absolute_id: int) -> int: def all_block_ids(self) -> FrozenSet[int]: return self._hashless_allocator.all_block_ids + def get_prefix_cache_hit_rate(self) -> float: + return self.metric_data.get_hit_rate() + def is_block_cached(self, block: Block) -> bool: assert block.content_hash is not None if block.content_hash in self._cached_blocks: diff --git a/aphrodite/processing/block_manager_v1.py b/aphrodite/processing/block_manager_v1.py index bbaf0706e..6e573671a 100644 --- a/aphrodite/processing/block_manager_v1.py +++ b/aphrodite/processing/block_manager_v1.py @@ -12,6 +12,7 @@ from aphrodite.common.block import BlockTable, PhysicalTokenBlock from aphrodite.common.sequence import Sequence, SequenceGroup, SequenceStatus from aphrodite.common.utils import Device +from aphrodite.processing.block.common import CacheMetricData from aphrodite.processing.block.utils import ( check_no_caching_or_swa_for_blockmgr_encdec) from aphrodite.processing.evictor_v1 import (EvictionPolicy, Evictor, @@ -62,6 +63,12 @@ def update_hash(self, block_hash: int, block: PhysicalTokenBlock): pass + @abstractmethod + def get_prefix_cache_hit_rate(self) -> float: + """Prefix cache hit rate. -1 means not supported or disabled.""" + pass + + class CachedBlockAllocator(BlockAllocatorBase): """Manages free physical token blocks for a device. 
@@ -86,6 +93,8 @@ def __init__(self, self.default_hash_ctr = count() + self.cache_metric_data = CacheMetricData() + def allocate_block(self, block_hash: int, num_hashed_tokens: int) -> PhysicalTokenBlock: if self.current_num_blocks == self.num_blocks: @@ -111,10 +120,10 @@ def allocate(self, block = self.evictor.remove(block_hash) assert block.ref_count == 0 self.cached_blocks[block_hash] = block - block.ref_count += 1 - assert block.block_hash == block_hash - return block - if block_hash not in self.cached_blocks: + if block_hash in self.cached_blocks: + self.cache_metric_data.query(hit=True) + else: + self.cache_metric_data.query(hit=False) self.cached_blocks[block_hash] = self.allocate_block( block_hash, num_hashed_tokens) block = self.cached_blocks[block_hash] @@ -151,6 +160,9 @@ def update_hash(self, block_hash: int, block: PhysicalTokenBlock): del self.cached_blocks[old_hash] self.cached_blocks[block_hash] = block + def get_prefix_cache_hit_rate(self) -> float: + return self.cache_metric_data.get_hit_rate() + class UncachedBlockAllocator(BlockAllocatorBase): """Manages free physical token blocks for a device. 
@@ -210,6 +222,9 @@ def update_hash(self, block_hash: int, block: PhysicalTokenBlock): raise NotImplementedError( "Invalid codepath for uncached block allocator.") + def get_prefix_cache_hit_rate(self) -> float: + return -1 + class BlockSpaceManagerV1(BlockSpaceManager): """Manages the mapping between logical and physical token blocks.""" @@ -706,3 +721,10 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup): if self.enable_caching: for seq in seq_group.get_seqs(): self.compute_full_blocks_in_seq(seq) + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + if device == Device.GPU: + return self.gpu_allocator.get_prefix_cache_hit_rate() + if device == Device.CPU: + return self.cpu_allocator.get_prefix_cache_hit_rate() + raise ValueError(f"Invalid device: {device}") diff --git a/aphrodite/processing/block_manager_v2.py b/aphrodite/processing/block_manager_v2.py index 8b16e2e8f..38154fa79 100644 --- a/aphrodite/processing/block_manager_v2.py +++ b/aphrodite/processing/block_manager_v2.py @@ -439,6 +439,9 @@ def get_num_free_gpu_blocks(self) -> int: def get_num_free_cpu_blocks(self) -> int: return self.block_allocator.get_num_free_blocks(Device.CPU) + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return self.block_allocator.get_prefix_cache_hit_rate(device) + def _can_swap(self, seq_group: SequenceGroup, device: Device, diff --git a/aphrodite/processing/evictor_v2.py b/aphrodite/processing/evictor_v2.py index 5b1a208b7..8dff37e8d 100644 --- a/aphrodite/processing/evictor_v2.py +++ b/aphrodite/processing/evictor_v2.py @@ -85,18 +85,21 @@ def evict(self) -> Tuple[int, int]: if len(self.free_table) == 0: raise ValueError("No usable cache memory left") - evicted_block = next(iter(self.free_table.values())) - evicted_block_id = next(iter(self.free_table.keys())) + evicted_block, evicted_block_id = None, None # The blocks with the lowest timestamps should be placed consecutively # at the start of OrderedDict. 
Loop through all these blocks to # find the one with maximum number of hashed tokens. for _id, block in self.free_table.items(): + if evicted_block is None: + evicted_block, evicted_block_id = block, _id + continue if evicted_block.last_accessed < block.last_accessed: break - if (evicted_block.last_accessed == block.last_accessed and - evicted_block.num_hashed_tokens < block.num_hashed_tokens): - evicted_block = block - evicted_block_id = _id + if evicted_block.num_hashed_tokens < block.num_hashed_tokens: + evicted_block, evicted_block_id = block, _id + + assert evicted_block is not None + assert evicted_block_id is not None self.free_table.pop(evicted_block_id) @@ -110,7 +113,6 @@ def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, def update(self, block_id: int, last_accessed: float): self.free_table[block_id].last_accessed = last_accessed - self.free_table.move_to_end(block_id) def remove(self, block_id: int): if block_id not in self.free_table: diff --git a/aphrodite/processing/interfaces.py b/aphrodite/processing/interfaces.py index 890db0747..8e61ddfe6 100644 --- a/aphrodite/processing/interfaces.py +++ b/aphrodite/processing/interfaces.py @@ -5,6 +5,7 @@ from typing import Tuple from aphrodite.common.sequence import Sequence, SequenceGroup +from aphrodite.common.utils import Device class AllocStatus(enum.Enum): @@ -118,3 +119,8 @@ def get_common_computed_block_ids( @abstractmethod def mark_blocks_as_computed(self, seq_group: SequenceGroup): pass + + @abstractmethod + def get_prefix_cache_hit_rate(self, device: Device) -> float: + """Prefix cache hit rate. 
-1 means not supported or disabled.""" + pass diff --git a/aphrodite/processing/placeholder_block_space_manager.py b/aphrodite/processing/placeholder_block_space_manager.py index 3e0f08cd4..6abc5a1d0 100644 --- a/aphrodite/processing/placeholder_block_space_manager.py +++ b/aphrodite/processing/placeholder_block_space_manager.py @@ -1,6 +1,7 @@ from typing import List, Tuple from aphrodite.common.sequence import Sequence, SequenceGroup +from aphrodite.common.utils import Device from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager @@ -81,3 +82,6 @@ def get_common_computed_block_ids(self, def mark_blocks_as_computed(self, seq_group: SequenceGroup): pass + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return -1 diff --git a/aphrodite/processing/scheduler.py b/aphrodite/processing/scheduler.py index f3033e1d1..5159265df 100644 --- a/aphrodite/processing/scheduler.py +++ b/aphrodite/processing/scheduler.py @@ -13,7 +13,7 @@ SequenceGroupMetadata, SequenceGroupMetadataDelta, SequenceStatus) -from aphrodite.common.utils import PyObjectCache +from aphrodite.common.utils import Device, PyObjectCache from aphrodite.lora.request import LoRARequest from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager from aphrodite.prompt_adapter.request import PromptAdapterRequest @@ -457,6 +457,9 @@ def has_unfinished_seqs(self) -> bool: return len(self.waiting) != 0 or len(self.running) != 0 or len( self.swapped) != 0 + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return self.block_manager.get_prefix_cache_hit_rate(device) + def get_num_unfinished_seq_groups(self) -> int: return len(self.waiting) + len(self.running) + len(self.swapped) diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index 22a45ee58..b856b58a6 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -682,6 +682,29 @@ def 
test_eviction_order(num_blocks: int, block_size: int, seed: int): assert new_block[0].block_id == last_block_id + # Test case for cache metrics + @staticmethod + def test_metric(): + block_size = 16 + allocator = PrefixCachingBlockAllocator(num_blocks=4, + block_size=block_size) + # Test when no query (0/0) + assert allocator.get_prefix_cache_hit_rate() == 0.0 + token_ids = list(range(block_size)) + allocator.allocate_immutable_block(prev_block=None, + token_ids=token_ids) + # Test 0/1 hit rate + assert allocator.get_prefix_cache_hit_rate() == 0.0 + allocator.allocate_immutable_block(prev_block=None, + token_ids=token_ids) + # Test 1/2 hit rate + assert allocator.get_prefix_cache_hit_rate() == 0.5 + # Test more than one block + for _ in range(2, 1005): + allocator.allocate_immutable_block(prev_block=None, + token_ids=token_ids) + assert allocator.get_prefix_cache_hit_rate() > 0.99 + @staticmethod def create_immutable_chain( block_size: int, diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index f12f92ffa..96f0643f8 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -34,6 +34,9 @@ def test_block_allocator( assert (first_block == second_block) assert (second_block.ref_count == 2) + # Check metric: 1 hit of 2 queries + assert block_allocator.get_prefix_cache_hit_rate() == 0.5 + # Free the first_block and confirm that the ref_count is correctly # decremented on the second block block_allocator.free(first_block) @@ -48,6 +51,10 @@ def test_block_allocator( assert (first_block == second_block) assert (first_block.block_hash == block_hash) + # Allocate one more time to get 3/4 hit rate for easy checking + block_allocator.allocate(block_hash, 0) + assert block_allocator.get_prefix_cache_hit_rate() == 0.75 + @pytest.mark.parametrize("num_blocks", [16]) def test_eviction(num_blocks: int, ):