diff --git a/tests/kernels/test_cascade.py b/tests/kernels/test_cascade.py new file mode 100644 index 0000000000000..742557bbef7d1 --- /dev/null +++ b/tests/kernels/test_cascade.py @@ -0,0 +1,380 @@ +from typing import List, Tuple + +import flashinfer +import pytest +import torch + +from vllm.utils import seed_everything + + +@pytest.mark.parametrize("beam_width", [16, 32]) +@pytest.mark.parametrize("seq_lens", [[(4096, 4096)]]) +@pytest.mark.parametrize("num_heads", [(16, 16)]) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_runs", [200]) +@pytest.mark.parametrize("max_num_blocks_per_seq", [512]) +@pytest.mark.parametrize("soft_cap", [None]) +@torch.inference_mode() +def test_cascade_speedup(beam_width, seq_lens, num_heads, head_size, dtype, + block_size, num_runs, max_num_blocks_per_seq, + soft_cap): + """ + Compares the performance of flashinfer multilevel kernel and batch decode. + """ + + cascade_outputs, time_taken_cascade = run_multilevel_cascade_attention_wrapper( # noqa: E501 + num_heads=num_heads, + head_size=head_size, + dtype=dtype, + seq_lens=seq_lens, + num_runs=num_runs, + block_size=block_size, + beam_width=beam_width, + num_levels=2, + max_num_blocks_per_seq=max_num_blocks_per_seq, + soft_cap=soft_cap, + ) + + batchdecode_outputs, time_taken_batchdecode = run_flashinfer_batchdecode_beam_search( # noqa: E501 + num_heads=num_heads, + head_size=head_size, + dtype=dtype, + seq_lens=seq_lens, + num_runs=num_runs, + block_size=block_size, + beam_width=beam_width, + max_num_blocks_per_seq=max_num_blocks_per_seq, + soft_cap=soft_cap, + ) + + assert len(cascade_outputs) == len( + batchdecode_outputs), "Output length mismatch between the two methods." + + max_diff = 0 + + for cascade_output, batchdecode_output in zip(cascade_outputs, + batchdecode_outputs): + assert cascade_output.shape == batchdecode_output.shape, "Shape mismatch between outputs." 
# noqa: E501 + + isclose = torch.isclose(cascade_output, + batchdecode_output, + rtol=1e-2, + atol=1e-3) + if not isclose.all(): + diff = torch.abs(cascade_output - batchdecode_output) + current_max_diff = torch.max(diff).item() + max_diff = max(max_diff, current_max_diff) + + speedup = time_taken_batchdecode / time_taken_cascade + + assert speedup > 1.0, f"No speedup with cascade infer: {speedup}" + assert max_diff <= 1e-3, f"Max difference too large: {max_diff}" + + +def run_flashinfer_batchdecode_beam_search( + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + soft_cap: float, + seq_lens: list, + num_runs: int, + block_size: int, + beam_width: int, + max_num_blocks_per_seq: int, +) -> Tuple[List[torch.Tensor], float]: + torch.set_default_device("cuda") + seed_everything(0) + num_query_heads, num_kv_heads = num_heads + assert num_query_heads % num_kv_heads == 0 + + num_seqs = len(seq_lens) + kv_lens = [x[1] for x in seq_lens] + + num_blocks = max_num_blocks_per_seq * num_seqs * beam_width + + key_value_cache = torch.randn(num_blocks, + 2, + block_size, + num_kv_heads, + head_size, + dtype=dtype, + device='cuda').reshape( + num_blocks, 2, block_size, num_kv_heads, + head_size) + + workspace_size = 128 * 1024 * 1024 + workspace_buffer_decode = torch.empty(workspace_size, + dtype=torch.int8, + device='cuda') + decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer_decode, "NHD") + + block_tables = torch.zeros((num_seqs * beam_width, max_num_blocks_per_seq), + dtype=torch.int32) + + block_offset = 0 + + for start_seq in range(num_seqs): + shared_len = kv_lens[start_seq] // block_size + + for i in range(start_seq * beam_width, (start_seq + 1) * beam_width): + block_tables[i, :shared_len] = torch.arange( + block_offset, block_offset + shared_len) + + block_offset += shared_len + + for i in range(beam_width): + beam_index = start_seq * beam_width + i + unique_start = block_offset + i + block_tables[beam_index, + shared_len:max_num_blocks_per_seq] = torch.arange( + unique_start, unique_start + + (max_num_blocks_per_seq - shared_len) * + beam_width, beam_width) + block_offset += (max_num_blocks_per_seq - shared_len) * beam_width + + cumulative_run_time = 0.0 + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + outputs = [] + + # index of the next block that we append for each sequence + next_block_index = [num // block_size + 1 for num in kv_lens] + + kv_indptr: List[int] = [0] + kv_indices: List[List[int]] = [] + kv_last_page_lens: List[int] = [] + + for i in range(num_seqs * beam_width): + seq_len = kv_lens[i // beam_width] + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.append(list(block_tables[i, :num_blocks])) + kv_last_page_len = seq_len % block_size + if kv_last_page_len == 0: + kv_last_page_len = block_size + kv_last_page_lens.append(kv_last_page_len) + kv_indptr.append(kv_indptr[-1] + num_blocks) + + for step in range(num_runs): + torch.manual_seed(step) + + query = torch.randn( + num_seqs * beam_width * num_query_heads * head_size, + dtype=dtype, + device='cuda').reshape(num_seqs * beam_width, num_query_heads, + head_size) + + kv_indptr_tensor = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices_tensor = torch.cat([torch.tensor(kv) + for kv in kv_indices]).reshape(-1) + kv_last_page_lens_tensor = torch.tensor(kv_last_page_lens, + dtype=torch.int32) + + decode_wrapper.plan(kv_indptr_tensor, + kv_indices_tensor, + kv_last_page_lens_tensor, + num_query_heads, + 
num_kv_heads, + head_size, + block_size, + "NONE", + data_type=dtype, + logits_soft_cap=soft_cap) + + start_event.record() + output = decode_wrapper.run(query, key_value_cache) + end_event.record() + torch.cuda.synchronize() + decode_time = start_event.elapsed_time(end_event) + cumulative_run_time += decode_time + + outputs.append(output.cpu()) + + if step % block_size == 0: + for i in range(beam_width * num_seqs): + kv_indices[i].append( + block_tables[i, next_block_index[i // beam_width]]) + + for i in range(len(next_block_index)): + next_block_index[i] += 1 + + for i in range(1, beam_width * num_seqs + 1): + kv_indptr[i] += i + kv_last_page_lens = [(prev + 1) % block_size or block_size + for prev in kv_last_page_lens] + + return outputs, cumulative_run_time + + +def run_multilevel_cascade_attention_wrapper( + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + seq_lens: list, + num_runs: int, + block_size: int, + beam_width: int, + num_levels: int, + max_num_blocks_per_seq: int, + soft_cap: float, +) -> Tuple[List[torch.Tensor], float]: + torch.set_default_device("cuda") + seed_everything(0) + num_query_heads, num_kv_heads = num_heads + assert num_query_heads % num_kv_heads == 0 + + num_seqs = len(seq_lens) + kv_lens = [x[1] for x in seq_lens] + + num_blocks = max_num_blocks_per_seq * num_seqs * beam_width + + key_value_cache = torch.randn(num_blocks, + 2, + block_size, + num_kv_heads, + head_size, + dtype=dtype, + device='cuda') + + workspace_size = 128 * 1024 * 1024 + workspace_buffer = torch.empty(workspace_size, + dtype=torch.uint8, + device='cuda') + wrapper = flashinfer.MultiLevelCascadeAttentionWrapper( + num_levels, workspace_buffer, "NHD") + + block_tables = torch.zeros((num_seqs * beam_width, max_num_blocks_per_seq), + dtype=torch.int32) + block_offset = 0 + + for start_seq in range(num_seqs): + shared_len = kv_lens[start_seq] // block_size + + for i in range(start_seq * beam_width, (start_seq + 1) * beam_width): + block_tables[i, :shared_len] = torch.arange( + block_offset, block_offset + shared_len) + + block_offset += shared_len + + for i in range(beam_width): + beam_index = start_seq * beam_width + i + unique_start = block_offset + i + block_tables[beam_index, + shared_len:max_num_blocks_per_seq] = torch.arange( + unique_start, unique_start + + (max_num_blocks_per_seq - shared_len) * + beam_width, beam_width) + block_offset += (max_num_blocks_per_seq - shared_len) * beam_width + + qo_indptr_arr = [ + torch.tensor([0, beam_width * num_seqs], + dtype=torch.int32, + device='cuda'), + torch.arange(beam_width * num_seqs + 1, + dtype=torch.int32, + device="cuda") + ] + + shared_kv_page_indptr: List[int] = [0] + unique_kv_page_indptr: List[int] = [0] + shared_kv_page_indices: List[List[int]] = [] + unique_kv_page_indices: List[List[int]] = [] + shared_kv_last_page_len: List[int] = [] + unique_kv_last_page_len: List[int] = [] + + query = torch.arange(num_seqs * beam_width * num_query_heads * head_size, + dtype=dtype, + device='cuda').reshape(num_seqs * beam_width, + num_query_heads, head_size) + + # Fill the shared metadatas + for i in range(num_seqs): + seq_len = kv_lens[i // beam_width] + num_shared_blocks = ( + seq_len) // block_size if seq_len % block_size == 0 else ( + (seq_len) // block_size) + 1 + shared_kv_page_indices.append(list( + block_tables[i, :num_shared_blocks])) + shared_kv_page_indptr.append(shared_kv_page_indptr[-1] + + num_shared_blocks) + shared_kv_len = seq_len % block_size + if shared_kv_len == 0: + shared_kv_len = block_size + 
shared_kv_last_page_len.append(shared_kv_len) + + for i in range(num_seqs * beam_width): + unique_kv_page_indices.append([]) + unique_kv_page_indptr.append(unique_kv_page_indptr[-1]) + unique_kv_last_page_len.append(block_size) + + shared_kv_page_indptr = torch.tensor(shared_kv_page_indptr, + dtype=torch.int32, + device='cuda') + shared_kv_page_indices = torch.cat( + [torch.tensor(x) for x in shared_kv_page_indices]).reshape(-1) + shared_kv_last_page_len = torch.tensor(shared_kv_last_page_len, + dtype=torch.int32, + device='cuda') + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + cumulative_run_time = 0.0 + + outputs = [] + + # index of the next block that we append for each sequence + next_block_index = [num // block_size + 1 for num in kv_lens] + + for step in range(num_runs): + torch.manual_seed(step) + query = torch.randn( + num_seqs * beam_width * num_query_heads * head_size, + dtype=dtype, + device='cuda').reshape(num_seqs * beam_width, num_query_heads, + head_size) + + wrapper.plan(qo_indptr_arr, [ + shared_kv_page_indptr, + torch.tensor( + unique_kv_page_indptr, dtype=torch.int32, device='cuda') + ], [ + shared_kv_page_indices, + torch.cat([torch.tensor(x) + for x in unique_kv_page_indices]).reshape(-1) + ], [ + shared_kv_last_page_len, + torch.tensor( + unique_kv_last_page_len, dtype=torch.int32, device='cuda') + ], + num_query_heads, + num_kv_heads, + head_size, + block_size, + logits_soft_cap=soft_cap) + + start_event.record() + output = wrapper.run(query, key_value_cache) + end_event.record() + torch.cuda.synchronize() + + cumulative_run_time += start_event.elapsed_time(end_event) + + outputs.append(output.cpu()) + + if step % block_size == 0: + for i in range(beam_width * num_seqs): + unique_kv_page_indices[i].append( + block_tables[i, next_block_index[i // beam_width]]) + for i in range(len(next_block_index)): + next_block_index[i] += 1 + for i in range(1, beam_width * num_seqs + 1): + unique_kv_page_indptr[i] += i + + unique_kv_last_page_len = [(x + 1) % block_size or block_size + for x in unique_kv_last_page_len] + + return outputs, cumulative_run_time \ No newline at end of file diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index aed04361e5fb4..0f6b8cdb30042 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -5,6 +5,7 @@ Tuple, Type, TypeVar) import torch +from torch import nn from vllm.multimodal import MultiModalPlaceholderMap @@ -200,7 +201,8 @@ def prepare_graph_input_buffers( ... @abstractmethod - def begin_forward(self, model_input: "ModelRunnerInputBase") -> None: + def begin_forward(self, model_input: "ModelRunnerInputBase", + model: nn.Module) -> None: """Prepare state for forward pass.""" ... 
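
The cascade test above relies on a two-level KV layout: for each sequence group, the prompt's KV blocks are shared by every beam (the first level), while blocks appended during decoding are unique per beam and allocated with stride beam_width (the second level). The standalone sketch below is illustrative only and not part of the patch; it rebuilds that layout on plain Python lists and derives the same indptr/indices/last_page_len index structure that the test later hands to MultiLevelCascadeAttentionWrapper.plan() as CUDA tensors. The helper name and the example values (prompt_blocks, unique_blocks, last_token_in_page) are assumptions made up for the example.

# Illustrative sketch (not part of the patch): derive two-level paged-KV
# metadata from the shared/interleaved block layout used by the test.
from typing import List, Tuple


def build_two_level_metadata(
    prompt_blocks: int,       # blocks holding the shared prompt KV
    unique_blocks: int,       # blocks appended per beam so far
    beam_width: int,
    block_size: int,
    last_token_in_page: int,  # tokens occupying each beam's last page
) -> Tuple[List[int], List[int], List[int], List[int], List[int], List[int]]:
    # First level: a single entry covering the shared prompt blocks
    # [0, prompt_blocks), shared by every beam of the group. The last page
    # is assumed to be exactly full, as with 4096-token prompts in the test.
    shared_indices = list(range(prompt_blocks))
    shared_indptr = [0, prompt_blocks]
    shared_last_page_len = [block_size]

    # Second level: per-beam unique blocks, interleaved with stride
    # beam_width, mirroring how the test fills block_tables.
    unique_indices: List[int] = []
    unique_indptr = [0]
    unique_last_page_len: List[int] = []
    for beam in range(beam_width):
        start = prompt_blocks + beam
        beam_blocks = [start + i * beam_width for i in range(unique_blocks)]
        unique_indices.extend(beam_blocks)
        unique_indptr.append(unique_indptr[-1] + len(beam_blocks))
        unique_last_page_len.append(last_token_in_page)

    return (shared_indices, shared_indptr, shared_last_page_len,
            unique_indices, unique_indptr, unique_last_page_len)


# Example: two shared prompt blocks, four beams, two unique blocks per beam.
print(build_two_level_metadata(prompt_blocks=2, unique_blocks=2,
                               beam_width=4, block_size=16,
                               last_token_in_page=5))

With beam_width=4 and two shared prompt blocks, beam 0 owns unique blocks [2, 6], beam 1 owns [3, 7], and so on, matching the torch.arange(..., step=beam_width) fill in the test's block_tables.
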
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index e367468d05d26..a978d5dc4d5f2 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -3,19 +3,18 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type +from vllm.logger import init_logger from vllm.multimodal import MultiModalPlaceholderMap try: - from flashinfer import BatchDecodeWithPagedKVCacheWrapper + from flashinfer.cascade import MultiLevelCascadeAttentionWrapper from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper - from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper from vllm.vllm_flash_attn import flash_attn_varlen_func FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 except ImportError: - BatchDecodeWithPagedKVCacheWrapper = None CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None - BatchPrefillWithPagedKVCacheWrapper = None + MultiLevelCascadeAttentionWrapper = None FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 import torch @@ -33,6 +32,8 @@ from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, make_tensor_with_pad) +logger = init_logger(__name__) + if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) @@ -104,8 +105,8 @@ def __init__(self, runner): self.runner = runner self._is_graph_capturing = False self._workspace_buffer = None - self._decode_wrapper = None - self._prefill_wrapper = None + self._cuda_wrapper = None + self._wrapper = None def _get_workspace_buffer(self): if self._workspace_buffer is None: @@ -115,25 +116,16 @@ def _get_workspace_buffer(self): device=self.runner.device) return self._workspace_buffer - def _get_prefill_wrapper(self): - if self._prefill_wrapper is None: - self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( - self._get_workspace_buffer(), "NHD") - return self._prefill_wrapper - - def _get_decode_wrapper(self): - if self._decode_wrapper is None: - num_qo_heads = (self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config)) - num_kv_heads = self.runner.model_config.get_num_kv_heads( - self.runner.parallel_config) - use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( - num_qo_heads // num_kv_heads > 4) - self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( - self._get_workspace_buffer(), - "NHD", - use_tensor_cores=use_tensor_cores) - return self._decode_wrapper + def _get_wrapper(self): + if self._wrapper is None: + self._wrapper = MultiLevelCascadeAttentionWrapper( + 2, self._get_workspace_buffer(), "NHD") + return self._wrapper + + def _get_cuda_wrapper(self): + if self._cuda_wrapper is not None: + return self._cuda_wrapper + return None @contextmanager def graph_capture(self, max_batch_size: int): @@ -173,8 +165,8 @@ def graph_clone(self, batch_size: int): assert self._is_graph_capturing state = self.__class__(self.runner) state._workspace_buffer = self._graph_decode_workspace_buffer - state._decode_wrapper = self._graph_decode_wrapper - state._prefill_wrapper = self._get_prefill_wrapper() + state._cuda_wrapper = self._graph_decode_wrapper + state._wrapper = self._get_wrapper() return state def graph_capture_get_metadata_for_batch( @@ -235,9 +227,11 @@ def graph_capture_get_metadata_for_batch( data_type=kv_cache_dtype, q_data_type=self.runner.model_config.dtype, use_cuda_graph=True, - decode_wrapper=self._graph_decode_wrapper, - prefill_wrapper=None) - attn_metadata.begin_forward() + 
cuda_wrapper=self._graph_decode_wrapper, + wrapper=self._wrapper) + # we don't need to pass logits and scale to begin_forward + # since in forward, it already gets it. + attn_metadata.begin_forward(None, None, (-1, -1)) return attn_metadata def get_graph_input_buffers(self, @@ -253,17 +247,40 @@ def prepare_graph_input_buffers(self, is_encoder_decoder_model: bool = False): return - def begin_forward(self, model_input): + def begin_forward(self, model_input, model): assert not self._is_graph_capturing state = self + + # sliding window needed for new kernel in begin_forward + sliding_window = self.runner.sliding_window + window_left = sliding_window[0] if sliding_window is not None else -1 + + try: + scale = model.model.layers[0].self_attn.attn.impl.scale + except AttributeError: + scale = None + logger.warning("Failed to retrieve 'scale'. \ + Check if 'self_attn.attn.impl' contains 'scale'.\ + Using default value of None") + + try: + logits_soft_cap = model.model.layers[ + 0].self_attn.attn.impl.logits_soft_cap + except AttributeError: + logits_soft_cap = None + logger.warning("Failed to retrieve 'logits_soft_cap'. \ + Check if 'self_attn.attn.impl' contains 'logits_soft_cap'. \ + Using default value of None") + if model_input.attn_metadata.use_cuda_graph: batch_size = model_input.input_tokens.shape[0] state = (self.runner.graph_runners[model_input.virtual_engine] [batch_size].attn_state) - model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper( - ) - model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper() - model_input.attn_metadata.begin_forward() + model_input.attn_metadata.cuda_wrapper = state._get_cuda_wrapper() + + model_input.attn_metadata.wrapper = state._get_wrapper() + model_input.attn_metadata.begin_forward(scale, logits_soft_cap, + window_left) @dataclass @@ -279,33 +296,35 @@ class FlashInferMetadata(AttentionMetadata): use_cuda_graph: bool = True - prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None - decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None + wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None + cuda_wrapper: Optional[CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None - # Metadata for the prefill stage + # Metadata for wrapper seq_start_loc: Optional[torch.Tensor] = None query_start_loc: Optional[torch.Tensor] = None + second_level_query_start_loc: Optional[torch.Tensor] = None block_tables: Optional[torch.Tensor] = None # used for GPU in-place advance_step seq_lens_tensor: Optional[torch.Tensor] = None block_table_bound: Optional[torch.Tensor] = None - # An example for paged_kv_indices, paged_kv_indptr: - # request 1, page indices [0, 5, 8] - # request 2, page indices [1, 6, 7] - # request 3, page indices [3, 4] - # paged_kv_indices is a concatenation of page indices of all requests: - # [0, 5, 8, 1, 6, 7, 3, 4] - # paged_kv_indptr is used to index into paged_kv_indices: - # [0, 3, 6, 8] - # The indptr of the paged kv cache, shape: [batch_size + 1] - paged_kv_indptr: Optional[torch.Tensor] = None - # The page indices of the paged kv cache + # Refer to: https://docs.flashinfer.ai/tutorials/kv_layout.html + # and: https://docs.flashinfer.ai/api/python/cascade.html + # Store shared prefix blocks of requests paged_kv_indices: Optional[torch.Tensor] = None - # The number of entries in the last page of each request in - # the paged kv cache, shape: [batch_size] + # Index pointers to the start of each shared block of KV-Cache + paged_kv_indptr: Optional[torch.Tensor] = None + # 
paged_kv_last_page_len is the length of the last page of the shared KVs paged_kv_last_page_len: Optional[torch.Tensor] = None + # Store the concatenated page indices of all requests for the second level + second_level_kv_indices: Optional[torch.Tensor] = None + # Index pointers to the start of each request's page indices + # in the second_level_kv_indices + second_level_kv_indptr: Optional[torch.Tensor] = None + # The length of the last page of each request in the second level + second_level_kv_last_page_len: Optional[torch.Tensor] = None + # The number of query/output heads num_qo_heads: Optional[int] = None # The number of key/value heads @@ -331,69 +350,137 @@ def __post_init__(self): f"Only {supported_head_sizes} are supported for head_dim,", f"received {self.head_dim}.") - def begin_forward(self): - if self.num_prefill_tokens > 0: - if self.paged_kv_indices is None: - return + def begin_forward(self, scale: Optional[float], + logits_soft_cap: Optional[float], + window_left: Optional[int]): + if self.paged_kv_indices is None: + return - assert self.prefill_wrapper is not None - assert self.query_start_loc is not None - assert self.paged_kv_indices is not None - assert self.paged_kv_indptr is not None - assert self.paged_kv_last_page_len is not None - assert self.block_table_bound is not None - assert self.seq_lens_tensor is not None - self.query_start_loc = self.query_start_loc[:self.num_prefills + 1] - batch_size = self.query_start_loc.shape[0] - 1 - assert batch_size >= 0 - # We will use flash attention for profiling to - # determine the number of blocks. Therefore, - # we don't need to prepare the input for flashinfer for profile run. - if not self.is_profile_run: - self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) - self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( - self.device) - self.block_table_bound = self.block_table_bound.to(self.device) - self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) - self.paged_kv_indices = self.paged_kv_indices.to(self.device) - self.prefill_wrapper.end_forward() - self.prefill_wrapper.begin_forward( - self.query_start_loc, - self.paged_kv_indptr[:self.num_prefills + 1], - self.paged_kv_indices, - self.paged_kv_last_page_len[:self.num_prefills], - self.num_qo_heads, self.num_kv_heads, self.head_dim, - self.page_size) - if self.num_decode_tokens > 0: - assert self.paged_kv_indices is not None - assert self.paged_kv_indptr is not None - assert self.paged_kv_last_page_len is not None + assert self.wrapper is not None + assert self.query_start_loc is not None + assert self.paged_kv_indices is not None + assert self.paged_kv_indptr is not None + assert self.paged_kv_last_page_len is not None + + if not self.is_profile_run: self.paged_kv_indices = self.paged_kv_indices.to(self.device) self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( self.device) - # handle model warmup path - if self.block_table_bound is not None: - self.block_table_bound = self.block_table_bound.to(self.device) - if self.seq_lens_tensor is not None: - self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) - - assert self.decode_wrapper is not None - self.decode_wrapper.end_forward() - self.decode_wrapper.begin_forward( - self.paged_kv_indptr[self.num_prefills:], - self.paged_kv_indices, - self.paged_kv_last_page_len[self.num_prefills:], - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - self.page_size, - # Disable flashinfer's pos encoding and use vllm's rope. 
- pos_encoding_mode="NONE", - # kv-cache data type. - data_type=self.data_type, - # query data type. - q_data_type=self.q_data_type) + + if self.num_decode_tokens > 0: + if self.block_table_bound is not None: + self.block_table_bound = self.block_table_bound.to( + self.device) + if self.seq_lens_tensor is not None: + self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) + + # Case 1: Prefill only + if self.num_prefill_tokens > 0 and self.num_decode_tokens == 0: + assert self.second_level_kv_indices is not None + assert self.second_level_kv_indptr is not None + assert self.second_level_kv_last_page_len is not None + assert self.second_level_query_start_loc is not None + assert self.query_start_loc is not None + + self.second_level_kv_indices = self.second_level_kv_indices.to( # noqa + self.device) + self.second_level_kv_indptr = self.second_level_kv_indptr.to( # noqa + self.device) + self.second_level_kv_last_page_len = self.second_level_kv_last_page_len.to( # noqa + self.device) + self.wrapper.plan([ + self.query_start_loc[:self.num_prefills + 1], + self.second_level_query_start_loc[:self.num_prefills + 1] + ], [ + self.paged_kv_indptr[:self.num_prefills + 1], + self.second_level_kv_indptr[:self.num_prefills + 1] + ], [self.paged_kv_indices, self.second_level_kv_indices], [ + self.paged_kv_last_page_len[:self.num_prefills], + self.second_level_kv_last_page_len[:self.num_prefills] + ], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + causal=True, + sm_scale=scale, + logits_soft_cap=logits_soft_cap, + window_left=window_left) + + # Case 2: Decode only + elif self.num_prefill_tokens == 0 and self.num_decode_tokens > 0: + if not self.use_cuda_graph: + assert self.second_level_kv_indices is not None + assert self.second_level_kv_indptr is not None + assert self.second_level_kv_last_page_len is not None + self.second_level_kv_indices = self.second_level_kv_indices.to( # noqa + self.device) + self.second_level_kv_indptr = self.second_level_kv_indptr.to( # noqa + self.device) + self.second_level_kv_last_page_len = self.second_level_kv_last_page_len.to( # noqa + self.device) + self.wrapper.plan([ + self.query_start_loc, self.second_level_query_start_loc + ], [ + self.paged_kv_indptr[self.num_prefills:], + self.second_level_kv_indptr[self.num_prefills:] + ], [self.paged_kv_indices, self.second_level_kv_indices], [ + self.paged_kv_last_page_len[self.num_prefills:], + self.second_level_kv_last_page_len[self.num_prefills:] + ], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + causal=True, + sm_scale=scale, + logits_soft_cap=logits_soft_cap, + window_left=window_left) + else: + assert self.cuda_wrapper is not None + self.cuda_wrapper.end_forward() + self.cuda_wrapper.begin_forward( + self.paged_kv_indptr[self.num_prefills:], + self.paged_kv_indices, + self.paged_kv_last_page_len[self.num_prefills:], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + # Disable flashinfer's pos encoding and use vllm's rope. + pos_encoding_mode="NONE", + # kv-cache data type. + data_type=self.data_type, + # query data type. 
+ q_data_type=self.q_data_type) + # Case 3: Both prefill and decode (chunked prefill case) + else: + assert self.second_level_kv_indices is not None + assert self.second_level_kv_indptr is not None + assert self.second_level_kv_last_page_len is not None + self.second_level_kv_indices = self.second_level_kv_indices.to( + self.device) + self.second_level_kv_indptr = self.second_level_kv_indptr.to( + self.device) + self.second_level_kv_last_page_len = self.second_level_kv_last_page_len.to( # noqa + self.device) + + self.wrapper.plan( + [self.query_start_loc, self.second_level_query_start_loc], + [self.paged_kv_indptr, self.second_level_kv_indptr], + [self.paged_kv_indices, self.second_level_kv_indices], [ + self.paged_kv_last_page_len, + self.second_level_kv_last_page_len + ], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + causal=True, + sm_scale=scale, + logits_soft_cap=logits_soft_cap, + window_left=window_left) def asdict_zerocopy(self, skip_fields: Optional[Set[str]] = None @@ -402,8 +489,8 @@ def asdict_zerocopy(self, skip_fields = set() # We need to skip the prefill/decode_wrapper field since it cannot be # broadcasted with nccl when TP is enabled. - skip_fields.add('prefill_wrapper') - skip_fields.add('decode_wrapper') + skip_fields.add('wrapper') + skip_fields.add('cuda_wrapper') return super().asdict_zerocopy(skip_fields) @property @@ -486,28 +573,26 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.sliding_window = input_builder.sliding_window self.block_size = input_builder.block_size - # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout - # for the precise definition of the following fields. - # An example: - # request 1, page indices [0, 5, 8] - # request 2, page indices [1, 6, 7] - # request 3, page indices [3, 4] - # paged_kv_indices is a concatenation of page indices of all requests: - # [0, 5, 8, 1, 6, 7, 3, 4] - # paged_kv_indptr is used to index into paged_kv_indices: - # [0, 3, 6, 8] + # Store the concatenated indices of shared prefix of the requests self.paged_kv_indices: List[int] = [] - # 0 at the beginning of paged_kv_indptr indicates the start of the - # first request’s page indices in the paged_kv_indices list. + # Index pointers to the start of each shared blocks self.paged_kv_indptr: List[int] = [0] - # paged_kv_last_page_len is the length of the last page of each request + # The length of the last page of the shared kvs self.paged_kv_last_page_len: List[int] = [] + # Store concatenated page indices of requests for the second level + self.second_level_kv_indices: List[int] = [] + # Index pointers to the start of each request's page indices + self.second_level_kv_indptr: List[int] = [0] + # The length of the last page of each request in the second level + self.second_level_kv_last_page_len: List[int] = [] + self.total_blocks = 0 self.is_profile_run: bool = False def _add_seq_group( self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", - chunked_prefill_enabled: bool): + chunked_prefill_enabled: bool, common_prefix: List[int], + use_cuda_graph: bool): """Add a sequence group to the metadata. Specifically update/append 1. context length. 2. block table. @@ -570,14 +655,18 @@ def _add_seq_group( return block_table = block_tables[seq_id] - self._update_paged_kv_tensors(block_table, seq_len) - - def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int): - # Get the number of valid blocks based on sequence length. 
- # If seq_len = 16, block_size = 16, - # block_table_bound is 1 with 1 valid block. - # If seq_len = 15, block_size = 16, - # block_table_bound is 0 + 1 with 1 valid block. + if use_cuda_graph: + self._update_cuda_wrapper_unique_kv_tensors( + block_table, seq_len) + else: + self._update_unique_kv_tensors(block_table, seq_len, + common_prefix) + + def _update_cuda_wrapper_unique_kv_tensors(self, block_table: List[int], + seq_len: int) -> None: + """ + Updates tensors for cuda decode wrapper + """ self.total_blocks += len(block_table) block_table_bound = seq_len // self.block_size + 1 \ if seq_len % self.block_size != 0 \ @@ -591,6 +680,71 @@ def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int): last_page_len = self.block_size self.paged_kv_last_page_len.append(last_page_len) + def _update_unique_kv_tensors(self, block_table: List[int], seq_len: int, + common_prefix: List[int]) -> None: + """ + Updates the unique level tensors + """ + + shared_length = len(common_prefix) + self.total_blocks += (len(block_table) - shared_length) + block_table_bound = (seq_len) // self.block_size + 1 \ + if seq_len % self.block_size != 0 \ + else (seq_len) // self.block_size + self.second_level_kv_indices.extend( + block_table[shared_length:block_table_bound]) + self.second_level_kv_indptr.append(self.second_level_kv_indptr[-1] + + (block_table_bound - shared_length)) + last_page_len = (seq_len) % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + self.second_level_kv_last_page_len.append(last_page_len) + + def _update_shared_kv_tensors(self, common_prefix: List[int], + batch_size: int) -> None: + """ + Updates the shared level kv tensors + """ + if not common_prefix: + # if we don't have common prefixes, we only use the unique level + # so we fill the first level indices, indptr, last page len with 0s + # to conform with multilevel wrapper input requirements + self.paged_kv_indices.extend([0] * batch_size) + self.paged_kv_indptr.extend([0] * batch_size) + self.paged_kv_last_page_len.extend([0] * batch_size) + else: + self.total_blocks += len(common_prefix) + self.paged_kv_indices.extend(common_prefix) + self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + + len(common_prefix)) + self.paged_kv_last_page_len.append(self.block_size) + + def get_shared_blocks_nums( + self, + inter_data_list: List["ModelInputForGPUBuilder.InterDataForSeqGroup"] + ) -> List[int]: + """ + Returns a list of consecutive shared blocks across sequence groups + """ + if len(inter_data_list) == 1: + return [] + + flattened_lists = [] + for data in inter_data_list: + if data.block_tables: + flattened_lists += list(data.block_tables.values()) + + common_prefix: List[int] = [] + for i, block_tuple in enumerate(zip(*flattened_lists)): + if all(block == block_tuple[0] for block in block_tuple): + if i > 0 and block_tuple[0] != common_prefix[-1] + 1: + break + common_prefix.append(block_tuple[0]) + else: + break + + return common_prefix + def build(self, seq_lens: List[int], query_lens: List[int], cuda_graph_pad_size: int, batch_size: int): """Build attention metadata with on-device tensors. @@ -602,12 +756,22 @@ def build(self, seq_lens: List[int], query_lens: List[int], -1 if cuda graph is not used. batch_size: The maybe padded batch size. """ + # common_prefix = self.get_shared_blocks_nums( + # self.input_builder.inter_data_list + # ) + # FIXME: we set common_prefix to empty list now since + # shared level is not working yet. 
+ common_prefix: List[int] = [] + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + for inter_data in self.input_builder.inter_data_list: self._add_seq_group(inter_data, - self.input_builder.chunked_prefill_enabled) + self.input_builder.chunked_prefill_enabled, + common_prefix, use_captured_graph) - device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 + if not use_captured_graph: + self._update_shared_kv_tensors(common_prefix, len(query_lens)) max_prefill_seq_len = max(self.prefill_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens @@ -662,6 +826,11 @@ def build(self, seq_lens: List[int], query_lens: List[int], seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device) + second_level_query_start_loc = torch.zeros(query_lens_tensor.shape[0] + + 1, + dtype=torch.int32, + device=device) + placeholder_index_maps = { modality: placeholder_map.index_map() for modality, placeholder_map in @@ -671,10 +840,26 @@ def build(self, seq_lens: List[int], query_lens: List[int], dim=0, dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, dim=0, dtype=query_start_loc.dtype, - out=query_start_loc[1:]) + out=second_level_query_start_loc[1:]) + + if not common_prefix: + # if no common prefix, we only use the unique kv level, so + # we just set the first level query start loc the same as + # the second levels + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + else: + # when we use shared level of the multilevel wrapper + query_start_loc = torch.tensor( + [0, second_level_query_start_loc[-1]], + dtype=torch.int32, + device=device) if len(self.paged_kv_indptr) > 0: # extend to the maximum number of blocks as returned by the @@ -689,6 +874,16 @@ def build(self, seq_lens: List[int], query_lens: List[int], dtype=torch.int) paged_kv_last_page_len_tensor = torch.tensor( self.paged_kv_last_page_len, device="cpu", dtype=torch.int) + + second_level_kv_indices_tensor = torch.tensor( + self.second_level_kv_indices, device="cpu", dtype=torch.int) + second_level_kv_indptr_tensor = torch.tensor( + self.second_level_kv_indptr, device="cpu", dtype=torch.int) + second_level_kv_last_page_len_tensor = torch.tensor( + self.second_level_kv_last_page_len, + device="cpu", + dtype=torch.int) + block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) - 1, device="cpu", @@ -698,6 +893,9 @@ def build(self, seq_lens: List[int], query_lens: List[int], paged_kv_indptr_tensor = None paged_kv_last_page_len_tensor = None block_table_bound_tensor = None + second_level_kv_indices_tensor = None + second_level_kv_indptr_tensor = None + second_level_kv_last_page_len_tensor = None if self.runner.kv_cache_dtype.startswith("fp8"): kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( @@ -718,6 +916,10 @@ def build(self, seq_lens: List[int], query_lens: List[int], paged_kv_indptr=paged_kv_indptr_tensor, paged_kv_indices=paged_kv_indices_tensor, paged_kv_last_page_len=paged_kv_last_page_len_tensor, + second_level_kv_indptr=second_level_kv_indptr_tensor, + second_level_kv_indices=second_level_kv_indices_tensor, + second_level_kv_last_page_len=second_level_kv_last_page_len_tensor, + second_level_query_start_loc=second_level_query_start_loc, block_table_bound=block_table_bound_tensor, seq_lens_tensor=seq_lens_tensor, num_qo_heads=self.runner.model_config.get_num_attention_heads( @@ -811,8 +1013,6 @@ def forward( k_scale, v_scale, ) - # The FlashInfer api 
requires data to be in fp8_e4m3 or fp8_e5m2 - # to process the cache when the kv_cache_dtype is fp8 if kv_cache_dtype.startswith("fp8"): torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( kv_cache_dtype) @@ -821,81 +1021,65 @@ def forward( num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa + f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa + f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa + query = query.contiguous( ) # Flashinfer requires query to be contiguous - # Query for decode. KV is not needed because it is already cached. + # Query for decode and prefill. + # KV is not needed because it is already cached. # QKV for prefill. decode_query = query[num_prefill_tokens:] - query = query[:num_prefill_tokens] + prefill_query = query[:num_prefill_tokens] key = key[:num_prefill_tokens] value = value[:num_prefill_tokens] - assert query.shape[0] == num_prefill_tokens + assert prefill_query.shape[0] == num_prefill_tokens assert decode_query.shape[0] == num_decode_tokens window_left = window_size[0] if window_size is not None else -1 - prefill_output: Optional[torch.Tensor] = None - decode_output: Optional[torch.Tensor] = None - if prefill_meta := attn_metadata.prefill_metadata: - # We will use flash attention for prefill - # when kv_cache is not provided. - # This happens when vllm runs the profiling to - # determine the number of blocks. 
- if kv_cache.numel() == 0: - prefill_output = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - ) - else: - assert prefill_meta is not None - assert prefill_meta.prefill_wrapper is not None - prefill_output = prefill_meta.prefill_wrapper.forward( - query, + if kv_cache.numel() == 0: + return flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=attn_metadata.seq_start_loc, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_q=attn_metadata.max_prefill_seq_len, + max_seqlen_k=attn_metadata.max_prefill_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + ).view(num_tokens, hidden_size) + + assert attn_metadata.wrapper is not None + + if num_prefill_tokens > 0 and num_decode_tokens == 0: + output = attn_metadata.wrapper.run(prefill_query, kv_cache) + return output.view(num_tokens, hidden_size) + elif num_prefill_tokens == 0 and num_decode_tokens > 0: + if attn_metadata.cuda_wrapper is not None: + output = attn_metadata.cuda_wrapper.forward( + decode_query, kv_cache, + sm_scale=softmax_scale, logits_soft_cap=logits_soft_cap, - causal=True, k_scale=k_scale, v_scale=v_scale, window_left=window_left) - if decode_meta := attn_metadata.decode_metadata: - assert decode_meta is not None - assert decode_meta.decode_wrapper is not None - decode_output = decode_meta.decode_wrapper.forward( - decode_query, - kv_cache, - sm_scale=softmax_scale, - logits_soft_cap=logits_soft_cap, - k_scale=k_scale, - v_scale=v_scale, - window_left=window_left) - - if prefill_output is None and decode_output is not None: - # Decode only batch. - output, num_tokens = decode_output, num_decode_tokens - elif decode_output is None and prefill_output is not None: - # Prefill only batch. - output, num_tokens = prefill_output, num_prefill_tokens + else: + assert attn_metadata.wrapper is not None + output = attn_metadata.wrapper.run(decode_query, kv_cache) + return output.view(num_tokens, hidden_size) else: - # Chunked prefill batch does not work with speculative decoding in - # FlashInfer backend, so the query length for decode should be 1. 
- assert prefill_output is not None - assert decode_output is not None - assert decode_meta is not None - assert decode_meta.decode_query_len == 1 - decode_output = decode_output.squeeze(1) - output = torch.cat([prefill_output, decode_output], dim=0) - return output.view(num_tokens, hidden_size) + # Ensure chunked prefill with speculative decoding is not allowed + assert decode_query.shape[0] == 1, \ + """Chunked prefill batch does not work with + speculative decoding in FlashInfer backend.""" + + output = attn_metadata.wrapper.run(query, kv_cache) + return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 56cc43430301f..e7c5f8389a968 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -382,7 +382,7 @@ def prepare_graph_input_buffers( self._prepare_input_buffers_for_enc_dec_model( attn_metadata, input_buffers) - def begin_forward(self, model_input) -> None: + def begin_forward(self, model_input, model) -> None: return def _update_captured_metadata_for_enc_dec_model(self, batch_size: int, diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index fe5fd39f42ac9..4939ccf68ef09 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -1,3 +1,4 @@ +import weakref from typing import List, Optional import torch @@ -222,7 +223,8 @@ def execute_model( model_input.prompt_adapter_requests, model_input.prompt_adapter_mapping) - self.attn_state.begin_forward(model_input) + self.attn_state.begin_forward(model_input, + weakref.proxy(self.model)) # Detect exec mode assert model_input.attn_metadata is not None diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2b545d1b28bd2..8302815d2690f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1639,7 +1639,7 @@ def execute_model( model_input.prompt_adapter_requests, model_input.prompt_adapter_mapping) - self.attn_state.begin_forward(model_input) + self.attn_state.begin_forward(model_input, weakref.proxy(self.model)) # Currently cuda graph is only supported by the decode phase. assert model_input.attn_metadata is not None
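
For reference, the builder's get_shared_blocks_nums above walks the block tables of all sequence groups column by column and keeps the leading run of block ids that are identical across every sequence and physically consecutive; the shared level is currently disabled (the FIXME in build() forces common_prefix to an empty list), so this path is not yet exercised at runtime. The sketch below is illustrative only, not part of the patch; it reproduces the intended behaviour on plain lists, and the block-id values in the examples are made up.

# Illustrative sketch (not part of the patch): intended behaviour of
# get_shared_blocks_nums on plain block tables.
from typing import List


def shared_prefix_blocks(block_tables: List[List[int]]) -> List[int]:
    if len(block_tables) <= 1:
        return []  # a single sequence has nothing to share
    common: List[int] = []
    for i, column in enumerate(zip(*block_tables)):
        # Stop as soon as the column differs across sequences.
        if any(block != column[0] for block in column):
            break
        # Shared blocks must also be consecutive block ids.
        if i > 0 and column[0] != common[-1] + 1:
            break
        common.append(column[0])
    return common


# Two sequences sharing their first three consecutive prompt blocks:
assert shared_prefix_blocks([[7, 8, 9, 20, 21],
                             [7, 8, 9, 30, 31]]) == [7, 8, 9]
# No sharing once the very first block differs:
assert shared_prefix_blocks([[5, 6], [9, 6]]) == []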