diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu index 1f3f4710735e5..195eb27dee749 100644 --- a/csrc/prepare_inputs/advance_step.cu +++ b/csrc/prepare_inputs/advance_step.cu @@ -52,7 +52,7 @@ __global__ void advance_step_flashattn_kernel( slot_mapping_ptr[cur_query_id] = slot_num; } -inline void verify_tensor(std::string const& name, torch::Tensor& t, +inline void verify_tensor(std::string const& name, torch::Tensor const& t, int64_t const size_0, int64_t const size_1, c10::ScalarType const type) { bool size_0_cond = true; diff --git a/tests/multi_step/test_correctness_async_llm.py b/tests/multi_step/test_correctness_async_llm.py index a75a671e57f74..615549f2134ad 100644 --- a/tests/multi_step/test_correctness_async_llm.py +++ b/tests/multi_step/test_correctness_async_llm.py @@ -37,6 +37,7 @@ @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("is_async", [True]) @pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"]) +@pytest.mark.parametrize("enable_chunked_prefill", [True, False]) @pytest.mark.asyncio async def test_multi_step( example_prompts, @@ -49,6 +50,7 @@ async def test_multi_step( is_async: bool, num_logprobs: Optional[int], attention_backend: str, + enable_chunked_prefill: bool, monkeypatch, ) -> None: """Test vLLM engine with multi-step scheduling in an OpenAI-protocol @@ -74,6 +76,10 @@ async def test_multi_step( num_logprobs: corresponds to the `logprobs` argument to the OpenAI completions endpoint; `None` -> no logprobs """ + if enable_chunked_prefill and \ (pp_size > 1 or attention_backend != "FLASH_ATTN"): + pytest.skip("Multi-step with Chunked-Prefill only supports " "PP=1 and FLASH_ATTN backend") override_backend_env_variable(monkeypatch, attention_backend) @@ -93,6 +99,9 @@ async def test_multi_step( if eager_mode: ms_server_args.append("--enforce-eager") + if enable_chunked_prefill: + ms_server_args.append("--enable-chunked-prefill") + distributed_args = [ "--tensor-parallel-size", str(tp_size), diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py index c5dc81cc25622..ff413e8e2da3f 100644 --- a/tests/multi_step/test_correctness_llm.py +++ b/tests/multi_step/test_correctness_llm.py @@ -16,6 +16,7 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("tp_size", [1]) +@pytest.mark.parametrize("enable_chunked_prefill", [False, True]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("enforce_eager", [True]) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @@ -28,6 +29,7 @@ def test_multi_step_llm( model: str, dtype: str, tp_size: int, + enable_chunked_prefill: bool, max_tokens: int, enforce_eager: int, num_scheduler_steps: int, @@ -51,6 +53,7 @@ def test_multi_step_llm( model: model under test (same for single- and multi-step engines) dtype: tensor datatype for engine to utilize tp_size: degree of tensor-parallelism + enable_chunked_prefill: chunked-prefill on/off max_tokens: the maximum number of tokens to generate enforce_eager num_scheduler_steps: for multi-step scheduling, GPU-side steps per @@ -73,6 +76,7 @@ def test_multi_step_llm( gpu_memory_utilization=0.7, tensor_parallel_size=tp_size, use_v2_block_manager=True, + enable_chunked_prefill=enable_chunked_prefill, num_scheduler_steps=num_scheduler_steps, ) as vllm_model: vllm_outputs = (vllm_model.generate_greedy(prompts, max_tokens) diff --git a/vllm/attention/backends/flash_attn.py
b/vllm/attention/backends/flash_attn.py index 22d07c0a4f689..43ca6c9ff160e 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -342,9 +342,13 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: ) return self._cached_decode_metadata - def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", sampled_token_ids: Optional[torch.Tensor], - block_size: int, num_seqs: int, num_queries: int): + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): """ Update metadata in-place to advance one decode step. """ @@ -355,6 +359,23 @@ def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", assert num_seqs > num_queries assert self.use_cuda_graph + if turn_prefills_into_decodes: + # When Multi-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. + assert self.num_decode_tokens + self.num_prefills == num_seqs + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens is not None + assert self.max_decode_seq_len == max(self.seq_lens) + assert self.num_prefills == 0 assert self.num_prefill_tokens == 0 assert self.num_decode_tokens == num_seqs @@ -366,7 +387,6 @@ def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", assert self.seq_lens_tensor.shape == (num_seqs, ) assert self.max_query_len == 1 assert self.max_prefill_seq_len == 0 - assert self.max_decode_seq_len == max(self.seq_lens) assert self.query_start_loc is not None assert self.query_start_loc.shape == (num_queries + 1, ) @@ -706,8 +726,10 @@ def forward( num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa + assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa # Query for decode. KV is not needed because it is already cached. decode_query = query[num_prefill_tokens:] diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 784cff0d9878e..a64bf34596f99 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -410,18 +410,22 @@ def decode_metadata(self) -> Optional["FlashInferMetadata"]: return self - def advance_step( - self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - ): + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): """ Update metadata in-place to advance one decode step. 
""" + assert not turn_prefills_into_decodes, \ + ("Chunked prefill is not supported with flashinfer yet." + "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill " + "specific parameter.") + assert num_seqs > 0 assert num_queries > 0 assert model_input.attn_metadata is not None diff --git a/vllm/config.py b/vllm/config.py index 108badf150c86..3139c5a08bfb8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -983,9 +983,16 @@ def __init__(self, policy: str = "fcfs") -> None: if max_num_batched_tokens is None: if enable_chunked_prefill: - # It is the values that have the best balance between ITL - # and TTFT on A100. Note it is not optimized for throughput. - max_num_batched_tokens = 512 + if num_scheduler_steps > 1: + # Multi-step Chunked-Prefill doesn't allow prompt-chunking + # for now. Have max_num_batched_tokens set to max_model_len + # so we don't reject sequences on account of a short + # max_num_batched_tokens. + max_num_batched_tokens = max(max_model_len, 2048) + else: + # It is the values that have the best balance between ITL + # and TTFT on A100. Note it is not optimized for throughput. + max_num_batched_tokens = 512 else: # If max_model_len is too short, use 2048 as the default value # for higher throughput. diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index c002dd1397f96..a9f4bd871dfda 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -55,9 +55,12 @@ def __init__( self._num_full_slots = self._get_num_token_ids() @staticmethod - def get_num_required_blocks(token_ids: List[int], block_size: int) -> int: + def get_num_required_blocks(token_ids: List[int], + block_size: int, + num_lookahead_slots: int = 0) -> int: """Calculates the minimum number of blocks required to store a given - sequence of token IDs. + sequence of token IDs along with any look-ahead slots that may be + required (like in multi-step + chunked-prefill). This assumes worst-case scenario, where every block requires a new allocation (e.g. ignoring prefix caching). @@ -66,12 +69,14 @@ def get_num_required_blocks(token_ids: List[int], block_size: int) -> int: token_ids (List[int]): The sequence of token IDs to be stored. block_size (int): The maximum number of tokens that can be stored in a single block. + num_lookahead_slots (int): look-ahead slots that the sequence may + require. Returns: int: The minimum number of blocks required to store the given - sequence of token IDs. + sequence of token IDs along with any required look-ahead slots. """ - return cdiv(len(token_ids), block_size) + return cdiv(len(token_ids) + num_lookahead_slots, block_size) def allocate(self, token_ids: List[int], diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 24ab9eb66194d..a1f96707a6b54 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -281,10 +281,15 @@ def __init__( def _get_seq_num_required_blocks(self, seq: Optional[Sequence]) -> int: return 0 if seq is None else seq.n_blocks - def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. 
+ assert (num_lookahead_slots == 0 + ), "lookahead allocation not supported in BlockSpaceManagerV1" + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) self_num_required_blocks = self._get_seq_num_required_blocks( diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 54818c7e3e9a6..bb78b1e1c9138 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -107,7 +107,9 @@ def __init__( self._last_access_blocks_tracker = LastAccessBlocksTracker( self.block_allocator) - def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. @@ -117,6 +119,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: num_required_blocks = BlockTable.get_num_required_blocks( seq.get_token_ids(), block_size=self.block_size, + num_lookahead_slots=num_lookahead_slots, ) if seq_group.is_encoder_decoder(): diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/embedding_model_block_manager.py index c47d7d8dfb075..476e043ecc52d 100644 --- a/vllm/core/embedding_model_block_manager.py +++ b/vllm/core/embedding_model_block_manager.py @@ -21,7 +21,9 @@ def __init__( ) -> None: pass - def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: # Always return OK for dummy purposes return AllocStatus.OK diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 96f8dd851b2f4..6346711587301 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -44,7 +44,9 @@ def get_block_space_manager_class(version: str): raise ValueError(f"Unknown version {version=}") @abstractmethod - def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: pass @abstractmethod diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 873decff37c1e..5b7587d150843 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -522,7 +522,7 @@ def _schedule_running( ret.swapped_out.clear() ret.num_lookahead_slots = self._get_num_lookahead_slots( - is_prefill=False) + is_prefill=False, enable_chunking=enable_chunking) ret.decode_seq_groups_list.clear() ret.prefill_seq_groups_list.clear() @@ -561,7 +561,7 @@ def _schedule_running( # NOTE(woosuk): Preemption happens only when there is no available # slot to keep all the sequence groups in the RUNNING state. - while not self._can_append_slots(seq_group): + while not self._can_append_slots(seq_group, enable_chunking): budget.subtract_num_batched_tokens(seq_group.request_id, num_running_tokens) num_running_seqs = seq_group.get_max_num_running_seqs() @@ -611,7 +611,7 @@ def _schedule_running( if not cont_loop: break else: - self._append_slots(seq_group, blocks_to_copy) + self._append_slots(seq_group, blocks_to_copy, enable_chunking) is_prefill = seq_group.is_prefill() scheduled_seq_group: ScheduledSequenceGroup = \ @@ -684,7 +684,8 @@ def _schedule_swapped( # If the sequence group cannot be swapped in, stop. 
is_prefill = seq_group.is_prefill() alloc_status = self.block_manager.can_swap_in( - seq_group, self._get_num_lookahead_slots(is_prefill)) + seq_group, + self._get_num_lookahead_slots(is_prefill, enable_chunking)) if alloc_status == AllocStatus.LATER: break elif alloc_status == AllocStatus.NEVER: @@ -727,7 +728,7 @@ def _schedule_swapped( curr_loras.add(lora_int_id) swapped_queue.popleft() self._swap_in(seq_group, blocks_to_swap_in) - self._append_slots(seq_group, blocks_to_copy) + self._append_slots(seq_group, blocks_to_copy, enable_chunking) is_prefill = seq_group.is_prefill() if is_prefill: prefill_seq_groups.append( @@ -747,12 +748,13 @@ def _schedule_swapped( blocks_to_swap_in=blocks_to_swap_in, blocks_to_copy=blocks_to_copy, num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=False), + is_prefill=False, enable_chunking=enable_chunking), infeasible_seq_groups=infeasible_seq_groups, ) def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: - if self.scheduler_config.chunked_prefill_enabled: + if self.scheduler_config.chunked_prefill_enabled and \ + not self.scheduler_config.is_multi_step: prompt_limit = self.scheduler_config.max_model_len else: prompt_limit = min(self.scheduler_config.max_model_len, @@ -899,15 +901,21 @@ def _schedule_prefills( waiting_queue.popleft() continue + num_lookahead_slots: int = 0 + if self.scheduler_config.is_multi_step and enable_chunking: + num_lookahead_slots = self._get_num_lookahead_slots( + True, enable_chunking) + # If the sequence group cannot be allocated, stop. - can_allocate = self.block_manager.can_allocate(seq_group) + can_allocate = self.block_manager.can_allocate( + seq_group, num_lookahead_slots=num_lookahead_slots) if can_allocate == AllocStatus.LATER: break elif can_allocate == AllocStatus.NEVER: logger.warning( - "Input prompt (%d tokens) is too long" - " and exceeds the capacity of block_manager", - num_new_tokens) + "Input prompt (%d tokens) + lookahead slots (%d) is " + "too long and exceeds the capacity of block_manager", + num_new_tokens, num_lookahead_slots) for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) @@ -939,9 +947,24 @@ def _schedule_prefills( curr_loras.add(lora_int_id) waiting_queue.popleft() self._allocate_and_set_running(seq_group) - seq_group.init_multi_step( - num_scheduler_steps=self._get_num_lookahead_slots( - is_prefill=True) + 1) + + if enable_chunking and self.scheduler_config.is_multi_step: + blocks_to_copy: List[Tuple[int, int]] = [] + # init_multi_step_from_lookahead_slots happens in append_slots + self._append_slots(seq_group, blocks_to_copy, enable_chunking) + # This assert will trip when a copy-on-write happens. This is + # not a concern as the very first sequence-group block + # allocation happens above. Still, we have the assert to + # catch any edge-cases. + assert not blocks_to_copy + else: + seq_group.init_multi_step_from_lookahead_slots( + num_lookahead_slots, + num_scheduler_steps=self.scheduler_config. 
+ num_scheduler_steps, + is_multi_step=self.scheduler_config.is_multi_step, + enable_chunking=enable_chunking) + seq_groups.append( ScheduledSequenceGroup(seq_group=seq_group, token_chunk_size=num_new_tokens)) @@ -956,7 +979,8 @@ def _schedule_prefills( return SchedulerPrefillOutputs( seq_groups=seq_groups, ignored_seq_groups=ignored_seq_groups, - num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True)) + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=True, enable_chunking=enable_chunking)) def _schedule_default(self) -> SchedulerOutputs: """Schedule queued requests. @@ -1153,7 +1177,8 @@ def _schedule(self) -> SchedulerOutputs: else: return self._schedule_default() - def _can_append_slots(self, seq_group: SequenceGroup) -> bool: + def _can_append_slots(self, seq_group: SequenceGroup, + enable_chunking: bool) -> bool: """Determine whether or not we have enough space in the KV cache to continue generation of the sequence group. """ @@ -1164,13 +1189,17 @@ def _can_append_slots(self, seq_group: SequenceGroup) -> bool: self.artificial_preempt_cnt -= 1 return False - # Appending slots only occurs in decoding. - is_prefill = False + is_prefill = seq_group.is_prefill() + num_lookahead_slots = self._get_num_lookahead_slots( + is_prefill, enable_chunking) + + if is_prefill and num_lookahead_slots > 0: + # Appending prefill slots only happens when multi-step and + # chunked-prefill are enabled together. + assert self.scheduler_config.is_multi_step and enable_chunking return self.block_manager.can_append_slots( - seq_group=seq_group, - num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), - ) + seq_group=seq_group, num_lookahead_slots=num_lookahead_slots) def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: no_beam_search = seq_group.sampling_params is None or ( @@ -1186,7 +1215,7 @@ def schedule( # such as self.running, self.swapped, and self.waiting. scheduler_start_time = time.perf_counter() - scheduler_outputs = self._schedule() + scheduler_outputs: SchedulerOutputs = self._schedule() now = time.time() if not self.cache_config.enable_prefix_caching: @@ -1383,11 +1412,10 @@ def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): seq.status = SequenceStatus.RUNNING - def _append_slots( - self, - seq_group: SequenceGroup, - blocks_to_copy: List[Tuple[int, int]], - ) -> None: + def _append_slots(self, + seq_group: SequenceGroup, + blocks_to_copy: List[Tuple[int, int]], + enable_chunking: bool = False) -> None: """Appends new slots to the sequences in the given sequence group. Args: @@ -1398,11 +1426,25 @@ def _append_slots( int is the destination block index. This list is updated with the new source and destination block indices for the appended slots. + enable_chunking (bool): True if chunked prefill is enabled. 
""" - num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False) - seq_group.init_multi_step(num_scheduler_steps=num_lookahead_slots + 1) - - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + is_prefill: bool = seq_group.is_prefill() + num_lookahead_slots: int = self._get_num_lookahead_slots( + is_prefill, enable_chunking) + + seq_group.init_multi_step_from_lookahead_slots( + num_lookahead_slots, + num_scheduler_steps=self.scheduler_config.num_scheduler_steps, + is_multi_step=self.scheduler_config.is_multi_step, + enable_chunking=enable_chunking) + + seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING + if self.scheduler_config.is_multi_step and enable_chunking: + # In multi-step chunked-prefill any sequence type can have + # slots appended. + seq_status = None + + for seq in seq_group.get_seqs(status=seq_status): cows = self.block_manager.append_slots(seq, num_lookahead_slots) if len(cows) > 0: blocks_to_copy.extend(cows) @@ -1513,16 +1555,32 @@ def _passed_delay(self, now: float) -> bool: passed_delay = True return passed_delay - def _get_num_lookahead_slots(self, is_prefill: bool) -> int: + def _get_num_lookahead_slots(self, is_prefill: bool, + enable_chunking: bool) -> int: """The number of slots to allocate per sequence per step, beyond known token ids. Speculative decoding uses these slots to store KV activations of tokens which may or may not be accepted. Speculative decoding does not yet support prefill, so we do not perform lookahead allocation for prefill. + + When chunking is enabled with multi-step, we allocate lookahead slots + for the prefills for when the prefills turn into decodes in the first + step. """ if is_prefill: - return 0 + if self.scheduler_config.is_multi_step and enable_chunking: + # num_lookahead_slots was introduced in the context of decodes, + # in Speculative Decoding. + # When the num_scheduler_steps is 8, say, then the + # num_lookahead_slots is 7. Meaning, we are doing a 1-step of + # decode anyways and we wish to do 7 more. + # + # "lookaheads" for prefills, is introduced in support for + # Chunked-Prefill in Multi-Step. + return self.scheduler_config.num_lookahead_slots + 1 + else: + return 0 return self.scheduler_config.num_lookahead_slots @@ -1565,6 +1623,16 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, if remaining_token_budget < num_new_tokens: num_new_tokens = (remaining_token_budget // block_size) * block_size + elif self.scheduler_config.is_multi_step: + if num_new_tokens > self._get_prompt_limit(seq_group): + # If the seq_group is in prompt-stage, pass the + # num_new_tokens as-is so the caller can ignore + # the sequence. 
+ pass + else: + num_new_tokens = 0 \ + if num_new_tokens > remaining_token_budget \ + else num_new_tokens else: num_new_tokens = min(num_new_tokens, remaining_token_budget) return num_new_tokens diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0d4559e377427..0efb0cbbf8bec 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -980,9 +980,13 @@ def create_engine_config(self) -> EngineConfig: if speculative_config is not None: raise ValueError("Speculative decoding is not supported with " "multi-step (--num-scheduler-steps > 1)") - if self.enable_chunked_prefill: - raise ValueError("Chunked prefill is not supported with " - "multi-step (--num-scheduler-steps > 1)") + if self.enable_chunked_prefill and self.enable_prefix_caching: + raise ValueError("Multi-Step is not supported with " + "both Chunked-Prefill and Prefix-Caching " + "enabled together.") + if self.enable_chunked_prefill and self.pipeline_parallel_size > 1: + raise ValueError("Multi-Step Chunked-Prefill is not supported " + "for pipeline-parallel-size > 1") # make sure num_lookahead_slots is set the higher value depending on # if we are using speculative decoding or multi-step diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 54c5af2fe3665..3361fdefc960c 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -363,11 +363,18 @@ async def step_async( self.cached_scheduler_outputs[ virtual_engine] = SchedulerOutputState() + # is_first_step_output is True only when the num_steps of all + # the sequences are 1. When the num_steps > 1, + # multi_step_model_runner does the first-step output append. + is_first_step_output: bool = False if not seq_group_metadata_list \ + else seq_group_metadata_list[0].state.num_steps == 1 + ctx.append_output(outputs=outputs, seq_group_metadata_list=seq_group_metadata_list, scheduler_outputs=scheduler_outputs, is_async=allow_async_output_proc, - is_last_step=True) + is_last_step=True, + is_first_step_output=is_first_step_output) if outputs and allow_async_output_proc: assert len( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 487255cb6b595..19f88ac3e7c5d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -90,6 +90,12 @@ class OutputData(NamedTuple): scheduler_outputs: SchedulerOutputs is_async: bool is_last_step: bool + # Indicates if this output is from the first step of the + # multi-step. When multi-step is disabled, this is always + # set to True. + # is_first_step_output is invalid when `outputs` has + # outputs from multiple steps. 
+ is_first_step_output: Optional[bool] skip: List[int] @@ -108,13 +114,15 @@ def __init__(self, multi_step_stream_outputs: bool = False): def append_output(self, outputs: List[SamplerOutput], seq_group_metadata_list: List[SequenceGroupMetadata], scheduler_outputs: SchedulerOutputs, is_async: bool, - is_last_step: bool): + is_last_step: bool, + is_first_step_output: Optional[bool]): self.output_queue.append( OutputData(outputs=outputs, seq_group_metadata_list=seq_group_metadata_list, scheduler_outputs=scheduler_outputs, is_async=is_async, is_last_step=is_last_step, + is_first_step_output=is_first_step_output, skip=[])) @@ -237,9 +245,10 @@ def __init__( "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " - "num_scheduler_steps=%d, multi_step_stream_outputs=%s, " - "enable_prefix_caching=%s, use_async_output_proc=%s, " - "use_cached_outputs=%s, mm_processor_kwargs=%s)", + "num_scheduler_steps=%d, chunked_prefill_enabled=%s, " + "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " + "use_async_output_proc=%s, use_cached_outputs=%s, " + "mm_processor_kwargs=%s)", VLLM_VERSION, model_config.model, speculative_config, @@ -270,6 +279,7 @@ def __init__( model_config.served_model_name, scheduler_config.use_v2_block_manager, scheduler_config.num_scheduler_steps, + scheduler_config.chunked_prefill_enabled, scheduler_config.multi_step_stream_outputs, cache_config.enable_prefix_caching, model_config.use_async_output_proc, @@ -957,8 +967,66 @@ def _process_model_outputs(self, ctx: The virtual engine context to work on request_id: If provided, then only this request is going to be processed - """ + """ + + def update_prefill_num_computed_tokens( + seq_group: SequenceGroup, + seq_group_meta: SequenceGroupMetadata, num_outputs: int, + is_first_step_output: Optional[bool]) -> None: + """ + When multi-step and chunked-prefill are enabled together, the + prefill sequences scheduled for multi-step execution turn into + decodes in the first step itself. This function accounts + for that conversion. + + seq_group: SequenceGroup - A prefill seq_group + seq_group_meta: SequenceGroupMetadata - Metadata of the given + prefill seq_group + num_outputs: int - number of output tokens being processed for the + given seq_group + is_first_step_output: Optional[bool] - + If multi-step is enabled and num_outputs is 1, this value + indicates if this output belongs to the first step in the + multi-step. + If multi-step is enabled and num_outputs > 1, this value + must be None, as num_outputs > 1 indicates that outputs from + all the steps in multi-step are submitted in a single burst. + When multi-step is disabled, this value is always True. + """ + + assert seq_group_meta.is_prompt + + token_chunk_size = seq_group_meta.token_chunk_size + + if num_outputs == 1: + assert is_first_step_output is not None + + if seq_group_meta.state.num_steps == 1: + assert is_first_step_output is True + seq_group.update_num_computed_tokens(token_chunk_size) + return + + # multi-step prefill is only supported when multi-step is + # enabled with chunked prefill + assert self.scheduler_config.is_multi_step and \ + self.scheduler_config.chunked_prefill_enabled + if is_first_step_output is True: + # This sequence is a prompt during the first step only.
+ seq_group.update_num_computed_tokens(token_chunk_size) + return + + assert is_first_step_output is None + + # multi-step prefill is only supported when multi-step is + # enabled with chunked prefill. Outputs from all the steps are + # submitted in a single burst. + assert self.scheduler_config.is_multi_step and \ + self.scheduler_config.chunked_prefill_enabled + assert num_outputs == seq_group_meta.state.num_steps, \ + f"#outputs {len(outputs)} - num steps {seq_group_meta.state.num_steps}" #noqa + # This sequence is a prompt during the first step only. + seq_group.update_num_computed_tokens(token_chunk_size) + now = time.time() if len(ctx.output_queue) == 0: @@ -969,20 +1037,27 @@ def _process_model_outputs(self, # When we process only one request, no pop is required # (since later we will process all of the rest) (outputs, seq_group_metadata_list, scheduler_outputs, is_async, - is_last_step, skip) = ctx.output_queue[0] + is_last_step, is_first_step_output, skip) = ctx.output_queue[0] else: (outputs, seq_group_metadata_list, scheduler_outputs, is_async, - is_last_step, skip) = ctx.output_queue.popleft() + is_last_step, is_first_step_output, + skip) = ctx.output_queue.popleft() # Sanity check assert len(seq_group_metadata_list) == len( scheduler_outputs.scheduled_seq_groups) - # Organize outputs by [step][sequence group] instead of - # [sequence group][step]. - if len(outputs) > 1: + has_multiple_outputs: bool = len(outputs) > 1 + if has_multiple_outputs: + assert self.scheduler_config.is_multi_step or \ + self.speculative_config + # Organize outputs by [step][sequence group] instead of + # [sequence group][step]. outputs_by_sequence_group = create_output_by_sequence_group( outputs, num_seq_groups=len(seq_group_metadata_list)) + # We have outputs for multiple steps submitted in a single burst, + # so invalidate is_first_step_output. + is_first_step_output = None else: outputs_by_sequence_group = outputs @@ -1018,14 +1093,17 @@ def _process_model_outputs(self, finished_before.append(i) continue - if len(outputs) > 1: + if has_multiple_outputs: output = outputs_by_sequence_group[i] else: output = [outputs_by_sequence_group[0][i]] - if not is_async: - seq_group.update_num_computed_tokens( - scheduled_seq_group.token_chunk_size) + if not is_async and seq_group_meta.is_prompt: + # Updates for all decodes happen when we actually append the + # token ids to the seq in process_outputs. + update_prefill_num_computed_tokens(seq_group, seq_group_meta, + len(output), + is_first_step_output) if outputs: for o in outputs: @@ -1159,8 +1237,18 @@ def _advance_to_next_step( if seq_group.is_finished(): continue - seq_group.update_num_computed_tokens( - seq_group_metadata.token_chunk_size) + if seq_group_metadata.is_prompt: + if self.scheduler_config.is_multi_step and \ + self.scheduler_config.chunked_prefill_enabled: + # Prompts are scheduled in multi-step only when + # chunking is enabled. These prompts turn into + # decodes after the very first step. Therefore, + # we skip the update to the num_computed_tokens + # here. 
+ pass + else: + seq_group.update_num_computed_tokens( + seq_group_metadata.token_chunk_size) if seq_group_metadata.do_sample: assert len(sequence_group_outputs.samples) == 1, ( @@ -1172,6 +1260,7 @@ def _advance_to_next_step( assert len(seq_group.seqs) == 1 seq = seq_group.seqs[0] seq.append_token_id(sample.output_token, sample.logprobs) + seq_group.update_num_computed_tokens(1) def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. @@ -1324,12 +1413,19 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: if self.scheduler_config.is_multi_step: self.cached_scheduler_outputs[0] = SchedulerOutputState() + # is_first_step_output is True only when the num_steps of all + # the sequences are 1. When the num_steps > 1, + # multi_step_model_runner does the first-step output append. + is_first_step_output: bool = False if not seq_group_metadata_list \ + else seq_group_metadata_list[0].state.num_steps == 1 + # Add results to the output_queue ctx.append_output(outputs=outputs, seq_group_metadata_list=seq_group_metadata_list, scheduler_outputs=scheduler_outputs, is_async=allow_async_output_proc, - is_last_step=True) + is_last_step=True, + is_first_step_output=is_first_step_output) if outputs and allow_async_output_proc: assert len(outputs) == 1, ( diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 31c2bbc8e7127..cd5cfe5485f21 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -170,6 +170,7 @@ def _process_seq_outputs(self, seq: Sequence, token_id=output_token_id, logprobs=output_logprob, ) + seq.data.update_num_computed_tokens(1) self._process_decode_and_stop(seq, sampling_params) diff --git a/vllm/sequence.py b/vllm/sequence.py index 49a198df045bd..781bcedde2b52 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -743,10 +743,35 @@ def prompt_adapter_num_virtual_tokens(self) -> int: return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\ if self.prompt_adapter_request else 0 - def init_multi_step(self, num_scheduler_steps: int) -> None: - self.state.num_steps = num_scheduler_steps + def init_multi_step(self, num_steps: int) -> None: + self.state.num_steps = num_steps self.state.current_step = 0 + def init_multi_step_from_lookahead_slots(self, num_lookahead_slots: int, + num_scheduler_steps: int, + is_multi_step: bool, + enable_chunking: bool) -> None: + + if not is_multi_step: + self.init_multi_step(num_steps=num_scheduler_steps) + return + + # Multi-Step case + is_prefill = self.is_prefill() + + # The asserts below reflect the expectations of the current system. + if is_prefill and enable_chunking: + assert num_lookahead_slots == num_scheduler_steps + self.init_multi_step(num_steps=num_lookahead_slots) + else: + is_decode: bool = not is_prefill + # If it is a prefill, num_lookahead_slots must be 0 + assert num_lookahead_slots == 0 or is_decode + # If it is a decode, num_lookahead_slots + 1 must match + # the scheduler steps. + assert num_lookahead_slots + 1 == num_scheduler_steps or is_prefill + self.init_multi_step(num_steps=num_lookahead_slots + 1) + def get_last_latency(self, now: float) -> Optional[float]: """Sets the last token time for Request level timings.""" # If still in prefill phase, raise Error. 
@@ -1010,6 +1035,20 @@ def prompt_adapter_num_virtual_tokens(self) -> int: return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \ if self.prompt_adapter_request else 0 + # Multi-Step Chunked-Prefill property + @property + def is_single_step_prompt(self) -> bool: + # do_sample is true, only when the token_chunk_size matches the + # num_uncomputed_tokens of the sequence. This indicates that + # the prompt will finish processing in a single `execute_model` + # step. + return self.is_prompt and self.do_sample + + def get_first_seq_id(self) -> int: + # This is an efficient way of fetching the seq_id when + # we know this SequenceGroup has only one sequence. + return next(iter(self.seq_data)) + def apply_delta(self, sequence_group_metadata_delta: SequenceGroupMetadataDelta): for id, delta in sequence_group_metadata_delta.seq_data_delta.items(): @@ -1022,7 +1061,8 @@ def apply_delta(self, def finish_step(self) -> None: assert self.state is not None - assert self.state.current_step < self.state.num_steps + assert self.state.current_step < self.state.num_steps, \ + f"current step {self.state.current_step}, num_steps {self.state.num_steps}" # noqa self.state.current_step += 1 diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index c7295f872f70f..4c57a37c87870 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -14,7 +14,7 @@ get_pythonized_sample_results) from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, Logprob, SequenceGroupMetadata, SequenceOutput) -from vllm.utils import PyObjectCache +from vllm.utils import PyObjectCache, async_tensor_h2d from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUWithSamplingMetadata) from vllm.worker.model_runner_base import ( @@ -30,6 +30,14 @@ logger = init_logger(__name__) MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "rocm-flash-attn", "flashinfer"] +MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["flash-attn"] + +def _get_supported_attention_backends(chunked_prefill_enabled: bool) \ + -> List[str]: + if chunked_prefill_enabled: + return MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS + else: + return MULTI_STEP_ATTENTION_BACKENDS def seq_output_builder(): @@ -144,11 +152,13 @@ class StatefulModelInput(BroadcastableModelInput): is_multi_step: bool = True is_last_step: bool = False is_first_multi_step: bool = False + base_output_proc_callback: Optional[Callable] = None # ping-pong data structures for multi-step to wait on the previous step step_cuda_events: List[torch.cuda.Event] = field( default_factory=lambda: [torch.cuda.Event(blocking=True)] * 2) num_seqs: int = -1 num_queries: int = -1 + num_single_step_prefills: int = 0 def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: assert self.frozen_model_input is not None @@ -161,6 +171,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: 'is_first_multi_step': self.is_first_multi_step, 'num_seqs': self.num_seqs, 'num_queries': self.num_queries, + 'num_single_step_prefills': self.num_single_step_prefills, } tensor_dict.update(new_tensor_dict) return tensor_dict @@ -209,6 +220,81 @@ def add_sampler_output(self, sampled_token_ids=sampled_token_ids, pythonized=False)) + def maybe_advance_sampling_metadata(self, device: str, pin_memory: bool): + """ + sampling_metadata.selected_token_indices is constructed for the + first-step in Multi-Step. 
However, when chunked-prefill is enabled with + multi-step, the scheduled prompts are fully processed in the + first-step and are processed as decodes in the rest of the steps. + This function updates the sampling_metadata.selected_token_indices + to account for this conversion. + + Example: + Let 2 prompts and 2 decodes be scheduled together. Let the + num-tokens to process for the 2 prompts be 5 and 8 respectively. + + In that case, sampling_metadata.selected_token_indices will be + [4, 12, 13, 14] as it is constructed for the first-step in + multi-step. + However, the prompts turn into decodes after the first-step + and the num-tokens for the previously-prompt sequences will + be 1 and 1 as they are decodes now. The selected_token_indices + must be updated to [0,1,2,3]. + """ + assert self.current_step == 1 and self.num_single_step_prefills > 0 + if not get_pp_group().is_last_rank: + return + + assert self.frozen_model_input is not None + assert self.frozen_model_input.sampling_metadata is not None + self.frozen_model_input.sampling_metadata.selected_token_indices = \ + async_tensor_h2d(list(range(self.num_queries)), + dtype=torch.long, + target_device=device, + pin_memory=pin_memory) + + def maybe_advance_frozen_model_input(self, device: str, pin_memory: bool): + """ + Advancing the data structures of StatefulModelInput::frozen_model_input + is only required when prefills are scheduled with decodes to run in + multi-step. This advancement/correction is required to account for + the conversion of Prefills to Decodes after the first multi-step. + """ + if self.current_step != 1 or self.num_single_step_prefills == 0: + return + + assert self.frozen_model_input is not None + fmi = self.frozen_model_input + + # Truncate input_tokens + assert fmi.input_tokens is not None + assert fmi.input_tokens.shape[0] >= self.num_seqs + fmi_new_input_tokens: torch.Tensor = fmi.input_tokens[:self.num_seqs] + + # Update frozen_model_input::input_positions. + assert fmi.input_positions is not None + assert fmi.input_positions.shape[0] >= self.num_seqs + fmi_new_input_positions: torch.Tensor = fmi.input_positions[:self. num_seqs] + + # Assert unsupported + assert fmi.lora_mapping is None + assert fmi.lora_requests is not None + assert len(fmi.lora_requests) == 0 + assert fmi.attn_metadata is not None + assert fmi.prompt_adapter_mapping is None + assert fmi.prompt_adapter_requests is not None + assert len(fmi.prompt_adapter_requests) == 0 + assert fmi.multi_modal_kwargs is not None + assert len(fmi.multi_modal_kwargs) == 0 + + self.frozen_model_input = dataclasses.replace( + self.frozen_model_input, + input_tokens=fmi_new_input_tokens, + input_positions=fmi_new_input_positions) + + self.maybe_advance_sampling_metadata(device, pin_memory) + # MutableModelInputForGPUWithMultiStepMetadata is not subclass of # ModelInputForGPU but it wraps the actual input dataclass and adds multi-step @@ -220,6 +306,19 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs): super().__init__(*args, **kwargs) + # Check attention backend support. 
+ supported_attention_backends: List[str] = \ + _get_supported_attention_backends( + self.scheduler_config.chunked_prefill_enabled) + if self.attn_backend.get_name() not in supported_attention_backends: + ms_config_str: str = "Multi-Step + Chunked-Prefill" \ + if self.scheduler_config.chunked_prefill_enabled \ + else "Multi-Step" + raise ValueError( + f"{ms_config_str} not supported for attention backend: " + f"{self.attn_backend.get_name()}. Set VLLM_ATTENTION_BACKEND " + f"to a value from {supported_attention_backends}.") + # uses the base model runner to execute the model and wraps it with # multi-step logic self._base_model_runner: GPUModelRunnerBase = base_model_runner @@ -248,14 +347,25 @@ def prepare_model_input( virtual_engine: int = 0, finished_requests_ids: Optional[List[str]] = None ) -> StatefulModelInput: - frozen_model_input = self._base_model_runner.prepare_model_input( - seq_group_metadata_list, virtual_engine, finished_requests_ids) + frozen_model_input: ModelInputForGPUWithSamplingMetadata = \ + self._base_model_runner.prepare_model_input( + seq_group_metadata_list, + virtual_engine, + finished_requests_ids) + + assert frozen_model_input.query_lens is not None + assert frozen_model_input.seq_lens is not None + assert frozen_model_input.attn_metadata is not None + num_queries = len(frozen_model_input.query_lens) + num_seqs = len(frozen_model_input.seq_lens) + num_single_step_prefills = frozen_model_input.attn_metadata.num_prefills model_input = StatefulModelInput( frozen_model_input=frozen_model_input, - num_seqs=len(frozen_model_input.seq_lens), - num_queries=len(frozen_model_input.query_lens), - ) + num_seqs=num_seqs, + num_queries=num_queries, + num_single_step_prefills=num_single_step_prefills) + return model_input def _async_process_outputs(self, model_input: StatefulModelInput, @@ -265,7 +375,7 @@ def _async_process_outputs(self, model_input: StatefulModelInput, output_proc_callback() cont = True - for model_output in model_input.cached_outputs: + for step_num, model_output in enumerate(model_input.cached_outputs): if not model_output.pythonized: model_output.maybe_pythonize(model_input, self._copy_stream, self.pinned_sampled_token_ids) @@ -276,7 +386,8 @@ def _async_process_outputs(self, model_input: StatefulModelInput, seq_group_metadata_list=ctx.seq_group_metadata_list, scheduler_outputs=ctx.scheduler_outputs, is_async=False, - is_last_step=False) + is_last_step=False, + is_first_step_output=step_num == 0) output_proc_callback() else: @@ -292,9 +403,8 @@ def _final_process_outputs(self, model_input: StatefulModelInput, has_async_callback = output_proc_callback is not None outputs = [] - for output_id in range(len(model_input.cached_outputs)): - output = model_input.cached_outputs[output_id] - is_last_step = output_id == len(model_input.cached_outputs) - 1 + for step_num, output in enumerate(model_input.cached_outputs): + is_last_step = step_num == len(model_input.cached_outputs) - 1 # For non-async case: # -- We simply add the outputs @@ -323,7 +433,8 @@ def _final_process_outputs(self, model_input: StatefulModelInput, seq_group_metadata_list, scheduler_outputs=ctx.scheduler_outputs, is_async=False, - is_last_step=False) + is_last_step=False, + is_first_step_output=step_num == 0) else: outputs.append(output.sampler_output) else: @@ -389,18 +500,27 @@ def execute_model( model_input = self._advance_step( model_input, model_input.cached_outputs[-1].sampler_output) - output_proc_callback = None + # frozen_model_input may have been updated + frozen_model_input = 
model_input.frozen_model_input + assert frozen_model_input is not None + + if model_input.base_output_proc_callback is None: + assert frozen_model_input is not None + model_input.base_output_proc_callback = \ + frozen_model_input.async_callback + if frozen_model_input.async_callback is not None: - output_proc_callback = frozen_model_input.async_callback - assert output_proc_callback is not None + assert model_input.base_output_proc_callback is not None async_callback = functools.partial( self._async_process_outputs, model_input=model_input, - output_proc_callback=output_proc_callback) + output_proc_callback=model_input.base_output_proc_callback) - frozen_model_input = dataclasses.replace( # type: ignore + model_input.frozen_model_input = dataclasses.replace( # type: ignore model_input.frozen_model_input, async_callback=async_callback) + # Update the local instance + frozen_model_input = model_input.frozen_model_input assert frozen_model_input is not None # Execute the model @@ -455,8 +575,8 @@ def execute_model( # Pythonize the output and block if needed since it is the last step if model_input.is_last_step: - outputs = self._final_process_outputs(model_input, - output_proc_callback) + outputs = self._final_process_outputs( + model_input, model_input.base_output_proc_callback) self.pythonization_cache.reset() return outputs @@ -484,11 +604,14 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs, def _advance_step(self, model_input: StatefulModelInput, out: SamplerOutput) -> StatefulModelInput: - if self.attn_backend.get_name() not in MULTI_STEP_ATTENTION_BACKENDS: - raise ValueError( - f"Multi-step not supported for attention backend: " - f"{self.attn_backend.get_name()}. Set VLLM_ATTENTION_BACKEND " - f"to a value from {MULTI_STEP_ATTENTION_BACKENDS}.") + + model_input.maybe_advance_frozen_model_input(self.device, + self.pin_memory) + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + assert frozen_model_input.input_tokens is not None + assert frozen_model_input.input_tokens.shape[0] == model_input.num_seqs + assert frozen_model_input.attn_metadata is not None sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids num_seqs = model_input.num_seqs @@ -498,13 +621,15 @@ def _advance_step(self, model_input: StatefulModelInput, attn_metadata = frozen_model_input.attn_metadata assert attn_metadata is not None + turn_prefills_into_decodes: bool = model_input.current_step == 1 and \ + model_input.num_single_step_prefills != 0 attn_metadata.advance_step( frozen_model_input, sampled_token_ids, self.block_size, num_seqs, num_queries, - ) + turn_prefills_into_decodes=turn_prefills_into_decodes) return model_input diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py index 562285f828cc7..bf66f32d7d244 100644 --- a/vllm/worker/multi_step_worker.py +++ b/vllm/worker/multi_step_worker.py @@ -76,8 +76,9 @@ def _get_driver_input_and_broadcast( frozen_model_input = model_input.frozen_model_input assert frozen_model_input is not None assert frozen_model_input.attn_metadata is not None - # clear the cached decode metadata so that it can be recomputed on - # the workers + # clear the cached metadata so that it can be recomputed on + # the workers. + frozen_model_input.attn_metadata._cached_prefill_metadata = None frozen_model_input.attn_metadata._cached_decode_metadata = None model_input.is_first_multi_step = is_first_multi_step
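
Usage note (not part of the patch): the sketch below mirrors the flags toggled by the updated tests and shows how multi-step scheduling and chunked prefill might be enabled together through the offline LLM entrypoint. The model name and the specific argument values are illustrative assumptions, not something this diff prescribes.

# Minimal sketch, assuming the offline `LLM` entrypoint forwards these
# engine arguments (as the updated tests do via the vllm_runner fixture).
# Model name and values below are illustrative only.
from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-160m",   # hypothetical small model for illustration
    use_v2_block_manager=True,     # multi-step requires the v2 block manager
    num_scheduler_steps=8,         # multi-step scheduling
    enable_chunked_prefill=True,   # now allowed together with multi-step
    enforce_eager=True,
    gpu_memory_utilization=0.7,
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)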