diff --git a/vllm/core/evictor.py b/vllm/core/evictor.py index ed7e06cab2996..44adc4158abec 100644 --- a/vllm/core/evictor.py +++ b/vllm/core/evictor.py @@ -1,6 +1,7 @@ import enum +import heapq from abc import ABC, abstractmethod -from typing import OrderedDict, Tuple +from typing import Dict, List, Tuple class EvictionPolicy(enum.Enum): @@ -75,8 +76,14 @@ class LRUEvictor(Evictor): highest num_hashed_tokens value, then one will be chose arbitrarily """ + # CLEANUP_THRESHOLD determines the maximum allowable size of the priority + # queue relative to the free table size. When this threshold is exceeded, + # a cleanup operation is triggered to reduce memory usage. + CLEANUP_THRESHOLD = 50 + def __init__(self): - self.free_table: OrderedDict[int, BlockMetaData] = OrderedDict() + self.free_table: Dict[int, BlockMetaData] = {} + self.priority_queue = [] def __contains__(self, block_id: int) -> bool: return block_id in self.free_table @@ -85,34 +92,50 @@ def evict(self) -> Tuple[int, int]: if len(self.free_table) == 0: raise ValueError("No usable cache memory left") - evicted_block, evicted_block_id = None, None - # The blocks with the lowest timestamps should be placed consecutively - # at the start of OrderedDict. Loop through all these blocks to - # find the one with maximum number of hashed tokens. - for _id, block in self.free_table.items(): - if evicted_block is None: - evicted_block, evicted_block_id = block, _id - continue - if evicted_block.last_accessed < block.last_accessed: - break - if evicted_block.num_hashed_tokens < block.num_hashed_tokens: - evicted_block, evicted_block_id = block, _id - - assert evicted_block is not None - assert evicted_block_id is not None - self.free_table.pop(evicted_block_id) - - return evicted_block_id, evicted_block.content_hash + while self.priority_queue: + # We do not remove outdated entries from the priority queue at the + # time of updating the last_accessed timestamp. Instead, outdated + # entries are filtered out here during eviction. Outdated entries + # would either not in the free table, or have older last accessed + # time. + last_accessed, _, block_id, content_hash = heapq.heappop( + self.priority_queue) + if (block_id in self.free_table and + self.free_table[block_id].last_accessed == last_accessed): + self.free_table.pop(block_id) + return block_id, content_hash + + raise ValueError("No usable cache memory left") def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, last_accessed: float): self.free_table[block_id] = BlockMetaData(content_hash, num_hashed_tokens, last_accessed) + heapq.heappush( + self.priority_queue, + (last_accessed, -num_hashed_tokens, block_id, content_hash)) + self._cleanup_if_necessary() def update(self, block_id: int, last_accessed: float): self.free_table[block_id].last_accessed = last_accessed + def _cleanup_if_necessary(self): + if len(self.priority_queue) > LRUEvictor.CLEANUP_THRESHOLD * len( + self.free_table): + self._cleanup() + + def _cleanup(self): + new_priority_queue: List[Tuple[float, int, int, int]] = [] + + for block_id, block in self.free_table.items(): + new_priority_queue.append( + (block.last_accessed, -block.num_hashed_tokens, block_id, + block.content_hash)) + heapq.heapify(new_priority_queue) + + self.priority_queue = new_priority_queue + def remove(self, block_id: int): if block_id not in self.free_table: raise ValueError(