diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py
index d10cb29ef4a7c..27a2fb3f08af8 100644
--- a/vllm/core/block/block_table.py
+++ b/vllm/core/block/block_table.py
@@ -189,6 +189,16 @@ def ensure_num_empty_slots(self, num_empty_slots: int) -> None:
                 self._allocator.allocate_mutable_block(
                     prev_block=self._blocks[-1], device=device))
 
+    def has_enough_empty_slots(self, num_empty_slots: int) -> bool:
+        """Return whether the BlockTable has at least the specified number
+        of empty slots available.
+        """
+        # Currently the block table only supports
+        # appending tokens to GPU blocks.
+        assert self._is_allocated
+
+        return self._num_empty_slots >= num_empty_slots
+
     def fork(self) -> "BlockTable":
         """Creates a new BlockTable instance with a copy of the blocks from
         the current instance.
diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py
index 209487c6b4f9e..fdabdbb697b29 100644
--- a/vllm/core/block_manager.py
+++ b/vllm/core/block_manager.py
@@ -1,5 +1,5 @@
 """A block manager that manages token blocks."""
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union
 from typing import Sequence as GenericSequence
 from typing import Tuple
 
@@ -226,6 +226,27 @@ def can_append_slots(self, seq_group: SequenceGroup,
             Device.GPU)
         return num_touched_blocks <= num_free_gpu_blocks
 
+    def can_add_slots(self, seq_groups: Union[SequenceGroup, List[SequenceGroup]],
+                      num_lookahead_slots: int) -> bool:
+        """Determine whether there is enough space in the GPU KV cache for num_lookahead_slots.
+        If a sequence has fewer empty slots than that, a new block must be allocated for it.
+        """
+
+        if not isinstance(seq_groups, list):
+            seq_groups = [seq_groups]
+
+        num_touched_blocks = 0
+        for seq_group in seq_groups:
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+                block_table = self.block_tables[seq.seq_id]
+
+                num_touched_blocks += 0 if (
+                    block_table.has_enough_empty_slots(num_lookahead_slots)) else 1
+
+        num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
+            Device.GPU)
+        return num_touched_blocks <= num_free_gpu_blocks
+
     def append_slots(
         self,
         seq: Sequence,
@@ -243,6 +264,10 @@ def append_slots(
         new_cows = self.block_allocator.clear_copy_on_writes()
         return new_cows
 
+    def add_lookahead_slots(self, seq: Sequence, num_lookahead_slots: int):
+        block_table = self.block_tables[seq.seq_id]
+        block_table.ensure_num_empty_slots(num_lookahead_slots)
+
     def free(self, seq: Sequence) -> None:
         seq_id = seq.seq_id
 
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index d23009dae01ee..63da20a0d004e 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -165,6 +165,11 @@ def is_empty(self) -> bool:
         return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
                 and not self.blocks_to_swap_out and not self.blocks_to_copy)
 
+    def is_any_finished(self) -> bool:
+        return (not self.scheduled_seq_groups or
+                any(sg.seq_group.is_finished()
+                    for sg in self.scheduled_seq_groups))
+
     def _sort_by_lora_ids(self):
         self.scheduled_seq_groups = sorted(
             self.scheduled_seq_groups,
@@ -492,6 +497,9 @@ def has_unfinished_seqs(self) -> bool:
         return len(self.waiting) != 0 or len(self.running) != 0 or len(
             self.swapped) != 0
 
+    def has_pending_seqs(self) -> bool:
+        return len(self.waiting) != 0 or len(self.swapped) != 0
+
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return self.block_manager.get_prefix_cache_hit_rate(device)
 
@@ -1231,7 +1239,7 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs:
             num_lookahead_slots=num_lookahead_slots,
             running_queue_size=len(self.running),
             preempted=(len(running_scheduled.preempted) +
-                       len(running_scheduled.swapped_out)),
+                       len(running_scheduled.swapped_out))
         )
 
     def _schedule(self) -> SchedulerOutputs:
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index ecc222f692c41..a83bbf87740c3 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -16,6 +16,8 @@
 from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                          ObservabilityConfig, ParallelConfig, SchedulerConfig,
                          VllmConfig)
+from vllm.core.block_manager import SelfAttnBlockSpaceManager
+from vllm.core.placeholder_block_space_manager import PlaceholderBlockSpaceManager
 from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
                                  SchedulerOutputs)
 from vllm.engine.arg_utils import EngineArgs
@@ -59,7 +61,7 @@
     BaseTokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
-from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
+from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind, weak_bind_with_ret
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger(__name__)
@@ -133,6 +135,15 @@ def append_output(self, outputs: List[SamplerOutput],
                 is_first_step_output=is_first_step_output,
                 skip=[]))
 
+    def is_any_finished(self) -> bool:
+        return self.scheduler_outputs.is_any_finished()
+
+    def get_running_reqs(self) -> int:
+        running_seqs = 0
+        for scheduled_seq_group in self.scheduler_outputs.scheduled_seq_groups:
+            running_seqs += len(scheduled_seq_group.seq_group.get_seqs(
+                status=SequenceStatus.RUNNING))
+        return running_seqs
 
 class LLMEngine:
     """An LLM engine that receives requests and generates texts.
@@ -417,6 +428,15 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
             for v_id in range(self.parallel_config.pipeline_parallel_size)
         ]
 
+        multi_step_modify = weak_bind_with_ret(self._verify_and_add_multi_step)
+
+        self.multi_step_modify_callback = [
+            partial(multi_step_modify,
+                    block_manager=self.scheduler[v_id].block_manager,
+                    ctx=self.scheduler_contexts[v_id])
+            for v_id in range(self.parallel_config.pipeline_parallel_size)
+        ]
+
         # Metric Logging.
         if self.log_stats:
             if stat_loggers is not None:
@@ -1437,11 +1457,14 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
                 finished_requests_ids=finished_requests_ids,
                 # We use ExecuteModelRequest to pass the last sampled_token_ids
                 # to each of the non-last PP stages for in-place prepare_input.
-                last_sampled_token_ids=last_sampled_token_ids)
+                last_sampled_token_ids=last_sampled_token_ids,
+                has_pending_reqs=self.scheduler[virtual_engine].has_pending_seqs())
 
             if allow_async_output_proc:
                 execute_model_req.async_callback = self.async_callbacks[
                     virtual_engine]
+                execute_model_req.multi_step_modify_callback \
+                    = self.multi_step_modify_callback[virtual_engine]
 
             outputs = self.model_executor.execute_model(
                 execute_model_req=execute_model_req)
@@ -2079,3 +2102,36 @@ def _build_logits_processors(
             sampling_params.logits_processors.extend(logits_processors)
 
         return sampling_params
+
+    def _verify_and_add_multi_step(self,
+                                   block_manager: Union[
+                                       SelfAttnBlockSpaceManager, PlaceholderBlockSpaceManager],
+                                   ctx: SchedulerContext) -> bool:
+        """Callback for NRNS to determine whether the multi-step count can
+        be increased and, if so, to perform the increment.
+        If there are enough blocks, add the lookahead slots for each
+        sequence and update num_steps of each seq_group_metadata.
+        Return True when the increment succeeds.
+        """
+        if len(ctx.output_queue) == 0:
+            return False
+        scheduler_step = ctx.seq_group_metadata_list[0].state.num_steps
+        # If num_steps is 1, multi-step is off; just return False.
+        if scheduler_step == 1:
+            return False
+        sequence_groups = [scheduled_seq_group.seq_group for scheduled_seq_group
+                           in ctx.scheduler_outputs.scheduled_seq_groups]
+
+        # For prefill with chunked prefill enabled, num_lookahead_slots ==
+        # scheduler_step; for decode, num_lookahead_slots == scheduler_step - 1.
+        # Consider the worst case and assume num_lookahead_slots == scheduler_step,
+        # so the new num_lookahead_slots should be num_lookahead_slots + 1.
+        if block_manager.can_add_slots(seq_groups=sequence_groups,
+                                       num_lookahead_slots=scheduler_step + 1):
+            for sequence_group in sequence_groups:
+                for seq in sequence_group.seqs:
+                    block_manager.add_lookahead_slots(seq, scheduler_step + 1)
+            for seq_group_metadata in ctx.seq_group_metadata_list:
+                seq_group_metadata.add_step()
+            return True
+        return False
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 669124319c4f4..2a5c490041619 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -607,6 +607,9 @@ class SequenceGroupState(msgspec.Struct,
     def remaining_steps(self) -> int:
         return self.num_steps - self.current_step
 
+    def add_step(self):
+        self.num_steps = self.num_steps + 1
+
 
 class SequenceGroup:
     """A group of sequences that are generated from the same prompt.
@@ -964,6 +967,10 @@ def is_single_step_prompt(self) -> bool:
         # step.
         return self.is_prompt and self.do_sample
 
+    def add_step(self):
+        if self.state is not None:
+            self.state.add_step()
+
     def get_first_seq_id(self) -> int:
         # This is an efficient way of fetching the seq_id when
         # we know this SequenceGroup has only one sequence.
@@ -1265,6 +1272,13 @@ class ExecuteModelRequest(
     last_sampled_token_ids: Optional[torch.Tensor] = None
     # Async callback
     async_callback: Optional[Callable] = None
+    multi_step_modify_callback: Optional[Callable[..., bool]] = None
+    has_pending_reqs: bool = True
+
+    def add_step(self):
+        if self.seq_group_metadata_list is not None:
+            for seq_group_metadata in self.seq_group_metadata_list:
+                seq_group_metadata.add_step()
 
     @property
     def is_first_multi_step(self) -> bool:
diff --git a/vllm/utils.py b/vllm/utils.py
index 6f7a6f8c54e47..aff69b7f866ac 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -281,12 +281,12 @@ class PyObjectCache:
     across scheduler iterations.
""" - def __init__(self, obj_builder): + def __init__(self, obj_builder, init_size: int = 128): self._obj_builder = obj_builder self._index = 0 self._obj_cache = [] - for _ in range(128): + for _ in range(init_size): self._obj_cache.append(self._obj_builder()) def _grow_cache(self): @@ -313,6 +313,9 @@ def reset(self): """ self._index = 0 + def get_remain_index(self) -> int: + return len(self._obj_cache) - self._index + @lru_cache(maxsize=None) def get_max_shared_memory_bytes(gpu: int = 0) -> int: @@ -1161,6 +1164,19 @@ def weak_bound(*args, **kwargs) -> None: return weak_bound +def weak_bind_with_ret(bound_method: Callable[..., Any], ) -> Callable[..., Any]: + """Make an instance method with return that weakly references + its associated instance and no-ops once that + instance is collected.""" + ref = weakref.ref(bound_method.__self__) # type: ignore[attr-defined] + unbound = bound_method.__func__ # type: ignore[attr-defined] + + def weak_bound(*args, **kwargs) -> Any: + if inst := ref(): + return unbound(inst, *args, **kwargs) + + return weak_bound + #From: https://stackoverflow.com/a/4104188/2749989 def run_once(f: Callable[P, None]) -> Callable[P, None]: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1f654a9cce465..1c63f6caaeac9 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -105,6 +105,7 @@ class ModelInputForGPU(ModelRunnerInputBase): finished_requests_ids: Optional[List[str]] = None virtual_engine: int = 0 async_callback: Optional[Callable] = None + multi_step_modify_callback: Optional[Callable[..., bool]] = None, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None scheduler_outputs: Optional[SchedulerOutputs] = None diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 3ee0fb4dc943e..2a926e8dcbe20 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -21,6 +21,7 @@ BroadcastableModelInput, _init_attn_metadata_from_tensor_dict, _init_frozen_model_input_from_tensor_dict, _init_sampling_metadata_from_tensor_dict) +from ..config import CacheConfig from ..model_executor.model_loader.tensorizer import TensorizerConfig @@ -54,14 +55,20 @@ def completion_seq_group_output_builder(): class PythonizationCache: def __init__(self): - self.cached_seq_output = PyObjectCache(seq_output_builder) + # I set the initila size here manually by the size of gpu block(21494) + # TODO: set the initial_size by the cache_config.num_gpu_blocks + self.cached_seq_output = PyObjectCache(seq_output_builder, 22000 * 16) self.cached_completion_seq_group_output = PyObjectCache( - completion_seq_group_output_builder) + completion_seq_group_output_builder, 20000 * 16) def reset(self): self.cached_seq_output.reset() self.cached_completion_seq_group_output.reset() + def get_remain_cache(self) -> int: + return min(self.cached_seq_output.get_remain_index(), + self.cached_completion_seq_group_output.get_remain_index()) + @dataclass class ModelOutput: @@ -159,6 +166,7 @@ class StatefulModelInput(BroadcastableModelInput): num_seqs: int = -1 num_queries: int = -1 num_single_step_prefills: int = 0 + has_pending_reqs: bool = True def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: assert self.frozen_model_input is not None @@ -405,10 +413,13 @@ def _async_process_outputs(self, model_input: StatefulModelInput, break def _final_process_outputs(self, model_input: StatefulModelInput, - output_proc_callback: Optional[Callable]): + output_proc_callback: 
+                               output_proc_callback: Optional[Callable],
+                               multi_step_modify_callback: Optional[Callable[..., bool]],
+                               pythonization_cache: Optional[PythonizationCache] = None):
         assert model_input.frozen_model_input is not None
 
-        has_async_callback = output_proc_callback is not None
+        has_async_output_proc_callback = output_proc_callback is not None
+        has_async_multi_step_modify_callback = multi_step_modify_callback is not None
 
         outputs = []
         for step_num, output in enumerate(model_input.cached_outputs):
@@ -419,9 +430,32 @@ def _final_process_outputs(self, model_input: StatefulModelInput,
             # For async case:
             # -- Invoke callback, pythonize, add to callback queue and repeat
             # -- For last output, just add to callback queue
-            if has_async_callback:
+            if has_async_output_proc_callback:
                 assert output_proc_callback is not None
 
+                ctx = output_proc_callback.keywords["ctx"]
+                # The boundary conditions for extending multi-step:
+                # 1. only one step remains
+                # 2. no finished requests
+                # 3. no waiting or swapped requests
+                # 4. the pythonization_cache has enough space for a new step
+                # We bound the maximum scheduler_step by the PythonizationCache:
+                # its initial size is 128 (PyObjectCache.__init__), and once the
+                # multi-step output size exceeds 128, the cache grows by doubling.
+                # The cost of those object allocations is larger than the cost of
+                # scheduling, so multi-step is limited by the remaining size of
+                # the PythonizationCache.
+                # TODO: make the PythonizationCache initial size user-defined.
+                if (has_async_multi_step_modify_callback
+                        and is_last_step
+                        and not ctx.is_any_finished()
+                        and not model_input.has_pending_reqs
+                        and pythonization_cache.get_remain_cache() > ctx.get_running_reqs()):
+                    if multi_step_modify_callback():
+                        # The multi-step count has been increased, so the current
+                        # step is no longer the last step; return immediately.
+                        return outputs
+
                 # Invoke callback before pythonize (to overlap with GPU)
                 output_proc_callback()
 
@@ -433,8 +467,6 @@ def _final_process_outputs(self, model_input: StatefulModelInput,
                 # For non last step, add to callback queue to chain
                 # callbacks=>pythonize pairs (for GPU overlap)
                 if not is_last_step:
-                    ctx = output_proc_callback.keywords[  # type: ignore
-                        "ctx"]  # type: ignore
                     ctx.append_output(
                         outputs=[output.sampler_output],
                         seq_group_metadata_list=ctx.
@@ -584,7 +616,11 @@ def execute_model(
         # Pythonize the output and block if needed since it is the last step
         if model_input.is_last_step:
             outputs = self._final_process_outputs(
-                model_input, model_input.base_output_proc_callback)
+                model_input, model_input.base_output_proc_callback,
+                model_input.frozen_model_input.multi_step_modify_callback,
+                self.pythonization_cache)
+            if not outputs:
+                return output
             if self.pythonization_cache:
                 self.pythonization_cache.reset()
             return outputs
diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py
index 1f982fe103366..85b75b7c556ea 100644
--- a/vllm/worker/multi_step_worker.py
+++ b/vllm/worker/multi_step_worker.py
@@ -46,6 +46,7 @@ def _get_driver_input_and_broadcast(
         assert self.is_driver_worker
         virtual_engine = execute_model_req.virtual_engine
         is_first_multi_step = execute_model_req.is_first_multi_step
+        has_pending_reqs = execute_model_req.has_pending_reqs
        if is_first_multi_step:
             # on first step we prepare the worker input and model input normally
             worker_input: WorkerInput = self.prepare_worker_input(
@@ -59,7 +60,8 @@ def _get_driver_input_and_broadcast(
             if execute_model_req.async_callback:
                 model_input.frozen_model_input = dataclasses.replace(  # type: ignore
                     model_input.frozen_model_input,
-                    async_callback=execute_model_req.async_callback)
+                    async_callback=execute_model_req.async_callback,
+                    multi_step_modify_callback=execute_model_req.multi_step_modify_callback)
         else:
             # on subsequent steps we reuse the worker input and model input
             multi_step_state = self.multi_step_states[virtual_engine]
@@ -75,6 +77,7 @@ def _get_driver_input_and_broadcast(
         model_input.is_first_multi_step = is_first_multi_step
         model_input.is_last_step = execute_model_req.is_last_step
+        model_input.has_pending_reqs = has_pending_reqs
 
         if not is_first_multi_step:
             # we broadcast the last sampled token ids to all TP workers so they