Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add metrics for prefix cache hit rate #829

Merged
merged 1 commit into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion aphrodite/engine/aphrodite_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
ExecuteModelRequest, PoolerOutput,
SamplerOutput, Sequence, SequenceGroup,
SequenceGroupMetadata, SequenceStatus)
from aphrodite.common.utils import Counter
from aphrodite.common.utils import Counter, Device
from aphrodite.engine.args_tools import EngineArgs
from aphrodite.engine.metrics_types import StatLoggerBase, Stats
from aphrodite.engine.output_processor.interfaces import (
Expand Down Expand Up @@ -1290,6 +1290,13 @@ def _get_stats(
else:
cpu_cache_usage_sys = 0.0

# Prefix Cache Hit Rate. Note that we always use
# the cache hit rate of the first virtual engine.
cpu_prefix_cache_hit_rate = self.scheduler[
0].get_prefix_cache_hit_rate(Device.CPU)
gpu_prefix_cache_hit_rate = self.scheduler[
0].get_prefix_cache_hit_rate(Device.GPU)

# Iteration stats
num_prompt_tokens_iter = 0
num_generation_tokens_iter = 0
Expand Down Expand Up @@ -1400,6 +1407,10 @@ def _get_stats(
gpu_cache_usage_sys=gpu_cache_usage_sys,
cpu_cache_usage_sys=cpu_cache_usage_sys,

# Prefix Cache Hit Rate
cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,

# Iteration stats
num_prompt_tokens_iter=num_prompt_tokens_iter,
num_generation_tokens_iter=num_generation_tokens_iter,
Expand Down
23 changes: 23 additions & 0 deletions aphrodite/engine/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,18 @@ def __init__(self, labelnames: List[str], max_model_len: int):
documentation="CPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames,
multiprocess_mode="sum")

# Prefix caching block hit rate
self.gauge_cpu_prefix_cache_hit_rate = self._gauge_cls(
name="aphrodite:cpu_prefix_cache_hit_rate",
documentation="CPU prefix cache block hit rate.",
labelnames=labelnames,
multiprocess_mode="sum")
self.gauge_gpu_prefix_cache_hit_rate = self._gauge_cls(
name="aphrodite:gpu_prefix_cache_hit_rate",
documentation="GPU prefix cache block hit rate.",
labelnames=labelnames,
multiprocess_mode="sum")

# Iteration stats
self.counter_num_preemption = self._counter_cls(
Expand Down Expand Up @@ -347,6 +359,13 @@ def log(self, stats: Stats) -> None:
f"CPU KV cache usage: {stats.cpu_cache_usage_sys * 100:.1f}%."
)

if (stats.cpu_prefix_cache_hit_rate >= 0
or stats.gpu_prefix_cache_hit_rate >= 0):
logger.info(
"Prefix cache hit rate: "
f"GPU: {stats.gpu_prefix_cache_hit_rate * 100:.2f}%, "
f"CPU: {stats.cpu_prefix_cache_hit_rate * 100:.2f}%")

if self.spec_decode_metrics is not None:
logger.info(
self._format_spec_decode_metrics_str(
Expand Down Expand Up @@ -418,6 +437,10 @@ def _log_prometheus(self, stats: Stats) -> None:
stats.gpu_cache_usage_sys)
self._log_gauge(self.metrics.gauge_cpu_cache_usage,
stats.cpu_cache_usage_sys)
self._log_gauge(self.metrics.gauge_cpu_prefix_cache_hit_rate,
stats.cpu_prefix_cache_hit_rate)
self._log_gauge(self.metrics.gauge_gpu_prefix_cache_hit_rate,
stats.gpu_prefix_cache_hit_rate)

# Iteration level data
self._log_counter(self.metrics.counter_num_preemption,
Expand Down
3 changes: 3 additions & 0 deletions aphrodite/engine/metrics_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class Stats:
# KV Cache Usage in %
gpu_cache_usage_sys: float
cpu_cache_usage_sys: float
# Prefix caching block hit rate
cpu_prefix_cache_hit_rate: float
gpu_prefix_cache_hit_rate: float
# Iteration stats (should have _iter suffix)
num_prompt_tokens_iter: int
num_generation_tokens_iter: int
Expand Down
49 changes: 49 additions & 0 deletions aphrodite/processing/block/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections import deque
from dataclasses import dataclass
from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple

from aphrodite.processing.block.interfaces import Block, BlockAllocator
Expand Down Expand Up @@ -282,6 +283,54 @@ def ids(self) -> List[int]:
return self._block_ids


@dataclass
class CacheMetricData:
    """Track a cache hit rate without unbounded hit/query counters.

    Queries are grouped into fixed-size blocks of ``block_size`` queries.
    Only the running average hit rate over the completed blocks is stored,
    plus raw counters for the block that is currently being filled. The
    overall hit rate is reconstructed as follows:

        BS = number of queries per block (``block_size``).
        nB = number of completed blocks.
        HR = average hit rate over the nB completed blocks.
        Q  = queries recorded in the current, incomplete block (< BS).
        H  = hits recorded in the current, incomplete block (<= Q).

        hit rate = (HR * nB + (H / Q) * (Q / BS)) / (nB + Q / BS)

    Raises:
        ValueError: if ``block_size`` is not a positive number.
    """

    # Number of blocks of `block_size` queries that have been folded into
    # `completed_block_cache_hit_rate`.
    num_completed_blocks: int = 0
    # Average hit rate over all completed blocks.
    completed_block_cache_hit_rate: float = 0.0
    # Queries recorded in the block currently being filled.
    num_incompleted_block_queries: int = 0
    # Hits recorded in the block currently being filled.
    num_incompleted_block_hit: int = 0
    # Number of queries that make up one block.
    block_size: int = 1000

    def __post_init__(self) -> None:
        # A zero block size makes get_hit_rate() divide by zero, and a
        # negative one means a block can never complete, so reject both
        # up front instead of failing later in the metrics path.
        if self.block_size <= 0:
            raise ValueError(
                f"block_size must be positive, got {self.block_size}")

    def query(self, hit: bool) -> None:
        """Record one cache lookup.

        Args:
            hit: True if the lookup was served from the cache.
        """
        self.num_incompleted_block_queries += 1
        if hit:
            self.num_incompleted_block_hit += 1

        # When a block is completed, fold its hit rate into the running
        # average and reset the counters of the incomplete block.
        if self.num_incompleted_block_queries == self.block_size:
            hit_rate = (self.num_incompleted_block_hit /
                        self.num_incompleted_block_queries)
            self.completed_block_cache_hit_rate = (
                self.completed_block_cache_hit_rate *
                self.num_completed_blocks +
                hit_rate) / (self.num_completed_blocks + 1)
            self.num_incompleted_block_queries = 0
            self.num_incompleted_block_hit = 0
            self.num_completed_blocks += 1

    def get_hit_rate(self) -> float:
        """Return the hit rate over every query recorded so far.

        Returns:
            A value in [0, 1] when counters are only updated through
            `query`; 0.0 when no query has been recorded yet.
        """
        # Fraction of a block covered by the incomplete block's queries.
        incomplete_ratio = self.num_incompleted_block_queries / self.block_size
        total_blocks = self.num_completed_blocks + incomplete_ratio
        if total_blocks == 0:
            return 0.0

        completed_block_hit, incompleted_block_hit = 0.0, 0.0
        if self.num_completed_blocks > 0:
            completed_block_hit = (self.completed_block_cache_hit_rate *
                                   self.num_completed_blocks)
        if self.num_incompleted_block_queries > 0:
            incompleted_hit_rate = (self.num_incompleted_block_hit /
                                    self.num_incompleted_block_queries)
            # Weight the partial block by how much of a block it represents.
            incompleted_block_hit = incompleted_hit_rate * incomplete_ratio
        return (completed_block_hit + incompleted_block_hit) / total_blocks


def get_all_blocks_recursively(last_block: Block) -> List[Block]:
"""Retrieves all the blocks in a sequence starting from the last block.

Expand Down
5 changes: 5 additions & 0 deletions aphrodite/processing/block/cpu_gpu_block_allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,11 @@ def get_common_computed_block_ids(
def all_block_ids(self) -> FrozenSet[int]:
return frozenset(self._block_ids_to_allocator.keys())

def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Return the prefix cache hit rate of the allocator for `device`.

    The value comes from the per-device allocator registered in
    `self._allocators`; -1 means not supported or disabled.
    """
    allocators = self._allocators
    # The device must have an allocator registered.
    assert device in allocators
    device_allocator = allocators[device]
    return device_allocator.get_prefix_cache_hit_rate()

def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
"""Returns and clears the mapping of source to destination block IDs.
Will be called after every swapping operations for now, and after every
Expand Down
10 changes: 10 additions & 0 deletions aphrodite/processing/block/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,11 @@ def get_num_blocks_touched(self,
num_lookahead_slots: int = 0) -> int:
pass

@abstractmethod
def get_prefix_cache_hit_rate(self) -> float:
    """Return the prefix cache hit rate of this allocator.

    Implementations return -1 when prefix caching is not supported or is
    disabled.
    """

class NoFreeBlocksError(ValueError):
pass

Expand Down Expand Up @@ -278,3 +283,8 @@ def allocate_or_get_null_block(self) -> Block:
There is at most one null block per allocator.
"""
pass

@abstractmethod
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Return the prefix cache hit rate for the given device.

    Implementations return -1 when prefix caching is not supported or is
    disabled.
    """
3 changes: 3 additions & 0 deletions aphrodite/processing/block/naive_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,9 @@ def swap_in(self, blocks: List[Block]) -> None:

block.block_id = block_id # Assign block_id

def get_prefix_cache_hit_rate(self) -> float:
    """Return the sentinel hit rate -1.0.

    This allocator does not report a prefix cache hit rate. The value is
    returned as a float literal so it matches the declared return type; it
    still compares equal to -1.
    """
    return -1.0


class NaiveBlock(Block):
"""An implementation of the Block class that does not support prefix
Expand Down
10 changes: 9 additions & 1 deletion aphrodite/processing/block/prefix_caching_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple

from aphrodite.common.utils import cdiv
from aphrodite.processing.block.common import (CopyOnWriteTracker,
from aphrodite.processing.block.common import (CacheMetricData,
CopyOnWriteTracker,
get_all_blocks_recursively)
from aphrodite.processing.block.interfaces import (Block, BlockAllocator,
BlockId, Device)
Expand Down Expand Up @@ -109,6 +110,8 @@ def __init__(
self._cow_tracker = CopyOnWriteTracker(
refcounter=self._refcounter.as_readonly())

self.metric_data = CacheMetricData()

# Implements Block.Factory.
def _create_block(
self,
Expand Down Expand Up @@ -157,9 +160,11 @@ def allocate_immutable_block(self,

cached_block_id = self._cached_blocks.get(block.content_hash, None)
if cached_block_id is not None:
self.metric_data.query(hit=True)
block.block_id = cached_block_id
self._incr_refcount_cached_block(block)
return block
self.metric_data.query(hit=False)
self._block_pool.free_block(block)

# No cached block => Allocate a new block
Expand Down Expand Up @@ -406,6 +411,9 @@ def get_physical_block_id(self, absolute_id: int) -> int:
def all_block_ids(self) -> FrozenSet[int]:
return self._hashless_allocator.all_block_ids

def get_prefix_cache_hit_rate(self) -> float:
    """Return the hit rate accumulated in `self.metric_data`."""
    metrics = self.metric_data
    return metrics.get_hit_rate()

def is_block_cached(self, block: Block) -> bool:
assert block.content_hash is not None
if block.content_hash in self._cached_blocks:
Expand Down
30 changes: 26 additions & 4 deletions aphrodite/processing/block_manager_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from aphrodite.common.block import BlockTable, PhysicalTokenBlock
from aphrodite.common.sequence import Sequence, SequenceGroup, SequenceStatus
from aphrodite.common.utils import Device
from aphrodite.processing.block.common import CacheMetricData
from aphrodite.processing.block.utils import (
check_no_caching_or_swa_for_blockmgr_encdec)
from aphrodite.processing.evictor_v1 import (EvictionPolicy, Evictor,
Expand Down Expand Up @@ -62,6 +63,12 @@ def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
pass


@abstractmethod
def get_prefix_cache_hit_rate(self) -> float:
    """Return the prefix cache hit rate of this block allocator.

    Implementations return -1 when prefix caching is not supported or is
    disabled.
    """


class CachedBlockAllocator(BlockAllocatorBase):
"""Manages free physical token blocks for a device.

Expand All @@ -86,6 +93,8 @@ def __init__(self,

self.default_hash_ctr = count()

self.cache_metric_data = CacheMetricData()

def allocate_block(self, block_hash: int,
num_hashed_tokens: int) -> PhysicalTokenBlock:
if self.current_num_blocks == self.num_blocks:
Expand All @@ -111,10 +120,10 @@ def allocate(self,
block = self.evictor.remove(block_hash)
assert block.ref_count == 0
self.cached_blocks[block_hash] = block
block.ref_count += 1
assert block.block_hash == block_hash
return block
if block_hash not in self.cached_blocks:
if block_hash in self.cached_blocks:
self.cache_metric_data.query(hit=True)
else:
self.cache_metric_data.query(hit=False)
self.cached_blocks[block_hash] = self.allocate_block(
block_hash, num_hashed_tokens)
block = self.cached_blocks[block_hash]
Expand Down Expand Up @@ -151,6 +160,9 @@ def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
del self.cached_blocks[old_hash]
self.cached_blocks[block_hash] = block

def get_prefix_cache_hit_rate(self) -> float:
    """Return the hit rate accumulated in `self.cache_metric_data`."""
    metrics = self.cache_metric_data
    return metrics.get_hit_rate()


class UncachedBlockAllocator(BlockAllocatorBase):
"""Manages free physical token blocks for a device.
Expand Down Expand Up @@ -210,6 +222,9 @@ def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
raise NotImplementedError(
"Invalid codepath for uncached block allocator.")

def get_prefix_cache_hit_rate(self) -> float:
    """Return the sentinel hit rate -1.0.

    This allocator does not report a prefix cache hit rate. The value is
    returned as a float literal so it matches the declared return type; it
    still compares equal to -1.
    """
    return -1.0


class BlockSpaceManagerV1(BlockSpaceManager):
"""Manages the mapping between logical and physical token blocks."""
Expand Down Expand Up @@ -706,3 +721,10 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup):
if self.enable_caching:
for seq in seq_group.get_seqs():
self.compute_full_blocks_in_seq(seq)

def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Return the prefix cache hit rate of the allocator for `device`.

    Args:
        device: Device.GPU or Device.CPU.

    Raises:
        ValueError: if `device` is neither Device.GPU nor Device.CPU.
    """
    # Pick the allocator first, then delegate once.
    if device == Device.GPU:
        allocator = self.gpu_allocator
    elif device == Device.CPU:
        allocator = self.cpu_allocator
    else:
        raise ValueError(f"Invalid device: {device}")
    return allocator.get_prefix_cache_hit_rate()
3 changes: 3 additions & 0 deletions aphrodite/processing/block_manager_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,9 @@ def get_num_free_gpu_blocks(self) -> int:
def get_num_free_cpu_blocks(self) -> int:
return self.block_allocator.get_num_free_blocks(Device.CPU)

def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Return the block allocator's prefix cache hit rate for `device`."""
    allocator = self.block_allocator
    return allocator.get_prefix_cache_hit_rate(device)

def _can_swap(self,
seq_group: SequenceGroup,
device: Device,
Expand Down
16 changes: 9 additions & 7 deletions aphrodite/processing/evictor_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,18 +85,21 @@ def evict(self) -> Tuple[int, int]:
if len(self.free_table) == 0:
raise ValueError("No usable cache memory left")

evicted_block = next(iter(self.free_table.values()))
evicted_block_id = next(iter(self.free_table.keys()))
evicted_block, evicted_block_id = None, None
# The blocks with the lowest timestamps should be placed consecutively
# at the start of OrderedDict. Loop through all these blocks to
# find the one with maximum number of hashed tokens.
for _id, block in self.free_table.items():
if evicted_block is None:
evicted_block, evicted_block_id = block, _id
continue
if evicted_block.last_accessed < block.last_accessed:
break
if (evicted_block.last_accessed == block.last_accessed and
evicted_block.num_hashed_tokens < block.num_hashed_tokens):
evicted_block = block
evicted_block_id = _id
if evicted_block.num_hashed_tokens < block.num_hashed_tokens:
evicted_block, evicted_block_id = block, _id

assert evicted_block is not None
assert evicted_block_id is not None

self.free_table.pop(evicted_block_id)

Expand All @@ -110,7 +113,6 @@ def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,

def update(self, block_id: int, last_accessed: float):
self.free_table[block_id].last_accessed = last_accessed
self.free_table.move_to_end(block_id)

def remove(self, block_id: int):
if block_id not in self.free_table:
Expand Down
6 changes: 6 additions & 0 deletions aphrodite/processing/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Tuple

from aphrodite.common.sequence import Sequence, SequenceGroup
from aphrodite.common.utils import Device


class AllocStatus(enum.Enum):
Expand Down Expand Up @@ -118,3 +119,8 @@ def get_common_computed_block_ids(
@abstractmethod
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
pass

@abstractmethod
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Return the prefix cache hit rate for the given device.

    Implementations return -1 when prefix caching is not supported or is
    disabled.
    """
4 changes: 4 additions & 0 deletions aphrodite/processing/placeholder_block_space_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List, Tuple

from aphrodite.common.sequence import Sequence, SequenceGroup
from aphrodite.common.utils import Device
from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager


Expand Down Expand Up @@ -81,3 +82,6 @@ def get_common_computed_block_ids(self,

def mark_blocks_as_computed(self, seq_group: SequenceGroup):
pass

def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Return the sentinel hit rate -1.0 for any device.

    `device` is ignored. The value is returned as a float literal so it
    matches the declared return type; it still compares equal to -1.
    """
    return -1.0
Loading
Loading