From 4f95ffee6f40198911ee824ed06d645fe9678511 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 7 Oct 2024 14:50:35 +0800 Subject: [PATCH] [Hardware][CPU] Cross-attention and Encoder-Decoder models support on CPU backend (#9089) --- .buildkite/run-cpu-test.sh | 1 + .../encoder_decoder/language/test_bart.py | 428 +++++++++--------- vllm/attention/backends/torch_sdpa.py | 360 ++++++++++++--- vllm/worker/cpu_enc_dec_model_runner.py | 311 +++++++++++++ vllm/worker/cpu_model_runner.py | 10 +- vllm/worker/cpu_worker.py | 11 +- 6 files changed, 834 insertions(+), 287 deletions(-) create mode 100644 vllm/worker/cpu_enc_dec_model_runner.py diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index 73ce82c5857ab..c1c471ec974f8 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -23,6 +23,7 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py" # Run basic model test docker exec cpu-test bash -c " pip install pytest matplotlib einops transformers_stream_generator datamodel_code_generator + pytest -v -s tests/models/encoder_decoder/language pytest -v -s tests/models/decoder_only/language \ --ignore=tests/models/test_fp8.py \ --ignore=tests/models/decoder_only/language/test_jamba.py \ diff --git a/tests/models/encoder_decoder/language/test_bart.py b/tests/models/encoder_decoder/language/test_bart.py index 758a9b743b397..8e8862fadbf04 100644 --- a/tests/models/encoder_decoder/language/test_bart.py +++ b/tests/models/encoder_decoder/language/test_bart.py @@ -4,220 +4,214 @@ """ from typing import List, Optional, Tuple, Type -from vllm.utils import is_cpu - -if not is_cpu(): - # CPU backend is not currently supported with encoder/decoder models - # skip test definitions entirely to avoid importing GPU kernel libs - # (xFormers, etc.) - - import pytest - from transformers import AutoModelForSeq2SeqLM - - from vllm.sequence import SampleLogprobs - - from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt, - HfRunner, VllmRunner) - from ....utils import multi_gpu_test - from ...utils import check_logprobs_close - - MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"] - - def vllm_to_hf_output( - vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], - decoder_prompt_type: DecoderPromptType, - ): - """Sanitize vllm output to be comparable with hf output.""" - output_ids, output_str, out_logprobs = vllm_output - - hf_output_str = output_str + "" - if decoder_prompt_type == DecoderPromptType.NONE: - hf_output_str = "" + hf_output_str - - return output_ids, hf_output_str, out_logprobs - - def run_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], - prompts: List[ExplicitEncoderDecoderPrompt[str, str]], - decoder_prompt_type: DecoderPromptType, - model: str, - *, - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, - ) -> None: - ''' - Test the vLLM BART model for a variety of encoder/decoder input prompts, - by validating it against HuggingFace (HF) BART. 
- - Arguments: - - * hf_runner: HuggingFace (HF) test model runner - * vllm_runner: vLLM test model runner - * example_encoder_decoder_prompts: test fixture which provides a - dictionary of dummy prompts - * model: the HF ID of the specific BART variant under test - * dtype: the tensor datatype to employ - * max_tokens - * num_logprobs - * decoder_prompt_type: key into the example_encoder_decoder_prompts - dictionary; selects specific encoder/decoder - prompt scenarios to test - - A note on using HF BART as a baseline for validating vLLM BART, - specifically when the decoder prompt is None. - - The HF GenerationMixin's default behavior is to force the first - decoded token to be if the prompt does not already contain - (this is accomplished using a logit - processor setting.) - - So when we use HF BART as our baseline for comparison, note that - when the user provides a request with a None decoder prompt - (i.e. a singleton encoder prompt, or else an explicit encoder/ - decoder prompt with the decoder sub-prompt set to None), HF and - vLLM handle this in different ways: - - * HF will (1) tokenize the None prompt as an empty token-list, - (2) append to the beginning, yielding - [], (3) pass this token list to the model, and - then (4) after computing logits during prefill, override the model - logits & force to be the first generated token. - - * vLLM will (1) tokenize the None prompt as [], (2) append decoder- - start-token to the beginning, yielding [], - (3) pass these tokens to the model & proceed with generation. - - The net effect is that compared to vLLM, the list of HF *decoded* tokens - will contain one more initial than the vLLM generated tokens, - because vLLM's token is injected into the prompt rather than into - the generated output. This is in spite of the fact that overall, the - complete sequences (prompt + decoded tokens) produced by vLLM will match - HF. - - So when we use HF decoded token output to validate vLLM's decoded token - output, the testing process must account for the difference in decoded - token sequences between vLLM and HF specifically in the - decoder-prompt-is-None case. - - One option is to disable the logit processor feature that forces the - token to be decoded (forced_bos_token_id = None), eliminating - the problem entirely. However this is not "normal" BART usage. - - The other option is - only in the decoder-prompt-is-None case - to - discard the first decoded token from the HF output before comparing it - to vLLM. - - To that end, when testing the scenario where the decoder prompt is None - (and only in that one scenario), this test skips the first HF decoded - token during the process of validating the vLLM decoded output. - ''' - - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with fork method (the default). - - # Note: currently encoder/decoder models are only compatible with - # enforce_eager=True. Normally this is not a problem because - # for encoder/decoder models vLLM will - # default to enforce_eager=True if enforce_eager - # is left unspecified. 
However, the - # VllmRunner test fixture (which wraps around the LLM class) defaults to - # enforce_eager=False (a behavior which a number of already-exisitng - # decoder-only unit tests expect), so when testing an encoder/decoder - # model we must explicitly specify enforce_eager=True in the VllmRunner - # constructor. - with vllm_runner( - model, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: - vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( - prompts, max_tokens, num_logprobs) - - # Configuration settings for HF baseline - hf_kwargs = { - "top_k": None, - "num_beams": 1, - "repetition_penalty": 1.0, - "top_p": 1.0, - "length_penalty": 1.0, - "early_stopping": False, - "no_repeat_ngram_size": None, - "min_length": 0 - } - - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForSeq2SeqLM) as hf_model: - hf_outputs = ( - hf_model.generate_encoder_decoder_greedy_logprobs_limit( - prompts, - max_tokens, - num_logprobs, - **hf_kwargs, - )) - - hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE - else 0) - - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(vllm_output, decoder_prompt_type) - for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - num_outputs_0_skip_tokens=hf_skip_tokens, - ) - - @pytest.mark.parametrize("model", MODELS) - @pytest.mark.parametrize("dtype", ["float", "bfloat16"]) - @pytest.mark.parametrize("max_tokens", [64]) - @pytest.mark.parametrize("num_logprobs", [5]) - @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) - def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, - model, dtype, max_tokens, num_logprobs, - decoder_prompt_type) -> None: - - run_test( - hf_runner, - vllm_runner, - example_encoder_decoder_prompts[decoder_prompt_type], - decoder_prompt_type, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=1, - ) - - @multi_gpu_test(num_gpus=2) - @pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) - @pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) - @pytest.mark.parametrize("dtype", ["float"]) - @pytest.mark.parametrize("max_tokens", [64]) - @pytest.mark.parametrize("num_logprobs", [5]) - @pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM]) - def test_models_distributed(hf_runner, vllm_runner, - example_encoder_decoder_prompts, - distributed_executor_backend, model, dtype, - max_tokens, num_logprobs, - decoder_prompt_type) -> None: - run_test( - hf_runner, - vllm_runner, - example_encoder_decoder_prompts[decoder_prompt_type], - decoder_prompt_type, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=2, - distributed_executor_backend=distributed_executor_backend, - ) +import pytest +from transformers import AutoModelForSeq2SeqLM + +from vllm.sequence import SampleLogprobs + +from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt, + HfRunner, VllmRunner) +from ....utils import multi_gpu_test +from ...utils import check_logprobs_close + +MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"] + + +def vllm_to_hf_output( + vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], + decoder_prompt_type: DecoderPromptType, +): + """Sanitize vllm output to be comparable with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + + 
hf_output_str = output_str + "" + if decoder_prompt_type == DecoderPromptType.NONE: + hf_output_str = "" + hf_output_str + + return output_ids, hf_output_str, out_logprobs + + +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + prompts: List[ExplicitEncoderDecoderPrompt[str, str]], + decoder_prompt_type: DecoderPromptType, + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +) -> None: + ''' + Test the vLLM BART model for a variety of encoder/decoder input prompts, + by validating it against HuggingFace (HF) BART. + + Arguments: + + * hf_runner: HuggingFace (HF) test model runner + * vllm_runner: vLLM test model runner + * example_encoder_decoder_prompts: test fixture which provides a + dictionary of dummy prompts + * model: the HF ID of the specific BART variant under test + * dtype: the tensor datatype to employ + * max_tokens + * num_logprobs + * decoder_prompt_type: key into the example_encoder_decoder_prompts + dictionary; selects specific encoder/decoder + prompt scenarios to test + + A note on using HF BART as a baseline for validating vLLM BART, + specifically when the decoder prompt is None. + + The HF GenerationMixin's default behavior is to force the first + decoded token to be if the prompt does not already contain + (this is accomplished using a logit + processor setting.) + + So when we use HF BART as our baseline for comparison, note that + when the user provides a request with a None decoder prompt + (i.e. a singleton encoder prompt, or else an explicit encoder/ + decoder prompt with the decoder sub-prompt set to None), HF and + vLLM handle this in different ways: + + * HF will (1) tokenize the None prompt as an empty token-list, + (2) append to the beginning, yielding + [], (3) pass this token list to the model, and + then (4) after computing logits during prefill, override the model + logits & force to be the first generated token. + + * vLLM will (1) tokenize the None prompt as [], (2) append decoder- + start-token to the beginning, yielding [], + (3) pass these tokens to the model & proceed with generation. + + The net effect is that compared to vLLM, the list of HF *decoded* tokens + will contain one more initial than the vLLM generated tokens, + because vLLM's token is injected into the prompt rather than into + the generated output. This is in spite of the fact that overall, the + complete sequences (prompt + decoded tokens) produced by vLLM will match + HF. + + So when we use HF decoded token output to validate vLLM's decoded token + output, the testing process must account for the difference in decoded + token sequences between vLLM and HF specifically in the + decoder-prompt-is-None case. + + One option is to disable the logit processor feature that forces the + token to be decoded (forced_bos_token_id = None), eliminating + the problem entirely. However this is not "normal" BART usage. + + The other option is - only in the decoder-prompt-is-None case - to + discard the first decoded token from the HF output before comparing it + to vLLM. + + To that end, when testing the scenario where the decoder prompt is None + (and only in that one scenario), this test skips the first HF decoded + token during the process of validating the vLLM decoded output. + ''' + + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. 
+ # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default). + + # Note: currently encoder/decoder models are only compatible with + # enforce_eager=True. Normally this is not a problem because + # for encoder/decoder models vLLM will + # default to enforce_eager=True if enforce_eager + # is left unspecified. However, the + # VllmRunner test fixture (which wraps around the LLM class) defaults to + # enforce_eager=False (a behavior which a number of already-exisitng + # decoder-only unit tests expect), so when testing an encoder/decoder + # model we must explicitly specify enforce_eager=True in the VllmRunner + # constructor. + with vllm_runner(model, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True) as vllm_model: + vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( + prompts, max_tokens, num_logprobs) + + # Configuration settings for HF baseline + hf_kwargs = { + "top_k": None, + "num_beams": 1, + "repetition_penalty": 1.0, + "top_p": 1.0, + "length_penalty": 1.0, + "early_stopping": False, + "no_repeat_ngram_size": None, + "min_length": 0 + } + + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForSeq2SeqLM) as hf_model: + hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( + prompts, + max_tokens, + num_logprobs, + **hf_kwargs, + )) + + hf_skip_tokens = (1 + if decoder_prompt_type == DecoderPromptType.NONE else 0) + + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, decoder_prompt_type) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + num_outputs_0_skip_tokens=hf_skip_tokens, + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) +def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, + dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None: + + run_test( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts[decoder_prompt_type], + decoder_prompt_type, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) +@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM]) +def test_models_distributed(hf_runner, vllm_runner, + example_encoder_decoder_prompts, + distributed_executor_backend, model, dtype, + max_tokens, num_logprobs, + decoder_prompt_type) -> None: + run_test( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts[decoder_prompt_type], + decoder_prompt_type, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=2, + distributed_executor_backend=distributed_executor_backend, + ) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 2a215331704c1..ef8d576616838 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -75,6 +75,22 @@ 
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): slot_mapping: torch.Tensor seq_lens: Optional[List[int]] + # Begin encoder attn & enc/dec cross-attn fields... + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + def __post_init__(self): # Set during the execution of the first attention op. # It is a list because it is needed to set per prompt @@ -82,6 +98,28 @@ def __post_init__(self): # from xformer API. # will not appear in the __repr__ and __init__ self.attn_bias: Optional[List[torch.Tensor]] = None + self.encoder_attn_bias: Optional[List[torch.Tensor]] = None + self.cross_attn_bias: Optional[List[torch.Tensor]] = None + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return ((self.encoder_seq_lens is not None) + and (self.encoder_seq_lens_tensor is not None) + and (self.max_encoder_seq_len is not None)) + + @property + def is_all_cross_attn_metadata_set(self): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return (self.is_all_encoder_attn_metadata_set + and (self.cross_slot_mapping is not None) + and (self.cross_block_tables is not None)) @property def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]: @@ -101,6 +139,136 @@ def decode_metadata(self) -> Optional["TorchSDPAMetadata"]: return self + def get_seq_lens( + self, + attn_type: AttentionType, + ): + ''' + Extract appropriate sequence lengths from attention metadata + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + * Appropriate sequence lengths tensor for query + * Appropriate sequence lengths tensor for key & value + ''' + + if attn_type == AttentionType.DECODER: + seq_lens_q = self.seq_lens + seq_lens_kv = self.seq_lens + elif attn_type == AttentionType.ENCODER: + seq_lens_q = self.encoder_seq_lens + seq_lens_kv = self.encoder_seq_lens + elif attn_type == AttentionType.ENCODER_DECODER: + seq_lens_q = self.seq_lens + seq_lens_kv = self.encoder_seq_lens + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + return seq_lens_q, seq_lens_kv + + def get_attn_bias( + self, + attn_type: AttentionType, + ) -> Optional[List[torch.Tensor]]: + ''' + Extract appropriate attention bias from attention metadata + according to attention type. 
+ + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + * Appropriate attention bias value given the attention type + ''' + + if attn_type == AttentionType.DECODER: + return self.attn_bias + elif attn_type == AttentionType.ENCODER: + return self.encoder_attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + return self.cross_attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + def set_attn_bias( + self, + attn_bias: List[torch.Tensor], + attn_type: AttentionType, + ) -> None: + ''' + Update appropriate attention bias field of attention metadata, + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_bias: The desired attention bias value + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + ''' + + if attn_type == AttentionType.DECODER: + self.attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER: + self.encoder_attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + self.cross_attn_bias = attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + def get_seq_len_block_table_args( + self, + attn_type: AttentionType, + ) -> tuple: + ''' + The particular choice of sequence-length- and block-table-related + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. + + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths & + cross-attn block-tables fields + Encoder attn -> select encoder sequence lengths fields & no block tables + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * is_prompt: True if prefill, False otherwise + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + + * Appropriate sequence-lengths tensor + * Appropriate max sequence-length scalar + * Appropriate block tables (or None) + ''' + + if attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + return (self.seq_lens_tensor, self.max_decode_seq_len, + self.block_tables) + elif attn_type == AttentionType.ENCODER_DECODER: + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, + self.cross_block_tables) + elif attn_type == AttentionType.ENCODER: + # No block tables associated with encoder attention + return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, + None) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): @@ -171,84 +339,101 @@ def forward( shape = [num_tokens, num_heads * head_size] """ assert k_scale == 1.0 and v_scale == 1.0 - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "TorchSDPABackendImpl") - num_tokens, hidden_size = query.shape + if (attn_type == AttentionType.ENCODER + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " + "encoder metadata attributes.") + 
elif (attn_type == AttentionType.ENCODER_DECODER + and (not attn_metadata.is_all_cross_attn_metadata_set)): + raise AttributeError("Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes.") + # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - - if kv_cache.numel() > 0: + if key is not None: + assert value is not None + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + else: + assert value is None + + if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0): + # KV-cache during decoder-self- or + # encoder-decoder-cross-attention, but not + # during encoder attention. + # + # Even if there are no new key/value pairs to cache, + # we still need to break out key_cache and value_cache + # i.e. for later use by paged attention key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) - PagedAttention.write_to_paged_cache(key, value, key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, k_scale, - v_scale) - if attn_metadata.is_prompt: + if (key is not None) and (value is not None): + if attn_type == AttentionType.ENCODER_DECODER: + # Update cross-attention KV cache (prefill-only) + # During cross-attention decode, key & value will be None, + # preventing this IF-statement branch from running + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + + PagedAttention.write_to_paged_cache(key, value, key_cache, + value_cache, + updated_slot_mapping, + self.kv_cache_dtype, + k_scale, v_scale) + + if attn_type != AttentionType.ENCODER: + # Decoder self-attention supports chunked prefill. 
+ # Encoder/decoder cross-attention requires no chunked + # prefill (100% prefill or 100% decode tokens, no mix) + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + else: + # Encoder attention - chunked prefill is not applicable; + # derive token-count from query shape & and treat them + # as 100% prefill tokens + assert attn_metadata.num_encoder_tokens is not None + num_prefill_tokens = attn_metadata.num_encoder_tokens + num_decode_tokens = 0 + + if attn_type == AttentionType.DECODER: + # Only enforce this shape-constraint for decoder + # self-attention + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: assert attn_metadata.seq_lens is not None if (kv_cache.numel() == 0 - or attn_metadata.block_tables.numel() == 0): - if self.num_kv_heads != self.num_heads: - key = key.repeat_interleave(self.num_queries_per_kv, dim=1) - value = value.repeat_interleave(self.num_queries_per_kv, - dim=1) - - if attn_metadata.attn_bias is None: - if self.alibi_slopes is not None: - att_masks = _make_alibi_bias( - self.alibi_slopes, query.dtype, - attn_metadata.seq_lens) # type: ignore - elif self.sliding_window is not None: - att_masks = _make_sliding_window_bias( - attn_metadata.seq_lens, self.sliding_window, - query.dtype) # type: ignore - else: - att_masks = [None] * len(attn_metadata.seq_lens) - attn_metadata.attn_bias = att_masks - - query = query.movedim(0, query.dim() - 2) - key = key.movedim(0, key.dim() - 2) - value = value.movedim(0, value.dim() - 2) - - start = 0 - output = torch.empty( - (num_tokens, self.num_heads, self.head_size), - dtype=query.dtype) - for seq_len, mask in zip(attn_metadata.seq_lens, - attn_metadata.attn_bias): - end = start + seq_len - sub_out = scaled_dot_product_attention( - query[None, :, start:end, :], - key[None, :, start:end, :], - value[None, :, start:end, :], - attn_mask=mask, - dropout_p=0.0, - is_causal=not self.need_mask, - scale=self.scale).squeeze(0).movedim( - query.dim() - 2, 0) - output[start:end, :, :] = sub_out - start = end + or prefill_meta.block_tables.numel() == 0): + output = self._run_sdpa_forward(query, + key, + value, + prefill_meta, + attn_type=attn_type) else: # prefix-enabled attention raise RuntimeError( "Torch SDPA backend doesn't support prefix decoding.") - else: + if decode_meta := attn_metadata.decode_metadata: # Decoding run. + ( + seq_lens_arg, + max_seq_len_arg, + block_tables_arg, + ) = decode_meta.get_seq_len_block_table_args(attn_type) + output = PagedAttention.forward_decode( query, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.seq_lens_tensor, - attn_metadata.max_decode_seq_len, + block_tables_arg, + seq_lens_arg, + max_seq_len_arg, self.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -260,6 +445,59 @@ def forward( # Reshape the output tensor. 
return output.view(-1, self.num_heads * self.head_size) + def _run_sdpa_forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: TorchSDPAMetadata, + attn_type: AttentionType = AttentionType.DECODER, + ): + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, dim=1) + + attn_masks = attn_metadata.get_attn_bias(attn_type) + if attn_masks is None: + if self.alibi_slopes is not None: + attn_masks = _make_alibi_bias( + self.alibi_slopes, query.dtype, + attn_metadata.seq_lens) # type: ignore + elif self.sliding_window is not None: + assert attn_metadata.seq_lens is not None + attn_masks = _make_sliding_window_bias( + attn_metadata.seq_lens, self.sliding_window, + query.dtype) # type: ignore + else: + seq_lens, _ = attn_metadata.get_seq_lens(attn_type) + attn_masks = [None] * len(seq_lens) + attn_metadata.set_attn_bias(attn_masks, attn_type) + + output = torch.empty_like(query) + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + + causal_attn = (attn_type == AttentionType.DECODER) + + seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type) + start_q, start_kv = 0, 0 + for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, + attn_masks): + end_q = start_q + seq_len_q + end_kv = start_kv + seq_len_kv + sub_out = scaled_dot_product_attention( + query[None, :, start_q:end_q, :], + key[None, :, start_kv:end_kv, :], + value[None, :, start_kv:end_kv, :], + attn_mask=mask, + dropout_p=0.0, + is_causal=causal_attn and not self.need_mask, + scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) + output[start_q:end_q, :, :] = sub_out + start_q, start_kv = end_q, end_kv + return output + def _make_alibi_bias( alibi_slopes: torch.Tensor, diff --git a/vllm/worker/cpu_enc_dec_model_runner.py b/vllm/worker/cpu_enc_dec_model_runner.py new file mode 100644 index 0000000000000..8ebbf6db939bc --- /dev/null +++ b/vllm/worker/cpu_enc_dec_model_runner.py @@ -0,0 +1,311 @@ +import dataclasses +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast + +import torch + +from vllm.attention import AttentionMetadata +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.multimodal import MultiModalInputs +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata +from vllm.utils import make_tensor_with_pad +from vllm.worker.cpu_model_runner import (CPUModelRunner, + ModelInputForCPUBuilder, + ModelInputForCPUWithSamplingMetadata) +from vllm.worker.model_runner_base import ( + _add_attn_metadata_broadcastable_dict, + _add_sampling_metadata_broadcastable_dict) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + + +@dataclasses.dataclass(frozen=True) +class EncoderDecoderModelInputForCPU(ModelInputForCPUWithSamplingMetadata): + """ + Used by the EncoderDecoderModelRunner. 
+ """ + encoder_input_tokens: Optional[torch.Tensor] = None + encoder_input_positions: Optional[torch.Tensor] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "encoder_input_tokens": self.encoder_input_tokens, + "encoder_input_positions": self.encoder_input_positions, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "EncoderDecoderModelInputForCPU": + return cast( + EncoderDecoderModelInputForCPU, + super().from_broadcasted_tensor_dict(tensor_dict, attn_backend)) + + +class CPUEncoderDecoderModelRunner(CPUModelRunner): + _model_input_cls: Type[EncoderDecoderModelInputForCPU] = ( + EncoderDecoderModelInputForCPU) + _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder + + def _list_to_int32_tensor( + self, + _list: List[int], + ) -> torch.Tensor: + return torch.tensor(_list, dtype=torch.int32, device=self.device) + + def _list_to_long_tensor( + self, + _list: List[int], + ) -> torch.Tensor: + return torch.tensor(_list, dtype=torch.long, device=self.device) + + def _empty_int32_tensor(self) -> torch.Tensor: + return self._list_to_int32_tensor([]) + + def _empty_long_tensor(self) -> torch.Tensor: + return self._list_to_long_tensor([]) + + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, + Any]) -> EncoderDecoderModelInputForCPU: + return EncoderDecoderModelInputForCPU.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> EncoderDecoderModelInputForCPU: + model_input = super().prepare_model_input(seq_group_metadata_list, + virtual_engine, + finished_requests_ids) + model_input = cast(EncoderDecoderModelInputForCPU, model_input) + ( + attn_metadata, + encoder_input_tokens_tensor, + encoder_input_positions_tensor, + ) = self._prepare_encoder_model_input_tensors(seq_group_metadata_list, + model_input) + return dataclasses.replace( + model_input, + attn_metadata=attn_metadata, + encoder_input_tokens=encoder_input_tokens_tensor, + encoder_input_positions=encoder_input_positions_tensor, + ) + + def _prepare_encoder_model_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + model_input: EncoderDecoderModelInputForCPU, + ) -> Tuple[AttentionMetadata, Optional[torch.Tensor], + Optional[torch.Tensor]]: + """Helper method to prepare the encoder- and cross-attn-related + model inputs based on a given sequence group. These additional inputs + are used to augment an already-computed `EncoderDecoderModelInput` + data structure which already has decoder-related model inputs + populated. 
+ + Sets the following attn_metadata fields: + * `num_encoder_tokens` + * `encoder_seq_lens` + * `encoder_seq_lens_tensor` + * `max_encoder_seq_len` + * `cross_slot_mapping` + * `cross_block_tables` + + Constructs a new model inputs data structure, based on + (1) the existing fields in the `model_inputs` argument, + and (2) the following additional fields which are + computed (or in the case of `attn_metadata`, updated) + by this function: + * attn_metadata + * encoder_input_tokens + * encoder_input_positions + + Arguments: + + * seq_group_metadata_list: list of sequence groups for which to + compute inputs + * model_inputs: model inputs data structure with decoder-oriented + fields already computed. + + Return: + + * Updated model inputs data structure + """ + + if len(seq_group_metadata_list) == 0: + return (model_input.attn_metadata, None, None) + + # Since we are not supporting chunked prefill either the entire + # batch is prefill or it is decode + is_prompt = seq_group_metadata_list[0].is_prompt + + # Build encoder inputs + encoder_seq_lens: List[int] = [] + if is_prompt: + # Prefill phase. + cross_block_tables = self._empty_int32_tensor().view( + len(seq_group_metadata_list), -1) + + # Extract input tokens/positions, cross-attention slot-mapping, + # & seq len from each sequence group metadata + ( + encoder_input_tokens, + encoder_input_positions, + cross_slot_mapping, + ) = ( + [], + [], + [], + ) + for seq_group_metadata in seq_group_metadata_list: + # Build seq lens + seq_len = seq_group_metadata.encoder_seq_data.get_len() + token_ids = seq_group_metadata.encoder_seq_data.get_token_ids() + encoder_seq_lens.append(seq_len) + + # Build slot mapping + for i in range(0, seq_len): + block_number = seq_group_metadata.cross_block_table[ + i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + cross_slot_mapping.append(slot) + + # Build encoder input tokens + encoder_input_tokens.extend(token_ids) + encoder_input_positions.extend(list(range(0, seq_len))) + + # Convert tokens/positions & cross-attention + # slot-mapping to encoder input tensors + encoder_input_tokens_tensor = self._list_to_long_tensor( + encoder_input_tokens) + encoder_input_positions_tensor = self._list_to_long_tensor( + encoder_input_positions) + cross_slot_mapping_tensor = self._list_to_long_tensor( + cross_slot_mapping) + + else: + # Decode phase. + encoder_input_tokens_tensor = self._empty_long_tensor() + encoder_input_positions_tensor = self._empty_long_tensor() + cross_slot_mapping_tensor = self._empty_long_tensor() + # Extract cross-attention block tables & + # seq len from each sequence group metadata. + # Cross-attention block tables are empty + # during vLLM memory profiling. 
+ cross_block_tables = [] + for seq_group_metadata in seq_group_metadata_list: + for _ in range(len(seq_group_metadata.seq_data)): + encoder_seq_lens.append( + seq_group_metadata.encoder_seq_data.get_len()) + cross_block_table = seq_group_metadata.cross_block_table + cross_block_tables.append([] if ( + cross_block_table is None) else cross_block_table) + + max_len_of_block_table = max( + len(block_table) for block_table in cross_block_tables) + + cross_block_tables = make_tensor_with_pad( + cross_block_tables, + max_len=max_len_of_block_table, + pad=0, + dtype=torch.int32, + device=self.device, + ) + + # Compute encoder sequence lengths & encoder + # sequence starting offset tensors + max_encoder_seq_len = max(encoder_seq_lens, default=0) + encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens) + encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + + 1, + dtype=torch.int32, + device=self.device) + torch.cumsum(encoder_seq_lens_tensor, + dim=0, + dtype=encoder_seq_start_loc.dtype, + out=encoder_seq_start_loc[1:]) + + # Update attention metadata with encoder-oriented attributes + attn_metadata = model_input.attn_metadata + assert attn_metadata is not None + ( + attn_metadata.num_encoder_tokens, + attn_metadata.encoder_seq_lens, + attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, + attn_metadata.cross_slot_mapping, + attn_metadata.cross_block_tables, + ) = ( + sum(encoder_seq_lens), + encoder_seq_lens, + encoder_seq_lens_tensor, + max_encoder_seq_len, + cross_slot_mapping_tensor, + cross_block_tables, + ) + + return (attn_metadata, encoder_input_tokens_tensor, + encoder_input_positions_tensor) + + @torch.no_grad() + def execute_model( + self, + model_input: EncoderDecoderModelInputForCPU, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[SamplerOutput]]: + if num_steps > 1: + raise ValueError( + "CPU worker does not support multi-step execution.") + + model_executable = self.model + execute_model_kwargs = { + "input_ids": + model_input.input_tokens, + "positions": + model_input.input_positions, + "encoder_input_ids": + model_input.encoder_input_tokens, + "encoder_positions": + model_input.encoder_input_positions, + "kv_caches": + kv_caches, + "attn_metadata": + model_input.attn_metadata, + **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, + device=self.device), + "intermediate_tensors": + intermediate_tensors, + } + + hidden_states = model_executable(**execute_model_kwargs) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, + model_input.sampling_metadata) + + # Only perform sampling in the driver worker. + if not self.is_driver_worker: + return [] + + # Sample the next token. 
+ output = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + return [output] diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 534d167d994fe..a03c562532179 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -19,7 +19,7 @@ MultiModalInputs) from vllm.sequence import (IntermediateTensors, SequenceData, SequenceGroupMetadata) -from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad +from vllm.utils import make_tensor_with_pad from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, @@ -434,10 +434,6 @@ def __init__( # Lazy initialization. self.model: nn.Module # Set after init_Model - if self.model_config.is_encoder_decoder_model: - raise NotImplementedError( - STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU']) - @property def model_is_mrope(self) -> bool: """Detect if the model has "mrope" rope_scaling type. @@ -459,8 +455,8 @@ def load_model(self) -> None: def make_model_input_from_broadcasted_tensor_dict( self, tensor_dict: Dict[str, Any], - ) -> ModelInputForCPU: - return ModelInputForCPU.from_broadcasted_tensor_dict( + ) -> ModelInputForCPUWithSamplingMetadata: + return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict( # noqa: E501 tensor_dict, attn_backend=self.attn_backend, ) diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 5e36fba6ccdea..7384ffcb2c5e5 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -1,5 +1,5 @@ """A CPU worker class.""" -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Type import torch import torch.distributed @@ -15,6 +15,7 @@ from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner from vllm.worker.cpu_model_runner import CPUModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, LoraNotSupportedWorkerBase, WorkerInput) @@ -163,7 +164,10 @@ def __init__( else: self.local_omp_cpuid = omp_cpuids.split("|")[rank] - self.model_runner: CPUModelRunner = CPUModelRunner( + ModelRunnerClass: Type[CPUModelRunner] = CPUModelRunner + if self._is_encoder_decoder_model(): + ModelRunnerClass = CPUEncoderDecoderModelRunner + self.model_runner: CPUModelRunner = ModelRunnerClass( model_config, parallel_config, scheduler_config, @@ -205,6 +209,9 @@ def stop_profile(self): raise RuntimeError("Profiler is not enabled.") self.profiler.stop() + def _is_encoder_decoder_model(self): + return self.model_config.is_encoder_decoder_model + def init_device(self) -> None: if self.local_omp_cpuid != "all": ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
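
In the TorchSDPA path above, only decoder self-attention runs causally; encoder attention and encoder/decoder cross-attention call scaled_dot_product_attention with is_causal=False, and for cross-attention the query and key/value lengths generally differ (decoder tokens attending over encoder tokens). A toy illustration of that shape and mask combination in plain PyTorch, with made-up tensor sizes:

    import torch
    from torch.nn.functional import scaled_dot_product_attention

    num_heads, head_size = 4, 8
    decoder_len, encoder_len = 3, 11   # query vs. key/value lengths differ

    # Cross-attention: decoder queries attend over encoder keys/values,
    # so no causal mask is applied and seq_lens_q != seq_lens_kv.
    q = torch.randn(1, num_heads, decoder_len, head_size)
    k = torch.randn(1, num_heads, encoder_len, head_size)
    v = torch.randn(1, num_heads, encoder_len, head_size)

    out = scaled_dot_product_attention(q, k, v, is_causal=False,
                                       scale=head_size**-0.5)
    print(out.shape)  # torch.Size([1, 4, 3, 8])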
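
The cross-attention slot mapping built in _prepare_encoder_model_input_tensors follows the usual paged-KV addressing scheme: encoder token i is stored in physical block cross_block_table[i // block_size] at offset i % block_size. A standalone sketch of that computation (the function name and the example values are illustrative only):

    from typing import List

    def build_cross_slot_mapping(cross_block_table: List[int],
                                 encoder_seq_len: int,
                                 block_size: int) -> List[int]:
        """Map each encoder token index to a slot in the paged KV cache,
        mirroring the prefill loop in CPUEncoderDecoderModelRunner."""
        slot_mapping: List[int] = []
        for i in range(encoder_seq_len):
            block_number = cross_block_table[i // block_size]
            block_offset = i % block_size
            slot_mapping.append(block_number * block_size + block_offset)
        return slot_mapping

    # A 5-token encoder sequence with block_size=4 spread over physical
    # blocks [7, 2] lands in slots [28, 29, 30, 31, 8].
    assert build_cross_slot_mapping([7, 2], 5, 4) == [28, 29, 30, 31, 8]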
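
With the worker and model-runner changes above, encoder/decoder models such as BART become usable on the CPU backend. A minimal offline-inference sketch, assuming a CPU build of vLLM; the prompt dictionary follows the ExplicitEncoderDecoderPrompt format used by the tests, and the exact keys and defaults should be checked against the installed version:

    from vllm import LLM, SamplingParams

    # On a CPU build, the TorchSDPA backend extended above is selected
    # automatically; no GPU is required.
    llm = LLM(model="facebook/bart-large-cnn", dtype="float")

    prompt = {
        "encoder_prompt": "PG&E scheduled the blackouts in response to "
                          "forecasts for high winds amid dry conditions.",
        # A None decoder prompt lets vLLM insert the decoder start token.
        "decoder_prompt": None,
    }

    outputs = llm.generate(prompt, SamplingParams(temperature=0.0,
                                                  max_tokens=64))
    for output in outputs:
        print(output.outputs[0].text)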