[CORE] No Request No Scheduler: auto-increment of multi-step
DriverSong committed Dec 3, 2024
1 parent 395b1c7 commit 6c59886
Showing 9 changed files with 184 additions and 15 deletions.
10 changes: 10 additions & 0 deletions vllm/core/block/block_table.py
@@ -189,6 +189,16 @@ def ensure_num_empty_slots(self, num_empty_slots: int) -> None:
self._allocator.allocate_mutable_block(
prev_block=self._blocks[-1], device=device))

def has_enough_empty_slots(self, num_empty_slots: int) -> bool:
"""Return if the BlockTable has at least the specified number of
empty slots available.
"""
# Currently the block table only supports
# appending tokens to GPU blocks.
assert self._is_allocated

return self._num_empty_slots >= num_empty_slots

def fork(self) -> "BlockTable":
"""Creates a new BlockTable instance with a copy of the blocks from the
current instance.
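As a rough illustration of how the new read-only check relates to ensure_num_empty_slots (a sketch with assumed numbers, not part of the diff; the internal slot accounting is simplified):

block_size = 16                 # assumed block size
num_blocks = 8                  # blocks currently allocated to the sequence
num_full_slots = 123            # tokens already appended
num_empty_slots = num_blocks * block_size - num_full_slots   # 5 free slots

# has_enough_empty_slots(9) only reads the counter and reports False here;
# ensure_num_empty_slots(9) would instead allocate another mutable block.
assert (num_empty_slots >= 9) is False
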
27 changes: 26 additions & 1 deletion vllm/core/block_manager.py
@@ -1,5 +1,5 @@
"""A block manager that manages token blocks."""
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union
from typing import Sequence as GenericSequence
from typing import Tuple

@@ -226,6 +226,27 @@ def can_append_slots(self, seq_group: SequenceGroup,
Device.GPU)
return num_touched_blocks <= num_free_gpu_blocks

def can_add_slots(self,
seq_groups: Union[SequenceGroup, List[SequenceGroup]],
num_lookahead_slots: int) -> bool:
""" If the empty slot is less then lookahead_slots, a new block should be allocated.

Check failure on line 231 in vllm/core/block_manager.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

vllm/core/block_manager.py:231:81: E501 Line too long (92 > 80)
Determine if there is enough space in GPU KV cache for the num_lookahead_slots.

Check failure on line 232 in vllm/core/block_manager.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

vllm/core/block_manager.py:232:81: E501 Line too long (87 > 80)
"""

if not isinstance(seq_groups, list):
seq_groups = [seq_groups]

num_touched_blocks = 0
for seq_group in seq_groups:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
block_table = self.block_tables[seq.seq_id]

if not block_table.has_enough_empty_slots(num_lookahead_slots):
num_touched_blocks += 1

num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
Device.GPU)
return num_touched_blocks <= num_free_gpu_blocks

def append_slots(
self,
seq: Sequence,
@@ -243,6 +264,10 @@ def append_slots(
new_cows = self.block_allocator.clear_copy_on_writes()
return new_cows

def add_lookahead_slots(self, seq: Sequence, num_lookahead_slots: int):
block_table = self.block_tables[seq.seq_id]
block_table.ensure_num_empty_slots(num_lookahead_slots)

def free(self, seq: Sequence) -> None:
seq_id = seq.seq_id

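An illustrative sketch of the counting that can_add_slots performs (hypothetical per-sequence free-slot numbers, not part of the diff): every running sequence whose block table cannot absorb the lookahead slots costs one new GPU block, and the total must stay within the free GPU blocks.

def blocks_touched(free_slots_per_seq, num_lookahead_slots):
    # One extra block per sequence that lacks room for the lookahead slots.
    return sum(1 for free in free_slots_per_seq
               if free < num_lookahead_slots)

# Three running sequences with 2, 9 and 20 free slots; growing multi-step
# needs 9 lookahead slots per sequence, so 2 new blocks would be touched.
assert blocks_touched([2, 9, 20], 9) == 2
# can_add_slots would then compare 2 against the free GPU block count.
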
10 changes: 9 additions & 1 deletion vllm/core/scheduler.py
@@ -165,6 +165,11 @@ def is_empty(self) -> bool:
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
and not self.blocks_to_swap_out and not self.blocks_to_copy)

def is_any_finished(self) -> bool:
return (not self.scheduled_seq_groups or
any(sg.seq_group.is_finished()
for sg in self.scheduled_seq_groups))

def _sort_by_lora_ids(self):
self.scheduled_seq_groups = sorted(
self.scheduled_seq_groups,
@@ -492,6 +497,9 @@ def has_unfinished_seqs(self) -> bool:
return len(self.waiting) != 0 or len(self.running) != 0 or len(
self.swapped) != 0

def has_pending_seqs(self) -> bool:
return len(self.waiting) != 0 or len(self.swapped) != 0

def get_prefix_cache_hit_rate(self, device: Device) -> float:
return self.block_manager.get_prefix_cache_hit_rate(device)

@@ -1231,7 +1239,7 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs:
num_lookahead_slots=num_lookahead_slots,
running_queue_size=len(self.running),
preempted=(len(running_scheduled.preempted) +
len(running_scheduled.swapped_out)),
len(running_scheduled.swapped_out))
)

def _schedule(self) -> SchedulerOutputs:
60 changes: 58 additions & 2 deletions vllm/engine/llm_engine.py
@@ -16,6 +16,8 @@
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig, SchedulerConfig,
VllmConfig)
from vllm.core.block_manager import SelfAttnBlockSpaceManager
from vllm.core.placeholder_block_space_manager import (
PlaceholderBlockSpaceManager)
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
SchedulerOutputs)
from vllm.engine.arg_utils import EngineArgs
@@ -59,7 +61,7 @@
BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
from vllm.utils import (Counter, Device, deprecate_kwargs, weak_bind,
weak_bind_with_ret)
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)
@@ -133,6 +135,15 @@ def append_output(self, outputs: List[SamplerOutput],
is_first_step_output=is_first_step_output,
skip=[]))

def is_any_finished(self) -> bool:
assert self.scheduler_outputs is not None
return self.scheduler_outputs.is_any_finished()

def get_running_reqs(self) -> int:
assert self.scheduler_outputs is not None
running_seqs = 0
for scheduled_seq_group in self.scheduler_outputs.scheduled_seq_groups:
running_seqs += len(scheduled_seq_group.seq_group.get_seqs(
status=SequenceStatus.RUNNING))
return running_seqs

class LLMEngine:
"""An LLM engine that receives requests and generates texts.
@@ -417,6 +428,15 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
for v_id in range(self.parallel_config.pipeline_parallel_size)
]

multi_step_modify = weak_bind_with_ret(self._verify_and_add_multi_step)

self.multi_step_modify_callback = [
partial(multi_step_modify,
block_manager=self.scheduler[v_id].block_manager,
ctx=self.scheduler_contexts[v_id])
for v_id in range(self.parallel_config.pipeline_parallel_size)
]

# Metric Logging.
if self.log_stats:
if stat_loggers is not None:
@@ -1437,11 +1457,14 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
finished_requests_ids=finished_requests_ids,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids)
last_sampled_token_ids=last_sampled_token_ids,
has_pending_reqs=self.scheduler[virtual_engine].has_pending_seqs())

if allow_async_output_proc:
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]
execute_model_req.multi_step_modify_callback \
= self.multi_step_modify_callback[virtual_engine]

outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req)
@@ -2079,3 +2102,36 @@ def _build_logits_processors(
sampling_params.logits_processors.extend(logits_processors)

return sampling_params

def _verify_and_add_multi_step(
self, block_manager: Union[SelfAttnBlockSpaceManager,
PlaceholderBlockSpaceManager],
ctx: SchedulerContext) -> bool:
""" Callback for NRNS to determine if the multi-step can increase and
then execute the increment.
If the block is enough for the increment, add the lookahead_slots of each seq,

Check failure on line 2112 in vllm/engine/llm_engine.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

vllm/engine/llm_engine.py:2112:81: E501 Line too long (86 > 80)
and also update the num_steps of each seq_group_metadata.
Return true when the increment succeeds,
"""
if len(ctx.output_queue) == 0:
return False
assert ctx.seq_group_metadata_list is not None
scheduler_step = ctx.seq_group_metadata_list[0].state.num_steps
# If num_steps is 1, multi-step is disabled, so just return False.
if scheduler_step == 1:
return False
assert ctx.scheduler_outputs is not None
sequence_groups = [
scheduled_seq_group.seq_group for scheduled_seq_group
in ctx.scheduler_outputs.scheduled_seq_groups]

# For prefill with chunked prefill enabled,
# num_lookahead_slots == scheduler_step; for decode it is
# scheduler_step - 1. Assume the worst case
# (num_lookahead_slots == scheduler_step), so adding one step
# requires scheduler_step + 1 lookahead slots.
if block_manager.can_add_slots(seq_groups=sequence_groups,
num_lookahead_slots=scheduler_step + 1):
for sequence_group in sequence_groups:
for seq in sequence_group.seqs:
block_manager.add_lookahead_slots(seq, scheduler_step + 1)
for seq_group_metadata in ctx.seq_group_metadata_list:
seq_group_metadata.add_step()
return True
return False
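
A small worked example of the slot arithmetic in the comment above (illustrative numbers; block_size and sequence length are assumptions, not part of the commit):

block_size = 16
scheduler_step = 8                    # num_steps of the scheduled groups
required_slots = scheduler_step + 1   # worst case, as derived above
seq_len = 1000                        # hypothetical current sequence length
empty_slots = -seq_len % block_size   # 8 free slots left in the last block
# 8 < 9, so this sequence needs one more free GPU block before the
# multi-step count can be incremented.
assert empty_slots < required_slots
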
14 changes: 14 additions & 0 deletions vllm/sequence.py
@@ -607,6 +607,9 @@ class SequenceGroupState(msgspec.Struct,
def remaining_steps(self) -> int:
return self.num_steps - self.current_step

def add_step(self):
self.num_steps = self.num_steps + 1


class SequenceGroup:
"""A group of sequences that are generated from the same prompt.
@@ -964,6 +967,10 @@ def is_single_step_prompt(self) -> bool:
# step.
return self.is_prompt and self.do_sample

def add_step(self):
if self.state is not None:
self.state.add_step()

def get_first_seq_id(self) -> int:
# This is an efficient way of fetching the seq_id when
# we know this SequenceGroup has only one sequence.
@@ -1265,6 +1272,13 @@ class ExecuteModelRequest(
last_sampled_token_ids: Optional[torch.Tensor] = None
# Async callback
async_callback: Optional[Callable] = None
multi_step_modify_callback: Optional[Callable[..., bool]] = None
has_pending_reqs: bool = True

def add_step(self):
if self.seq_group_metadata_list is not None:
for seq_group_metadata in self.seq_group_metadata_list:
seq_group_metadata.add_step()

@property
def is_first_multi_step(self) -> bool:
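A condensed sketch of how a single add_step() call propagates through the new hooks (simplified stand-in for the real classes; field names follow the diff above):

class State:
    def __init__(self, num_steps: int, current_step: int):
        self.num_steps = num_steps
        self.current_step = current_step

    def remaining_steps(self) -> int:
        return self.num_steps - self.current_step

    def add_step(self):
        self.num_steps = self.num_steps + 1

state = State(num_steps=8, current_step=7)  # about to run the last step
assert state.remaining_steps() == 1
state.add_step()                            # NRNS grants one more step
assert state.remaining_steps() == 2         # the runner keeps stepping
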
20 changes: 18 additions & 2 deletions vllm/utils.py
@@ -281,12 +281,12 @@ class PyObjectCache:
across scheduler iterations.
"""

def __init__(self, obj_builder):
def __init__(self, obj_builder, init_size: int = 128):
self._obj_builder = obj_builder
self._index = 0

self._obj_cache = []
for _ in range(128):
for _ in range(init_size):
self._obj_cache.append(self._obj_builder())

def _grow_cache(self):
Expand All @@ -313,6 +313,9 @@ def reset(self):
"""
self._index = 0

def get_remain_index(self) -> int:
return len(self._obj_cache) - self._index


@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
@@ -1161,6 +1164,19 @@ def weak_bound(*args, **kwargs) -> None:

return weak_bound

def weak_bind_with_ret(
bound_method: Callable[..., Any]) -> Callable[..., Any]:
"""Like weak_bind, but propagates the bound method's return value.
The wrapper weakly references the associated instance and no-ops
(returning None) once that instance has been collected."""
ref = weakref.ref(bound_method.__self__) # type: ignore[attr-defined]
unbound = bound_method.__func__ # type: ignore[attr-defined]

def weak_bound(*args, **kwargs) -> Any:
if inst := ref():
return unbound(inst, *args, **kwargs)

return weak_bound


#From: https://stackoverflow.com/a/4104188/2749989
def run_once(f: Callable[P, None]) -> Callable[P, None]:
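A minimal usage sketch of weak_bind_with_ret (hypothetical class, assuming the helper above is in scope): the wrapper forwards the return value while the instance is alive and silently returns None once it has been collected.

import gc

class Engine:
    def is_ready(self) -> bool:
        return True

engine = Engine()
callback = weak_bind_with_ret(engine.is_ready)
assert callback() is True    # instance alive: the result is propagated
del engine
gc.collect()
assert callback() is None    # instance collected: the wrapper no-ops
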
1 change: 1 addition & 0 deletions vllm/worker/model_runner.py
@@ -105,6 +105,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
finished_requests_ids: Optional[List[str]] = None
virtual_engine: int = 0
async_callback: Optional[Callable] = None
multi_step_modify_callback: Optional[Callable[..., bool]] = None
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
scheduler_outputs: Optional[SchedulerOutputs] = None

52 changes: 44 additions & 8 deletions vllm/worker/multi_step_model_runner.py
@@ -21,6 +21,7 @@
BroadcastableModelInput, _init_attn_metadata_from_tensor_dict,
_init_frozen_model_input_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
from ..config import CacheConfig

from ..model_executor.model_loader.tensorizer import TensorizerConfig

@@ -54,14 +55,20 @@ def completion_seq_group_output_builder():
class PythonizationCache:

def __init__(self):
self.cached_seq_output = PyObjectCache(seq_output_builder)
# The initial size is set manually here based on the number of GPU
# blocks (21494).
# TODO: derive the initial size from cache_config.num_gpu_blocks.
self.cached_seq_output = PyObjectCache(seq_output_builder, 22000 * 16)
self.cached_completion_seq_group_output = PyObjectCache(
completion_seq_group_output_builder)
completion_seq_group_output_builder, 20000 * 16)

def reset(self):
self.cached_seq_output.reset()
self.cached_completion_seq_group_output.reset()

def get_remain_cache(self) -> int:
return min(self.cached_seq_output.get_remain_index(),
self.cached_completion_seq_group_output.get_remain_index())


@dataclass
class ModelOutput:
@@ -159,6 +166,7 @@ class StatefulModelInput(BroadcastableModelInput):
num_seqs: int = -1
num_queries: int = -1
num_single_step_prefills: int = 0
has_pending_reqs: bool = True

def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
assert self.frozen_model_input is not None
@@ -405,10 +413,13 @@ def _async_process_outputs(self, model_input: StatefulModelInput,
break

def _final_process_outputs(self, model_input: StatefulModelInput,
output_proc_callback: Optional[Callable]):
output_proc_callback: Optional[Callable],
multi_step_modify_callback: Optional[Callable[..., bool]],
pythonization_cache: Optional[PythonizationCache] = None):
assert model_input.frozen_model_input is not None

has_async_callback = output_proc_callback is not None
has_async_output_proc_callback = output_proc_callback is not None
has_async_multi_step_modify_callback = multi_step_modify_callback is not None

outputs = []
for step_num, output in enumerate(model_input.cached_outputs):
@@ -419,9 +430,32 @@ def _final_process_outputs(self, model_input: StatefulModelInput,
# For async case:
# -- Invoke callback, pythonize, add to callback queue and repeat
# -- For last output, just add to callback queue
if has_async_callback:
if has_async_output_proc_callback:
assert output_proc_callback is not None

ctx = output_proc_callback.keywords["ctx"]
# The boundary conditions for modifying multi-step:
# 1. only one step remains
# 2. no finished reqs
# 3. no waiting or swapped reqs
# 4. the PythonizationCache has enough space for a new step
# The max scheduler_step is limited by the PythonizationCache: its
# initial size is 128 (PyObjectCache.__init__), and once the
# multi-step output size exceeds 128, the cache grows by doubling.
# Since object allocation costs more than scheduling, multi-step is
# capped by the remaining capacity of the PythonizationCache.
# TODO: make the initial PythonizationCache size user-configurable.
if (has_async_multi_step_modify_callback
and is_last_step
and not ctx.is_any_finished()
and not model_input.has_pending_reqs
and pythonization_cache.get_remain_cache() > ctx.get_running_reqs()):
if multi_step_modify_callback():
# if the multi-step has been modified,
# just return as current step is no longer the last step
return outputs

# Invoke callback before pythonize (to overlap with GPU)
output_proc_callback()

@@ -433,8 +467,6 @@ def _final_process_outputs(self, model_input: StatefulModelInput,
# For non last step, add to callback queue to chain
# callbacks=>pythonize pairs (for GPU overlap)
if not is_last_step:
ctx = output_proc_callback.keywords[ # type: ignore
"ctx"] # type: ignore
ctx.append_output(
outputs=[output.sampler_output],
seq_group_metadata_list=ctx.
@@ -584,7 +616,11 @@ def execute_model(
# Pythonize the output and block if needed since it is the last step
if model_input.is_last_step:
outputs = self._final_process_outputs(
model_input, model_input.base_output_proc_callback)
model_input, model_input.base_output_proc_callback,
model_input.frozen_model_input.multi_step_modify_callback,
self.pythonization_cache)
if not outputs:
return output
if self.pythonization_cache:
self.pythonization_cache.reset()
return outputs
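The TODO above suggests deriving the initial PythonizationCache size from the cache config instead of hard-coding 22000 * 16; a hedged sketch of what that could look like (the helper and numbers are assumptions, not part of the commit):

def pythonization_cache_size(num_gpu_blocks: int, block_size: int) -> int:
    # One cached object per token slot that a full multi-step pass over
    # all GPU blocks could produce.
    return num_gpu_blocks * block_size

# e.g. 21494 GPU blocks of 16 tokens each, which the hard-coded
# 22000 * 16 above roughly over-approximates.
assert pythonization_cache_size(21494, 16) == 343904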