From afb050b29d0cac27c32c19c8206a9ac2a4662de2 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 2 Oct 2024 15:44:39 -0400 Subject: [PATCH] [Core] CUDA Graphs for Multi-Step + Chunked-Prefill (#8645) Co-authored-by: Varun Sundar Rabindranath --- csrc/prepare_inputs/advance_step.cu | 11 ++++ vllm/attention/backends/flash_attn.py | 48 ++++++++++-------- vllm/worker/model_runner.py | 72 +++++++++++++++++++++------ 3 files changed, 97 insertions(+), 34 deletions(-) diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu index 195eb27dee749..46fef79f439fb 100644 --- a/csrc/prepare_inputs/advance_step.cu +++ b/csrc/prepare_inputs/advance_step.cu @@ -17,6 +17,17 @@ __global__ void advance_step_flashattn_kernel( long const* sampled_token_ids_ptr, long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr, int64_t const block_tables_stride) { + int const n_pad = num_seqs - num_queries; + if (n_pad && blockIdx.x == 0) { + // Handle cuda graph padding + int const offset = num_queries; + for (int i = threadIdx.x; i < n_pad; i += blockDim.x) { + input_tokens_ptr[offset + i] = 0; + input_positions_ptr[offset + i] = 0; + slot_mapping_ptr[offset + i] = -1; + } + } + int num_query_blocks = div_ceil(num_queries, num_threads); if (blockIdx.x >= num_query_blocks) { diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index e277023367195..bb8ab1e3c8c26 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -500,6 +500,30 @@ def _add_seq_group( seq_len, context_len, start_idx, self.block_size, inter_data.block_tables) + def _get_graph_runner_block_tables( + self, num_seqs: int, + block_tables: List[List[int]]) -> torch.Tensor: + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + max_batch_size, max_blocks = self.runner.graph_block_tables.shape + assert max_batch_size >= num_seqs + + graph_block_tables = self.runner.graph_block_tables[:num_seqs] + for i, block_table in enumerate(block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + graph_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. + graph_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + return torch.from_numpy(graph_block_tables).to( + device=self.runner.device, non_blocking=True) + def build(self, seq_lens: List[int], query_lens: List[int], cuda_graph_pad_size: int, batch_size: int): """Build attention metadata with on-device tensors. @@ -533,29 +557,13 @@ def build(self, seq_lens: List[int], query_lens: List[int], max_decode_seq_len = max(self.curr_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens + num_seqs = len(seq_lens) if use_captured_graph: self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) self.block_tables.extend([] * cuda_graph_pad_size) - num_decode_tokens = batch_size - - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. - input_block_tables = self.runner.graph_block_tables[:batch_size] - max_blocks = input_block_tables.shape[1] - for i, block_table in enumerate(self.block_tables): - if block_table: - num_blocks = len(block_table) - if num_blocks <= max_blocks: - input_block_tables[i, :num_blocks] = block_table - else: - # It may be possible to have more blocks allocated due - # to lookahead slots of multi-step, however, they are - # not used anyway, so can be safely ignored. - input_block_tables[ - i, :max_blocks] = block_table[:max_blocks] - - block_tables = torch.from_numpy(input_block_tables).to( - device=device, non_blocking=True) + num_decode_tokens = batch_size - self.num_prefill_tokens + block_tables = self._get_graph_runner_block_tables( + num_seqs, self.block_tables) else: block_tables = make_tensor_with_pad( self.block_tables, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 95739f82552a4..f44e5113c218d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -712,14 +712,62 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): def _use_captured_graph(self, batch_size: int, + decode_only: bool, max_decode_seq_len: int, max_encoder_seq_len: int = 0) -> bool: - return (self.decode_only and not self.runner.model_config.enforce_eager + return (decode_only and not self.runner.model_config.enforce_eager and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] and max_decode_seq_len <= self.runner.max_seq_len_to_capture and max_encoder_seq_len <= self.runner.max_seq_len_to_capture and batch_size <= self.runner.max_batchsize_to_capture) + def _get_cuda_graph_pad_size(self, + num_seqs: int, + max_decode_seq_len: int, + max_encoder_seq_len: int = 0) -> int: + """ + Determine the number of padding sequences required for running in + CUDA graph mode. Returns -1 if CUDA graphs cannot be used. + + In the multi-step + chunked-prefill case, only the first step + has Prefills (if any). The rest of the steps are guaranteed to be all + decodes. In this case, we set up the padding as if all the sequences + are decodes so we may run all steps except the first step in CUDA graph + mode. The padding is accounted for in the multi-step `advance_step` + family of functions. + + Args: + num_seqs (int): Number of sequences scheduled to run. + max_decode_seq_len (int): Greatest of all the decode sequence + lengths. Used only in checking the viablility of using + CUDA graphs. + max_encoder_seq_len (int, optional): Greatest of all the encode + sequence lengths. Defaults to 0. Used only in checking the + viability of using CUDA graphs. + Returns: + int: Returns the determined number of padding sequences. If + CUDA graphs is not viable, returns -1. + """ + is_mscp: bool = self.runner.scheduler_config.is_multi_step and \ + self.runner.scheduler_config.chunked_prefill_enabled + decode_only = self.decode_only or is_mscp + if not decode_only: + # Early exit so we can treat num_seqs as the batch_size below. + return -1 + + # batch_size out of this function refers to the number of input + # tokens being scheduled. This conflation of num_seqs as batch_size + # is valid as this is a decode-only case. + batch_size = num_seqs + if not self._use_captured_graph(batch_size, decode_only, + max_decode_seq_len, + max_encoder_seq_len): + return -1 + + graph_batch_size = _get_graph_batch_size(batch_size) + assert graph_batch_size >= batch_size + return graph_batch_size - batch_size + def build(self) -> ModelInputForGPU: """Finalize the builder intermediate data and create on-device tensors. @@ -778,21 +826,17 @@ def build(self) -> ModelInputForGPU: for data in self.inter_data_list } - batch_size = len(input_tokens) - use_captured_graph = self._use_captured_graph( - batch_size, - max_decode_seq_len, + cuda_graph_pad_size = self._get_cuda_graph_pad_size( + num_seqs=len(seq_lens), + max_decode_seq_len=max_encoder_seq_len, max_encoder_seq_len=max_encoder_seq_len) - # If cuda graph can be used, pad tensors accordingly. - # See `capture_model` API for more details. - # vLLM uses cuda graph only for decoding requests. - cuda_graph_pad_size = -1 - if use_captured_graph: - graph_batch_size = _get_graph_batch_size(batch_size) - assert graph_batch_size >= batch_size - cuda_graph_pad_size = graph_batch_size - batch_size - batch_size = graph_batch_size + batch_size = len(input_tokens) + if cuda_graph_pad_size != -1: + # If cuda graph can be used, pad tensors accordingly. + # See `capture_model` API for more details. + # vLLM uses cuda graph only for decoding requests. + batch_size += cuda_graph_pad_size # Tokens and positions. if cuda_graph_pad_size: