diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh index 65be3c5d93b20..380ad0cd1903b 100644 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh @@ -23,22 +23,22 @@ usage() { while getopts "m:b:l:f:t:" OPT; do case ${OPT} in - m ) + m ) MODEL="$OPTARG" ;; - b ) + b ) BATCH_SIZE="$OPTARG" ;; - l ) + l ) LIMIT="$OPTARG" ;; - f ) + f ) FEWSHOT="$OPTARG" ;; t ) TP_SIZE="$OPTARG" ;; - \? ) + \? ) usage exit 1 ;; @@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do done lm_eval --model vllm \ - --model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend=ray,trust_remote_code=true,max_model_len=4096" \ + --model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend=ray,trust_remote_code=true,max_model_len=4096,enforce_eager=true" \ --tasks gsm8k --num_fewshot "$FEWSHOT" --limit "$LIMIT" \ --batch_size "$BATCH_SIZE" diff --git a/examples/offline_inference/basic.py b/examples/offline_inference/basic.py index 23cc6e8539431..cb5ee8d72c836 100644 --- a/examples/offline_inference/basic.py +++ b/examples/offline_inference/basic.py @@ -8,10 +8,25 @@ "The future of AI is", ] # Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) +sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + max_tokens=512, +) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +llm = LLM( + model="deepseek-ai/DeepSeek-V2-Lite-Chat", + # model="deepseek-ai/DeepSeek-V2.5", + tensor_parallel_size=1, + trust_remote_code=True, + max_model_len=4096, + # dtype="float16", + enforce_eager=True, + # max_num_seqs=1, + # block_size=128, + # disable_mla=True, +) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. 
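The example above turns MLA on by default for DeepSeek-V2-Lite-Chat. For a quick A/B against the non-MLA path, the toggles added later in this patch (the `--disable-mla` engine arg / `disable_mla` kwarg and the `VLLM_DISABLE_MLA` env var) can be used. A minimal sketch, reusing the same model and sampling settings; the token budget here is an arbitrary choice:

```python
import os

# Either toggle disables MLA for comparison (both are added in this patch).
os.environ["VLLM_DISABLE_MLA"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite-Chat",
    trust_remote_code=True,
    max_model_len=4096,
    enforce_eager=True,
    # disable_mla=True,   # equivalent engine-arg toggle
)
outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64))
print(outputs[0].outputs[0].text)
```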
outputs = llm.generate(prompts, sampling_params) @@ -19,4 +34,4 @@ for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/vllm/attention/backends/flashinfer_mla.py b/vllm/attention/backends/flashinfer_mla.py new file mode 100644 index 0000000000000..ac0281032f68c --- /dev/null +++ b/vllm/attention/backends/flashinfer_mla.py @@ -0,0 +1,720 @@ +import math +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass, field +from functools import cached_property +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type + +from vllm.multimodal import MultiModalPlaceholderMap + +try: + from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper + FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 +except ImportError: + BatchDecodeMlaWithPagedKVCacheWrapper = None + FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 + +import torch +from vllm_flash_attn import flash_attn_varlen_func + +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionState, AttentionType) +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, + compute_slot_mapping_start_idx, + is_block_tables_empty) +from vllm.attention.ops.paged_attn import PagedAttention +from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, + make_tensor_with_pad) + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) + + +class FlashInferMLABackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "FLASHINFER_MLA" + + @staticmethod + def get_impl_cls() -> Type["FlashInferMLAImpl"]: + return FlashInferMLAImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return FlashInferMLAMetadata + + @staticmethod + def get_builder_cls() -> Type["FlashInferMLAMetadataBuilder"]: + return FlashInferMLAMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["FlashInferMLAState"]: + return FlashInferMLAState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + # NOTE(simon): we repurpose the "key" cache for latent, + # and "value" cache for rope. Until we have hybrid memory + # allocate, we are living with some memory waste. 
+ return (num_blocks, 2, block_size, num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [512] + + @staticmethod + def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + return torch.float8_e4m3fn + elif kv_cache_dtype == "fp8_e5m2": + return torch.float8_e5m2 + else: + raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") + + +class FlashInferMLAState(AttentionState): + + def __init__(self, runner): + self.runner = runner + + @cached_property + def _workspace_buffer(self): + return torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.runner.device) + + @cached_property + def _decode_wrapper(self): + return BatchDecodeMlaWithPagedKVCacheWrapper(self._workspace_buffer) + + @contextmanager + def graph_capture(self, max_batch_size: int): + raise NotImplementedError( + "FlashInferMLAState does not support graph capture") + + def graph_clone(self, batch_size: int): + raise NotImplementedError( + "FlashInferMLAState does not support graph capture") + + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): + raise NotImplementedError( + "FlashInferMLAState does not support graph capture") + + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): + raise NotImplementedError( + "FlashInferMLAState does not support graph capture") + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): + raise NotImplementedError( + "FlashInferMLAState does not support graph capture") + + def begin_forward(self, model_input): + model_input.attn_metadata.decode_wrapper = self._decode_wrapper + model_input.attn_metadata.begin_forward() + + +@dataclass +class FlashInferMLAMetadata(AttentionMetadata): + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + + # Number of query tokens for each request in the batch. + # Currently, we require that all requests have the same number of query + # tokens during the decoding phase. When speculavie decoding is enabled, + # decode_query_len might be greater than 1. In all other cases, it is 1. + decode_query_len: Optional[int] = 1 + + use_cuda_graph: bool = True + + # Note(simon): we are using Flash Attention for prefill so we don't need a + # wrapper. However, it can be replaced with a + # BatchPrefillWithRaggedKVCacheWrapper implementation. 
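To make the repurposed cache layout concrete, here is a small sketch of the tensor this backend allocates, assuming the MLA dimensions used elsewhere in this patch (one KV head, head_size equal to the 512-wide latent, the only supported head size above) and an arbitrary block count and block size:

```python
import torch

from vllm.attention.backends.flashinfer_mla import FlashInferMLABackend

num_blocks, block_size, num_kv_heads, head_size = 1024, 16, 1, 512
shape = FlashInferMLABackend.get_kv_cache_shape(num_blocks, block_size,
                                                num_kv_heads, head_size)
kv_cache = torch.empty(shape)
print(kv_cache.shape)            # torch.Size([1024, 2, 16, 1, 512])

latent_pages = kv_cache[:, 0]    # "key" half: compressed KV latent
rope_pages = kv_cache[:, 1]      # "value" half: rope slice, zero-padded to 512
```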
+ decode_wrapper: Optional[BatchDecodeMlaWithPagedKVCacheWrapper] = None + + # Metadata for the prefill stage + seq_start_loc: Optional[torch.Tensor] = None + query_start_loc: Optional[torch.Tensor] = None + block_tables: Optional[torch.Tensor] = None + + # used for GPU in-place advance_step + seq_lens_tensor: Optional[torch.Tensor] = None + block_table_bound: Optional[torch.Tensor] = None + + # An example for paged_kv_indices, paged_kv_indptr: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: Optional[torch.Tensor] = None + # The page indices of the paged kv cache + paged_kv_indices: Optional[torch.Tensor] = None + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_len: Optional[torch.Tensor] = None + # The number of query/output heads + num_qo_heads: Optional[int] = None + # The number of key/value heads + num_kv_heads: Optional[int] = None + # The dimension of the attention heads + head_dim: Optional[int] = None + # Block size of vllm + page_size: Optional[int] = None + # The data type of the paged kv cache + data_type: torch.dtype = None + # The data type of the query + q_data_type: torch.dtype = None + device: torch.device = torch.device("cuda") + is_profile_run: bool = False + + sm_scale: float = 0.0 + extras: Dict[str, torch.Tensor] = field(default_factory=dict) + + def __post_init__(self): + supported_head_sizes = FlashInferMLABackend.get_supported_head_sizes() + if self.head_dim is not None and self.head_dim \ + not in supported_head_sizes: + raise ValueError( + f"Only {supported_head_sizes} are supported for head_dim,", + f"received {self.head_dim}.") + + # Note(simon): for MLA: soft max scale needs to be + # `1 / sqrt(qk_nope_head_dim + qk_rope_head_dim)`. + assert self.head_dim is not None + self.sm_scale = 1.0 / math.sqrt(self.head_dim + self.head_dim // 8) + + def begin_forward(self): + if self.num_prefill_tokens > 0: + return + + if self.num_decode_tokens > 0: + assert self.paged_kv_indices is not None + assert self.paged_kv_indptr is not None + assert self.paged_kv_last_page_len is not None + self.paged_kv_indices = self.paged_kv_indices.to(self.device) + self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) + self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( + self.device) + # handle model warmup path + if self.block_table_bound is not None: + self.block_table_bound = self.block_table_bound.to(self.device) + if self.seq_lens_tensor is not None: + self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) + + assert self.decode_wrapper is not None + + self.decode_wrapper.plan( + self.paged_kv_indptr[self.num_prefills:], + self.paged_kv_indices, + self.paged_kv_last_page_len[self.num_prefills:], + self.num_qo_heads, + self.head_dim, + self.page_size, + sm_scale=self.sm_scale, + data_type=self.data_type, + q_data_type=self.q_data_type) + + def asdict_zerocopy(self, + skip_fields: Optional[Set[str]] = None + ) -> Dict[str, Any]: + if skip_fields is None: + skip_fields = set() + # We need to skip the prefill/decode_wrapper field since it cannot be + # broadcasted with nccl when TP is enabled. 
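A tiny sketch of how these three paged-KV fields are assembled, mirroring the worked example in the comments above; block_size = 16 and the sequence lengths are assumptions chosen to reproduce those page indices:

```python
block_tables = [[0, 5, 8], [1, 6, 7], [3, 4]]
seq_lens = [40, 48, 17]          # assumed lengths, in tokens
block_size = 16

paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len = [], [0], []
for table, seq_len in zip(block_tables, seq_lens):
    full, rem = divmod(seq_len, block_size)
    used = full + 1 if rem else full          # number of valid blocks
    paged_kv_indices.extend(table[:used])
    paged_kv_indptr.append(paged_kv_indptr[-1] + used)
    paged_kv_last_page_len.append(rem or block_size)

print(paged_kv_indices)          # [0, 5, 8, 1, 6, 7, 3, 4]
print(paged_kv_indptr)           # [0, 3, 6, 8]
print(paged_kv_last_page_len)    # [8, 16, 1]
```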
+ skip_fields.add('decode_wrapper') + return super().asdict_zerocopy(skip_fields) + + @property + def prefill_metadata(self) -> Optional["FlashInferMLAMetadata"]: + if self.num_prefills == 0: + return None + return self + + @property + def decode_metadata(self) -> Optional["FlashInferMLAMetadata"]: + if self.num_decode_tokens == 0: + return None + return self + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + raise NotImplementedError( + "FlashInferMLAMetadata does not support multi-step") + + +class FlashInferMLAMetadataBuilder( + AttentionMetadataBuilder[FlashInferMLAMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + self.input_builder = input_builder + self.runner = input_builder.runner + + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + + # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout + # for the precise definition of the following fields. + # An example: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + self.paged_kv_indices: List[int] = [] + # 0 at the beginning of paged_kv_indptr indicates the start of the + # first request’s page indices in the paged_kv_indices list. + self.paged_kv_indptr: List[int] = [0] + # paged_kv_last_page_len is the length of the last page of each request + self.paged_kv_last_page_len: List[int] = [] + self.total_blocks = 0 + self.is_profile_run: bool = False + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. 
+ """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + computed_block_nums = inter_data.computed_block_nums + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if inter_data.prefix_cache_hit: + block_table = computed_block_nums + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + block_table = block_tables[seq_id][-curr_sliding_window_block:] + self.block_tables.append(block_table) + + is_profile_run = is_block_tables_empty(block_tables) + + # Compute slot mapping. + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + # It is not necessary to add paged_kv_indices, paged_kv_indptr, + # and paged_kv_last_page_len for profile run because we will + # create dummy inputs. + if is_profile_run: + self.is_profile_run = is_profile_run + return + + block_table = block_tables[seq_id] + self._update_paged_kv_tensors(block_table, seq_len) + + def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int): + # Get the number of valid blocks based on sequence length. + # If seq_len = 16, block_size = 16, + # block_table_bound is 1 with 1 valid block. + # If seq_len = 15, block_size = 16, + # block_table_bound is 0 + 1 with 1 valid block. + self.total_blocks += len(block_table) + block_table_bound = seq_len // self.block_size + 1 \ + if seq_len % self.block_size != 0 \ + else seq_len // self.block_size + self.paged_kv_indices.extend(block_table[:block_table_bound]) + self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + + block_table_bound) + + last_page_len = seq_len % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + self.paged_kv_last_page_len.append(last_page_len) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. 
+ """ + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + decode_query_len = max(query_lens[self.num_prefills:], default=1) + + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size - self.num_prefill_tokens + + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = self.runner.graph_block_tables[:batch_size] + max_blocks = input_block_tables.shape[1] + for i, block_table in enumerate(self.block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + input_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. + input_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + block_tables = torch.from_numpy(input_block_tables).to( + device, non_blocking=True) + + last_paged_kv_indptr = self.paged_kv_indptr[-1] + self.paged_kv_indptr.extend([last_paged_kv_indptr] * + cuda_graph_pad_size) + self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + + assert device is not None + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + if len(self.paged_kv_indptr) > 0: + # extend to the maximum number of blocks as returned by the + # scheduler + self.paged_kv_indices.extend( + [0] * (self.total_blocks - len(self.paged_kv_indices))) + paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, + device="cpu", + dtype=torch.int) + paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr, + device="cpu", + dtype=torch.int) + paged_kv_last_page_len_tensor = torch.tensor( + self.paged_kv_last_page_len, device="cpu", dtype=torch.int) + block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) - + 1, + device="cpu", + dtype=torch.int) + else: + paged_kv_indices_tensor = None + paged_kv_indptr_tensor = None + paged_kv_last_page_len_tensor = None + block_table_bound_tensor = None + + if self.runner.kv_cache_dtype.startswith("fp8"): + kv_cache_dtype = FlashInferMLABackend.get_fp8_dtype_for_flashinfer( + self.runner.kv_cache_dtype) + else: + kv_cache_dtype = get_kv_cache_torch_dtype( + self.runner.kv_cache_dtype, self.runner.model_config.dtype) + + return 
FlashInferMLAMetadata( + decode_query_len=decode_query_len, + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + multi_modal_placeholder_index_maps=placeholder_index_maps, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + max_prefill_seq_len=max_prefill_seq_len, + block_tables=block_tables, + paged_kv_indptr=paged_kv_indptr_tensor, + paged_kv_indices=paged_kv_indices_tensor, + paged_kv_last_page_len=paged_kv_last_page_len_tensor, + block_table_bound=block_table_bound_tensor, + seq_lens_tensor=seq_lens_tensor, + num_qo_heads=self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config), + num_kv_heads=self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config), + head_dim=self.runner.model_config.get_head_size(), + page_size=self.block_size, + seq_start_loc=seq_start_loc, + query_start_loc=query_start_loc, + device=device, + data_type=kv_cache_dtype, + q_data_type=self.runner.model_config.dtype, + use_cuda_graph=use_captured_graph, + is_profile_run=self.is_profile_run) + + +class FlashInferMLAImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "FlashInferMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferMLAImpl") + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashInferMLAMetadata, + k_scale: float = 1.0, + v_scale: float = 1.0, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if output is not None: + raise NotImplementedError( + "output is not yet supported for FlashInferMLAImpl") + + if attn_metadata.prefill_metadata is not None: + return self._forward_prefill(query, key, value, kv_cache, + attn_metadata, k_scale, v_scale) + + if attn_metadata.decode_metadata is not None: + return self._forward_decode(query, key, value, kv_cache, + attn_metadata, k_scale, v_scale) + + def _forward_prefill( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashInferMLAMetadata, + k_scale: float, + v_scale: float, + ) -> torch.Tensor: + + kv_a = attn_metadata.extras["kv_a"] + k_pe = attn_metadata.extras["k_pe"] + + # write the latent and rope to kv cache + # TODO(simon): remove the hard code, k_pe is assumed to be 1/8 of the + # latent size. 
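In numbers, assuming DeepSeek-V2's head dimensions (which the hard-coded ratios in this file imply: 512-wide latent, 64-wide rope slice, 192-wide q/k heads, 128-wide v heads), the prefill path pads twice: once to fit the rope slice into the 512-wide "value" pages for the cache write, and once to reach a FlashAttention-supported head size of 256 for the varlen call. A shape-only sketch, with the attention output faked by random data:

```python
import torch
import torch.nn.functional as F

T, N = 8, 16                                   # tokens, local heads (assumed)
head_size, rope_dim, qk_head_dim, v_head_dim = 512, 64, 192, 128

# 1) Cache write: pad the rope slice up to the latent width.
k_pe = torch.randn(T, 1, rope_dim)
k_pe_padded = F.pad(k_pe, [0, head_size - rope_dim], value=0)    # (T, 1, 512)

# 2) Prefill attention: pad q/k/v to 256, run varlen attention, then slice
#    the output back down to v_head_dim.
q = torch.randn(T, N, qk_head_dim)
v = torch.randn(T, N, v_head_dim)
q_padded = F.pad(q, [0, 256 - qk_head_dim], value=0)             # (T, N, 256)
v_padded = F.pad(v, [0, 256 - v_head_dim], value=0)              # (T, N, 256)

attn_out = torch.randn(T, N, 256)              # stand-in for flash_attn output
attn_out = attn_out[..., :v_head_dim].reshape(T, N * v_head_dim)
print(k_pe_padded.shape, attn_out.shape)       # (8, 1, 512)  (8, 2048)
```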
+ assert k_pe.shape[-1] == self.head_size // 8 + to_cache_key_rope = torch.nn.functional.pad( + k_pe, [0, self.head_size - self.head_size // 8], value=0) + if kv_cache.numel() > 0: + ops.reshape_and_cache_flash( + kv_a, + to_cache_key_rope, + kv_cache[:, 0], + kv_cache[:, 1], + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + k_scale=k_scale, + v_scale=v_scale, + ) + + # run prefill without paged kv cache. + q = torch.nn.functional.pad(query, [0, 256 - query.shape[-1]], value=0) + k = torch.nn.functional.pad(key, [0, 256 - key.shape[-1]], value=0) + v = torch.nn.functional.pad(value, [0, 256 - value.shape[-1]], value=0) + + attn_output = flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=attn_metadata.seq_start_loc, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_q=attn_metadata.max_prefill_seq_len, + max_seqlen_k=attn_metadata.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + ) + attn_output = attn_output.view(-1, self.num_heads, + 256)[..., :value.shape[-1]].reshape( + -1, + self.num_heads * value.shape[-1]) + return attn_output + + def _forward_decode( + self, + query: torch.Tensor, + key: torch.Tensor, + rope: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashInferMLAMetadata, + k_scale: float, + v_scale: float, + ) -> torch.Tensor: + assert kv_cache.numel() > 0 + + # Use the same reshape and cache kernel as flash attention. + ops.reshape_and_cache_flash( + key.contiguous(), + rope.contiguous(), + kv_cache[:, 0], + kv_cache[:, 1], + attn_metadata.slot_mapping.flatten(), + self.kv_cache_dtype, + k_scale, + v_scale, + ) + + # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 + # to process the cache when the kv_cache_dtype is fp8 + if self.kv_cache_dtype.startswith("fp8"): + torch_dtype = FlashInferMLABackend.get_fp8_dtype_for_flashinfer( + self.kv_cache_dtype) + kv_cache = kv_cache.view(torch_dtype) + + decode_query_nope = query[:, :, :self.head_size].contiguous() + decode_query_pe = query[:, :, self.head_size:].contiguous() + + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + assert decode_meta.decode_wrapper is not None + + paged_kpe_cache = kv_cache[:, 1] + paged_kpe_cache = paged_kpe_cache[..., :64] + + decode_meta.decode_wrapper._sm_scale = self.scale + decode_output = decode_meta.decode_wrapper.run( + q_nope=decode_query_nope, + q_pe=decode_query_pe, + paged_ckv_cache=kv_cache[:, 0], + paged_kpe_cache=kv_cache[:, 1], + ) + return decode_output diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 9b03fd73fe690..7d7876fbeed75 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -40,6 +40,7 @@ def __init__( blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, per_layer_sliding_window: Optional[int] = None, + use_mla: bool = False, prefix: str = "", attn_type: str = AttentionType.DECODER, ) -> None: @@ -91,9 +92,13 @@ def __init__( # During model initialization, the default dtype is set as the model # weight and activation dtype. 
dtype = torch.get_default_dtype() - attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype, - block_size, is_attention_free, - blocksparse_params is not None) + attn_backend = get_attn_backend(head_size, + dtype, + kv_cache_dtype, + block_size, + is_attention_free, + blocksparse_params is not None, + use_mla=use_mla) impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 81ea6eefb5410..5600374b02ef2 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -83,6 +83,7 @@ def get_attn_backend( block_size: int, is_attention_free: bool, is_blocksparse: bool = False, + use_mla: bool = False, ) -> Type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" # Accessing envs.* behind an @lru_cache decorator can cause the wrong @@ -97,6 +98,7 @@ def get_attn_backend( is_attention_free=is_attention_free, is_blocksparse=is_blocksparse, use_v1=envs.VLLM_USE_V1, + use_mla=use_mla, ) @@ -109,6 +111,7 @@ def _cached_get_attn_backend( is_attention_free: bool, is_blocksparse: bool = False, use_v1: bool = False, + use_mla: bool = False, ) -> Type[AttentionBackend]: if is_blocksparse: logger.info("Using BlocksparseFlashAttention backend.") @@ -141,7 +144,8 @@ def _cached_get_attn_backend( # get device-specific attn_backend attention_cls = current_platform.get_attn_backend_cls( - selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1) + selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, + use_mla) if not attention_cls: raise ValueError( f"Invalid attention backend for {current_platform.device_name}") diff --git a/vllm/config.py b/vllm/config.py index 79754bd04102f..d1f793b370f20 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -74,6 +74,20 @@ PretrainedConfig]] +def _is_flashinfer_available() -> bool: + """Check if FlashInfer is available. + + Returns: + bool: True if FlashInfer is installed and available, False otherwise. + """ + try: + from flashinfer import ( # noqa:F401 + BatchDecodeMlaWithPagedKVCacheWrapper) + return True + except ImportError: + return False + + class SupportsHash(Protocol): def compute_hash(self) -> str: @@ -169,6 +183,7 @@ class ModelConfig: `logits_processors` extra completion argument. Defaults to None, which allows no processors. generation_config: Configuration parameter file for generation. + disable_mla: Whether to disable MLA for DeepSeek models. 
""" def compute_hash(self) -> str: @@ -195,40 +210,43 @@ def compute_hash(self) -> str: factors.append(self.rope_theta) return hashlib.sha256(str(factors).encode()).hexdigest() - def __init__(self, - model: str, - task: Union[TaskOption, Literal["draft"]], - tokenizer: str, - tokenizer_mode: str, - trust_remote_code: bool, - dtype: Union[str, torch.dtype], - seed: int, - allowed_local_media_path: str = "", - revision: Optional[str] = None, - code_revision: Optional[str] = None, - rope_scaling: Optional[Dict[str, Any]] = None, - rope_theta: Optional[float] = None, - tokenizer_revision: Optional[str] = None, - max_model_len: Optional[int] = None, - spec_target_max_model_len: Optional[int] = None, - quantization: Optional[str] = None, - quantization_param_path: Optional[str] = None, - enforce_eager: Optional[bool] = None, - max_seq_len_to_capture: Optional[int] = None, - max_logprobs: int = 20, - disable_sliding_window: bool = False, - skip_tokenizer_init: bool = False, - served_model_name: Optional[Union[str, List[str]]] = None, - limit_mm_per_prompt: Optional[Mapping[str, int]] = None, - use_async_output_proc: bool = True, - config_format: ConfigFormat = ConfigFormat.AUTO, - hf_overrides: Optional[HfOverrides] = None, - mm_processor_kwargs: Optional[Dict[str, Any]] = None, - disable_mm_preprocessor_cache: bool = False, - override_neuron_config: Optional[Dict[str, Any]] = None, - override_pooler_config: Optional["PoolerConfig"] = None, - logits_processor_pattern: Optional[str] = None, - generation_config: Optional[str] = None) -> None: + def __init__( + self, + model: str, + task: Union[TaskOption, Literal["draft"]], + tokenizer: str, + tokenizer_mode: str, + trust_remote_code: bool, + dtype: Union[str, torch.dtype], + seed: int, + allowed_local_media_path: str = "", + revision: Optional[str] = None, + code_revision: Optional[str] = None, + rope_scaling: Optional[Dict[str, Any]] = None, + rope_theta: Optional[float] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + spec_target_max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + quantization_param_path: Optional[str] = None, + enforce_eager: Optional[bool] = None, + max_seq_len_to_capture: Optional[int] = None, + max_logprobs: int = 20, + disable_sliding_window: bool = False, + skip_tokenizer_init: bool = False, + served_model_name: Optional[Union[str, List[str]]] = None, + limit_mm_per_prompt: Optional[Mapping[str, int]] = None, + use_async_output_proc: bool = True, + config_format: ConfigFormat = ConfigFormat.AUTO, + hf_overrides: Optional[HfOverrides] = None, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + disable_mm_preprocessor_cache: bool = False, + override_neuron_config: Optional[Dict[str, Any]] = None, + override_pooler_config: Optional["PoolerConfig"] = None, + logits_processor_pattern: Optional[str] = None, + generation_config: Optional[str] = None, + disable_mla: bool = False, + ) -> None: self.model = model self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode @@ -277,6 +295,7 @@ def __init__(self, self.max_logprobs = max_logprobs self.disable_sliding_window = disable_sliding_window self.skip_tokenizer_init = skip_tokenizer_init + self.disable_mla = disable_mla hf_config = get_config(self.model, trust_remote_code, revision, code_revision, config_format) @@ -730,17 +749,26 @@ def get_vocab_size(self) -> int: def get_hidden_size(self) -> int: return self.hf_text_config.hidden_size + @property + def is_deepseek_v2(self) -> bool: + return 
hasattr(self.hf_text_config, + "model_type") and (self.hf_text_config.model_type + in ('deepseek_v2', 'deepseek_v3')) + def get_head_size(self) -> int: # TODO remove hard code - if hasattr(self.hf_text_config, - "model_type") and (self.hf_text_config.model_type - in ('deepseek_v2', 'deepseek_v3')): - qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim", - 0) - qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim", - 0) - if qk_rope_head_dim and qk_nope_head_dim: - return qk_rope_head_dim + qk_nope_head_dim + if self.is_deepseek_v2: + # FlashAttention supports only head_size 32, 64, 128, 256, + # we need to pad head_size 192 to 256 + if self.should_use_mla: + return self.hf_text_config.kv_lora_rank + else: + qk_rope_head_dim = getattr(self.hf_text_config, + "qk_rope_head_dim", 0) + qk_nope_head_dim = getattr(self.hf_text_config, + "qk_nope_head_dim", 0) + if qk_rope_head_dim and qk_nope_head_dim: + return qk_rope_head_dim + qk_nope_head_dim if self.is_attention_free: return 0 @@ -799,6 +827,10 @@ def get_total_num_kv_heads(self) -> int: def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: """Returns the number of KV heads per GPU.""" + if self.should_use_mla: + # TODO(simon): feature flag MLA + return 1 + total_num_kv_heads = self.get_total_num_kv_heads() # If tensor parallelism is used, we divide the number of KV heads by # the tensor parallel size. We will replicate the KV heads in the @@ -939,6 +971,28 @@ def is_cross_encoder(self) -> bool: return ModelRegistry.is_cross_encoder_model(architectures) @property + def should_use_mla(self) -> bool: + """Whether MLA should be used for this model. + + Returns True if: + 1. The model is DeepSeek V2 + 2. MLA is not explicitly disabled + 3. FlashInfer is available + + If conditions 1 and 2 are met but FlashInfer is not available, + logs a warning and returns False. + """ + use_mla = (self.is_deepseek_v2 and not self.disable_mla + and not envs.VLLM_DISABLE_MLA) + if use_mla and not _is_flashinfer_available(): + logger.warning( + "Please install or update FlashInfer for better performance on " + "DeepSeek model via enabling MLA. See " + "https://github.com/flashinfer-ai/flashinfer for installation." + ) + return False + return use_mla + def supported_runner_types(self) -> Set[RunnerType]: return {_TASK_RUNNER[task] for task in self.supported_tasks} diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a4f4c9558d056..b74dfce10d71b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -102,6 +102,7 @@ class EngineArgs: seed: int = 0 max_model_len: Optional[int] = None worker_use_ray: bool = False + disable_mla: bool = False # Note: Specifying a custom executor backend by passing a class # is intended for expert use only. The API may change without # notice. 
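Condensed sketch of the selection logic introduced here: MLA is used only for DeepSeek V2/V3 when neither the engine arg nor `VLLM_DISABLE_MLA` disables it and FlashInfer is importable, and with MLA on the KV cache is sized for a single 512-wide latent head rather than the 192-wide (later padded to 256) attention heads. The dimension values below are assumptions for DeepSeek-V2:

```python
def should_use_mla(is_deepseek_v2: bool, disable_mla: bool,
                   env_disable: bool, flashinfer_ok: bool) -> bool:
    # flashinfer_ok stands in for _is_flashinfer_available().
    return (is_deepseek_v2 and not disable_mla and not env_disable
            and flashinfer_ok)

def head_size(use_mla: bool, kv_lora_rank: int = 512,
              qk_nope_head_dim: int = 128, qk_rope_head_dim: int = 64) -> int:
    # With MLA the cache holds the latent; without it, the 192-wide heads.
    return kv_lora_rank if use_mla else qk_nope_head_dim + qk_rope_head_dim

use_mla = should_use_mla(True, False, False, True)
print(use_mla, head_size(use_mla))   # True 512  (and num_kv_heads becomes 1)
```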
@@ -944,7 +945,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=str, default="auto", help='The worker class to use for distributed execution.') - + parser.add_argument('--disable-mla', + action='store_true', + help='Disable MLA for DeepSeek models.') parser.add_argument( "--generation-config", type=nullable_str, @@ -998,6 +1001,7 @@ def create_model_config(self) -> ModelConfig: disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache, override_neuron_config=self.override_neuron_config, override_pooler_config=self.override_pooler_config, + disable_mla=self.disable_mla, logits_processor_pattern=self.logits_processor_pattern, generation_config=self.generation_config) diff --git a/vllm/envs.py b/vllm/envs.py index b7b597ea15af3..adad53cc31449 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -285,6 +285,10 @@ def get_default_config_root(): "VLLM_FLASHINFER_FORCE_TENSOR_CORES": lambda: bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0"))), + # If set, vLLM will disable the MLA attention optimizations. + "VLLM_DISABLE_MLA": + lambda: bool(int(os.getenv("VLLM_DISABLE_MLA", "0"))), + # Pipeline stage partition strategy "VLLM_PP_LAYER_PARTITION": lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None), diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index af6810a140b43..829c62d21894c 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -28,7 +28,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -326,12 +326,345 @@ def forward( return output +class DeepseekV2MLAAttention(nn.Module): + """ + Main reference: DeepseekV2 paper, and FlashInfer Implementation + (https://github.com/flashinfer-ai/flashinfer/pull/551). + + Deepseek's MLA attention works the following way: + * Use a single latent vector to represent the entire KV cache. + * The attention "simulates" a multi-head attention, while the compute is + similar to multi-query attention. + * The dataflow is as follows, + + * B: batch/sequence length + * H: hidden size + * N: number of attention heads + * Lq: latent dimension for Q + * Lkv: latent dimension for K/V + * P: nope dimension, P+R is the actual head_dim in common attention. + * R: rope dimension, this slide of the head_dim goes through rope. + * V: V head dim. + + # The reconstructed way, as implemented in DeepseekV2Attention: + 1. The hidden states (B, H) are projected down into q_latent (B, Lq) and + kv_latent (B, Lkv+R). + 2. The kv_latent is split into kv_a (B, Lkv) and k_pe (B, R). q_latent + and kv_a are normalized. + 3. The q_latent and kv_a are then projected up into the multi-head + version. q_latent goes from (B, Lq) to (B, N(P+R)) included the rope + dimension, which is split into q_nope (B, N, P) and q_pe (B, N, R). + kv_a goes from (B, Lkv) to (B, N(P+V)) which has the nope dimensions + for K and V, which is split into k_nope (B, N, P) and v (B, N, V). + 3. q_pe, k_pe are then passed through rotary embeddings. + 4. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from + q_nope, q_pe, k_nope, k_pe. + 5. Attention is computued with q, k, v. + 6. 
The KV cache is updated with the new entries k (B, N, (P+R)) and v + (B, N, V), we pad the head dim to 256 so that the KV cache has + consistent shape and works with a typical cache implementation. + 7. The attention computation returns (B, N, V), which is projected back + to (B, H) using out projection. + + # The recommended way, as described in the paper: + 1. The hidden states (B, H) are projected down into q_latent (B, Lq) and + kv_latent (B, Lkv+R). + 2. The kv_latent is split into kv_a (B, Lkv) and k_pe (B, R). q_latent + and kv_a are normalized. + 3. Here's the change, we do not perform up the full up projection for + q_latent, and there is no up projection at all for kv_a. This is + achieved by the technique of "weight absorption". The paper says + "Fortunately, due to the associative law of matrix multiplication, + we can absorb WUK into WUQ, and WUV into WO" + * The q up projection turns (B, Lq) into (B, N(P+R)), we split it + into W_UQ (Lq, N, P) and W_QR (Lq, N, R). + * The kv_a up projection turns (B, Lkv) into (B, N(P+V)), we split it + into W_UK (Lkv, N, P) and W_UV (Lkv, N, V). + * The out projection shape W_O (V, H)turns (B, N, V) into (B, H). + * We can precompute the product of W_UQ and W_UK into + W_UQ_UK (Lq, N, Lkv), which is possible due to QK^T operation in + attention. + * We can precompute the product of W_UV and W_O into + W_UV_O (N, Lkv, H), which is possible due to V@O as the + "epilogue" of attention + 4. We still need to compute q_pe (B, N, R) by applying W_QR to q_latent. + The rotary embeddingss still need to be applied to q_pe and k_pe. + 5. By applying W_UQ_UK to q_latent, we have the new q_nope of shape + (B, N, Lkv). + 6. q (B, N, (Lkv+R)), k (B, (Lkv+R)) are assembled from q_nope, q_pe, + kv_a, k_pe. v (B, Lkv) is exactly the same vector as kv_a. + 6. The attention is computed with q, k, v. Note that we just performed + a MQA attention with (LKv+R) as our head dim. + 7. The KV cache is updated using the new entries k (B, N, (Lkv+R)), + which included the v and rope values. + 8. The attention computation returns (B, N, Lkv), which is projected + back to (B, H) using W_UV_O. + + From @tsu-bin's calculation, we only want to use the absorption technique + for decode. The prefill algorithm should still use the up-projected MHA + for less flops and memory usage. 
+ """ + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj") + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj") + else: + self.q_proj = ColumnParallelLinear(self.hidden_size, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj") + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa") + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj") + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + rope_scaling["rope_type"] = 'deepseek_yarn' + self.rotary_emb = get_rope(qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + self.attn = Attention(num_heads=self.num_local_heads, + head_size=self.kv_lora_rank, + scale=self.scaling, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_mla=True) + + # To be computed during weight loading + # self.W_QR = None + # self.W_UQ_UK = None + # self.W_UV_O = None + + kv_b_proj_weight = self.kv_b_proj.weight.T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, self.num_local_heads * + (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_local_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_local_heads, + self.qk_nope_head_dim + 
self.v_head_dim, + ) + self.W_UK, self.W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + self.prefix = prefix + self.debug_layer_idx = int(self.prefix.split(".")[-2]) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + # TODO(simon): support append/chunked prefill by two kernels, + # or using the decode kernel somehow. + if attn_metadata.prefill_metadata and attn_metadata.decode_metadata: + raise ValueError( + "Chunked prefill is not supported when MLA is enabled.") + if attn_metadata.prefill_metadata: + return self.forward_prefill(positions, hidden_states, kv_cache, + attn_metadata) + if attn_metadata.decode_metadata: + return self.forward_decode(positions, hidden_states, kv_cache, + attn_metadata) + + def forward_prefill( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + if self.q_lora_rank is not None: + q = self.q_a_proj(hidden_states)[0] + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, + self.qk_head_dim) + else: + q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads, + self.qk_head_dim) + + _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) + latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + kv_a, _ = latent_cache.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + latent_cache = latent_cache.unsqueeze(1) + kv_a = self.kv_a_layernorm(kv_a.contiguous()) + kv = self.kv_b_proj(kv_a)[0] + kv = kv.view(-1, self.num_local_heads, + self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_pe = latent_cache[:, :, self.kv_lora_rank:] + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + q[..., self.qk_nope_head_dim:] = q_pe + k = torch.empty_like(q) + k[..., :self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim:] = k_pe + + # HACK(simon): these need to be passed into the attention backend + # to write to the kv cache. + # TODO(simon): do we need to free these? + attn_metadata.extras = { + "kv_a": + kv_a.unsqueeze(1), # restore the head dim to write to kv cache + "k_pe": k_pe, + } + + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + + # B(N'V) -> BH + output, _ = self.o_proj(attn_output) + return output + + def forward_decode( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + B = hidden_states.shape[0] + + # Apply UQ and QR. 
+ if self.q_lora_rank is not None: + q = self.q_a_proj(hidden_states)[0] + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, + self.qk_head_dim) + else: + q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads, + self.qk_head_dim) + + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) + + latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + kv_a, k_pe = latent_cache.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_a = self.kv_a_layernorm(kv_a.contiguous()) + k_pe = k_pe.unsqueeze(1) + + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + # Apply UK, q_nope (B, N, P) @ W_UK (L, N, P) -> (B, N, L) + q_nope = torch.einsum("bnp,lnp->bnl", q_nope, self.W_UK) + # essemble q, k, and v; here v is repurposed to represent k_pe + + q = torch.empty((B, self.num_local_heads, + self.kv_lora_rank + self.qk_rope_head_dim), + dtype=q.dtype, + device=q.device) + q[..., :self.kv_lora_rank] = q_nope + q[..., self.kv_lora_rank:] = q_pe + + k = kv_a.unsqueeze(1) + # The padding is only used for kv storage. + v = torch.nn.functional.pad( + k_pe, [0, self.kv_lora_rank - self.qk_rope_head_dim], value=0) + assert k.numel() == v.numel(), f"{k.numel()=} != {v.numel()=}" + + attn_metadata.debug_layer_idx = self.debug_layer_idx + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + + # idk why but the attn_output is fp32 + attn_output = attn_output.to(q.dtype) + # Apply UV, (B, N, L) @ W_UV (L, N, V) -> (B, N, V) + attn_output = torch.einsum("bnl,lnv->bnv", attn_output, self.W_UV) + attn_output = attn_output.reshape( + B, self.num_local_heads * self.v_head_dim) + + output, _ = self.o_proj(attn_output) + return output + + class DeepseekV2DecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, prefix: str, + model_config: ModelConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: @@ -344,7 +677,11 @@ def __init__( # DecoderLayers are created with `make_layers` which passes the prefix # with the layer's index. 
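Stepping back to the decode path above, a shape-level sketch of the weight absorption it implements, with dimensions assumed for DeepSeek-V2 (B tokens, N local heads, L = 512 latent, P = 128 nope, V = 128) and random tensors standing in for the real weights and attention output:

```python
import torch

B, N, L, P, V = 2, 16, 512, 128, 128
q_nope = torch.randn(B, N, P)
W_UK = torch.randn(L, N, P)      # split from the kv_b_proj weight
W_UV = torch.randn(L, N, V)

# Absorb W_UK into the query: (B, N, P) x (L, N, P) -> (B, N, L)
q_latent = torch.einsum("bnp,lnp->bnl", q_nope, W_UK)

# MQA-style attention against the latent cache returns (B, N, L);
# faked here to show only the epilogue.
attn_out_latent = torch.randn(B, N, L)

# Apply W_UV on the way out: (B, N, L) x (L, N, V) -> (B, N, V)
attn_out = torch.einsum("bnl,lnv->bnv", attn_out_latent, W_UV)
print(q_latent.shape, attn_out.shape)   # (2, 16, 512)  (2, 16, 128)
```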
layer_idx = int(prefix.split(sep='.')[-1]) - self.self_attn = DeepseekV2Attention( + if model_config.should_use_mla: + attn_cls = DeepseekV2MLAAttention + else: + attn_cls = DeepseekV2Attention + self.self_attn = attn_cls( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -421,6 +758,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config @@ -440,6 +778,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lambda prefix: DeepseekV2DecoderLayer( config, prefix, + model_config=model_config, cache_config=cache_config, quant_config=quant_config, ), diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 74948202cbe48..159ea94f99a27 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -31,7 +31,8 @@ def get_device_name(cls, device_id: int = 0) -> str: @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool) -> str: + block_size: int, use_v1: bool, + use_mla: bool) -> str: if selected_backend != _Backend.TORCH_SDPA: logger.info("Cannot use %s backend on CPU.", selected_backend) logger.info("Using Torch SDPA backend.") diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 8350177b68ade..9f3db86c47e86 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -174,10 +174,14 @@ def get_current_memory_usage(cls, @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, - kv_cache_dtype, block_size, use_v1) -> str: + kv_cache_dtype, block_size, use_v1, + use_mla) -> str: if use_v1: logger.info("Using Flash Attention backend on V1 engine.") return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + if use_mla: + logger.info("Using FlashInfer MLA backend.") + return "vllm.attention.backends.flashinfer_mla.FlashInferMLABackend" if selected_backend == _Backend.FLASHINFER: logger.info("Using FlashInfer backend.") return "vllm.attention.backends.flashinfer.FlashInferBackend" diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py index 242c2c127979a..f87b48df9f823 100644 --- a/vllm/platforms/hpu.py +++ b/vllm/platforms/hpu.py @@ -25,7 +25,8 @@ class HpuPlatform(Platform): @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool) -> str: + block_size: int, use_v1: bool, + use_mla: bool) -> str: logger.info("Using HPUAttention backend.") return "vllm.attention.backends.hpu_attn.HPUAttentionBackend" diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index f2ecec3203fb7..dc1328650d525 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -30,6 +30,7 @@ class _Backend(enum.Enum): TORCH_SDPA = enum.auto() OPENVINO = enum.auto() FLASHINFER = enum.auto() + FLASHINFER_MLA = enum.auto() HPU_ATTN = enum.auto() PALLAS = enum.auto() IPEX = enum.auto() @@ -139,7 +140,8 @@ def is_cuda_alike(self) -> bool: @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool) -> str: + block_size: int, use_v1: bool, + use_mla: bool) -> str: """Get the attention backend class of a device.""" return "" diff --git 
a/vllm/platforms/openvino.py b/vllm/platforms/openvino.py index 7d414165a8188..3282c061714d3 100644 --- a/vllm/platforms/openvino.py +++ b/vllm/platforms/openvino.py @@ -30,7 +30,8 @@ class OpenVinoPlatform(Platform): @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool) -> str: + block_size: int, use_v1: bool, + use_mla: bool) -> str: if selected_backend != _Backend.OPENVINO: logger.info("Cannot use %s backend on OpenVINO.", selected_backend) logger.info("Using OpenVINO Attention backend.") diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 5ef56406e1935..8888521631481 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -75,7 +75,8 @@ class RocmPlatform(Platform): @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, - kv_cache_dtype, block_size, use_v1) -> str: + kv_cache_dtype, block_size, use_v1, + use_mla) -> str: selected_backend = (_Backend.ROCM_FLASH if selected_backend == _Backend.FLASH_ATTN else selected_backend) if selected_backend == _Backend.ROCM_FLASH: diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 05a3aa4305cfa..494a17633974d 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -29,7 +29,8 @@ class TpuPlatform(Platform): @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool) -> str: + block_size: int, use_v1: bool, + use_mla: bool) -> str: if selected_backend != _Backend.PALLAS: logger.info("Cannot use %s backend on TPU.", selected_backend) logger.info("Using Pallas backend.") diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index c34b5b58672e7..a5ca77f57cf47 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -27,7 +27,8 @@ class XPUPlatform(Platform): @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool) -> str: + block_size: int, use_v1: bool, + use_mla: bool) -> str: if selected_backend != _Backend.IPEX: logger.info("Cannot use %s backend on XPU.", selected_backend) logger.info("Using IPEX attention backend.") diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 7ccd4571b19df..d960f53f6d4de 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -52,11 +52,13 @@ def __init__( self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] # Get attention backend. - self.attn_backend = get_attn_backend(self.head_size, - model_config.dtype, - cache_config.cache_dtype, - self.block_size, - model_config.is_attention_free) + self.attn_backend = get_attn_backend( + self.head_size, + model_config.dtype, + cache_config.cache_dtype, + self.block_size, + model_config.is_attention_free, + use_mla=model_config.should_use_mla) # Initialize the cache. self.gpu_cache = self._allocate_kv_cache( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index ae8b7f97c827d..0dac50f1ac9f4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1057,6 +1057,7 @@ def __init__( self.kv_cache_dtype, self.block_size, self.model_config.is_attention_free, + use_mla=self.model_config.should_use_mla, ) if needs_attn_backend else None if self.attn_backend: self.attn_state = self.attn_backend.get_state_cls()(
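For reference, the platform-level dispatch added earlier in this patch routes MLA to the new backend before any other CUDA selection runs; a condensed sketch, with the final return acting only as a placeholder for the unchanged default path:

```python
def get_attn_backend_cls(use_v1: bool, use_mla: bool) -> str:
    # Condensed from CudaPlatform.get_attn_backend_cls as modified above;
    # other platforms accept use_mla but do not change behavior.
    if use_v1:
        return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
    if use_mla:
        return "vllm.attention.backends.flashinfer_mla.FlashInferMLABackend"
    # ...pre-existing selection (FLASHINFER, FLASH_ATTN, ...) is unchanged...
    return "vllm.attention.backends.flash_attn.FlashAttentionBackend"
```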