From ef7b292499ca578bb570603a5b9e93b3413d6d6e Mon Sep 17 00:00:00 2001 From: Ray Wan Date: Fri, 1 Nov 2024 19:45:38 +0000 Subject: [PATCH 01/12] eager mode --- vllm/attention/backends/abstract.py | 3 +- vllm/attention/backends/flashinfer.py | 452 ++++++++++++++++++------- vllm/attention/backends/utils.py | 2 +- vllm/spec_decode/draft_model_runner.py | 4 +- vllm/worker/model_runner.py | 2 +- 5 files changed, 333 insertions(+), 130 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 9ea89eca01f5b..37266a53a6068 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -2,6 +2,7 @@ from contextlib import contextmanager from dataclasses import dataclass, fields from enum import Enum, auto +from torch import nn from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar) @@ -184,7 +185,7 @@ def prepare_graph_input_buffers( ... @abstractmethod - def begin_forward(self, model_input: "ModelRunnerInputBase") -> None: + def begin_forward(self, model_input: "ModelRunnerInputBase", model: nn.Module) -> None: """Prepare state for forward pass.""" ... diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 5ea101ae0432f..fbd3ac9e2eb40 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -3,16 +3,14 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type try: - from flashinfer import BatchDecodeWithPagedKVCacheWrapper from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper - from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper + from flashinfer.cascade import MultiLevelCascadeAttentionWrapper from vllm.vllm_flash_attn import flash_attn_varlen_func FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 except ImportError: - BatchDecodeWithPagedKVCacheWrapper = None CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None - BatchPrefillWithPagedKVCacheWrapper = None + MultiLevelCascadeAttentionWrapper = None FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 import torch @@ -102,8 +100,8 @@ def __init__(self, runner): self.runner = runner self._is_graph_capturing = False self._workspace_buffer = None - self._decode_wrapper = None - self._prefill_wrapper = None + self._cuda_wrapper = None + self._wrapper = None def _get_workspace_buffer(self): if self._workspace_buffer is None: @@ -113,25 +111,17 @@ def _get_workspace_buffer(self): device=self.runner.device) return self._workspace_buffer - def _get_prefill_wrapper(self): - if self._prefill_wrapper is None: - self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( - self._get_workspace_buffer(), "NHD") - return self._prefill_wrapper - - def _get_decode_wrapper(self): - if self._decode_wrapper is None: - num_qo_heads = (self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config)) - num_kv_heads = self.runner.model_config.get_num_kv_heads( - self.runner.parallel_config) - use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( - num_qo_heads // num_kv_heads > 4) - self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( - self._get_workspace_buffer(), - "NHD", - use_tensor_cores=use_tensor_cores) - return self._decode_wrapper + def _get_wrapper(self): + if self._wrapper is None: + self._wrapper = MultiLevelCascadeAttentionWrapper( + 2, self._get_workspace_buffer(), "NHD") + return self._wrapper + + + def _get_cuda_wrapper(self): + if self._cuda_wrapper is not None: + return 
self._cuda_wrapper + return None @contextmanager def graph_capture(self, max_batch_size: int): @@ -171,8 +161,8 @@ def graph_clone(self, batch_size: int): assert self._is_graph_capturing state = self.__class__(self.runner) state._workspace_buffer = self._graph_decode_workspace_buffer - state._decode_wrapper = self._graph_decode_wrapper - state._prefill_wrapper = self._get_prefill_wrapper() + state._cuda_wrapper = self._graph_decode_wrapper + state._wrapper = self._get_wrapper() return state def graph_capture_get_metadata_for_batch( @@ -232,9 +222,11 @@ def graph_capture_get_metadata_for_batch( data_type=kv_cache_dtype, q_data_type=self.runner.model_config.dtype, use_cuda_graph=True, - decode_wrapper=self._graph_decode_wrapper, - prefill_wrapper=None) - attn_metadata.begin_forward() + cuda_wrapper=self._graph_decode_wrapper, + wrapper=self._wrapper) + # we don't need to pass logits and scale to begin_forward + # since in forward, it already gets it. + attn_metadata.begin_forward(None, None) return attn_metadata def get_graph_input_buffers(self, @@ -250,17 +242,46 @@ def prepare_graph_input_buffers(self, is_encoder_decoder_model: bool = False): return - def begin_forward(self, model_input): + # def begin_forward(self, model_input): + # assert not self._is_graph_capturing + # state = self + # if model_input.attn_metadata.use_cuda_graph: + # batch_size = model_input.input_tokens.shape[0] + # state = (self.runner.graph_runners[model_input.virtual_engine] + # [batch_size].attn_state) + # model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper( + # ) + # model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper() + # model_input.attn_metadata.begin_forward() + + def begin_forward(self, model_input, model): assert not self._is_graph_capturing state = self + + try: + scale = getattr(model.model.layers[0].self_attn.attn.impl, "scale", + None) + except AttributeError as e: + raise AttributeError("Failed to retrieve 'scale'. \ + Check if 'self_attn.attn.impl' contains 'scale'.") from e + + try: + logits_soft_cap = getattr( + model.model.layers[0].self_attn.attn.impl, "logits_soft_cap", + None) + except AttributeError as e: + raise AttributeError("Failed to retrieve 'logits_soft_cap'. \ + Check if 'self_attn.attn.impl' contains 'logits_soft_cap'." 
+ ) from e + if model_input.attn_metadata.use_cuda_graph: batch_size = model_input.input_tokens.shape[0] state = (self.runner.graph_runners[model_input.virtual_engine] [batch_size].attn_state) - model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper( - ) - model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper() - model_input.attn_metadata.begin_forward() + model_input.attn_metadata.cuda_wrapper = state._get_cuda_wrapper() + + model_input.attn_metadata.wrapper = state._get_wrapper() + model_input.attn_metadata.begin_forward(scale, logits_soft_cap) @dataclass @@ -276,33 +297,35 @@ class FlashInferMetadata(AttentionMetadata): use_cuda_graph: bool = True - prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None - decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None - - # Metadata for the prefill stage + wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None + cuda_wrapper: Optional[CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None + + # Metadata for wrapper seq_start_loc: Optional[torch.Tensor] = None query_start_loc: Optional[torch.Tensor] = None + second_level_query_start_loc: Optional[torch.Tensor] = None block_tables: Optional[torch.Tensor] = None # used for GPU in-place advance_step seq_lens_tensor: Optional[torch.Tensor] = None block_table_bound: Optional[torch.Tensor] = None - # An example for paged_kv_indices, paged_kv_indptr: - # request 1, page indices [0, 5, 8] - # request 2, page indices [1, 6, 7] - # request 3, page indices [3, 4] - # paged_kv_indices is a concatenation of page indices of all requests: - # [0, 5, 8, 1, 6, 7, 3, 4] - # paged_kv_indptr is used to index into paged_kv_indices: - # [0, 3, 6, 8] - # The indptr of the paged kv cache, shape: [batch_size + 1] - paged_kv_indptr: Optional[torch.Tensor] = None - # The page indices of the paged kv cache + # Refer to: https://docs.flashinfer.ai/tutorials/kv_layout.html + # and: https://docs.flashinfer.ai/api/python/cascade.html + # Store shared prefix blocks of requests paged_kv_indices: Optional[torch.Tensor] = None - # The number of entries in the last page of each request in - # the paged kv cache, shape: [batch_size] + # Index pointers to the start of each shared block of KV-Cache + paged_kv_indptr: Optional[torch.Tensor] = None + # paged_kv_last_page_len is the length of the last page of the shared KVs paged_kv_last_page_len: Optional[torch.Tensor] = None + # Store the concatenated page indices of all requests for the second layer + second_layer_kv_indices: Optional[torch.Tensor] = None + # Index pointers to the start of each request's page indices + # in the second_layer_kv_indices + second_layer_kv_indptr: Optional[torch.Tensor] = None + # The length of the last page of each request in the second layer + second_layer_kv_last_page_len: Optional[torch.Tensor] = None + # The number of query/output heads num_qo_heads: Optional[int] = None # The number of key/value heads @@ -328,16 +351,20 @@ def __post_init__(self): f"Only {supported_head_sizes} are supported for head_dim,", f"received {self.head_dim}.") - def begin_forward(self): + #TODO: NEED TO ADD CHUNKED PREFILL + def begin_forward(self, scale: Optional[float], logits_soft_cap: Optional[float]): if self.num_prefill_tokens > 0: if self.paged_kv_indices is None: return - assert self.prefill_wrapper is not None + assert self.wrapper is not None assert self.query_start_loc is not None assert self.paged_kv_indices is not None assert self.paged_kv_indptr is not None assert self.paged_kv_last_page_len is not 
None + assert self.second_layer_kv_indices is not None + assert self.second_layer_kv_indptr is not None + assert self.second_layer_kv_last_page_len is not None assert self.block_table_bound is not None assert self.seq_lens_tensor is not None self.query_start_loc = self.query_start_loc[:self.num_prefills + 1] @@ -347,50 +374,112 @@ def begin_forward(self): # determine the number of blocks. Therefore, # we don't need to prepare the input for flashinfer for profile run. if not self.is_profile_run: + print(self.paged_kv_indices) + print(self.second_layer_kv_indices) + print(self.paged_kv_indptr) + print(self.second_layer_kv_indptr) self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( self.device) self.block_table_bound = self.block_table_bound.to(self.device) self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) self.paged_kv_indices = self.paged_kv_indices.to(self.device) - self.prefill_wrapper.end_forward() - self.prefill_wrapper.begin_forward( - self.query_start_loc, - self.paged_kv_indptr[:self.num_prefills + 1], - self.paged_kv_indices, - self.paged_kv_last_page_len[:self.num_prefills], - self.num_qo_heads, self.num_kv_heads, self.head_dim, - self.page_size) + self.second_layer_kv_indices = self.second_layer_kv_indices.to( + self.device) + self.second_layer_kv_indptr = self.second_layer_kv_indptr.to( + self.device) + self.second_layer_kv_last_page_len = self.second_layer_kv_last_page_len.to( # noqa: E501 + self.device) + + self.wrapper.plan( + [self.query_start_loc, self.second_level_query_start_loc], + [self.paged_kv_indptr, self.second_layer_kv_indptr], + [self.paged_kv_indices, self.second_layer_kv_indices], [ + self.paged_kv_last_page_len, + self.second_layer_kv_last_page_len + ], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + causal=True, + sm_scale=scale, + logits_soft_cap=logits_soft_cap) + if self.num_decode_tokens > 0: + if self.cuda_wrapper: + assert self.paged_kv_indices is not None + assert self.paged_kv_indptr is not None + assert self.paged_kv_last_page_len is not None + self.paged_kv_indices = self.paged_kv_indices.to(self.device) + self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) + self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( + self.device) + # handle model warmup path + if self.block_table_bound is not None: + self.block_table_bound = self.block_table_bound.to(self.device) + if self.seq_lens_tensor is not None: + self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) + + assert self.cuda_wrapper is not None + self.cuda_wrapper.end_forward() + self.cuda_wrapper.begin_forward( + self.paged_kv_indptr[self.num_prefills:], + self.paged_kv_indices, + self.paged_kv_last_page_len[self.num_prefills:], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + # Disable flashinfer's pos encoding and use vllm's rope. + pos_encoding_mode="NONE", + # kv-cache data type. + data_type=self.data_type, + # query data type. 
+ q_data_type=self.q_data_type) + + return + assert self.paged_kv_indices is not None assert self.paged_kv_indptr is not None assert self.paged_kv_last_page_len is not None + assert self.second_layer_kv_indices is not None + assert self.second_layer_kv_indptr is not None + assert self.second_layer_kv_last_page_len is not None + self.paged_kv_indices = self.paged_kv_indices.to(self.device) self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( self.device) + self.second_layer_kv_indices = self.second_layer_kv_indices.to( + self.device) + self.second_layer_kv_indptr = self.second_layer_kv_indptr.to( + self.device) + self.second_layer_kv_last_page_len = self.second_layer_kv_last_page_len.to( # noqa: E501 + self.device) + # handle model warmup path if self.block_table_bound is not None: self.block_table_bound = self.block_table_bound.to(self.device) if self.seq_lens_tensor is not None: self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) - assert self.decode_wrapper is not None - self.decode_wrapper.end_forward() - self.decode_wrapper.begin_forward( - self.paged_kv_indptr[self.num_prefills:], - self.paged_kv_indices, - self.paged_kv_last_page_len[self.num_prefills:], + assert self.wrapper is not None + + self.wrapper.plan( + [self.query_start_loc, self.second_level_query_start_loc], + [self.paged_kv_indptr, self.second_layer_kv_indptr], + [self.paged_kv_indices, self.second_layer_kv_indices], [ + self.paged_kv_last_page_len, + self.second_layer_kv_last_page_len + ], self.num_qo_heads, self.num_kv_heads, self.head_dim, self.page_size, - # Disable flashinfer's pos encoding and use vllm's rope. - pos_encoding_mode="NONE", - # kv-cache data type. - data_type=self.data_type, - # query data type. - q_data_type=self.q_data_type) + causal=True, + sm_scale=scale, + logits_soft_cap=logits_soft_cap) def asdict_zerocopy(self, skip_fields: Optional[Set[str]] = None @@ -399,8 +488,8 @@ def asdict_zerocopy(self, skip_fields = set() # We need to skip the prefill/decode_wrapper field since it cannot be # broadcasted with nccl when TP is enabled. - skip_fields.add('prefill_wrapper') - skip_fields.add('decode_wrapper') + skip_fields.add('wrapper') + skip_fields.add('cuda_wrapper') return super().asdict_zerocopy(skip_fields) @property @@ -480,28 +569,25 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.sliding_window = input_builder.sliding_window self.block_size = input_builder.block_size - # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout - # for the precise definition of the following fields. - # An example: - # request 1, page indices [0, 5, 8] - # request 2, page indices [1, 6, 7] - # request 3, page indices [3, 4] - # paged_kv_indices is a concatenation of page indices of all requests: - # [0, 5, 8, 1, 6, 7, 3, 4] - # paged_kv_indptr is used to index into paged_kv_indices: - # [0, 3, 6, 8] + # Store the concatenated indices of shared prefix of the requests self.paged_kv_indices: List[int] = [] - # 0 at the beginning of paged_kv_indptr indicates the start of the - # first request’s page indices in the paged_kv_indices list. 
+ # Index pointers to the start of each shared blocks self.paged_kv_indptr: List[int] = [0] - # paged_kv_last_page_len is the length of the last page of each request + # The length of the last page of the shared kvs self.paged_kv_last_page_len: List[int] = [] + # Store concatenated page indices of requests for the second layer + self.second_layer_kv_indices: List[int] = [] + # Index pointers to the start of each request's page indices + self.second_layer_kv_indptr: List[int] = [0] + # The length of the last page of each request in the second layer + self.second_layer_kv_last_page_len: List[int] = [] + self.total_blocks = 0 self.is_profile_run: bool = False def _add_seq_group( self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", - chunked_prefill_enabled: bool): + chunked_prefill_enabled: bool, common_prefix: List[int], use_cuda_graph: bool): """Add a sequence group to the metadata. Specifically update/append 1. context length. 2. block table. @@ -559,26 +645,88 @@ def _add_seq_group( return block_table = block_tables[seq_id] - self._update_paged_kv_tensors(block_table, seq_len) - - def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int): - # Get the number of valid blocks based on sequence length. - # If seq_len = 16, block_size = 16, - # block_table_bound is 1 with 1 valid block. - # If seq_len = 15, block_size = 16, - # block_table_bound is 0 + 1 with 1 valid block. - self.total_blocks += len(block_table) - block_table_bound = seq_len // self.block_size + 1 \ - if seq_len % self.block_size != 0 \ - else seq_len // self.block_size - self.paged_kv_indices.extend(block_table[:block_table_bound]) - self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + - block_table_bound) + self._update_unique_kv_tensors(block_table, seq_len, common_prefix, use_cuda_graph) + - last_page_len = seq_len % self.block_size + def _update_unique_kv_tensors(self, block_table: List[int], seq_len: int, + common_prefix: List[int], use_cuda_graph: bool) -> None: + """ + Updates the unique level kv tensors + """ + if use_cuda_graph: + self.total_blocks += len(block_table) + block_table_bound = seq_len // self.block_size + 1 \ + if seq_len % self.block_size != 0 \ + else seq_len // self.block_size + self.paged_kv_indices.extend(block_table[:block_table_bound]) + self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + + block_table_bound) + + last_page_len = seq_len % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + self.paged_kv_last_page_len.append(last_page_len) + return + + shared_length = len(common_prefix) + self.total_blocks += (len(block_table) - shared_length) + block_table_bound = (seq_len) // self.block_size + 1 \ + if seq_len % self.block_size != 0 \ + else (seq_len) // self.block_size + self.second_layer_kv_indices.extend( + block_table[shared_length:block_table_bound]) + self.second_layer_kv_indptr.append(self.second_layer_kv_indptr[-1] + + (block_table_bound - shared_length)) + last_page_len = (seq_len) % self.block_size if last_page_len == 0: last_page_len = self.block_size - self.paged_kv_last_page_len.append(last_page_len) + self.second_layer_kv_last_page_len.append(last_page_len) + + def _update_shared_kv_tensors(self, common_prefix: List[int], + batch_size: int) -> None: + """ + Updates the shared level kv tensors + """ + if not common_prefix: + # if we don't have common prefixes, we only use the unique level + # so we fill the first level indices, indptr, last page len with 0s + # to conform with multilevel wrapper input requirements + 
self.paged_kv_indices.extend([0] * batch_size) + self.paged_kv_indptr.extend([0] * batch_size) + self.paged_kv_last_page_len.extend([0] * batch_size) + else: + # FIXME: this still doesn't work + self.total_blocks += len(common_prefix) + self.paged_kv_indices.extend(common_prefix) + self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + + len(common_prefix)) + self.paged_kv_last_page_len.append(self.block_size) + + def get_shared_blocks_nums( + self, + inter_data_list: List["ModelInputForGPUBuilder.InterDataForSeqGroup"] + ) -> List[int]: + """ + Returns a list of consecutive shared blocks across sequence groups + """ + if len(inter_data_list) == 1: + return [] + + flattened_lists = [] + for data in inter_data_list: + if data.block_tables: + flattened_lists += list(data.block_tables.values()) + + common_prefix: List[int] = [] + for i, block_tuple in enumerate(zip(*flattened_lists)): + if all(block == block_tuple[0] for block in block_tuple): + if i > 0 and block_tuple[0] != common_prefix[-1] + 1: + break + common_prefix.append(block_tuple[0]) + else: + break + + return common_prefix def build(self, seq_lens: List[int], query_lens: List[int], cuda_graph_pad_size: int, batch_size: int): @@ -591,13 +739,23 @@ def build(self, seq_lens: List[int], query_lens: List[int], -1 if cuda graph is not used. batch_size: The maybe padded batch size. """ - for inter_data in self.input_builder.inter_data_list: - self._add_seq_group(inter_data, - self.input_builder.chunked_prefill_enabled) - + # common_prefix = self.get_shared_blocks_nums( + # self.input_builder.inter_data_list + # ) + # FIXME: we set common_prefix to empty list now since + # shared level is not working yet. + common_prefix: List[int] = [] device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled, + common_prefix, use_captured_graph) + + if not use_captured_graph: + self._update_shared_kv_tensors(common_prefix, len(query_lens)) + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens decode_query_len = max(query_lens[self.num_prefills:], default=1) @@ -651,20 +809,41 @@ def build(self, seq_lens: List[int], query_lens: List[int], seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device) + second_level_query_start_loc = torch.zeros(query_lens_tensor.shape[0] + + 1, + dtype=torch.int32, + device=device) + torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, dim=0, dtype=query_start_loc.dtype, - out=query_start_loc[1:]) + out=second_level_query_start_loc[1:]) + + if not common_prefix: + # if no common prefix, we only use the unique kv level, so + # we just set the first level query start loc the same as + # the second levels + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + else: + # when we use shared level of the multilevel wrapper + query_start_loc = torch.tensor( + [0, second_level_query_start_loc[-1]], + dtype=torch.int32, + device=device) if len(self.paged_kv_indptr) > 0: # extend to the maximum number of blocks as returned by the # scheduler - self.paged_kv_indices.extend( - [0] * (self.total_blocks - len(self.paged_kv_indices))) + self.second_layer_kv_indices.extend( + [0] * (self.total_blocks - len(self.second_layer_kv_indices))) paged_kv_indices_tensor = 
torch.tensor(self.paged_kv_indices, device="cpu", dtype=torch.int) @@ -673,6 +852,16 @@ def build(self, seq_lens: List[int], query_lens: List[int], dtype=torch.int) paged_kv_last_page_len_tensor = torch.tensor( self.paged_kv_last_page_len, device="cpu", dtype=torch.int) + + second_level_kv_indices_tensor = torch.tensor( + self.second_layer_kv_indices, device="cpu", dtype=torch.int) + second_level_kv_indptr_tensor = torch.tensor( + self.second_layer_kv_indptr, device="cpu", dtype=torch.int) + second_level_kv_last_page_len_tensor = torch.tensor( + self.second_layer_kv_last_page_len, + device="cpu", + dtype=torch.int) + block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) - 1, device="cpu", @@ -682,6 +871,9 @@ def build(self, seq_lens: List[int], query_lens: List[int], paged_kv_indptr_tensor = None paged_kv_last_page_len_tensor = None block_table_bound_tensor = None + second_level_kv_indices_tensor = None + second_level_kv_indptr_tensor = None + second_level_kv_last_page_len_tensor = None if self.runner.kv_cache_dtype.startswith("fp8"): kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( @@ -701,6 +893,10 @@ def build(self, seq_lens: List[int], query_lens: List[int], paged_kv_indptr=paged_kv_indptr_tensor, paged_kv_indices=paged_kv_indices_tensor, paged_kv_last_page_len=paged_kv_last_page_len_tensor, + second_layer_kv_indptr=second_level_kv_indptr_tensor, + second_layer_kv_indices=second_level_kv_indices_tensor, + second_layer_kv_last_page_len=second_level_kv_last_page_len_tensor, + second_level_query_start_loc=second_level_query_start_loc, block_table_bound=block_table_bound_tensor, seq_lens_tensor=seq_lens_tensor, num_qo_heads=self.runner.model_config.get_num_attention_heads( @@ -718,6 +914,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], is_profile_run=self.is_profile_run) + class FlashInferImpl(AttentionImpl): def __init__( @@ -874,19 +1071,24 @@ def unified_flash_infer( ) else: assert prefill_meta is not None - assert prefill_meta.prefill_wrapper is not None - prefill_output = prefill_meta.prefill_wrapper.forward( - query, kv_cache, logits_soft_cap=logits_soft_cap, causal=True) + assert prefill_meta.wrapper is not None + prefill_output = prefill_meta.wrapper.run( + query, kv_cache) if decode_meta := attn_metadata.decode_metadata: assert attn_metadata.decode_metadata is not None - assert attn_metadata.decode_metadata.decode_wrapper is not None - decode_output = attn_metadata.decode_metadata.decode_wrapper.forward( - decode_query, - kv_cache, - sm_scale=softmax_scale, - logits_soft_cap=logits_soft_cap, - k_scale=k_scale, - v_scale=v_scale) + if attn_metadata.decode_metadata.cuda_wrapper is not None: + decode_output = attn_metadata.decode_metadata.cuda_wrapper.forward( + decode_query, + kv_cache, + sm_scale=softmax_scale, + logits_soft_cap=logits_soft_cap, + k_scale=k_scale, + v_scale=v_scale) + else: + assert attn_metadata.decode_metadata.wrapper is not None + decode_output = attn_metadata.decode_metadata.wrapper.run( + decode_query, + kv_cache) if prefill_output is None and decode_output is not None: # Decode only batch. 
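
Note: the flashinfer.py changes above replace the separate BatchPrefillWithPagedKVCacheWrapper / BatchDecodeWithPagedKVCacheWrapper pair with one two-level MultiLevelCascadeAttentionWrapper: level 0 holds KV pages shared by all requests (the common prefix), level 1 holds each request's unique pages, and begin_forward() now calls plan() with per-level indptr/indices/last-page-len lists plus the sm_scale and logits_soft_cap pulled from the attention impl. A minimal sketch of that plan()/run() pattern, with made-up page indices and head counts rather than the exact tensors FlashInferMetadataBuilder produces:

    import torch
    from flashinfer.cascade import MultiLevelCascadeAttentionWrapper

    def i32(xs):
        return torch.tensor(xs, dtype=torch.int32, device="cuda")

    # Two decode requests, page size 16: both read shared physical page 0
    # (the common prefix), then one unique page each (pages 1 and 2).
    page_size, num_qo_heads, num_kv_heads, head_dim = 16, 32, 8, 128
    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
    wrapper = MultiLevelCascadeAttentionWrapper(2, workspace, "NHD")

    wrapper.plan(
        # qo_indptr per level: level 0 groups all queries together,
        # level 1 splits them per request.
        [i32([0, 2]), i32([0, 1, 2])],
        # paged_kv_indptr per level.
        [i32([0, 1]), i32([0, 1, 2])],
        # paged_kv_indices per level: shared page(s), then unique pages.
        [i32([0]), i32([1, 2])],
        # paged_kv_last_page_len per level.
        [i32([page_size]), i32([5, 9])],
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        causal=True,
        sm_scale=head_dim**-0.5,
        logits_soft_cap=None)

    # query: [num_tokens, num_qo_heads, head_dim]; kv_cache is paged in NHD
    # layout: [num_pages, 2, page_size, num_kv_heads, head_dim].
    query = torch.randn(2, num_qo_heads, head_dim,
                        dtype=torch.float16, device="cuda")
    kv_cache = torch.randn(3, 2, page_size, num_kv_heads, head_dim,
                           dtype=torch.float16, device="cuda")
    output = wrapper.run(query, kv_cache)  # [2, num_qo_heads, head_dim]
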
diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 32fccd0dfb496..fcb976fcae33f 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -365,7 +365,7 @@ def prepare_graph_input_buffers( self._prepare_input_buffers_for_enc_dec_model( attn_metadata, input_buffers) - def begin_forward(self, model_input) -> None: + def begin_forward(self, model_input, model) -> None: return def _update_captured_metadata_for_enc_dec_model(self, batch_size: int, diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 3aa999fcb9ebb..69d28ba4a2f7b 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -1,7 +1,7 @@ from typing import List, Optional import torch - +import weakref from vllm.forward_context import set_forward_context from vllm.model_executor.layers.sampler import SamplerOutput @@ -246,7 +246,7 @@ def execute_model( model_input.prompt_adapter_requests, model_input.prompt_adapter_mapping) - self.attn_state.begin_forward(model_input) + self.attn_state.begin_forward(model_input, weakref.proxy(self.model)) # Detect exec mode assert model_input.attn_metadata is not None diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 891637dafbb14..905fb64f5d0d0 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1614,7 +1614,7 @@ def execute_model( model_input.prompt_adapter_requests, model_input.prompt_adapter_mapping) - self.attn_state.begin_forward(model_input) + self.attn_state.begin_forward(model_input, weakref.proxy(self.model)) # Currently cuda graph is only supported by the decode phase. assert model_input.attn_metadata is not None From 78f0cbec62395268294b64b5a262c3ec295ea390 Mon Sep 17 00:00:00 2001 From: Ray Wan Date: Fri, 1 Nov 2024 19:47:45 +0000 Subject: [PATCH 02/12] test --- tests/kernels/test_cascade.py | 380 ++++++++++++++++++++++++++++++++++ 1 file changed, 380 insertions(+) create mode 100644 tests/kernels/test_cascade.py diff --git a/tests/kernels/test_cascade.py b/tests/kernels/test_cascade.py new file mode 100644 index 0000000000000..742557bbef7d1 --- /dev/null +++ b/tests/kernels/test_cascade.py @@ -0,0 +1,380 @@ +from typing import List, Tuple + +import flashinfer +import pytest +import torch + +from vllm.utils import seed_everything + + +@pytest.mark.parametrize("beam_width", [16, 32]) +@pytest.mark.parametrize("seq_lens", [[(4096, 4096)]]) +@pytest.mark.parametrize("num_heads", [(16, 16)]) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_runs", [200]) +@pytest.mark.parametrize("max_num_blocks_per_seq", [512]) +@pytest.mark.parametrize("soft_cap", [None]) +@torch.inference_mode() +def test_cascade_speedup(beam_width, seq_lens, num_heads, head_size, dtype, + block_size, num_runs, max_num_blocks_per_seq, + soft_cap): + """ + Compares the performance of flashinfer multilevel kernel and batch decode. 
+ """ + + cascade_outputs, time_taken_cascade = run_multilevel_cascade_attention_wrapper( # noqa: E501 + num_heads=num_heads, + head_size=head_size, + dtype=dtype, + seq_lens=seq_lens, + num_runs=num_runs, + block_size=block_size, + beam_width=beam_width, + num_levels=2, + max_num_blocks_per_seq=max_num_blocks_per_seq, + soft_cap=soft_cap, + ) + + batchdecode_outputs, time_taken_batchdecode = run_flashinfer_batchdecode_beam_search( # noqa: E501 + num_heads=num_heads, + head_size=head_size, + dtype=dtype, + seq_lens=seq_lens, + num_runs=num_runs, + block_size=block_size, + beam_width=beam_width, + max_num_blocks_per_seq=max_num_blocks_per_seq, + soft_cap=soft_cap, + ) + + assert len(cascade_outputs) == len( + batchdecode_outputs), "Output length mismatch between the two methods." + + max_diff = 0 + + for cascade_output, batchdecode_output in zip(cascade_outputs, + batchdecode_outputs): + assert cascade_output.shape == batchdecode_output.shape, "Shape mismatch between outputs." # noqa: E501 + + isclose = torch.isclose(cascade_output, + batchdecode_output, + rtol=1e-2, + atol=1e-3) + if not isclose.all(): + diff = torch.abs(cascade_output - batchdecode_output) + current_max_diff = torch.max(diff).item() + max_diff = max(max_diff, current_max_diff) + + speedup = time_taken_batchdecode / time_taken_cascade + + assert speedup > 1.0, f"No speedup with cascade infer: {speedup}" + assert max_diff <= 1e-3, f"Max difference too large: {max_diff}" + + +def run_flashinfer_batchdecode_beam_search( + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + soft_cap: float, + seq_lens: list, + num_runs: int, + block_size: int, + beam_width: int, + max_num_blocks_per_seq: int, +) -> Tuple[List[torch.Tensor], float]: + torch.set_default_device("cuda") + seed_everything(0) + num_query_heads, num_kv_heads = num_heads + assert num_query_heads % num_kv_heads == 0 + + num_seqs = len(seq_lens) + kv_lens = [x[1] for x in seq_lens] + + num_blocks = max_num_blocks_per_seq * num_seqs * beam_width + + key_value_cache = torch.randn(num_blocks, + 2, + block_size, + num_kv_heads, + head_size, + dtype=dtype, + device='cuda').reshape( + num_blocks, 2, block_size, num_kv_heads, + head_size) + + workspace_size = 128 * 1024 * 1024 + workspace_buffer_decode = torch.empty(workspace_size, + dtype=torch.int8, + device='cuda') + decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer_decode, "NHD") + + block_tables = torch.zeros((num_seqs * beam_width, max_num_blocks_per_seq), + dtype=torch.int32) + + block_offset = 0 + + for start_seq in range(num_seqs): + shared_len = kv_lens[start_seq] // block_size + + for i in range(start_seq * beam_width, (start_seq + 1) * beam_width): + block_tables[i, :shared_len] = torch.arange( + block_offset, block_offset + shared_len) + + block_offset += shared_len + + for i in range(beam_width): + beam_index = start_seq * beam_width + i + unique_start = block_offset + i + block_tables[beam_index, + shared_len:max_num_blocks_per_seq] = torch.arange( + unique_start, unique_start + + (max_num_blocks_per_seq - shared_len) * + beam_width, beam_width) + block_offset += (max_num_blocks_per_seq - shared_len) * beam_width + + cumulative_run_time = 0.0 + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + outputs = [] + + # index of the next block that we append for each sequence + next_block_index = [num // block_size + 1 for num in kv_lens] + + kv_indptr: List[int] = [0] + kv_indices: List[List[int]] = [] + 
kv_last_page_lens: List[int] = [] + + for i in range(num_seqs * beam_width): + seq_len = kv_lens[i // beam_width] + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.append(list(block_tables[i, :num_blocks])) + kv_last_page_len = seq_len % block_size + if kv_last_page_len == 0: + kv_last_page_len = block_size + kv_last_page_lens.append(kv_last_page_len) + kv_indptr.append(kv_indptr[-1] + num_blocks) + + for step in range(num_runs): + torch.manual_seed(step) + + query = torch.randn( + num_seqs * beam_width * num_query_heads * head_size, + dtype=dtype, + device='cuda').reshape(num_seqs * beam_width, num_query_heads, + head_size) + + kv_indptr_tensor = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices_tensor = torch.cat([torch.tensor(kv) + for kv in kv_indices]).reshape(-1) + kv_last_page_lens_tensor = torch.tensor(kv_last_page_lens, + dtype=torch.int32) + + decode_wrapper.plan(kv_indptr_tensor, + kv_indices_tensor, + kv_last_page_lens_tensor, + num_query_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + data_type=dtype, + logits_soft_cap=soft_cap) + + start_event.record() + output = decode_wrapper.run(query, key_value_cache) + end_event.record() + torch.cuda.synchronize() + decode_time = start_event.elapsed_time(end_event) + cumulative_run_time += decode_time + + outputs.append(output.cpu()) + + if step % block_size == 0: + for i in range(beam_width * num_seqs): + kv_indices[i].append( + block_tables[i, next_block_index[i // beam_width]]) + + for i in range(len(next_block_index)): + next_block_index[i] += 1 + + for i in range(1, beam_width * num_seqs + 1): + kv_indptr[i] += i + kv_last_page_lens = [(prev + 1) % block_size or block_size + for prev in kv_last_page_lens] + + return outputs, cumulative_run_time + + +def run_multilevel_cascade_attention_wrapper( + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + seq_lens: list, + num_runs: int, + block_size: int, + beam_width: int, + num_levels: int, + max_num_blocks_per_seq: int, + soft_cap: float, +) -> Tuple[List[torch.Tensor], float]: + torch.set_default_device("cuda") + seed_everything(0) + num_query_heads, num_kv_heads = num_heads + assert num_query_heads % num_kv_heads == 0 + + num_seqs = len(seq_lens) + kv_lens = [x[1] for x in seq_lens] + + num_blocks = max_num_blocks_per_seq * num_seqs * beam_width + + key_value_cache = torch.randn(num_blocks, + 2, + block_size, + num_kv_heads, + head_size, + dtype=dtype, + device='cuda') + + workspace_size = 128 * 1024 * 1024 + workspace_buffer = torch.empty(workspace_size, + dtype=torch.uint8, + device='cuda') + wrapper = flashinfer.MultiLevelCascadeAttentionWrapper( + num_levels, workspace_buffer, "NHD") + + block_tables = torch.zeros((num_seqs * beam_width, max_num_blocks_per_seq), + dtype=torch.int32) + block_offset = 0 + + for start_seq in range(num_seqs): + shared_len = kv_lens[start_seq] // block_size + + for i in range(start_seq * beam_width, (start_seq + 1) * beam_width): + block_tables[i, :shared_len] = torch.arange( + block_offset, block_offset + shared_len) + + block_offset += shared_len + + for i in range(beam_width): + beam_index = start_seq * beam_width + i + unique_start = block_offset + i + block_tables[beam_index, + shared_len:max_num_blocks_per_seq] = torch.arange( + unique_start, unique_start + + (max_num_blocks_per_seq - shared_len) * + beam_width, beam_width) + block_offset += (max_num_blocks_per_seq - shared_len) * beam_width + + qo_indptr_arr = [ + torch.tensor([0, beam_width * num_seqs], + dtype=torch.int32, + 
device='cuda'), + torch.arange(beam_width * num_seqs + 1, + dtype=torch.int32, + device="cuda") + ] + + shared_kv_page_indptr: List[int] = [0] + unique_kv_page_indptr: List[int] = [0] + shared_kv_page_indices: List[List[int]] = [] + unique_kv_page_indices: List[List[int]] = [] + shared_kv_last_page_len: List[int] = [] + unique_kv_last_page_len: List[int] = [] + + query = torch.arange(num_seqs * beam_width * num_query_heads * head_size, + dtype=dtype, + device='cuda').reshape(num_seqs * beam_width, + num_query_heads, head_size) + + # Fill the shared metadatas + for i in range(num_seqs): + seq_len = kv_lens[i // beam_width] + num_shared_blocks = ( + seq_len) // block_size if seq_len % block_size == 0 else ( + (seq_len) // block_size) + 1 + shared_kv_page_indices.append(list( + block_tables[i, :num_shared_blocks])) + shared_kv_page_indptr.append(shared_kv_page_indptr[-1] + + num_shared_blocks) + shared_kv_len = seq_len % block_size + if shared_kv_len == 0: + shared_kv_len = block_size + shared_kv_last_page_len.append(shared_kv_len) + + for i in range(num_seqs * beam_width): + unique_kv_page_indices.append([]) + unique_kv_page_indptr.append(unique_kv_page_indptr[-1]) + unique_kv_last_page_len.append(block_size) + + shared_kv_page_indptr = torch.tensor(shared_kv_page_indptr, + dtype=torch.int32, + device='cuda') + shared_kv_page_indices = torch.cat( + [torch.tensor(x) for x in shared_kv_page_indices]).reshape(-1) + shared_kv_last_page_len = torch.tensor(shared_kv_last_page_len, + dtype=torch.int32, + device='cuda') + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + cumulative_run_time = 0.0 + + outputs = [] + + # index of the next block that we append for each sequence + next_block_index = [num // block_size + 1 for num in kv_lens] + + for step in range(num_runs): + torch.manual_seed(step) + query = torch.randn( + num_seqs * beam_width * num_query_heads * head_size, + dtype=dtype, + device='cuda').reshape(num_seqs * beam_width, num_query_heads, + head_size) + + wrapper.plan(qo_indptr_arr, [ + shared_kv_page_indptr, + torch.tensor( + unique_kv_page_indptr, dtype=torch.int32, device='cuda') + ], [ + shared_kv_page_indices, + torch.cat([torch.tensor(x) + for x in unique_kv_page_indices]).reshape(-1) + ], [ + shared_kv_last_page_len, + torch.tensor( + unique_kv_last_page_len, dtype=torch.int32, device='cuda') + ], + num_query_heads, + num_kv_heads, + head_size, + block_size, + logits_soft_cap=soft_cap) + + start_event.record() + output = wrapper.run(query, key_value_cache) + end_event.record() + torch.cuda.synchronize() + + cumulative_run_time += start_event.elapsed_time(end_event) + + outputs.append(output.cpu()) + + if step % block_size == 0: + for i in range(beam_width * num_seqs): + unique_kv_page_indices[i].append( + block_tables[i, next_block_index[i // beam_width]]) + for i in range(len(next_block_index)): + next_block_index[i] += 1 + for i in range(1, beam_width * num_seqs + 1): + unique_kv_page_indptr[i] += i + + unique_kv_last_page_len = [(x + 1) % block_size or block_size + for x in unique_kv_last_page_len] + + return outputs, cumulative_run_time \ No newline at end of file From 3f6c458e15eb479e1d12e9338f6412bd2f80ba32 Mon Sep 17 00:00:00 2001 From: Ray Wan Date: Fri, 1 Nov 2024 19:52:29 +0000 Subject: [PATCH 03/12] clean up --- vllm/attention/backends/flashinfer.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 
fbd3ac9e2eb40..0782228df03e0 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -242,18 +242,6 @@ def prepare_graph_input_buffers(self, is_encoder_decoder_model: bool = False): return - # def begin_forward(self, model_input): - # assert not self._is_graph_capturing - # state = self - # if model_input.attn_metadata.use_cuda_graph: - # batch_size = model_input.input_tokens.shape[0] - # state = (self.runner.graph_runners[model_input.virtual_engine] - # [batch_size].attn_state) - # model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper( - # ) - # model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper() - # model_input.attn_metadata.begin_forward() - def begin_forward(self, model_input, model): assert not self._is_graph_capturing state = self @@ -374,10 +362,6 @@ def begin_forward(self, scale: Optional[float], logits_soft_cap: Optional[float] # determine the number of blocks. Therefore, # we don't need to prepare the input for flashinfer for profile run. if not self.is_profile_run: - print(self.paged_kv_indices) - print(self.second_layer_kv_indices) - print(self.paged_kv_indptr) - print(self.second_layer_kv_indptr) self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( self.device) From 84b8be9048ea403c1301bf1833619fd9c18fde00 Mon Sep 17 00:00:00 2001 From: Ray Wan Date: Mon, 4 Nov 2024 23:50:58 +0000 Subject: [PATCH 04/12] cleanup --- docs/source/dev/pooling_params.rst | 5 + ...i_chat_completion_client_for_multimodal.py | 236 ++++++++++++++++++ ...ai_chat_embedding_client_for_multimodal.py | 33 +++ examples/template_vlm2vec.jinja | 16 ++ .../openai/test_vision_embedding.py | 99 ++++++++ 5 files changed, 389 insertions(+) create mode 100644 docs/source/dev/pooling_params.rst create mode 100644 examples/openai_chat_completion_client_for_multimodal.py create mode 100644 examples/openai_chat_embedding_client_for_multimodal.py create mode 100644 examples/template_vlm2vec.jinja create mode 100644 tests/entrypoints/openai/test_vision_embedding.py diff --git a/docs/source/dev/pooling_params.rst b/docs/source/dev/pooling_params.rst new file mode 100644 index 0000000000000..334e0287aff09 --- /dev/null +++ b/docs/source/dev/pooling_params.rst @@ -0,0 +1,5 @@ +Pooling Parameters +================== + +.. autoclass:: vllm.PoolingParams + :members: diff --git a/examples/openai_chat_completion_client_for_multimodal.py b/examples/openai_chat_completion_client_for_multimodal.py new file mode 100644 index 0000000000000..0ec4f71dddf93 --- /dev/null +++ b/examples/openai_chat_completion_client_for_multimodal.py @@ -0,0 +1,236 @@ +"""An example showing how to use vLLM to serve multimodal models +and run online inference with OpenAI client. + +Launch the vLLM server with the following command: + +(single image inference with Llava) +vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja + +(multi-image inference with Phi-3.5-vision-instruct) +vllm serve microsoft/Phi-3.5-vision-instruct --task generate \ + --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2 + +(audio inference with Ultravox) +vllm serve fixie-ai/ultravox-v0_3 --max-model-len 4096 +""" +import base64 + +import requests +from openai import OpenAI + +from vllm.assets.audio import AudioAsset +from vllm.utils import FlexibleArgumentParser + +# Modify OpenAI's API key and API base to use vLLM's API server. 
+openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + +client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, +) + +models = client.models.list() +model = models.data[0].id + + +def encode_base64_content_from_url(content_url: str) -> str: + """Encode a content retrieved from a remote url to base64 format.""" + + with requests.get(content_url) as response: + response.raise_for_status() + result = base64.b64encode(response.content).decode('utf-8') + + return result + + +# Text-only inference +def run_text_only() -> None: + chat_completion = client.chat.completions.create( + messages=[{ + "role": "user", + "content": "What's the capital of France?" + }], + model=model, + max_completion_tokens=64, + ) + + result = chat_completion.choices[0].message.content + print("Chat completion output:", result) + + +# Single-image input inference +def run_single_image() -> None: + + ## Use image url in the payload + image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + chat_completion_from_url = client.chat.completions.create( + messages=[{ + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's in this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + }, + ], + }], + model=model, + max_completion_tokens=64, + ) + + result = chat_completion_from_url.choices[0].message.content + print("Chat completion output from image url:", result) + + ## Use base64 encoded image in the payload + image_base64 = encode_base64_content_from_url(image_url) + chat_completion_from_base64 = client.chat.completions.create( + messages=[{ + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's in this image?" + }, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}" + }, + }, + ], + }], + model=model, + max_completion_tokens=64, + ) + + result = chat_completion_from_base64.choices[0].message.content + print("Chat completion output from base64 encoded image:", result) + + +# Multi-image input inference +def run_multi_image() -> None: + image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg" + image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg" + chat_completion_from_url = client.chat.completions.create( + messages=[{ + "role": + "user", + "content": [ + { + "type": "text", + "text": "What are the animals in these images?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url_duck + }, + }, + { + "type": "image_url", + "image_url": { + "url": image_url_lion + }, + }, + ], + }], + model=model, + max_completion_tokens=64, + ) + + result = chat_completion_from_url.choices[0].message.content + print("Chat completion output:", result) + + +# Audio input inference +def run_audio() -> None: + # Any format supported by librosa is supported + audio_url = AudioAsset("winning_call").url + + # Use audio url in the payload + chat_completion_from_url = client.chat.completions.create( + messages=[{ + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's in this audio?" 
+ }, + { + "type": "audio_url", + "audio_url": { + "url": audio_url + }, + }, + ], + }], + model=model, + max_completion_tokens=64, + ) + + result = chat_completion_from_url.choices[0].message.content + print("Chat completion output from audio url:", result) + + audio_base64 = encode_base64_content_from_url(audio_url) + chat_completion_from_base64 = client.chat.completions.create( + messages=[{ + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's in this audio?" + }, + { + "type": "audio_url", + "audio_url": { + # Any format supported by librosa is supported + "url": f"data:audio/ogg;base64,{audio_base64}" + }, + }, + ], + }], + model=model, + max_completion_tokens=64, + ) + + result = chat_completion_from_base64.choices[0].message.content + print("Chat completion output from base64 encoded audio:", result) + + +example_function_map = { + "text-only": run_text_only, + "single-image": run_single_image, + "multi-image": run_multi_image, + "audio": run_audio, +} + + +def main(args) -> None: + chat_type = args.chat_type + example_function_map[chat_type]() + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description='Demo on using OpenAI client for online inference with ' + 'multimodal language models served with vLLM.') + parser.add_argument( + '--chat-type', + '-c', + type=str, + default="single-image", + choices=["text-only", "single-image", "multi-image", "audio"], + help='Conversation type with multimodal data.') + args = parser.parse_args() + main(args) diff --git a/examples/openai_chat_embedding_client_for_multimodal.py b/examples/openai_chat_embedding_client_for_multimodal.py new file mode 100644 index 0000000000000..effb588e1387f --- /dev/null +++ b/examples/openai_chat_embedding_client_for_multimodal.py @@ -0,0 +1,33 @@ +import requests + +image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + +response = requests.post( + "http://localhost:8000/v1/embeddings", + json={ + "model": + "TIGER-Lab/VLM2Vec-Full", + "messages": [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "Represent the given image." 
+ }, + ], + }], + "encoding_format": + "float", + }, +) +response.raise_for_status() +response_json = response.json() + +print("Embedding output:", response_json["data"][0]["embedding"]) diff --git a/examples/template_vlm2vec.jinja b/examples/template_vlm2vec.jinja new file mode 100644 index 0000000000000..489b99604af38 --- /dev/null +++ b/examples/template_vlm2vec.jinja @@ -0,0 +1,16 @@ +{%- if messages | length > 1 -%} + {{ raise_exception('Embedding models should only embed one message at a time') }} +{%- endif -%} + +{% set vars = namespace(parts=[], next_image_id=1) %} +{%- for message in messages -%} + {%- for content in message['content'] -%} + {%- if content['type'] == 'text' -%} + {%- set vars.parts = vars.parts + [content['text']] %} + {%- elif content['type'] == 'image' -%} + {%- set vars.parts = vars.parts + ['<|image_{i:d}|>'.format(i=vars.next_image_id)] %} + {%- set vars.next_image_id = vars.next_image_id + 1 %} + {%- endif -%} + {%- endfor -%} +{%- endfor -%} +{{ vars.parts | join(' ') }} diff --git a/tests/entrypoints/openai/test_vision_embedding.py b/tests/entrypoints/openai/test_vision_embedding.py new file mode 100644 index 0000000000000..d0c43b47bf0af --- /dev/null +++ b/tests/entrypoints/openai/test_vision_embedding.py @@ -0,0 +1,99 @@ +from typing import Dict + +import pytest +import pytest_asyncio +import requests + +from vllm.multimodal.utils import encode_image_base64, fetch_image + +from ...utils import VLLM_PATH, RemoteOpenAIServer + +MODEL_NAME = "TIGER-Lab/VLM2Vec-Full" +MAXIMUM_IMAGES = 2 + +vlm2vec_jinja_path = VLLM_PATH / "examples/template_vlm2vec.jinja" +assert vlm2vec_jinja_path.exists() + +# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) +TEST_IMAGE_URLS = [ + "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", + "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", + "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", + "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", +] + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--task", + "embedding", + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "5", + "--enforce-eager", + "--trust-remote-code", + "--limit-mm-per-prompt", + f"image={MAXIMUM_IMAGES}", + "--chat-template", + str(vlm2vec_jinja_path), + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +@pytest.fixture(scope="session") +def base64_encoded_image() -> Dict[str, str]: + return { + image_url: encode_image_base64(fetch_image(image_url)) + for image_url in TEST_IMAGE_URLS + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +async def test_image_embedding(server: RemoteOpenAIServer, model_name: str, + image_url: str): + messages = [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "Represent the given image." 
+ }, + ], + }] + + response = requests.post(server.url_for("v1/embeddings"), + json={ + "model": model_name, + "messages": messages, + "encoding_format": "float" + }) + response.raise_for_status() + + embeddings = response.json() + assert embeddings["id"] is not None + assert len(embeddings["data"]) == 1 + assert len(embeddings["data"][0]["embedding"]) == 3072 + assert embeddings["usage"]["completion_tokens"] == 0 + assert embeddings["usage"]["prompt_tokens"] == 762 + assert embeddings["usage"]["total_tokens"] == 762 From 7babe4f09247860a4eb3ede3f05828da0747a8af Mon Sep 17 00:00:00 2001 From: Ray Wan Date: Mon, 4 Nov 2024 23:53:57 +0000 Subject: [PATCH 05/12] cleanup code --- vllm/attention/backends/abstract.py | 5 +-- vllm/attention/backends/flashinfer.py | 44 +++++++++++++------------- vllm/spec_decode/draft_model_runner.py | 6 ++-- 3 files changed, 29 insertions(+), 26 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 37266a53a6068..144eeaf34139e 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -2,11 +2,11 @@ from contextlib import contextmanager from dataclasses import dataclass, fields from enum import Enum, auto -from torch import nn from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar) import torch +from torch import nn if TYPE_CHECKING: from vllm.worker.model_runner_base import (ModelRunnerBase, @@ -185,7 +185,8 @@ def prepare_graph_input_buffers( ... @abstractmethod - def begin_forward(self, model_input: "ModelRunnerInputBase", model: nn.Module) -> None: + def begin_forward(self, model_input: "ModelRunnerInputBase", + model: nn.Module) -> None: """Prepare state for forward pass.""" ... diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 0782228df03e0..704d897cfcb03 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -3,8 +3,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type try: - from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper from flashinfer.cascade import MultiLevelCascadeAttentionWrapper + from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper from vllm.vllm_flash_attn import flash_attn_varlen_func FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 @@ -117,7 +117,6 @@ def _get_wrapper(self): 2, self._get_workspace_buffer(), "NHD") return self._wrapper - def _get_cuda_wrapper(self): if self._cuda_wrapper is not None: return self._cuda_wrapper @@ -224,7 +223,7 @@ def graph_capture_get_metadata_for_batch( use_cuda_graph=True, cuda_wrapper=self._graph_decode_wrapper, wrapper=self._wrapper) - # we don't need to pass logits and scale to begin_forward + # we don't need to pass logits and scale to begin_forward # since in forward, it already gets it. 
attn_metadata.begin_forward(None, None) return attn_metadata @@ -287,7 +286,7 @@ class FlashInferMetadata(AttentionMetadata): wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None cuda_wrapper: Optional[CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None - + # Metadata for wrapper seq_start_loc: Optional[torch.Tensor] = None query_start_loc: Optional[torch.Tensor] = None @@ -339,8 +338,9 @@ def __post_init__(self): f"Only {supported_head_sizes} are supported for head_dim,", f"received {self.head_dim}.") - #TODO: NEED TO ADD CHUNKED PREFILL - def begin_forward(self, scale: Optional[float], logits_soft_cap: Optional[float]): + #TODO: NEED TO ADD CHUNKED PREFILL + def begin_forward(self, scale: Optional[float], + logits_soft_cap: Optional[float]): if self.num_prefill_tokens > 0: if self.paged_kv_indices is None: return @@ -389,7 +389,7 @@ def begin_forward(self, scale: Optional[float], logits_soft_cap: Optional[float] causal=True, sm_scale=scale, logits_soft_cap=logits_soft_cap) - + if self.num_decode_tokens > 0: if self.cuda_wrapper: assert self.paged_kv_indices is not None @@ -401,7 +401,8 @@ def begin_forward(self, scale: Optional[float], logits_soft_cap: Optional[float] self.device) # handle model warmup path if self.block_table_bound is not None: - self.block_table_bound = self.block_table_bound.to(self.device) + self.block_table_bound = self.block_table_bound.to( + self.device) if self.seq_lens_tensor is not None: self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) @@ -421,7 +422,7 @@ def begin_forward(self, scale: Optional[float], logits_soft_cap: Optional[float] data_type=self.data_type, # query data type. q_data_type=self.q_data_type) - + return assert self.paged_kv_indices is not None @@ -571,7 +572,8 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): def _add_seq_group( self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", - chunked_prefill_enabled: bool, common_prefix: List[int], use_cuda_graph: bool): + chunked_prefill_enabled: bool, common_prefix: List[int], + use_cuda_graph: bool): """Add a sequence group to the metadata. Specifically update/append 1. context length. 2. block table. 
@@ -629,11 +631,12 @@ def _add_seq_group( return block_table = block_tables[seq_id] - self._update_unique_kv_tensors(block_table, seq_len, common_prefix, use_cuda_graph) - + self._update_unique_kv_tensors(block_table, seq_len, common_prefix, + use_cuda_graph) def _update_unique_kv_tensors(self, block_table: List[int], seq_len: int, - common_prefix: List[int], use_cuda_graph: bool) -> None: + common_prefix: List[int], + use_cuda_graph: bool) -> None: """ Updates the unique level kv tensors """ @@ -651,7 +654,7 @@ def _update_unique_kv_tensors(self, block_table: List[int], seq_len: int, last_page_len = self.block_size self.paged_kv_last_page_len.append(last_page_len) return - + shared_length = len(common_prefix) self.total_blocks += (len(block_table) - shared_length) block_table_bound = (seq_len) // self.block_size + 1 \ @@ -660,7 +663,7 @@ def _update_unique_kv_tensors(self, block_table: List[int], seq_len: int, self.second_layer_kv_indices.extend( block_table[shared_length:block_table_bound]) self.second_layer_kv_indptr.append(self.second_layer_kv_indptr[-1] + - (block_table_bound - shared_length)) + (block_table_bound - shared_length)) last_page_len = (seq_len) % self.block_size if last_page_len == 0: last_page_len = self.block_size @@ -685,7 +688,7 @@ def _update_shared_kv_tensors(self, common_prefix: List[int], self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + len(common_prefix)) self.paged_kv_last_page_len.append(self.block_size) - + def get_shared_blocks_nums( self, inter_data_list: List["ModelInputForGPUBuilder.InterDataForSeqGroup"] @@ -736,7 +739,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], self._add_seq_group(inter_data, self.input_builder.chunked_prefill_enabled, common_prefix, use_captured_graph) - + if not use_captured_graph: self._update_shared_kv_tensors(common_prefix, len(query_lens)) @@ -898,7 +901,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], is_profile_run=self.is_profile_run) - class FlashInferImpl(AttentionImpl): def __init__( @@ -1056,8 +1058,7 @@ def unified_flash_infer( else: assert prefill_meta is not None assert prefill_meta.wrapper is not None - prefill_output = prefill_meta.wrapper.run( - query, kv_cache) + prefill_output = prefill_meta.wrapper.run(query, kv_cache) if decode_meta := attn_metadata.decode_metadata: assert attn_metadata.decode_metadata is not None if attn_metadata.decode_metadata.cuda_wrapper is not None: @@ -1071,8 +1072,7 @@ def unified_flash_infer( else: assert attn_metadata.decode_metadata.wrapper is not None decode_output = attn_metadata.decode_metadata.wrapper.run( - decode_query, - kv_cache) + decode_query, kv_cache) if prefill_output is None and decode_output is not None: # Decode only batch. 
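The draft-runner change below passes a weakref.proxy of the model into begin_forward so the attention state can read the first layer's softmax scale and logits soft cap without holding a strong reference to the model. A minimal sketch of that lookup, assuming the decoder-layer layout used elsewhere in this series (model.model.layers[0].self_attn.attn.impl); the helper name and comments are illustrative only.

from typing import Optional, Tuple

def get_attn_scale_and_soft_cap(
        model) -> Tuple[Optional[float], Optional[float]]:
    # All layers are assumed to share these values, so reading them from the
    # first decoder layer's attention implementation is enough.
    impl = model.model.layers[0].self_attn.attn.impl
    scale = getattr(impl, "scale", None)  # typically head_dim ** -0.5
    logits_soft_cap = getattr(impl, "logits_soft_cap", None)
    return scale, logits_soft_cap

# Called with a proxy so the attention state does not keep the model alive:
#     scale, soft_cap = get_attn_scale_and_soft_cap(weakref.proxy(model))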
diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 69d28ba4a2f7b..c18a55a9cbfbf 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -1,7 +1,8 @@ +import weakref from typing import List, Optional import torch -import weakref + from vllm.forward_context import set_forward_context from vllm.model_executor.layers.sampler import SamplerOutput @@ -246,7 +247,8 @@ def execute_model( model_input.prompt_adapter_requests, model_input.prompt_adapter_mapping) - self.attn_state.begin_forward(model_input, weakref.proxy(self.model)) + self.attn_state.begin_forward(model_input, + weakref.proxy(self.model)) # Detect exec mode assert model_input.attn_metadata is not None From a908e7714db199cf85905ecf26a695dc36456d9e Mon Sep 17 00:00:00 2001 From: Ray Wan Date: Tue, 19 Nov 2024 22:20:11 +0000 Subject: [PATCH 06/12] adds chunked prefill support, but haven't cleaned up code --- vllm/attention/backends/flashinfer.py | 576 ++++++++++++++++++-------- 1 file changed, 405 insertions(+), 171 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index f16eb479f0b10..f2a989f22060c 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -343,49 +343,200 @@ def __post_init__(self): f"received {self.head_dim}.") #TODO: NEED TO ADD CHUNKED PREFILL - def begin_forward(self, scale: Optional[float], - logits_soft_cap: Optional[float]): - if self.num_prefill_tokens > 0: - if self.paged_kv_indices is None: - return + # def begin_forward(self, scale: Optional[float], + # logits_soft_cap: Optional[float]): + # if self.num_prefill_tokens > 0: + # if self.paged_kv_indices is None: + # return + + # assert self.wrapper is not None + # assert self.query_start_loc is not None + # assert self.paged_kv_indices is not None + # assert self.paged_kv_indptr is not None + # assert self.paged_kv_last_page_len is not None + # assert self.second_layer_kv_indices is not None + # assert self.second_layer_kv_indptr is not None + # assert self.second_layer_kv_last_page_len is not None + # assert self.block_table_bound is not None + # assert self.seq_lens_tensor is not None + # self.query_start_loc = self.query_start_loc[:self.num_prefills + 1] + # self.second_level_query_start_loc = self.second_level_query_start_loc[:self.num_prefills+1] + # batch_size = self.query_start_loc.shape[0] - 1 + # assert batch_size >= 0 + # # We will use flash attention for profiling to + # # determine the number of blocks. Therefore, + # # we don't need to prepare the input for flashinfer for profile run. 
+ # if not self.is_profile_run: + # self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) + # self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( + # self.device) + # self.block_table_bound = self.block_table_bound.to(self.device) + # self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) + # self.paged_kv_indices = self.paged_kv_indices.to(self.device) + # self.second_layer_kv_indices = self.second_layer_kv_indices.to( + # self.device) + # self.second_layer_kv_indptr = self.second_layer_kv_indptr.to( + # self.device) + # self.second_layer_kv_last_page_len = self.second_layer_kv_last_page_len.to( # noqa: E501 + # self.device) + + # # print("Batch Size", batch_size) + # # print("NUM PREFILLS", self.num_prefills) + # # print("FIRST LEVEL PAGED KV INDPTR", self.paged_kv_indptr[:self.num_prefills+1]) + # # print("QUERY LOC", self.query_start_loc) + # # print("SECOND QUERY LOC", self.second_level_query_start_loc) + # # print("SECOND LAYER PAGED KV INDPTR", self.second_layer_kv_indptr[:self.num_prefills+1]) + + # self.wrapper.plan( + # [self.query_start_loc, self.second_level_query_start_loc], + # [self.paged_kv_indptr[:self.num_prefills+1], self.second_layer_kv_indptr[:self.num_prefills+1]], + # [self.paged_kv_indices, self.second_layer_kv_indices], [ + # self.paged_kv_last_page_len[:self.num_prefills], + # self.second_layer_kv_last_page_len[:self.num_prefills] + # ], + # self.num_qo_heads, + # self.num_kv_heads, + # self.head_dim, + # self.page_size, + # causal=True, + # sm_scale=scale, + # logits_soft_cap=logits_soft_cap) + + # if self.num_decode_tokens > 0: + # if self.cuda_wrapper: + # assert self.paged_kv_indices is not None + # assert self.paged_kv_indptr is not None + # assert self.paged_kv_last_page_len is not None + # self.paged_kv_indices = self.paged_kv_indices.to(self.device) + # self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) + # self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( + # self.device) + # # handle model warmup path + # if self.block_table_bound is not None: + # self.block_table_bound = self.block_table_bound.to( + # self.device) + # if self.seq_lens_tensor is not None: + # self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) + + # # print("Batch Size", batch_size) + # # print("NUM PREFILLS", self.num_prefills) + # # print("FIRST LEVEL PAGED KV INDPTR", self.paged_kv_indptr[:self.num_prefills+1]) + # # print("QUERY LOC", self.query_start_loc) + # # print("PAGED KV INDICES", self.paged_kv_indices) + + # assert self.cuda_wrapper is not None + # self.cuda_wrapper.end_forward() + # self.cuda_wrapper.begin_forward( + # self.paged_kv_indptr[self.num_prefills:], + # self.paged_kv_indices, + # self.paged_kv_last_page_len[self.num_prefills:], + # self.num_qo_heads, + # self.num_kv_heads, + # self.head_dim, + # self.page_size, + # # Disable flashinfer's pos encoding and use vllm's rope. + # pos_encoding_mode="NONE", + # # kv-cache data type. + # data_type=self.data_type, + # # query data type. 
+ # q_data_type=self.q_data_type) + + # return + + # assert self.paged_kv_indices is not None + # assert self.paged_kv_indptr is not None + # assert self.paged_kv_last_page_len is not None + # assert self.second_layer_kv_indices is not None + # assert self.second_layer_kv_indptr is not None + # assert self.second_layer_kv_last_page_len is not None + + # self.paged_kv_indices = self.paged_kv_indices.to(self.device) + # self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) + # self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( + # self.device) + # self.second_layer_kv_indices = self.second_layer_kv_indices.to( + # self.device) + # self.second_layer_kv_indptr = self.second_layer_kv_indptr.to( + # self.device) + # self.second_layer_kv_last_page_len = self.second_layer_kv_last_page_len.to( # noqa: E501 + # self.device) + + # # handle model warmup path + # if self.block_table_bound is not None: + # self.block_table_bound = self.block_table_bound.to(self.device) + # if self.seq_lens_tensor is not None: + # self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) + + # assert self.wrapper is not None + + # # print("Batch Size", batch_size) + # # print("NUM PREFILLS", self.num_prefills) + # # print("FIRST LEVEL PAGED KV INDPTR", self.paged_kv_indptr[:self.num_prefills+1]) + # # print("SECOND LAYER PAGED KV INDPTR", self.second_layer_kv_indptr[:self.num_prefills+1]) + # # print("QUERY LOC", self.query_start_loc) + # # print("SECOND QUERY LOC", self.second_level_query_start_loc) + # # print("PAGED KV INDICES", self.paged_kv_indices) + # # print("SECOND PAGED KV INDICES", self.second_layer_kv_indices) + + # self.wrapper.plan( + # [self.query_start_loc, self.second_level_query_start_loc], + # [self.paged_kv_indptr[self.num_prefills:], self.second_layer_kv_indptr[self.num_prefills:]], + # [self.paged_kv_indices, self.second_layer_kv_indices], [ + # self.paged_kv_last_page_len[self.num_prefills:], + # self.second_layer_kv_last_page_len[self.num_prefills:] + # ], + # self.num_qo_heads, + # self.num_kv_heads, + # self.head_dim, + # self.page_size, + # causal=True, + # sm_scale=scale, + # logits_soft_cap=logits_soft_cap) + + def begin_forward(self, scale: Optional[float], logits_soft_cap: Optional[float]): + if self.paged_kv_indices is None: + return + + assert self.wrapper is not None + assert self.query_start_loc is not None + assert self.paged_kv_indices is not None + assert self.paged_kv_indptr is not None + assert self.paged_kv_last_page_len is not None - assert self.wrapper is not None - assert self.query_start_loc is not None - assert self.paged_kv_indices is not None - assert self.paged_kv_indptr is not None - assert self.paged_kv_last_page_len is not None + if not self.use_cuda_graph: assert self.second_layer_kv_indices is not None assert self.second_layer_kv_indptr is not None assert self.second_layer_kv_last_page_len is not None - assert self.block_table_bound is not None - assert self.seq_lens_tensor is not None - self.query_start_loc = self.query_start_loc[:self.num_prefills + 1] - batch_size = self.query_start_loc.shape[0] - 1 - assert batch_size >= 0 - # We will use flash attention for profiling to - # determine the number of blocks. Therefore, - # we don't need to prepare the input for flashinfer for profile run. 
- if not self.is_profile_run: - self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) - self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( - self.device) + + print("DO WE USE CUDA GRAPH?", self.use_cuda_graph) + # Skip device transfer for profile run + if not self.is_profile_run: + # Move tensors to device + self.paged_kv_indices = self.paged_kv_indices.to(self.device) + self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) + self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(self.device) + + if not self.use_cuda_graph: + self.second_layer_kv_indices = self.second_layer_kv_indices.to(self.device) + self.second_layer_kv_indptr = self.second_layer_kv_indptr.to(self.device) + self.second_layer_kv_last_page_len = self.second_layer_kv_last_page_len.to(self.device) + + if self.block_table_bound is not None: self.block_table_bound = self.block_table_bound.to(self.device) + if self.seq_lens_tensor is not None: self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) - self.paged_kv_indices = self.paged_kv_indices.to(self.device) - self.second_layer_kv_indices = self.second_layer_kv_indices.to( - self.device) - self.second_layer_kv_indptr = self.second_layer_kv_indptr.to( - self.device) - self.second_layer_kv_last_page_len = self.second_layer_kv_last_page_len.to( # noqa: E501 - self.device) + # Case 1: Prefill only + if self.num_prefill_tokens > 0 and self.num_decode_tokens == 0: self.wrapper.plan( - [self.query_start_loc, self.second_level_query_start_loc], - [self.paged_kv_indptr, self.second_layer_kv_indptr], - [self.paged_kv_indices, self.second_layer_kv_indices], [ - self.paged_kv_last_page_len, - self.second_layer_kv_last_page_len - ], + [self.query_start_loc[:self.num_prefills + 1], + self.second_level_query_start_loc[:self.num_prefills + 1]], + [self.paged_kv_indptr[:self.num_prefills + 1], + self.second_layer_kv_indptr[:self.num_prefills + 1]], + [self.paged_kv_indices, + self.second_layer_kv_indices], + [self.paged_kv_last_page_len[:self.num_prefills], + self.second_layer_kv_last_page_len[:self.num_prefills]], self.num_qo_heads, self.num_kv_heads, self.head_dim, @@ -393,75 +544,55 @@ def begin_forward(self, scale: Optional[float], causal=True, sm_scale=scale, logits_soft_cap=logits_soft_cap) - - if self.num_decode_tokens > 0: - if self.cuda_wrapper: - assert self.paged_kv_indices is not None - assert self.paged_kv_indptr is not None - assert self.paged_kv_last_page_len is not None - self.paged_kv_indices = self.paged_kv_indices.to(self.device) - self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) - self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( - self.device) - # handle model warmup path - if self.block_table_bound is not None: - self.block_table_bound = self.block_table_bound.to( - self.device) - if self.seq_lens_tensor is not None: - self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) - - assert self.cuda_wrapper is not None - self.cuda_wrapper.end_forward() - self.cuda_wrapper.begin_forward( - self.paged_kv_indptr[self.num_prefills:], - self.paged_kv_indices, - self.paged_kv_last_page_len[self.num_prefills:], - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - self.page_size, - # Disable flashinfer's pos encoding and use vllm's rope. - pos_encoding_mode="NONE", - # kv-cache data type. - data_type=self.data_type, - # query data type. 
- q_data_type=self.q_data_type) - - return - - assert self.paged_kv_indices is not None - assert self.paged_kv_indptr is not None - assert self.paged_kv_last_page_len is not None - assert self.second_layer_kv_indices is not None - assert self.second_layer_kv_indptr is not None - assert self.second_layer_kv_last_page_len is not None - - self.paged_kv_indices = self.paged_kv_indices.to(self.device) - self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) - self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( - self.device) - self.second_layer_kv_indices = self.second_layer_kv_indices.to( - self.device) - self.second_layer_kv_indptr = self.second_layer_kv_indptr.to( - self.device) - self.second_layer_kv_last_page_len = self.second_layer_kv_last_page_len.to( # noqa: E501 - self.device) - - # handle model warmup path - if self.block_table_bound is not None: - self.block_table_bound = self.block_table_bound.to(self.device) - if self.seq_lens_tensor is not None: - self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) - - assert self.wrapper is not None - - self.wrapper.plan( - [self.query_start_loc, self.second_level_query_start_loc], - [self.paged_kv_indptr, self.second_layer_kv_indptr], - [self.paged_kv_indices, self.second_layer_kv_indices], [ - self.paged_kv_last_page_len, - self.second_layer_kv_last_page_len - ], + + # Case 2: Decode only + elif self.num_prefill_tokens == 0 and self.num_decode_tokens > 0: + if not self.use_cuda_graph: + self.wrapper.plan( + [self.query_start_loc, + self.second_level_query_start_loc], + [self.paged_kv_indptr[self.num_prefills:], + self.second_layer_kv_indptr[self.num_prefills:]], + [self.paged_kv_indices, + self.second_layer_kv_indices], + [self.paged_kv_last_page_len[self.num_prefills:], + self.second_layer_kv_last_page_len[self.num_prefills:]], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + causal=True, + sm_scale=scale, + logits_soft_cap=logits_soft_cap) + else: + assert self.cuda_wrapper is not None + self.cuda_wrapper.end_forward() + self.cuda_wrapper.begin_forward( + self.paged_kv_indptr[self.num_prefills:], + self.paged_kv_indices, + self.paged_kv_last_page_len[self.num_prefills:], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + # Disable flashinfer's pos encoding and use vllm's rope. + pos_encoding_mode="NONE", + # kv-cache data type. + data_type=self.data_type, + # query data type. 
+ q_data_type=self.q_data_type + ) + # Case 3: Both prefill and decode (chunked prefill case) + else: + self.wrapper.plan( + [self.query_start_loc, + self.second_level_query_start_loc], + [self.paged_kv_indptr, + self.second_layer_kv_indptr], + [self.paged_kv_indices, + self.second_layer_kv_indices], + [self.paged_kv_last_page_len, + self.second_layer_kv_last_page_len], self.num_qo_heads, self.num_kv_heads, self.head_dim, @@ -846,8 +977,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], if len(self.paged_kv_indptr) > 0: # extend to the maximum number of blocks as returned by the # scheduler - self.second_layer_kv_indices.extend( - [0] * (self.total_blocks - len(self.second_layer_kv_indices))) + self.paged_kv_indices.extend( + [0] * (self.total_blocks - len(self.paged_kv_indices))) paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, device="cpu", dtype=torch.int) @@ -984,6 +1115,127 @@ def forward( ) +# def unified_flash_infer( +# query: torch.Tensor, +# key: torch.Tensor, +# value: torch.Tensor, +# num_heads: int, +# head_size: int, +# num_kv_heads: int, +# kv_cache: torch.Tensor, +# kv_cache_dtype: str, +# k_scale: float, +# v_scale: float, +# softmax_scale: float, +# window_size: Optional[List[int]] = None, +# alibi_slopes: Optional[torch.Tensor] = None, +# logits_soft_cap: Optional[float] = None, +# ) -> torch.Tensor: + +# current_metadata = get_forward_context() +# assert current_metadata is not None +# assert isinstance(current_metadata, FlashInferMetadata) +# attn_metadata: FlashInferMetadata = current_metadata + +# num_tokens, hidden_size = query.shape +# query = query.view(-1, num_heads, head_size) +# key = key.view(-1, num_kv_heads, head_size) +# value = value.view(-1, num_kv_heads, head_size) + +# if kv_cache.numel() > 0: +# # Use the same reshape and cache kernel as flash attention. +# ops.reshape_and_cache_flash( +# key, +# value, +# kv_cache[:, 0], +# kv_cache[:, 1], +# attn_metadata.slot_mapping.flatten(), +# kv_cache_dtype, +# k_scale, +# v_scale, +# ) +# # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 +# # to process the cache when the kv_cache_dtype is fp8 +# if kv_cache_dtype.startswith("fp8"): +# torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( +# kv_cache_dtype) +# kv_cache = kv_cache.view(torch_dtype) + +# num_prefill_tokens = attn_metadata.num_prefill_tokens +# num_decode_tokens = attn_metadata.num_decode_tokens +# assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ +# f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa +# assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ +# f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa +# query = query.contiguous() # Flashinfer requires query to be contiguous +# # Query for decode. KV is not needed because it is already cached. +# # QKV for prefill. +# decode_query = query[num_prefill_tokens:] +# query = query[:num_prefill_tokens] + +# key = key[:num_prefill_tokens] +# value = value[:num_prefill_tokens] + +# assert query.shape[0] == num_prefill_tokens +# assert decode_query.shape[0] == num_decode_tokens + +# prefill_output: Optional[torch.Tensor] = None +# decode_output: Optional[torch.Tensor] = None +# if prefill_meta := attn_metadata.prefill_metadata: +# # We will use flash attention for prefill +# # when kv_cache is not provided. +# # This happens when vllm runs the profiling to +# # determine the number of blocks. 
+# if kv_cache.numel() == 0: +# prefill_output = flash_attn_varlen_func( +# q=query, +# k=key, +# v=value, +# cu_seqlens_q=prefill_meta.seq_start_loc, +# cu_seqlens_k=prefill_meta.seq_start_loc, +# max_seqlen_q=prefill_meta.max_prefill_seq_len, +# max_seqlen_k=prefill_meta.max_prefill_seq_len, +# softmax_scale=softmax_scale, +# causal=True, +# window_size=window_size, +# alibi_slopes=alibi_slopes, +# ) +# else: +# assert prefill_meta is not None +# assert prefill_meta.wrapper is not None +# prefill_output = prefill_meta.wrapper.run(query, kv_cache) +# if decode_meta := attn_metadata.decode_metadata: +# assert attn_metadata.decode_metadata is not None +# if attn_metadata.decode_metadata.cuda_wrapper is not None: +# decode_output = attn_metadata.decode_metadata.cuda_wrapper.forward( +# decode_query, +# kv_cache, +# sm_scale=softmax_scale, +# logits_soft_cap=logits_soft_cap, +# k_scale=k_scale, +# v_scale=v_scale) +# else: +# assert attn_metadata.decode_metadata.wrapper is not None +# decode_output = attn_metadata.decode_metadata.wrapper.run( +# decode_query, kv_cache) + +# if prefill_output is None and decode_output is not None: +# # Decode only batch. +# output, num_tokens = decode_output, num_decode_tokens +# elif decode_output is None and prefill_output is not None: +# # Prefill only batch. +# output, num_tokens = prefill_output, num_prefill_tokens +# else: +# # Chunked prefill batch does not work with speculative decoding in +# # FlashInfer backend, so the query length for decode should be 1. +# assert prefill_output is not None +# assert decode_output is not None +# assert decode_meta is not None +# assert decode_meta.decode_query_len == 1 +# decode_output = decode_output.squeeze(1) +# output = torch.cat([prefill_output, decode_output], dim=0) +# return output.view(num_tokens, hidden_size) + def unified_flash_infer( query: torch.Tensor, key: torch.Tensor, @@ -1000,7 +1252,6 @@ def unified_flash_infer( alibi_slopes: Optional[torch.Tensor] = None, logits_soft_cap: Optional[float] = None, ) -> torch.Tensor: - current_metadata = get_forward_context() assert current_metadata is not None assert isinstance(current_metadata, FlashInferMetadata) @@ -1012,7 +1263,6 @@ def unified_flash_infer( value = value.view(-1, num_kv_heads, head_size) if kv_cache.numel() > 0: - # Use the same reshape and cache kernel as flash attention. ops.reshape_and_cache_flash( key, value, @@ -1023,8 +1273,6 @@ def unified_flash_infer( k_scale, v_scale, ) - # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 - # to process the cache when the kv_cache_dtype is fp8 if kv_cache_dtype.startswith("fp8"): torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( kv_cache_dtype) @@ -1033,50 +1281,47 @@ def unified_flash_infer( num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa + f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa + f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" + query = query.contiguous() # Flashinfer requires query to be contiguous - # Query for decode. KV is not needed because it is already cached. 
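One step worth calling out from the hunk above: before any wrapper is invoked, the new keys and values are written into the paged cache with the flash-attention reshape_and_cache kernel, and an fp8 cache must then be viewed as a torch float8 dtype so FlashInfer can read it. A small sketch of that view step follows; the fp8-string-to-dtype mapping here is an assumption for illustration (vLLM resolves it through FlashInferBackend.get_fp8_dtype_for_flashinfer).

import torch

def view_kv_cache_for_flashinfer(kv_cache: torch.Tensor,
                                 kv_cache_dtype: str) -> torch.Tensor:
    # FlashInfer reads fp8 caches through a float8 view of the same storage.
    if kv_cache_dtype.startswith("fp8"):
        fp8_dtype = (torch.float8_e5m2 if kv_cache_dtype == "fp8_e5m2"
                     else torch.float8_e4m3fn)
        return kv_cache.view(fp8_dtype)
    return kv_cache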
- # QKV for prefill. + decode_query = query[num_prefill_tokens:] - query = query[:num_prefill_tokens] - - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens - - prefill_output: Optional[torch.Tensor] = None - decode_output: Optional[torch.Tensor] = None - if prefill_meta := attn_metadata.prefill_metadata: - # We will use flash attention for prefill - # when kv_cache is not provided. - # This happens when vllm runs the profiling to - # determine the number of blocks. - if kv_cache.numel() == 0: - prefill_output = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - ) - else: - assert prefill_meta is not None - assert prefill_meta.wrapper is not None - prefill_output = prefill_meta.wrapper.run(query, kv_cache) - if decode_meta := attn_metadata.decode_metadata: - assert attn_metadata.decode_metadata is not None - if attn_metadata.decode_metadata.cuda_wrapper is not None: - decode_output = attn_metadata.decode_metadata.cuda_wrapper.forward( + prefill_query = query[:num_prefill_tokens] + + # Profile run case - use flash attention when no kv_cache + if kv_cache.numel() == 0: + return flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=attn_metadata.seq_start_loc, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_q=attn_metadata.max_prefill_seq_len, + max_seqlen_k=attn_metadata.max_prefill_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + ).view(num_tokens, hidden_size) + + # For all non-profile cases, we need a wrapper + assert attn_metadata.wrapper is not None + + + # Case 1: Prefill only + if num_prefill_tokens > 0 and num_decode_tokens == 0: + output = attn_metadata.wrapper.run(prefill_query, kv_cache) + print("PREFILL") + print(output.shape) + return output.view(num_tokens, hidden_size) + + # Case 2: Decode only + if num_prefill_tokens == 0 and num_decode_tokens > 0: + if attn_metadata.cuda_wrapper is not None: + print("CUDA DECODE") + output = attn_metadata.cuda_wrapper.forward( decode_query, kv_cache, sm_scale=softmax_scale, @@ -1084,28 +1329,17 @@ def unified_flash_infer( k_scale=k_scale, v_scale=v_scale) else: - assert attn_metadata.decode_metadata.wrapper is not None - decode_output = attn_metadata.decode_metadata.wrapper.run( - decode_query, kv_cache) - - if prefill_output is None and decode_output is not None: - # Decode only batch. - output, num_tokens = decode_output, num_decode_tokens - elif decode_output is None and prefill_output is not None: - # Prefill only batch. - output, num_tokens = prefill_output, num_prefill_tokens - else: - # Chunked prefill batch does not work with speculative decoding in - # FlashInfer backend, so the query length for decode should be 1. 
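Stripped of the debug prints and the removed combine logic, the control flow this hunk introduces is a three-way dispatch on the token counts. A condensed sketch for reference; the attribute names and wrapper call signatures mirror the code shown in this diff, and the function name is illustrative.

import torch
from typing import Optional

def run_flashinfer(attn_metadata, query: torch.Tensor,
                   prefill_query: torch.Tensor, decode_query: torch.Tensor,
                   kv_cache: torch.Tensor, softmax_scale: float,
                   logits_soft_cap: Optional[float], k_scale: float,
                   v_scale: float) -> torch.Tensor:
    num_prefill = attn_metadata.num_prefill_tokens
    num_decode = attn_metadata.num_decode_tokens
    if num_prefill > 0 and num_decode == 0:
        # Prefill only: the two-level cascade wrapper serves the whole batch.
        return attn_metadata.wrapper.run(prefill_query, kv_cache)
    if num_prefill == 0 and num_decode > 0:
        if attn_metadata.cuda_wrapper is not None:
            # CUDA-graph decode: single-level wrapper, scale supplied here.
            return attn_metadata.cuda_wrapper.forward(
                decode_query, kv_cache, sm_scale=softmax_scale,
                logits_soft_cap=logits_soft_cap, k_scale=k_scale,
                v_scale=v_scale)
        return attn_metadata.wrapper.run(decode_query, kv_cache)
    # Chunked prefill: prefill and decode tokens run under one cascade plan.
    return attn_metadata.wrapper.run(query, kv_cache)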
- assert prefill_output is not None - assert decode_output is not None - assert decode_meta is not None - assert decode_meta.decode_query_len == 1 - decode_output = decode_output.squeeze(1) - output = torch.cat([prefill_output, decode_output], dim=0) + print("DECODE") + output = attn_metadata.wrapper.run(decode_query, kv_cache) + print(output.shape) + return output.view(num_tokens, hidden_size) + + # Case 3: Both prefill and decode (chunked prefill case) + print("PREFILL AND DECODE") + output = attn_metadata.wrapper.run(query, kv_cache) + print(output.shape) return output.view(num_tokens, hidden_size) - def unified_flash_infer_fake( query: torch.Tensor, key: torch.Tensor, From 9064ec37be9810264f5adc6a5ad3d9b825248fc6 Mon Sep 17 00:00:00 2001 From: Ray Wan Date: Tue, 19 Nov 2024 22:25:47 +0000 Subject: [PATCH 07/12] change layer to level --- vllm/attention/backends/flashinfer.py | 128 +++++++++++++------------- 1 file changed, 65 insertions(+), 63 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index f2a989f22060c..3b83350e5e09f 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -309,13 +309,13 @@ class FlashInferMetadata(AttentionMetadata): paged_kv_indptr: Optional[torch.Tensor] = None # paged_kv_last_page_len is the length of the last page of the shared KVs paged_kv_last_page_len: Optional[torch.Tensor] = None - # Store the concatenated page indices of all requests for the second layer - second_layer_kv_indices: Optional[torch.Tensor] = None + # Store the concatenated page indices of all requests for the second level + second_level_kv_indices: Optional[torch.Tensor] = None # Index pointers to the start of each request's page indices - # in the second_layer_kv_indices - second_layer_kv_indptr: Optional[torch.Tensor] = None - # The length of the last page of each request in the second layer - second_layer_kv_last_page_len: Optional[torch.Tensor] = None + # in the second_level_kv_indices + second_level_kv_indptr: Optional[torch.Tensor] = None + # The length of the last page of each request in the second level + second_level_kv_last_page_len: Optional[torch.Tensor] = None # The number of query/output heads num_qo_heads: Optional[int] = None @@ -354,9 +354,9 @@ def __post_init__(self): # assert self.paged_kv_indices is not None # assert self.paged_kv_indptr is not None # assert self.paged_kv_last_page_len is not None - # assert self.second_layer_kv_indices is not None - # assert self.second_layer_kv_indptr is not None - # assert self.second_layer_kv_last_page_len is not None + # assert self.second_level_kv_indices is not None + # assert self.second_level_kv_indptr is not None + # assert self.second_level_kv_last_page_len is not None # assert self.block_table_bound is not None # assert self.seq_lens_tensor is not None # self.query_start_loc = self.query_start_loc[:self.num_prefills + 1] @@ -373,11 +373,11 @@ def __post_init__(self): # self.block_table_bound = self.block_table_bound.to(self.device) # self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) # self.paged_kv_indices = self.paged_kv_indices.to(self.device) - # self.second_layer_kv_indices = self.second_layer_kv_indices.to( + # self.second_level_kv_indices = self.second_level_kv_indices.to( # self.device) - # self.second_layer_kv_indptr = self.second_layer_kv_indptr.to( + # self.second_level_kv_indptr = self.second_level_kv_indptr.to( # self.device) - # self.second_layer_kv_last_page_len = self.second_layer_kv_last_page_len.to( 
# noqa: E501 + # self.second_level_kv_last_page_len = self.second_level_kv_last_page_len.to( # noqa: E501 # self.device) # # print("Batch Size", batch_size) @@ -385,14 +385,14 @@ def __post_init__(self): # # print("FIRST LEVEL PAGED KV INDPTR", self.paged_kv_indptr[:self.num_prefills+1]) # # print("QUERY LOC", self.query_start_loc) # # print("SECOND QUERY LOC", self.second_level_query_start_loc) - # # print("SECOND LAYER PAGED KV INDPTR", self.second_layer_kv_indptr[:self.num_prefills+1]) + # # print("SECOND LAYER PAGED KV INDPTR", self.second_level_kv_indptr[:self.num_prefills+1]) # self.wrapper.plan( # [self.query_start_loc, self.second_level_query_start_loc], - # [self.paged_kv_indptr[:self.num_prefills+1], self.second_layer_kv_indptr[:self.num_prefills+1]], - # [self.paged_kv_indices, self.second_layer_kv_indices], [ + # [self.paged_kv_indptr[:self.num_prefills+1], self.second_level_kv_indptr[:self.num_prefills+1]], + # [self.paged_kv_indices, self.second_level_kv_indices], [ # self.paged_kv_last_page_len[:self.num_prefills], - # self.second_layer_kv_last_page_len[:self.num_prefills] + # self.second_level_kv_last_page_len[:self.num_prefills] # ], # self.num_qo_heads, # self.num_kv_heads, @@ -446,19 +446,19 @@ def __post_init__(self): # assert self.paged_kv_indices is not None # assert self.paged_kv_indptr is not None # assert self.paged_kv_last_page_len is not None - # assert self.second_layer_kv_indices is not None - # assert self.second_layer_kv_indptr is not None - # assert self.second_layer_kv_last_page_len is not None + # assert self.second_level_kv_indices is not None + # assert self.second_level_kv_indptr is not None + # assert self.second_level_kv_last_page_len is not None # self.paged_kv_indices = self.paged_kv_indices.to(self.device) # self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) # self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( # self.device) - # self.second_layer_kv_indices = self.second_layer_kv_indices.to( + # self.second_level_kv_indices = self.second_level_kv_indices.to( # self.device) - # self.second_layer_kv_indptr = self.second_layer_kv_indptr.to( + # self.second_level_kv_indptr = self.second_level_kv_indptr.to( # self.device) - # self.second_layer_kv_last_page_len = self.second_layer_kv_last_page_len.to( # noqa: E501 + # self.second_level_kv_last_page_len = self.second_level_kv_last_page_len.to( # noqa: E501 # self.device) # # handle model warmup path @@ -472,18 +472,18 @@ def __post_init__(self): # # print("Batch Size", batch_size) # # print("NUM PREFILLS", self.num_prefills) # # print("FIRST LEVEL PAGED KV INDPTR", self.paged_kv_indptr[:self.num_prefills+1]) - # # print("SECOND LAYER PAGED KV INDPTR", self.second_layer_kv_indptr[:self.num_prefills+1]) + # # print("SECOND LAYER PAGED KV INDPTR", self.second_level_kv_indptr[:self.num_prefills+1]) # # print("QUERY LOC", self.query_start_loc) # # print("SECOND QUERY LOC", self.second_level_query_start_loc) # # print("PAGED KV INDICES", self.paged_kv_indices) - # # print("SECOND PAGED KV INDICES", self.second_layer_kv_indices) + # # print("SECOND PAGED KV INDICES", self.second_level_kv_indices) # self.wrapper.plan( # [self.query_start_loc, self.second_level_query_start_loc], - # [self.paged_kv_indptr[self.num_prefills:], self.second_layer_kv_indptr[self.num_prefills:]], - # [self.paged_kv_indices, self.second_layer_kv_indices], [ + # [self.paged_kv_indptr[self.num_prefills:], self.second_level_kv_indptr[self.num_prefills:]], + # [self.paged_kv_indices, self.second_level_kv_indices], [ # 
self.paged_kv_last_page_len[self.num_prefills:], - # self.second_layer_kv_last_page_len[self.num_prefills:] + # self.second_level_kv_last_page_len[self.num_prefills:] # ], # self.num_qo_heads, # self.num_kv_heads, @@ -504,9 +504,9 @@ def begin_forward(self, scale: Optional[float], logits_soft_cap: Optional[float] assert self.paged_kv_last_page_len is not None if not self.use_cuda_graph: - assert self.second_layer_kv_indices is not None - assert self.second_layer_kv_indptr is not None - assert self.second_layer_kv_last_page_len is not None + assert self.second_level_kv_indices is not None + assert self.second_level_kv_indptr is not None + assert self.second_level_kv_last_page_len is not None print("DO WE USE CUDA GRAPH?", self.use_cuda_graph) # Skip device transfer for profile run @@ -517,9 +517,9 @@ def begin_forward(self, scale: Optional[float], logits_soft_cap: Optional[float] self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(self.device) if not self.use_cuda_graph: - self.second_layer_kv_indices = self.second_layer_kv_indices.to(self.device) - self.second_layer_kv_indptr = self.second_layer_kv_indptr.to(self.device) - self.second_layer_kv_last_page_len = self.second_layer_kv_last_page_len.to(self.device) + self.second_level_kv_indices = self.second_level_kv_indices.to(self.device) + self.second_level_kv_indptr = self.second_level_kv_indptr.to(self.device) + self.second_level_kv_last_page_len = self.second_level_kv_last_page_len.to(self.device) if self.block_table_bound is not None: self.block_table_bound = self.block_table_bound.to(self.device) @@ -532,11 +532,11 @@ def begin_forward(self, scale: Optional[float], logits_soft_cap: Optional[float] [self.query_start_loc[:self.num_prefills + 1], self.second_level_query_start_loc[:self.num_prefills + 1]], [self.paged_kv_indptr[:self.num_prefills + 1], - self.second_layer_kv_indptr[:self.num_prefills + 1]], + self.second_level_kv_indptr[:self.num_prefills + 1]], [self.paged_kv_indices, - self.second_layer_kv_indices], + self.second_level_kv_indices], [self.paged_kv_last_page_len[:self.num_prefills], - self.second_layer_kv_last_page_len[:self.num_prefills]], + self.second_level_kv_last_page_len[:self.num_prefills]], self.num_qo_heads, self.num_kv_heads, self.head_dim, @@ -552,11 +552,11 @@ def begin_forward(self, scale: Optional[float], logits_soft_cap: Optional[float] [self.query_start_loc, self.second_level_query_start_loc], [self.paged_kv_indptr[self.num_prefills:], - self.second_layer_kv_indptr[self.num_prefills:]], + self.second_level_kv_indptr[self.num_prefills:]], [self.paged_kv_indices, - self.second_layer_kv_indices], + self.second_level_kv_indices], [self.paged_kv_last_page_len[self.num_prefills:], - self.second_layer_kv_last_page_len[self.num_prefills:]], + self.second_level_kv_last_page_len[self.num_prefills:]], self.num_qo_heads, self.num_kv_heads, self.head_dim, @@ -588,11 +588,11 @@ def begin_forward(self, scale: Optional[float], logits_soft_cap: Optional[float] [self.query_start_loc, self.second_level_query_start_loc], [self.paged_kv_indptr, - self.second_layer_kv_indptr], + self.second_level_kv_indptr], [self.paged_kv_indices, - self.second_layer_kv_indices], + self.second_level_kv_indices], [self.paged_kv_last_page_len, - self.second_layer_kv_last_page_len], + self.second_level_kv_last_page_len], self.num_qo_heads, self.num_kv_heads, self.head_dim, @@ -698,12 +698,12 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.paged_kv_indptr: List[int] = [0] # The length of the last page of the shared kvs 
self.paged_kv_last_page_len: List[int] = [] - # Store concatenated page indices of requests for the second layer - self.second_layer_kv_indices: List[int] = [] + # Store concatenated page indices of requests for the second level + self.second_level_kv_indices: List[int] = [] # Index pointers to the start of each request's page indices - self.second_layer_kv_indptr: List[int] = [0] - # The length of the last page of each request in the second layer - self.second_layer_kv_last_page_len: List[int] = [] + self.second_level_kv_indptr: List[int] = [0] + # The length of the last page of each request in the second level + self.second_level_kv_last_page_len: List[int] = [] self.total_blocks = 0 self.is_profile_run: bool = False @@ -783,6 +783,8 @@ def _update_unique_kv_tensors(self, block_table: List[int], seq_len: int, """ Updates the unique level kv tensors """ + # if use cuda graph, we need a different logic for handling the tensors + # since this is still single level if use_cuda_graph: self.total_blocks += len(block_table) block_table_bound = seq_len // self.block_size + 1 \ @@ -803,14 +805,14 @@ def _update_unique_kv_tensors(self, block_table: List[int], seq_len: int, block_table_bound = (seq_len) // self.block_size + 1 \ if seq_len % self.block_size != 0 \ else (seq_len) // self.block_size - self.second_layer_kv_indices.extend( + self.second_level_kv_indices.extend( block_table[shared_length:block_table_bound]) - self.second_layer_kv_indptr.append(self.second_layer_kv_indptr[-1] + + self.second_level_kv_indptr.append(self.second_level_kv_indptr[-1] + (block_table_bound - shared_length)) last_page_len = (seq_len) % self.block_size if last_page_len == 0: last_page_len = self.block_size - self.second_layer_kv_last_page_len.append(last_page_len) + self.second_level_kv_last_page_len.append(last_page_len) def _update_shared_kv_tensors(self, common_prefix: List[int], batch_size: int) -> None: @@ -989,11 +991,11 @@ def build(self, seq_lens: List[int], query_lens: List[int], self.paged_kv_last_page_len, device="cpu", dtype=torch.int) second_level_kv_indices_tensor = torch.tensor( - self.second_layer_kv_indices, device="cpu", dtype=torch.int) + self.second_level_kv_indices, device="cpu", dtype=torch.int) second_level_kv_indptr_tensor = torch.tensor( - self.second_layer_kv_indptr, device="cpu", dtype=torch.int) + self.second_level_kv_indptr, device="cpu", dtype=torch.int) second_level_kv_last_page_len_tensor = torch.tensor( - self.second_layer_kv_last_page_len, + self.second_level_kv_last_page_len, device="cpu", dtype=torch.int) @@ -1029,9 +1031,9 @@ def build(self, seq_lens: List[int], query_lens: List[int], paged_kv_indptr=paged_kv_indptr_tensor, paged_kv_indices=paged_kv_indices_tensor, paged_kv_last_page_len=paged_kv_last_page_len_tensor, - second_layer_kv_indptr=second_level_kv_indptr_tensor, - second_layer_kv_indices=second_level_kv_indices_tensor, - second_layer_kv_last_page_len=second_level_kv_last_page_len_tensor, + second_level_kv_indptr=second_level_kv_indptr_tensor, + second_level_kv_indices=second_level_kv_indices_tensor, + second_level_kv_last_page_len=second_level_kv_last_page_len_tensor, second_level_query_start_loc=second_level_query_start_loc, block_table_bound=block_table_bound_tensor, seq_lens_tensor=seq_lens_tensor, @@ -1313,14 +1315,14 @@ def unified_flash_infer( # Case 1: Prefill only if num_prefill_tokens > 0 and num_decode_tokens == 0: output = attn_metadata.wrapper.run(prefill_query, kv_cache) - print("PREFILL") - print(output.shape) + # print("PREFILL") + # 
print(output.shape) return output.view(num_tokens, hidden_size) # Case 2: Decode only if num_prefill_tokens == 0 and num_decode_tokens > 0: if attn_metadata.cuda_wrapper is not None: - print("CUDA DECODE") + # print("CUDA DECODE") output = attn_metadata.cuda_wrapper.forward( decode_query, kv_cache, @@ -1329,15 +1331,15 @@ def unified_flash_infer( k_scale=k_scale, v_scale=v_scale) else: - print("DECODE") + # print("DECODE") output = attn_metadata.wrapper.run(decode_query, kv_cache) - print(output.shape) + # print(output.shape) return output.view(num_tokens, hidden_size) # Case 3: Both prefill and decode (chunked prefill case) - print("PREFILL AND DECODE") + # print("PREFILL AND DECODE") output = attn_metadata.wrapper.run(query, kv_cache) - print(output.shape) + # print(output.shape) return output.view(num_tokens, hidden_size) def unified_flash_infer_fake( From 2637e090bbfc002cdd3d9e66278bb33baa371323 Mon Sep 17 00:00:00 2001 From: Ray Wan Date: Tue, 19 Nov 2024 22:28:41 +0000 Subject: [PATCH 08/12] remove comment --- vllm/attention/backends/flashinfer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 3b83350e5e09f..06031929f6849 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -250,7 +250,7 @@ def begin_forward(self, model_input, model): state = self try: - scale = getattr(model.model.layers[0].self_attn.attn.impl, "scale", + scale = getattr(model.model.begs[0].self_attn.attn.impl, "scale", None) except AttributeError as e: raise AttributeError("Failed to retrieve 'scale'. \ @@ -508,7 +508,6 @@ def begin_forward(self, scale: Optional[float], logits_soft_cap: Optional[float] assert self.second_level_kv_indptr is not None assert self.second_level_kv_last_page_len is not None - print("DO WE USE CUDA GRAPH?", self.use_cuda_graph) # Skip device transfer for profile run if not self.is_profile_run: # Move tensors to device From 29950bc11be01b0e310453989b1d33adf0bdd425 Mon Sep 17 00:00:00 2001 From: Ray Wan Date: Tue, 19 Nov 2024 23:26:27 +0000 Subject: [PATCH 09/12] typo --- vllm/attention/backends/flashinfer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 06031929f6849..b059569d446a8 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -250,7 +250,7 @@ def begin_forward(self, model_input, model): state = self try: - scale = getattr(model.model.begs[0].self_attn.attn.impl, "scale", + scale = getattr(model.model.layers[0].self_attn.attn.impl, "scale", None) except AttributeError as e: raise AttributeError("Failed to retrieve 'scale'. 
\ From 555de935e6ced817425c2521000e1e10dda3f4ca Mon Sep 17 00:00:00 2001 From: Ray Wan Date: Fri, 22 Nov 2024 21:17:31 +0000 Subject: [PATCH 10/12] formatting --- vllm/attention/backends/flashinfer.py | 522 +++++++------------------- 1 file changed, 138 insertions(+), 384 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index b059569d446a8..e09339e1b1e10 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -342,158 +342,8 @@ def __post_init__(self): f"Only {supported_head_sizes} are supported for head_dim,", f"received {self.head_dim}.") - #TODO: NEED TO ADD CHUNKED PREFILL - # def begin_forward(self, scale: Optional[float], - # logits_soft_cap: Optional[float]): - # if self.num_prefill_tokens > 0: - # if self.paged_kv_indices is None: - # return - - # assert self.wrapper is not None - # assert self.query_start_loc is not None - # assert self.paged_kv_indices is not None - # assert self.paged_kv_indptr is not None - # assert self.paged_kv_last_page_len is not None - # assert self.second_level_kv_indices is not None - # assert self.second_level_kv_indptr is not None - # assert self.second_level_kv_last_page_len is not None - # assert self.block_table_bound is not None - # assert self.seq_lens_tensor is not None - # self.query_start_loc = self.query_start_loc[:self.num_prefills + 1] - # self.second_level_query_start_loc = self.second_level_query_start_loc[:self.num_prefills+1] - # batch_size = self.query_start_loc.shape[0] - 1 - # assert batch_size >= 0 - # # We will use flash attention for profiling to - # # determine the number of blocks. Therefore, - # # we don't need to prepare the input for flashinfer for profile run. - # if not self.is_profile_run: - # self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) - # self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( - # self.device) - # self.block_table_bound = self.block_table_bound.to(self.device) - # self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) - # self.paged_kv_indices = self.paged_kv_indices.to(self.device) - # self.second_level_kv_indices = self.second_level_kv_indices.to( - # self.device) - # self.second_level_kv_indptr = self.second_level_kv_indptr.to( - # self.device) - # self.second_level_kv_last_page_len = self.second_level_kv_last_page_len.to( # noqa: E501 - # self.device) - - # # print("Batch Size", batch_size) - # # print("NUM PREFILLS", self.num_prefills) - # # print("FIRST LEVEL PAGED KV INDPTR", self.paged_kv_indptr[:self.num_prefills+1]) - # # print("QUERY LOC", self.query_start_loc) - # # print("SECOND QUERY LOC", self.second_level_query_start_loc) - # # print("SECOND LAYER PAGED KV INDPTR", self.second_level_kv_indptr[:self.num_prefills+1]) - - # self.wrapper.plan( - # [self.query_start_loc, self.second_level_query_start_loc], - # [self.paged_kv_indptr[:self.num_prefills+1], self.second_level_kv_indptr[:self.num_prefills+1]], - # [self.paged_kv_indices, self.second_level_kv_indices], [ - # self.paged_kv_last_page_len[:self.num_prefills], - # self.second_level_kv_last_page_len[:self.num_prefills] - # ], - # self.num_qo_heads, - # self.num_kv_heads, - # self.head_dim, - # self.page_size, - # causal=True, - # sm_scale=scale, - # logits_soft_cap=logits_soft_cap) - - # if self.num_decode_tokens > 0: - # if self.cuda_wrapper: - # assert self.paged_kv_indices is not None - # assert self.paged_kv_indptr is not None - # assert self.paged_kv_last_page_len is not None - # self.paged_kv_indices = 
self.paged_kv_indices.to(self.device) - # self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) - # self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( - # self.device) - # # handle model warmup path - # if self.block_table_bound is not None: - # self.block_table_bound = self.block_table_bound.to( - # self.device) - # if self.seq_lens_tensor is not None: - # self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) - - # # print("Batch Size", batch_size) - # # print("NUM PREFILLS", self.num_prefills) - # # print("FIRST LEVEL PAGED KV INDPTR", self.paged_kv_indptr[:self.num_prefills+1]) - # # print("QUERY LOC", self.query_start_loc) - # # print("PAGED KV INDICES", self.paged_kv_indices) - - # assert self.cuda_wrapper is not None - # self.cuda_wrapper.end_forward() - # self.cuda_wrapper.begin_forward( - # self.paged_kv_indptr[self.num_prefills:], - # self.paged_kv_indices, - # self.paged_kv_last_page_len[self.num_prefills:], - # self.num_qo_heads, - # self.num_kv_heads, - # self.head_dim, - # self.page_size, - # # Disable flashinfer's pos encoding and use vllm's rope. - # pos_encoding_mode="NONE", - # # kv-cache data type. - # data_type=self.data_type, - # # query data type. - # q_data_type=self.q_data_type) - - # return - - # assert self.paged_kv_indices is not None - # assert self.paged_kv_indptr is not None - # assert self.paged_kv_last_page_len is not None - # assert self.second_level_kv_indices is not None - # assert self.second_level_kv_indptr is not None - # assert self.second_level_kv_last_page_len is not None - - # self.paged_kv_indices = self.paged_kv_indices.to(self.device) - # self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) - # self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( - # self.device) - # self.second_level_kv_indices = self.second_level_kv_indices.to( - # self.device) - # self.second_level_kv_indptr = self.second_level_kv_indptr.to( - # self.device) - # self.second_level_kv_last_page_len = self.second_level_kv_last_page_len.to( # noqa: E501 - # self.device) - - # # handle model warmup path - # if self.block_table_bound is not None: - # self.block_table_bound = self.block_table_bound.to(self.device) - # if self.seq_lens_tensor is not None: - # self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) - - # assert self.wrapper is not None - - # # print("Batch Size", batch_size) - # # print("NUM PREFILLS", self.num_prefills) - # # print("FIRST LEVEL PAGED KV INDPTR", self.paged_kv_indptr[:self.num_prefills+1]) - # # print("SECOND LAYER PAGED KV INDPTR", self.second_level_kv_indptr[:self.num_prefills+1]) - # # print("QUERY LOC", self.query_start_loc) - # # print("SECOND QUERY LOC", self.second_level_query_start_loc) - # # print("PAGED KV INDICES", self.paged_kv_indices) - # # print("SECOND PAGED KV INDICES", self.second_level_kv_indices) - - # self.wrapper.plan( - # [self.query_start_loc, self.second_level_query_start_loc], - # [self.paged_kv_indptr[self.num_prefills:], self.second_level_kv_indptr[self.num_prefills:]], - # [self.paged_kv_indices, self.second_level_kv_indices], [ - # self.paged_kv_last_page_len[self.num_prefills:], - # self.second_level_kv_last_page_len[self.num_prefills:] - # ], - # self.num_qo_heads, - # self.num_kv_heads, - # self.head_dim, - # self.page_size, - # causal=True, - # sm_scale=scale, - # logits_soft_cap=logits_soft_cap) - - def begin_forward(self, scale: Optional[float], logits_soft_cap: Optional[float]): + def begin_forward(self, scale: Optional[float], + logits_soft_cap: Optional[float]): if 
self.paged_kv_indices is None: return @@ -503,66 +353,79 @@ def begin_forward(self, scale: Optional[float], logits_soft_cap: Optional[float] assert self.paged_kv_indptr is not None assert self.paged_kv_last_page_len is not None - if not self.use_cuda_graph: - assert self.second_level_kv_indices is not None - assert self.second_level_kv_indptr is not None - assert self.second_level_kv_last_page_len is not None - - # Skip device transfer for profile run if not self.is_profile_run: - # Move tensors to device self.paged_kv_indices = self.paged_kv_indices.to(self.device) self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) - self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(self.device) + self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( + self.device) - if not self.use_cuda_graph: - self.second_level_kv_indices = self.second_level_kv_indices.to(self.device) - self.second_level_kv_indptr = self.second_level_kv_indptr.to(self.device) - self.second_level_kv_last_page_len = self.second_level_kv_last_page_len.to(self.device) - - if self.block_table_bound is not None: - self.block_table_bound = self.block_table_bound.to(self.device) - if self.seq_lens_tensor is not None: - self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) + if self.num_decode_tokens > 0: + if self.block_table_bound is not None: + self.block_table_bound = self.block_table_bound.to( + self.device) + if self.seq_lens_tensor is not None: + self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) # Case 1: Prefill only if self.num_prefill_tokens > 0 and self.num_decode_tokens == 0: - self.wrapper.plan( - [self.query_start_loc[:self.num_prefills + 1], - self.second_level_query_start_loc[:self.num_prefills + 1]], - [self.paged_kv_indptr[:self.num_prefills + 1], - self.second_level_kv_indptr[:self.num_prefills + 1]], - [self.paged_kv_indices, - self.second_level_kv_indices], - [self.paged_kv_last_page_len[:self.num_prefills], - self.second_level_kv_last_page_len[:self.num_prefills]], - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - self.page_size, - causal=True, - sm_scale=scale, - logits_soft_cap=logits_soft_cap) - + assert self.second_level_kv_indices is not None + assert self.second_level_kv_indptr is not None + assert self.second_level_kv_last_page_len is not None + assert self.second_level_query_start_loc is not None + assert self.query_start_loc is not None + + self.second_level_kv_indices = self.second_level_kv_indices.to( # noqa + self.device) + self.second_level_kv_indptr = self.second_level_kv_indptr.to( # noqa + self.device) + self.second_level_kv_last_page_len = self.second_level_kv_last_page_len.to( # noqa + self.device) + self.wrapper.plan([ + self.query_start_loc[:self.num_prefills + 1], + self.second_level_query_start_loc[:self.num_prefills + 1] + ], [ + self.paged_kv_indptr[:self.num_prefills + 1], + self.second_level_kv_indptr[:self.num_prefills + 1] + ], [self.paged_kv_indices, self.second_level_kv_indices], [ + self.paged_kv_last_page_len[:self.num_prefills], + self.second_level_kv_last_page_len[:self.num_prefills] + ], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + causal=True, + sm_scale=scale, + logits_soft_cap=logits_soft_cap) + # Case 2: Decode only elif self.num_prefill_tokens == 0 and self.num_decode_tokens > 0: if not self.use_cuda_graph: - self.wrapper.plan( - [self.query_start_loc, - self.second_level_query_start_loc], - [self.paged_kv_indptr[self.num_prefills:], - self.second_level_kv_indptr[self.num_prefills:]], - 
[self.paged_kv_indices, - self.second_level_kv_indices], - [self.paged_kv_last_page_len[self.num_prefills:], - self.second_level_kv_last_page_len[self.num_prefills:]], - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - self.page_size, - causal=True, - sm_scale=scale, - logits_soft_cap=logits_soft_cap) + assert self.second_level_kv_indices is not None + assert self.second_level_kv_indptr is not None + assert self.second_level_kv_last_page_len is not None + self.second_level_kv_indices = self.second_level_kv_indices.to( # noqa + self.device) + self.second_level_kv_indptr = self.second_level_kv_indptr.to( # noqa + self.device) + self.second_level_kv_last_page_len = self.second_level_kv_last_page_len.to( # noqa + self.device) + self.wrapper.plan([ + self.query_start_loc, self.second_level_query_start_loc + ], [ + self.paged_kv_indptr[self.num_prefills:], + self.second_level_kv_indptr[self.num_prefills:] + ], [self.paged_kv_indices, self.second_level_kv_indices], [ + self.paged_kv_last_page_len[self.num_prefills:], + self.second_level_kv_last_page_len[self.num_prefills:] + ], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + causal=True, + sm_scale=scale, + logits_soft_cap=logits_soft_cap) else: assert self.cuda_wrapper is not None self.cuda_wrapper.end_forward() @@ -579,26 +442,33 @@ def begin_forward(self, scale: Optional[float], logits_soft_cap: Optional[float] # kv-cache data type. data_type=self.data_type, # query data type. - q_data_type=self.q_data_type - ) + q_data_type=self.q_data_type) # Case 3: Both prefill and decode (chunked prefill case) else: + assert self.second_level_kv_indices is not None + assert self.second_level_kv_indptr is not None + assert self.second_level_kv_last_page_len is not None + self.second_level_kv_indices = self.second_level_kv_indices.to( + self.device) + self.second_level_kv_indptr = self.second_level_kv_indptr.to( + self.device) + self.second_level_kv_last_page_len = self.second_level_kv_last_page_len.to( # noqa + self.device) + self.wrapper.plan( - [self.query_start_loc, - self.second_level_query_start_loc], - [self.paged_kv_indptr, - self.second_level_kv_indptr], - [self.paged_kv_indices, - self.second_level_kv_indices], - [self.paged_kv_last_page_len, - self.second_level_kv_last_page_len], - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - self.page_size, - causal=True, - sm_scale=scale, - logits_soft_cap=logits_soft_cap) + [self.query_start_loc, self.second_level_query_start_loc], + [self.paged_kv_indptr, self.second_level_kv_indptr], + [self.paged_kv_indices, self.second_level_kv_indices], [ + self.paged_kv_last_page_len, + self.second_level_kv_last_page_len + ], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + causal=True, + sm_scale=scale, + logits_soft_cap=logits_soft_cap) def asdict_zerocopy(self, skip_fields: Optional[Set[str]] = None @@ -773,31 +643,36 @@ def _add_seq_group( return block_table = block_tables[seq_id] - self._update_unique_kv_tensors(block_table, seq_len, common_prefix, - use_cuda_graph) + if use_cuda_graph: + self._update_cuda_wrapper_unique_kv_tensors( + block_table, seq_len) + else: + self._update_unique_kv_tensors(block_table, seq_len, + common_prefix) - def _update_unique_kv_tensors(self, block_table: List[int], seq_len: int, - common_prefix: List[int], - use_cuda_graph: bool) -> None: + def _update_cuda_wrapper_unique_kv_tensors(self, block_table: List[int], + seq_len: int) -> None: """ - Updates the unique level kv tensors + Updates tensors for 
cuda decode wrapper """ - # if use cuda graph, we need a different logic for handling the tensors - # since this is still single level - if use_cuda_graph: - self.total_blocks += len(block_table) - block_table_bound = seq_len // self.block_size + 1 \ - if seq_len % self.block_size != 0 \ - else seq_len // self.block_size - self.paged_kv_indices.extend(block_table[:block_table_bound]) - self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + - block_table_bound) + self.total_blocks += len(block_table) + block_table_bound = seq_len // self.block_size + 1 \ + if seq_len % self.block_size != 0 \ + else seq_len // self.block_size + self.paged_kv_indices.extend(block_table[:block_table_bound]) + self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + + block_table_bound) - last_page_len = seq_len % self.block_size - if last_page_len == 0: - last_page_len = self.block_size - self.paged_kv_last_page_len.append(last_page_len) - return + last_page_len = seq_len % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + self.paged_kv_last_page_len.append(last_page_len) + + def _update_unique_kv_tensors(self, block_table: List[int], seq_len: int, + common_prefix: List[int]) -> None: + """ + Updates the unique level tensors + """ shared_length = len(common_prefix) self.total_blocks += (len(block_table) - shared_length) @@ -826,7 +701,6 @@ def _update_shared_kv_tensors(self, common_prefix: List[int], self.paged_kv_indptr.extend([0] * batch_size) self.paged_kv_last_page_len.extend([0] * batch_size) else: - # FIXME: this still doesn't work self.total_blocks += len(common_prefix) self.paged_kv_indices.extend(common_prefix) self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + @@ -1116,127 +990,6 @@ def forward( ) -# def unified_flash_infer( -# query: torch.Tensor, -# key: torch.Tensor, -# value: torch.Tensor, -# num_heads: int, -# head_size: int, -# num_kv_heads: int, -# kv_cache: torch.Tensor, -# kv_cache_dtype: str, -# k_scale: float, -# v_scale: float, -# softmax_scale: float, -# window_size: Optional[List[int]] = None, -# alibi_slopes: Optional[torch.Tensor] = None, -# logits_soft_cap: Optional[float] = None, -# ) -> torch.Tensor: - -# current_metadata = get_forward_context() -# assert current_metadata is not None -# assert isinstance(current_metadata, FlashInferMetadata) -# attn_metadata: FlashInferMetadata = current_metadata - -# num_tokens, hidden_size = query.shape -# query = query.view(-1, num_heads, head_size) -# key = key.view(-1, num_kv_heads, head_size) -# value = value.view(-1, num_kv_heads, head_size) - -# if kv_cache.numel() > 0: -# # Use the same reshape and cache kernel as flash attention. 
-# ops.reshape_and_cache_flash( -# key, -# value, -# kv_cache[:, 0], -# kv_cache[:, 1], -# attn_metadata.slot_mapping.flatten(), -# kv_cache_dtype, -# k_scale, -# v_scale, -# ) -# # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 -# # to process the cache when the kv_cache_dtype is fp8 -# if kv_cache_dtype.startswith("fp8"): -# torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( -# kv_cache_dtype) -# kv_cache = kv_cache.view(torch_dtype) - -# num_prefill_tokens = attn_metadata.num_prefill_tokens -# num_decode_tokens = attn_metadata.num_decode_tokens -# assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ -# f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa -# assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ -# f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa -# query = query.contiguous() # Flashinfer requires query to be contiguous -# # Query for decode. KV is not needed because it is already cached. -# # QKV for prefill. -# decode_query = query[num_prefill_tokens:] -# query = query[:num_prefill_tokens] - -# key = key[:num_prefill_tokens] -# value = value[:num_prefill_tokens] - -# assert query.shape[0] == num_prefill_tokens -# assert decode_query.shape[0] == num_decode_tokens - -# prefill_output: Optional[torch.Tensor] = None -# decode_output: Optional[torch.Tensor] = None -# if prefill_meta := attn_metadata.prefill_metadata: -# # We will use flash attention for prefill -# # when kv_cache is not provided. -# # This happens when vllm runs the profiling to -# # determine the number of blocks. -# if kv_cache.numel() == 0: -# prefill_output = flash_attn_varlen_func( -# q=query, -# k=key, -# v=value, -# cu_seqlens_q=prefill_meta.seq_start_loc, -# cu_seqlens_k=prefill_meta.seq_start_loc, -# max_seqlen_q=prefill_meta.max_prefill_seq_len, -# max_seqlen_k=prefill_meta.max_prefill_seq_len, -# softmax_scale=softmax_scale, -# causal=True, -# window_size=window_size, -# alibi_slopes=alibi_slopes, -# ) -# else: -# assert prefill_meta is not None -# assert prefill_meta.wrapper is not None -# prefill_output = prefill_meta.wrapper.run(query, kv_cache) -# if decode_meta := attn_metadata.decode_metadata: -# assert attn_metadata.decode_metadata is not None -# if attn_metadata.decode_metadata.cuda_wrapper is not None: -# decode_output = attn_metadata.decode_metadata.cuda_wrapper.forward( -# decode_query, -# kv_cache, -# sm_scale=softmax_scale, -# logits_soft_cap=logits_soft_cap, -# k_scale=k_scale, -# v_scale=v_scale) -# else: -# assert attn_metadata.decode_metadata.wrapper is not None -# decode_output = attn_metadata.decode_metadata.wrapper.run( -# decode_query, kv_cache) - -# if prefill_output is None and decode_output is not None: -# # Decode only batch. -# output, num_tokens = decode_output, num_decode_tokens -# elif decode_output is None and prefill_output is not None: -# # Prefill only batch. -# output, num_tokens = prefill_output, num_prefill_tokens -# else: -# # Chunked prefill batch does not work with speculative decoding in -# # FlashInfer backend, so the query length for decode should be 1. 
-# assert prefill_output is not None -# assert decode_output is not None -# assert decode_meta is not None -# assert decode_meta.decode_query_len == 1 -# decode_output = decode_output.squeeze(1) -# output = torch.cat([prefill_output, decode_output], dim=0) -# return output.view(num_tokens, hidden_size) - def unified_flash_infer( query: torch.Tensor, key: torch.Tensor, @@ -1282,16 +1035,23 @@ def unified_flash_infer( num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" + f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" - - query = query.contiguous() # Flashinfer requires query to be contiguous + f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa + query = query.contiguous() # Flashinfer requires query to be contiguous + # Query for decode and prefill. + # KV is not needed because it is already cached. + # QKV for prefill. decode_query = query[num_prefill_tokens:] prefill_query = query[:num_prefill_tokens] - - # Profile run case - use flash attention when no kv_cache + + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert prefill_query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + if kv_cache.numel() == 0: return flash_attn_varlen_func( q=query, @@ -1307,21 +1067,13 @@ def unified_flash_infer( alibi_slopes=alibi_slopes, ).view(num_tokens, hidden_size) - # For all non-profile cases, we need a wrapper assert attn_metadata.wrapper is not None - - # Case 1: Prefill only if num_prefill_tokens > 0 and num_decode_tokens == 0: output = attn_metadata.wrapper.run(prefill_query, kv_cache) - # print("PREFILL") - # print(output.shape) return output.view(num_tokens, hidden_size) - - # Case 2: Decode only - if num_prefill_tokens == 0 and num_decode_tokens > 0: + elif num_prefill_tokens == 0 and num_decode_tokens > 0: if attn_metadata.cuda_wrapper is not None: - # print("CUDA DECODE") output = attn_metadata.cuda_wrapper.forward( decode_query, kv_cache, @@ -1330,16 +1082,18 @@ def unified_flash_infer( k_scale=k_scale, v_scale=v_scale) else: - # print("DECODE") + assert attn_metadata.wrapper is not None output = attn_metadata.wrapper.run(decode_query, kv_cache) - # print(output.shape) + return output.view(num_tokens, hidden_size) + else: + # Ensure chunked prefill with speculative decoding is not allowed + assert decode_query.shape[0] == 1, \ + """Chunked prefill batch does not work with + speculative decoding in FlashInfer backend.""" + + output = attn_metadata.wrapper.run(query, kv_cache) return output.view(num_tokens, hidden_size) - # Case 3: Both prefill and decode (chunked prefill case) - # print("PREFILL AND DECODE") - output = attn_metadata.wrapper.run(query, kv_cache) - # print(output.shape) - return output.view(num_tokens, hidden_size) def unified_flash_infer_fake( query: torch.Tensor, From 3616ac63980e7a396355fd7cc115a72507f8f3a6 Mon Sep 17 00:00:00 2001 From: Ray Wan Date: Fri, 22 Nov 2024 22:05:12 +0000 Subject: [PATCH 11/12] sliding window --- vllm/attention/backends/flashinfer.py | 29 ++++++++++++++++++--------- 1 file changed, 19 insertions(+), 10 
deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index e09339e1b1e10..bb5c4b00e4474 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -229,7 +229,7 @@ def graph_capture_get_metadata_for_batch( wrapper=self._wrapper) # we don't need to pass logits and scale to begin_forward # since in forward, it already gets it. - attn_metadata.begin_forward(None, None) + attn_metadata.begin_forward(None, None, (-1, -1)) return attn_metadata def get_graph_input_buffers(self, @@ -249,6 +249,10 @@ def begin_forward(self, model_input, model): assert not self._is_graph_capturing state = self + # sliding window needed for new kernel in begin_forward + sliding_window = self.runner.sliding_window + window_left = sliding_window[0] if sliding_window is not None else -1 + try: scale = getattr(model.model.layers[0].self_attn.attn.impl, "scale", None) @@ -272,7 +276,7 @@ def begin_forward(self, model_input, model): model_input.attn_metadata.cuda_wrapper = state._get_cuda_wrapper() model_input.attn_metadata.wrapper = state._get_wrapper() - model_input.attn_metadata.begin_forward(scale, logits_soft_cap) + model_input.attn_metadata.begin_forward(scale, logits_soft_cap, window_left) @dataclass @@ -343,7 +347,7 @@ def __post_init__(self): f"received {self.head_dim}.") def begin_forward(self, scale: Optional[float], - logits_soft_cap: Optional[float]): + logits_soft_cap: Optional[float], window_left: Optional[int]): if self.paged_kv_indices is None: return @@ -396,7 +400,8 @@ def begin_forward(self, scale: Optional[float], self.page_size, causal=True, sm_scale=scale, - logits_soft_cap=logits_soft_cap) + logits_soft_cap=logits_soft_cap, + window_left=window_left) # Case 2: Decode only elif self.num_prefill_tokens == 0 and self.num_decode_tokens > 0: @@ -425,7 +430,8 @@ def begin_forward(self, scale: Optional[float], self.page_size, causal=True, sm_scale=scale, - logits_soft_cap=logits_soft_cap) + logits_soft_cap=logits_soft_cap, + window_left=window_left) else: assert self.cuda_wrapper is not None self.cuda_wrapper.end_forward() @@ -468,7 +474,8 @@ def begin_forward(self, scale: Optional[float], self.page_size, causal=True, sm_scale=scale, - logits_soft_cap=logits_soft_cap) + logits_soft_cap=logits_soft_cap, + window_left=window_left) def asdict_zerocopy(self, skip_fields: Optional[Set[str]] = None @@ -946,9 +953,8 @@ def __init__( if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes - if sliding_window is not None: - raise ValueError("Sliding window is not supported in FlashInfer.") - self.sliding_window = (-1, -1) + self.sliding_window = ((sliding_window - 1, + 0) if sliding_window is not None else (-1, -1)) self.kv_cache_dtype = kv_cache_dtype self.logits_soft_cap = logits_soft_cap @@ -1052,6 +1058,8 @@ def unified_flash_infer( assert prefill_query.shape[0] == num_prefill_tokens assert decode_query.shape[0] == num_decode_tokens + window_left = window_size[0] if window_size is not None else -1 + if kv_cache.numel() == 0: return flash_attn_varlen_func( q=query, @@ -1080,7 +1088,8 @@ def unified_flash_infer( sm_scale=softmax_scale, logits_soft_cap=logits_soft_cap, k_scale=k_scale, - v_scale=v_scale) + v_scale=v_scale, + window_left=window_left) else: assert attn_metadata.wrapper is not None output = attn_metadata.wrapper.run(decode_query, kv_cache) From 81d3f438e10f3cf3c6a5384ea922f1b21dc24da0 Mon Sep 17 00:00:00 2001 From: Ray Wan Date: Thu, 28 Nov 
2024 19:12:54 -0800 Subject: [PATCH 12/12] add warning for logits and scale --- vllm/attention/backends/flashinfer.py | 28 +++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 1b0b41643e0a6..7529b5cd6b178 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type +from vllm.logger import init_logger from vllm.multimodal import MultiModalPlaceholderMap try: @@ -31,6 +32,8 @@ from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, make_tensor_with_pad) +logger = init_logger(__name__) + if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) @@ -253,20 +256,21 @@ def begin_forward(self, model_input, model): window_left = sliding_window[0] if sliding_window is not None else -1 try: - scale = getattr(model.model.layers[0].self_attn.attn.impl, "scale", - None) - except AttributeError as e: - raise AttributeError("Failed to retrieve 'scale'. \ - Check if 'self_attn.attn.impl' contains 'scale'.") from e + scale = model.model.layers[0].self_attn.attn.impl.scale + except AttributeError: + scale = None + logger.warning("Failed to retrieve 'scale'. \ + Check if 'self_attn.attn.impl' contains 'scale'.\ + Using default value of None") try: - logits_soft_cap = getattr( - model.model.layers[0].self_attn.attn.impl, "logits_soft_cap", - None) - except AttributeError as e: - raise AttributeError("Failed to retrieve 'logits_soft_cap'. \ - Check if 'self_attn.attn.impl' contains 'logits_soft_cap'." - ) from e + logits_soft_cap = model.model.layers[ + 0].self_attn.attn.impl.logits_soft_cap + except AttributeError: + logits_soft_cap = None + logger.warning("Failed to retrieve 'logits_soft_cap'. \ + Check if 'self_attn.attn.impl' contains 'logits_soft_cap'. \ + Using default value of None") if model_input.attn_metadata.use_cuda_graph: batch_size = model_input.input_tokens.shape[0]
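For reference, a minimal standalone sketch of the fallback pattern this last patch introduces for reading `scale` and `logits_soft_cap` from the first decoder layer's attention impl. The helper name `get_attn_impl_param`, the stdlib `logging` logger (standing in for `vllm.logger.init_logger`), and the `SimpleNamespace` toy model are assumptions for illustration only, not part of the patch.

```python
import logging
from types import SimpleNamespace

logger = logging.getLogger(__name__)


def get_attn_impl_param(model, name, default=None):
    """Fetch an attribute such as 'scale' or 'logits_soft_cap' from
    model.model.layers[0].self_attn.attn.impl, logging a warning and
    returning `default` when the attribute chain is missing."""
    try:
        return getattr(model.model.layers[0].self_attn.attn.impl, name)
    except AttributeError:
        logger.warning(
            "Failed to retrieve %r from the first layer's attention impl; "
            "using default value %r", name, default)
        return default


# Toy model exercising both the success and the fallback path.
impl = SimpleNamespace(scale=0.125)  # deliberately has no logits_soft_cap
layer = SimpleNamespace(
    self_attn=SimpleNamespace(attn=SimpleNamespace(impl=impl)))
toy_model = SimpleNamespace(model=SimpleNamespace(layers=[layer]))

assert get_attn_impl_param(toy_model, "scale") == 0.125
assert get_attn_impl_param(toy_model, "logits_soft_cap") is None  # warns
```

Similarly, a small sketch of the per-sequence paged-KV bookkeeping the metadata builder performs (the `block_table_bound` and `last_page_len` computations that feed the cascade wrapper's plan inputs); the helper name `paged_kv_entry` and the example numbers are hypothetical.

```python
def paged_kv_entry(block_table, seq_len, block_size):
    """Return (used_block_ids, last_page_len) for one sequence: the pieces
    appended to paged_kv_indices / paged_kv_last_page_len."""
    # Number of pages actually holding tokens (ceiling division).
    block_table_bound = (seq_len // block_size +
                         (1 if seq_len % block_size else 0))
    # Tokens resident in the final, possibly partial, page.
    last_page_len = seq_len % block_size or block_size
    return block_table[:block_table_bound], last_page_len


blocks, last = paged_kv_entry(block_table=[7, 3, 9, 4], seq_len=35,
                              block_size=16)
assert blocks == [7, 3, 9] and last == 3  # 35 tokens span 3 pages, 3 in last
```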