diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh
index 73ce82c5857ab..c1c471ec974f8 100644
--- a/.buildkite/run-cpu-test.sh
+++ b/.buildkite/run-cpu-test.sh
@@ -23,6 +23,7 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
# Run basic model test
docker exec cpu-test bash -c "
pip install pytest matplotlib einops transformers_stream_generator datamodel_code_generator
+ pytest -v -s tests/models/encoder_decoder/language
pytest -v -s tests/models/decoder_only/language \
--ignore=tests/models/test_fp8.py \
--ignore=tests/models/decoder_only/language/test_jamba.py \
diff --git a/tests/models/encoder_decoder/language/test_bart.py b/tests/models/encoder_decoder/language/test_bart.py
index 758a9b743b397..8e8862fadbf04 100644
--- a/tests/models/encoder_decoder/language/test_bart.py
+++ b/tests/models/encoder_decoder/language/test_bart.py
@@ -4,220 +4,214 @@
"""
from typing import List, Optional, Tuple, Type
-from vllm.utils import is_cpu
-
-if not is_cpu():
- # CPU backend is not currently supported with encoder/decoder models
- # skip test definitions entirely to avoid importing GPU kernel libs
- # (xFormers, etc.)
-
- import pytest
- from transformers import AutoModelForSeq2SeqLM
-
- from vllm.sequence import SampleLogprobs
-
- from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt,
- HfRunner, VllmRunner)
- from ....utils import multi_gpu_test
- from ...utils import check_logprobs_close
-
- MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"]
-
- def vllm_to_hf_output(
- vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
- decoder_prompt_type: DecoderPromptType,
- ):
- """Sanitize vllm output to be comparable with hf output."""
- output_ids, output_str, out_logprobs = vllm_output
-
- hf_output_str = output_str + "</s>"
- if decoder_prompt_type == DecoderPromptType.NONE:
- hf_output_str = "" + hf_output_str
-
- return output_ids, hf_output_str, out_logprobs
-
- def run_test(
- hf_runner: Type[HfRunner],
- vllm_runner: Type[VllmRunner],
- prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
- decoder_prompt_type: DecoderPromptType,
- model: str,
- *,
- dtype: str,
- max_tokens: int,
- num_logprobs: int,
- tensor_parallel_size: int,
- distributed_executor_backend: Optional[str] = None,
- ) -> None:
- '''
- Test the vLLM BART model for a variety of encoder/decoder input prompts,
- by validating it against HuggingFace (HF) BART.
-
- Arguments:
-
- * hf_runner: HuggingFace (HF) test model runner
- * vllm_runner: vLLM test model runner
- * example_encoder_decoder_prompts: test fixture which provides a
- dictionary of dummy prompts
- * model: the HF ID of the specific BART variant under test
- * dtype: the tensor datatype to employ
- * max_tokens
- * num_logprobs
- * decoder_prompt_type: key into the example_encoder_decoder_prompts
- dictionary; selects specific encoder/decoder
- prompt scenarios to test
-
- A note on using HF BART as a baseline for validating vLLM BART,
- specifically when the decoder prompt is None.
-
- The HF GenerationMixin's default behavior is to force the first
- decoded token to be <BOS> if the prompt does not already contain
- <BOS> (this is accomplished using a logit
- processor setting.)
-
- So when we use HF BART as our baseline for comparison, note that
- when the user provides a request with a None decoder prompt
- (i.e. a singleton encoder prompt, or else an explicit encoder/
- decoder prompt with the decoder sub-prompt set to None), HF and
- vLLM handle this in different ways:
-
- * HF will (1) tokenize the None prompt as an empty token-list,
- (2) append <decoder-start-token> to the beginning, yielding
- [<decoder-start-token>], (3) pass this token list to the model, and
- then (4) after computing logits during prefill, override the model
- logits & force to be the first generated token.
-
- * vLLM will (1) tokenize the None prompt as [], (2) append decoder-
- start-token to the beginning, yielding [<decoder-start-token>],
- (3) pass these tokens to the model & proceed with generation.
-
- The net effect is that compared to vLLM, the list of HF *decoded* tokens
- will contain one more initial <BOS> than the vLLM generated tokens,
- because vLLM's <BOS> token is injected into the prompt rather than into
- the generated output. This is in spite of the fact that overall, the
- complete sequences (prompt + decoded tokens) produced by vLLM will match
- HF.
-
- So when we use HF decoded token output to validate vLLM's decoded token
- output, the testing process must account for the difference in decoded
- token sequences between vLLM and HF specifically in the
- decoder-prompt-is-None case.
-
- One option is to disable the logit processor feature that forces the
- <BOS> token to be decoded (forced_bos_token_id = None), eliminating
- the problem entirely. However this is not "normal" BART usage.
-
- The other option is - only in the decoder-prompt-is-None case - to
- discard the first decoded token from the HF output before comparing it
- to vLLM.
-
- To that end, when testing the scenario where the decoder prompt is None
- (and only in that one scenario), this test skips the first HF decoded
- token during the process of validating the vLLM decoded output.
- '''
-
- # NOTE: take care of the order: run vLLM first, then HF.
- # vLLM needs a fresh process without CUDA initialization; if HF
- # runs first, CUDA will already be initialized, which interferes
- # with the multiprocessing backend's default fork start method.
-
- # Note: currently encoder/decoder models are only compatible with
- # enforce_eager=True. Normally this is not a problem because
- # for encoder/decoder models vLLM will
- # default to enforce_eager=True if enforce_eager
- # is left unspecified. However, the
- # VllmRunner test fixture (which wraps around the LLM class) defaults to
- # enforce_eager=False (a behavior which a number of already-existing
- # decoder-only unit tests expect), so when testing an encoder/decoder
- # model we must explicitly specify enforce_eager=True in the VllmRunner
- # constructor.
- with vllm_runner(
- model,
- dtype=dtype,
- tensor_parallel_size=tensor_parallel_size,
- distributed_executor_backend=distributed_executor_backend,
- enforce_eager=True) as vllm_model:
- vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
- prompts, max_tokens, num_logprobs)
-
- # Configuration settings for HF baseline
- hf_kwargs = {
- "top_k": None,
- "num_beams": 1,
- "repetition_penalty": 1.0,
- "top_p": 1.0,
- "length_penalty": 1.0,
- "early_stopping": False,
- "no_repeat_ngram_size": None,
- "min_length": 0
- }
-
- with hf_runner(model, dtype=dtype,
- auto_cls=AutoModelForSeq2SeqLM) as hf_model:
- hf_outputs = (
- hf_model.generate_encoder_decoder_greedy_logprobs_limit(
- prompts,
- max_tokens,
- num_logprobs,
- **hf_kwargs,
- ))
-
- hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
- else 0)
-
- check_logprobs_close(
- outputs_0_lst=hf_outputs,
- outputs_1_lst=[
- vllm_to_hf_output(vllm_output, decoder_prompt_type)
- for vllm_output in vllm_outputs
- ],
- name_0="hf",
- name_1="vllm",
- num_outputs_0_skip_tokens=hf_skip_tokens,
- )
-
- @pytest.mark.parametrize("model", MODELS)
- @pytest.mark.parametrize("dtype", ["float", "bfloat16"])
- @pytest.mark.parametrize("max_tokens", [64])
- @pytest.mark.parametrize("num_logprobs", [5])
- @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
- def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts,
- model, dtype, max_tokens, num_logprobs,
- decoder_prompt_type) -> None:
-
- run_test(
- hf_runner,
- vllm_runner,
- example_encoder_decoder_prompts[decoder_prompt_type],
- decoder_prompt_type,
- model,
- dtype=dtype,
- max_tokens=max_tokens,
- num_logprobs=num_logprobs,
- tensor_parallel_size=1,
- )
-
- @multi_gpu_test(num_gpus=2)
- @pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
- @pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
- @pytest.mark.parametrize("dtype", ["float"])
- @pytest.mark.parametrize("max_tokens", [64])
- @pytest.mark.parametrize("num_logprobs", [5])
- @pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM])
- def test_models_distributed(hf_runner, vllm_runner,
- example_encoder_decoder_prompts,
- distributed_executor_backend, model, dtype,
- max_tokens, num_logprobs,
- decoder_prompt_type) -> None:
- run_test(
- hf_runner,
- vllm_runner,
- example_encoder_decoder_prompts[decoder_prompt_type],
- decoder_prompt_type,
- model,
- dtype=dtype,
- max_tokens=max_tokens,
- num_logprobs=num_logprobs,
- tensor_parallel_size=2,
- distributed_executor_backend=distributed_executor_backend,
- )
+import pytest
+from transformers import AutoModelForSeq2SeqLM
+
+from vllm.sequence import SampleLogprobs
+
+from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt,
+ HfRunner, VllmRunner)
+from ....utils import multi_gpu_test
+from ...utils import check_logprobs_close
+
+MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"]
+
+
+def vllm_to_hf_output(
+ vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
+ decoder_prompt_type: DecoderPromptType,
+):
+ """Sanitize vllm output to be comparable with hf output."""
+ output_ids, output_str, out_logprobs = vllm_output
+
+ hf_output_str = output_str + "</s>"
+ if decoder_prompt_type == DecoderPromptType.NONE:
+ hf_output_str = "" + hf_output_str
+
+ return output_ids, hf_output_str, out_logprobs
+
+
+def run_test(
+ hf_runner: Type[HfRunner],
+ vllm_runner: Type[VllmRunner],
+ prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
+ decoder_prompt_type: DecoderPromptType,
+ model: str,
+ *,
+ dtype: str,
+ max_tokens: int,
+ num_logprobs: int,
+ tensor_parallel_size: int,
+ distributed_executor_backend: Optional[str] = None,
+) -> None:
+ '''
+ Test the vLLM BART model for a variety of encoder/decoder input prompts,
+ by validating it against HuggingFace (HF) BART.
+
+ Arguments:
+
+ * hf_runner: HuggingFace (HF) test model runner
+ * vllm_runner: vLLM test model runner
+ * example_encoder_decoder_prompts: test fixture which provides a
+ dictionary of dummy prompts
+ * model: the HF ID of the specific BART variant under test
+ * dtype: the tensor datatype to employ
+ * max_tokens
+ * num_logprobs
+ * decoder_prompt_type: key into the example_encoder_decoder_prompts
+ dictionary; selects specific encoder/decoder
+ prompt scenarios to test
+
+ A note on using HF BART as a baseline for validating vLLM BART,
+ specifically when the decoder prompt is None.
+
+ The HF GenerationMixin's default behavior is to force the first
+ decoded token to be <BOS> if the prompt does not already contain
+ <BOS> (this is accomplished using a logit
+ processor setting.)
+
+ So when we use HF BART as our baseline for comparison, note that
+ when the user provides a request with a None decoder prompt
+ (i.e. a singleton encoder prompt, or else an explicit encoder/
+ decoder prompt with the decoder sub-prompt set to None), HF and
+ vLLM handle this in different ways:
+
+ * HF will (1) tokenize the None prompt as an empty token-list,
+ (2) append <decoder-start-token> to the beginning, yielding
+ [<decoder-start-token>], (3) pass this token list to the model, and
+ then (4) after computing logits during prefill, override the model
+ logits & force to be the first generated token.
+
+ * vLLM will (1) tokenize the None prompt as [], (2) append decoder-
+ start-token to the beginning, yielding [<decoder-start-token>],
+ (3) pass these tokens to the model & proceed with generation.
+
+ The net effect is that compared to vLLM, the list of HF *decoded* tokens
+ will contain one more initial <BOS> than the vLLM generated tokens,
+ because vLLM's <BOS> token is injected into the prompt rather than into
+ the generated output. This is in spite of the fact that overall, the
+ complete sequences (prompt + decoded tokens) produced by vLLM will match
+ HF.
+
+ So when we use HF decoded token output to validate vLLM's decoded token
+ output, the testing process must account for the difference in decoded
+ token sequences between vLLM and HF specifically in the
+ decoder-prompt-is-None case.
+
+ One option is to disable the logit processor feature that forces the
+ <BOS> token to be decoded (forced_bos_token_id = None), eliminating
+ the problem entirely. However this is not "normal" BART usage.
+
+ The other option is - only in the decoder-prompt-is-None case - to
+ discard the first decoded token from the HF output before comparing it
+ to vLLM.
+
+ To that end, when testing the scenario where the decoder prompt is None
+ (and only in that one scenario), this test skips the first HF decoded
+ token during the process of validating the vLLM decoded output.
+ '''
+
+ # NOTE: take care of the order: run vLLM first, then HF.
+ # vLLM needs a fresh process without CUDA initialization; if HF
+ # runs first, CUDA will already be initialized, which interferes
+ # with the multiprocessing backend's default fork start method.
+
+ # Note: currently encoder/decoder models are only compatible with
+ # enforce_eager=True. Normally this is not a problem because
+ # for encoder/decoder models vLLM will
+ # default to enforce_eager=True if enforce_eager
+ # is left unspecified. However, the
+ # VllmRunner test fixture (which wraps around the LLM class) defaults to
+ # enforce_eager=False (a behavior which a number of already-existing
+ # decoder-only unit tests expect), so when testing an encoder/decoder
+ # model we must explicitly specify enforce_eager=True in the VllmRunner
+ # constructor.
+ with vllm_runner(model,
+ dtype=dtype,
+ tensor_parallel_size=tensor_parallel_size,
+ distributed_executor_backend=distributed_executor_backend,
+ enforce_eager=True) as vllm_model:
+ vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
+ prompts, max_tokens, num_logprobs)
+
+ # Configuration settings for HF baseline
+ hf_kwargs = {
+ "top_k": None,
+ "num_beams": 1,
+ "repetition_penalty": 1.0,
+ "top_p": 1.0,
+ "length_penalty": 1.0,
+ "early_stopping": False,
+ "no_repeat_ngram_size": None,
+ "min_length": 0
+ }
+
+ with hf_runner(model, dtype=dtype,
+ auto_cls=AutoModelForSeq2SeqLM) as hf_model:
+ hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
+ prompts,
+ max_tokens,
+ num_logprobs,
+ **hf_kwargs,
+ ))
+
+ hf_skip_tokens = (1
+ if decoder_prompt_type == DecoderPromptType.NONE else 0)
+
+ check_logprobs_close(
+ outputs_0_lst=hf_outputs,
+ outputs_1_lst=[
+ vllm_to_hf_output(vllm_output, decoder_prompt_type)
+ for vllm_output in vllm_outputs
+ ],
+ name_0="hf",
+ name_1="vllm",
+ num_outputs_0_skip_tokens=hf_skip_tokens,
+ )
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float", "bfloat16"])
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
+def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model,
+ dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None:
+
+ run_test(
+ hf_runner,
+ vllm_runner,
+ example_encoder_decoder_prompts[decoder_prompt_type],
+ decoder_prompt_type,
+ model,
+ dtype=dtype,
+ max_tokens=max_tokens,
+ num_logprobs=num_logprobs,
+ tensor_parallel_size=1,
+ )
+
+
+@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
+@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM])
+def test_models_distributed(hf_runner, vllm_runner,
+ example_encoder_decoder_prompts,
+ distributed_executor_backend, model, dtype,
+ max_tokens, num_logprobs,
+ decoder_prompt_type) -> None:
+ run_test(
+ hf_runner,
+ vllm_runner,
+ example_encoder_decoder_prompts[decoder_prompt_type],
+ decoder_prompt_type,
+ model,
+ dtype=dtype,
+ max_tokens=max_tokens,
+ num_logprobs=num_logprobs,
+ tensor_parallel_size=2,
+ distributed_executor_backend=distributed_executor_backend,
+ )
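Why `run_test` sets `num_outputs_0_skip_tokens=1` for `DecoderPromptType.NONE` can be seen with a tiny standalone sketch (not part of the patch; the token IDs are invented): HF emits `<BOS>` as its first generated token, while vLLM injects it into the prompt, so the outputs align only after dropping one leading HF token.

```python
# Illustrative only -- invented token IDs, not real BART vocabulary.
hf_generated = [0, 8234, 17, 99, 2]   # HF's first *generated* token is <BOS> (id 0 here)
vllm_generated = [8234, 17, 99, 2]    # vLLM injected <BOS> into the prompt instead

hf_skip_tokens = 1                    # what run_test uses for DecoderPromptType.NONE
assert hf_generated[hf_skip_tokens:] == vllm_generated
```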
diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py
index 2a215331704c1..ef8d576616838 100644
--- a/vllm/attention/backends/torch_sdpa.py
+++ b/vllm/attention/backends/torch_sdpa.py
@@ -75,6 +75,22 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
slot_mapping: torch.Tensor
seq_lens: Optional[List[int]]
+ # Begin encoder attn & enc/dec cross-attn fields...
+ # Encoder sequence lengths representation
+ encoder_seq_lens: Optional[List[int]] = None
+ encoder_seq_lens_tensor: Optional[torch.Tensor] = None
+
+ # Maximum sequence length among encoder sequences
+ max_encoder_seq_len: Optional[int] = None
+
+ # Number of tokens input to encoder
+ num_encoder_tokens: Optional[int] = None
+
+ # Cross-attention memory-mapping data structures: slot mapping
+ # and block tables
+ cross_slot_mapping: Optional[torch.Tensor] = None
+ cross_block_tables: Optional[torch.Tensor] = None
+
def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
@@ -82,6 +98,28 @@ def __post_init__(self):
# from xformer API.
# will not appear in the __repr__ and __init__
self.attn_bias: Optional[List[torch.Tensor]] = None
+ self.encoder_attn_bias: Optional[List[torch.Tensor]] = None
+ self.cross_attn_bias: Optional[List[torch.Tensor]] = None
+
+ @property
+ def is_all_encoder_attn_metadata_set(self):
+ '''
+ All attention metadata required for encoder attention is set.
+ '''
+ return ((self.encoder_seq_lens is not None)
+ and (self.encoder_seq_lens_tensor is not None)
+ and (self.max_encoder_seq_len is not None))
+
+ @property
+ def is_all_cross_attn_metadata_set(self):
+ '''
+ All attention metadata required for enc/dec cross-attention is set.
+
+ Superset of encoder attention required metadata.
+ '''
+ return (self.is_all_encoder_attn_metadata_set
+ and (self.cross_slot_mapping is not None)
+ and (self.cross_block_tables is not None))
@property
def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
@@ -101,6 +139,136 @@ def decode_metadata(self) -> Optional["TorchSDPAMetadata"]:
return self
+ def get_seq_lens(
+ self,
+ attn_type: AttentionType,
+ ):
+ '''
+ Extract appropriate sequence lengths from attention metadata
+ according to attention type.
+
+ Arguments:
+
+ * attn_type: encoder attention, decoder self-attention,
+ encoder/decoder cross-attention
+
+ Returns:
+ * Appropriate sequence lengths tensor for query
+ * Appropriate sequence lengths tensor for key & value
+ '''
+
+ if attn_type == AttentionType.DECODER:
+ seq_lens_q = self.seq_lens
+ seq_lens_kv = self.seq_lens
+ elif attn_type == AttentionType.ENCODER:
+ seq_lens_q = self.encoder_seq_lens
+ seq_lens_kv = self.encoder_seq_lens
+ elif attn_type == AttentionType.ENCODER_DECODER:
+ seq_lens_q = self.seq_lens
+ seq_lens_kv = self.encoder_seq_lens
+ else:
+ raise AttributeError(f"Invalid attention type {str(attn_type)}")
+ return seq_lens_q, seq_lens_kv
+
+ def get_attn_bias(
+ self,
+ attn_type: AttentionType,
+ ) -> Optional[List[torch.Tensor]]:
+ '''
+ Extract appropriate attention bias from attention metadata
+ according to attention type.
+
+ Arguments:
+
+ * attn_type: encoder attention, decoder self-attention,
+ encoder/decoder cross-attention
+
+ Returns:
+ * Appropriate attention bias value given the attention type
+ '''
+
+ if attn_type == AttentionType.DECODER:
+ return self.attn_bias
+ elif attn_type == AttentionType.ENCODER:
+ return self.encoder_attn_bias
+ elif attn_type == AttentionType.ENCODER_DECODER:
+ return self.cross_attn_bias
+ else:
+ raise AttributeError(f"Invalid attention type {str(attn_type)}")
+
+ def set_attn_bias(
+ self,
+ attn_bias: List[torch.Tensor],
+ attn_type: AttentionType,
+ ) -> None:
+ '''
+ Update appropriate attention bias field of attention metadata,
+ according to attention type.
+
+ Arguments:
+
+ * attn_bias: The desired attention bias value
+ * attn_type: encoder attention, decoder self-attention,
+ encoder/decoder cross-attention
+ '''
+
+ if attn_type == AttentionType.DECODER:
+ self.attn_bias = attn_bias
+ elif attn_type == AttentionType.ENCODER:
+ self.encoder_attn_bias = attn_bias
+ elif attn_type == AttentionType.ENCODER_DECODER:
+ self.cross_attn_bias = attn_bias
+ else:
+ raise AttributeError(f"Invalid attention type {str(attn_type)}")
+
+ def get_seq_len_block_table_args(
+ self,
+ attn_type: AttentionType,
+ ) -> tuple:
+ '''
+ The particular choice of sequence-length- and block-table-related
+ attributes which should be extracted from attn_metadata is dependent
+ on the type of attention operation.
+
+ Decoder attn -> select entirely decoder self-attention-related fields
+ Encoder/decoder cross-attn -> select encoder sequence lengths &
+ cross-attn block-tables fields
+ Encoder attn -> select encoder sequence lengths fields & no block tables
+
+ Arguments:
+
+ * attn_type: encoder attention, decoder self-attention,
+ encoder/decoder cross-attention
+
+ Returns:
+
+ * Appropriate sequence-lengths tensor
+ * Appropriate max sequence-length scalar
+ * Appropriate block tables (or None)
+ '''
+
+ if attn_type == AttentionType.DECODER:
+ # Decoder self-attention
+ return (self.seq_lens_tensor, self.max_decode_seq_len,
+ self.block_tables)
+ elif attn_type == AttentionType.ENCODER_DECODER:
+ # Enc/dec cross-attention KVs match encoder sequence length;
+ # cross-attention utilizes special "cross" block tables
+ return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
+ self.cross_block_tables)
+ elif attn_type == AttentionType.ENCODER:
+ # No block tables associated with encoder attention
+ return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
+ None)
+ else:
+ raise AttributeError(f"Invalid attention type {str(attn_type)}")
+
class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
@@ -171,84 +339,101 @@ def forward(
shape = [num_tokens, num_heads * head_size]
"""
assert k_scale == 1.0 and v_scale == 1.0
- if attn_type != AttentionType.DECODER:
- raise NotImplementedError("Encoder self-attention and "
- "encoder/decoder cross-attention "
- "are not implemented for "
- "TorchSDPABackendImpl")
- num_tokens, hidden_size = query.shape
+ if (attn_type == AttentionType.ENCODER
+ and (not attn_metadata.is_all_encoder_attn_metadata_set)):
+ raise AttributeError("Encoder attention requires setting "
+ "encoder metadata attributes.")
+ elif (attn_type == AttentionType.ENCODER_DECODER
+ and (not attn_metadata.is_all_cross_attn_metadata_set)):
+ raise AttributeError("Encoder/decoder cross-attention "
+ "requires setting cross-attention "
+ "metadata attributes.")
+
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
- key = key.view(-1, self.num_kv_heads, self.head_size)
- value = value.view(-1, self.num_kv_heads, self.head_size)
-
- if kv_cache.numel() > 0:
+ if key is not None:
+ assert value is not None
+ key = key.view(-1, self.num_kv_heads, self.head_size)
+ value = value.view(-1, self.num_kv_heads, self.head_size)
+ else:
+ assert value is None
+
+ if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
+ # KV-cache during decoder-self- or
+ # encoder-decoder-cross-attention, but not
+ # during encoder attention.
+ #
+ # Even if there are no new key/value pairs to cache,
+ # we still need to break out key_cache and value_cache
+ # i.e. for later use by paged attention
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
- PagedAttention.write_to_paged_cache(key, value, key_cache,
- value_cache,
- attn_metadata.slot_mapping,
- self.kv_cache_dtype, k_scale,
- v_scale)
- if attn_metadata.is_prompt:
+ if (key is not None) and (value is not None):
+ if attn_type == AttentionType.ENCODER_DECODER:
+ # Update cross-attention KV cache (prefill-only)
+ # During cross-attention decode, key & value will be None,
+ # preventing this IF-statement branch from running
+ updated_slot_mapping = attn_metadata.cross_slot_mapping
+ else:
+ # Update self-attention KV cache (prefill/decode)
+ updated_slot_mapping = attn_metadata.slot_mapping
+
+ PagedAttention.write_to_paged_cache(key, value, key_cache,
+ value_cache,
+ updated_slot_mapping,
+ self.kv_cache_dtype,
+ k_scale, v_scale)
+
+ if attn_type != AttentionType.ENCODER:
+ # Decoder self-attention supports chunked prefill.
+ # Encoder/decoder cross-attention requires no chunked
+ # prefill (100% prefill or 100% decode tokens, no mix)
+ num_prefill_tokens = attn_metadata.num_prefill_tokens
+ num_decode_tokens = attn_metadata.num_decode_tokens
+ else:
+ # Encoder attention - chunked prefill is not applicable;
+ # derive token-count from query shape and treat them
+ # as 100% prefill tokens
+ assert attn_metadata.num_encoder_tokens is not None
+ num_prefill_tokens = attn_metadata.num_encoder_tokens
+ num_decode_tokens = 0
+
+ if attn_type == AttentionType.DECODER:
+ # Only enforce this shape-constraint for decoder
+ # self-attention
+ assert key.shape[0] == num_prefill_tokens + num_decode_tokens
+ assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+
+ if prefill_meta := attn_metadata.prefill_metadata:
assert attn_metadata.seq_lens is not None
if (kv_cache.numel() == 0
- or attn_metadata.block_tables.numel() == 0):
- if self.num_kv_heads != self.num_heads:
- key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
- value = value.repeat_interleave(self.num_queries_per_kv,
- dim=1)
-
- if attn_metadata.attn_bias is None:
- if self.alibi_slopes is not None:
- att_masks = _make_alibi_bias(
- self.alibi_slopes, query.dtype,
- attn_metadata.seq_lens) # type: ignore
- elif self.sliding_window is not None:
- att_masks = _make_sliding_window_bias(
- attn_metadata.seq_lens, self.sliding_window,
- query.dtype) # type: ignore
- else:
- att_masks = [None] * len(attn_metadata.seq_lens)
- attn_metadata.attn_bias = att_masks
-
- query = query.movedim(0, query.dim() - 2)
- key = key.movedim(0, key.dim() - 2)
- value = value.movedim(0, value.dim() - 2)
-
- start = 0
- output = torch.empty(
- (num_tokens, self.num_heads, self.head_size),
- dtype=query.dtype)
- for seq_len, mask in zip(attn_metadata.seq_lens,
- attn_metadata.attn_bias):
- end = start + seq_len
- sub_out = scaled_dot_product_attention(
- query[None, :, start:end, :],
- key[None, :, start:end, :],
- value[None, :, start:end, :],
- attn_mask=mask,
- dropout_p=0.0,
- is_causal=not self.need_mask,
- scale=self.scale).squeeze(0).movedim(
- query.dim() - 2, 0)
- output[start:end, :, :] = sub_out
- start = end
+ or prefill_meta.block_tables.numel() == 0):
+ output = self._run_sdpa_forward(query,
+ key,
+ value,
+ prefill_meta,
+ attn_type=attn_type)
else:
# prefix-enabled attention
raise RuntimeError(
"Torch SDPA backend doesn't support prefix decoding.")
- else:
+ if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
+ (
+ seq_lens_arg,
+ max_seq_len_arg,
+ block_tables_arg,
+ ) = decode_meta.get_seq_len_block_table_args(attn_type)
+
output = PagedAttention.forward_decode(
query,
key_cache,
value_cache,
- attn_metadata.block_tables,
- attn_metadata.seq_lens_tensor,
- attn_metadata.max_decode_seq_len,
+ block_tables_arg,
+ seq_lens_arg,
+ max_seq_len_arg,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
@@ -260,6 +445,59 @@ def forward(
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
+ def _run_sdpa_forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_metadata: TorchSDPAMetadata,
+ attn_type: AttentionType = AttentionType.DECODER,
+ ):
+ if self.num_kv_heads != self.num_heads:
+ key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
+ value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
+
+ attn_masks = attn_metadata.get_attn_bias(attn_type)
+ if attn_masks is None:
+ if self.alibi_slopes is not None:
+ attn_masks = _make_alibi_bias(
+ self.alibi_slopes, query.dtype,
+ attn_metadata.seq_lens) # type: ignore
+ elif self.sliding_window is not None:
+ assert attn_metadata.seq_lens is not None
+ attn_masks = _make_sliding_window_bias(
+ attn_metadata.seq_lens, self.sliding_window,
+ query.dtype) # type: ignore
+ else:
+ seq_lens, _ = attn_metadata.get_seq_lens(attn_type)
+ attn_masks = [None] * len(seq_lens)
+ attn_metadata.set_attn_bias(attn_masks, attn_type)
+
+ output = torch.empty_like(query)
+ query = query.movedim(0, query.dim() - 2)
+ key = key.movedim(0, key.dim() - 2)
+ value = value.movedim(0, value.dim() - 2)
+
+ causal_attn = (attn_type == AttentionType.DECODER)
+
+ seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
+ start_q, start_kv = 0, 0
+ for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv,
+ attn_masks):
+ end_q = start_q + seq_len_q
+ end_kv = start_kv + seq_len_kv
+ sub_out = scaled_dot_product_attention(
+ query[None, :, start_q:end_q, :],
+ key[None, :, start_kv:end_kv, :],
+ value[None, :, start_kv:end_kv, :],
+ attn_mask=mask,
+ dropout_p=0.0,
+ is_causal=causal_attn and not self.need_mask,
+ scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0)
+ output[start_q:end_q, :, :] = sub_out
+ start_q, start_kv = end_q, end_kv
+ return output
+
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
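The cross-attention path added to `_run_sdpa_forward` packs all sequences into one token dimension and runs SDPA per sequence with different query and key/value lengths and no causal mask. A self-contained sketch of that loop (not the backend's actual entry point; shapes and sequence lengths are made up):

```python
import torch
from torch.nn.functional import scaled_dot_product_attention

num_heads, head_size = 4, 8
seq_lens_q = [3, 2]    # decoder tokens per sequence (query side)
seq_lens_kv = [5, 7]   # encoder tokens per sequence (key/value side)

query = torch.randn(sum(seq_lens_q), num_heads, head_size)
key = torch.randn(sum(seq_lens_kv), num_heads, head_size)
value = torch.randn(sum(seq_lens_kv), num_heads, head_size)

output = torch.empty_like(query)
q = query.movedim(0, 1)  # -> [num_heads, num_tokens, head_size]
k = key.movedim(0, 1)
v = value.movedim(0, 1)

start_q = start_kv = 0
for len_q, len_kv in zip(seq_lens_q, seq_lens_kv):
    end_q, end_kv = start_q + len_q, start_kv + len_kv
    sub = scaled_dot_product_attention(
        q[None, :, start_q:end_q, :],
        k[None, :, start_kv:end_kv, :],
        v[None, :, start_kv:end_kv, :],
        is_causal=False,          # cross-attention never uses a causal mask
        scale=head_size**-0.5,
    ).squeeze(0).movedim(1, 0)    # back to [num_tokens, num_heads, head_size]
    output[start_q:end_q] = sub
    start_q, start_kv = end_q, end_kv

print(output.shape)  # torch.Size([5, 4, 8])
```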
diff --git a/vllm/worker/cpu_enc_dec_model_runner.py b/vllm/worker/cpu_enc_dec_model_runner.py
new file mode 100644
index 0000000000000..8ebbf6db939bc
--- /dev/null
+++ b/vllm/worker/cpu_enc_dec_model_runner.py
@@ -0,0 +1,311 @@
+import dataclasses
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast
+
+import torch
+
+from vllm.attention import AttentionMetadata
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.multimodal import MultiModalInputs
+from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
+from vllm.utils import make_tensor_with_pad
+from vllm.worker.cpu_model_runner import (CPUModelRunner,
+ ModelInputForCPUBuilder,
+ ModelInputForCPUWithSamplingMetadata)
+from vllm.worker.model_runner_base import (
+ _add_attn_metadata_broadcastable_dict,
+ _add_sampling_metadata_broadcastable_dict)
+
+if TYPE_CHECKING:
+ from vllm.attention.backends.abstract import AttentionBackend
+
+
+@dataclasses.dataclass(frozen=True)
+class EncoderDecoderModelInputForCPU(ModelInputForCPUWithSamplingMetadata):
+ """
+ Used by the EncoderDecoderModelRunner.
+ """
+ encoder_input_tokens: Optional[torch.Tensor] = None
+ encoder_input_positions: Optional[torch.Tensor] = None
+
+ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
+ tensor_dict = {
+ "input_tokens": self.input_tokens,
+ "input_positions": self.input_positions,
+ "encoder_input_tokens": self.encoder_input_tokens,
+ "encoder_input_positions": self.encoder_input_positions,
+ }
+ _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
+ _add_sampling_metadata_broadcastable_dict(tensor_dict,
+ self.sampling_metadata)
+ return tensor_dict
+
+ @classmethod
+ def from_broadcasted_tensor_dict(
+ cls,
+ tensor_dict: Dict[str, Any],
+ attn_backend: Optional["AttentionBackend"] = None,
+ ) -> "EncoderDecoderModelInputForCPU":
+ return cast(
+ EncoderDecoderModelInputForCPU,
+ super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))
+
+
+class CPUEncoderDecoderModelRunner(CPUModelRunner):
+ _model_input_cls: Type[EncoderDecoderModelInputForCPU] = (
+ EncoderDecoderModelInputForCPU)
+ _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
+
+ def _list_to_int32_tensor(
+ self,
+ _list: List[int],
+ ) -> torch.Tensor:
+ return torch.tensor(_list, dtype=torch.int32, device=self.device)
+
+ def _list_to_long_tensor(
+ self,
+ _list: List[int],
+ ) -> torch.Tensor:
+ return torch.tensor(_list, dtype=torch.long, device=self.device)
+
+ def _empty_int32_tensor(self) -> torch.Tensor:
+ return self._list_to_int32_tensor([])
+
+ def _empty_long_tensor(self) -> torch.Tensor:
+ return self._list_to_long_tensor([])
+
+ def make_model_input_from_broadcasted_tensor_dict(
+ self, tensor_dict: Dict[str,
+ Any]) -> EncoderDecoderModelInputForCPU:
+ return EncoderDecoderModelInputForCPU.from_broadcasted_tensor_dict(
+ tensor_dict,
+ attn_backend=self.attn_backend,
+ )
+
+ def prepare_model_input(
+ self,
+ seq_group_metadata_list: List[SequenceGroupMetadata],
+ virtual_engine: int = 0,
+ finished_requests_ids: Optional[List[str]] = None
+ ) -> EncoderDecoderModelInputForCPU:
+ model_input = super().prepare_model_input(seq_group_metadata_list,
+ virtual_engine,
+ finished_requests_ids)
+ model_input = cast(EncoderDecoderModelInputForCPU, model_input)
+ (
+ attn_metadata,
+ encoder_input_tokens_tensor,
+ encoder_input_positions_tensor,
+ ) = self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
+ model_input)
+ return dataclasses.replace(
+ model_input,
+ attn_metadata=attn_metadata,
+ encoder_input_tokens=encoder_input_tokens_tensor,
+ encoder_input_positions=encoder_input_positions_tensor,
+ )
+
+ def _prepare_encoder_model_input_tensors(
+ self,
+ seq_group_metadata_list: List[SequenceGroupMetadata],
+ model_input: EncoderDecoderModelInputForCPU,
+ ) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
+ Optional[torch.Tensor]]:
+ """Helper method to prepare the encoder- and cross-attn-related
+ model inputs based on a given sequence group. These additional inputs
+ are used to augment an already-computed `EncoderDecoderModelInputForCPU`
+ data structure which already has decoder-related model inputs
+ populated.
+
+ Sets the following attn_metadata fields:
+ * `num_encoder_tokens`
+ * `encoder_seq_lens`
+ * `encoder_seq_lens_tensor`
+ * `max_encoder_seq_len`
+ * `cross_slot_mapping`
+ * `cross_block_tables`
+
+ Constructs a new model inputs data structure, based on
+ (1) the existing fields in the `model_input` argument,
+ and (2) the following additional fields which are
+ computed (or in the case of `attn_metadata`, updated)
+ by this function:
+ * attn_metadata
+ * encoder_input_tokens
+ * encoder_input_positions
+
+ Arguments:
+
+ * seq_group_metadata_list: list of sequence groups for which to
+ compute inputs
+ * model_input: model inputs data structure with decoder-oriented
+ fields already computed.
+
+ Returns:
+
+ * Attention metadata, updated with encoder/cross-attention fields
+ * Encoder input tokens tensor
+ * Encoder input positions tensor
+ """
+
+ if len(seq_group_metadata_list) == 0:
+ return (model_input.attn_metadata, None, None)
+
+ # Since we do not support chunked prefill, the batch is either
+ # entirely prefill or entirely decode
+ is_prompt = seq_group_metadata_list[0].is_prompt
+
+ # Build encoder inputs
+ encoder_seq_lens: List[int] = []
+ if is_prompt:
+ # Prefill phase.
+ cross_block_tables = self._empty_int32_tensor().view(
+ len(seq_group_metadata_list), -1)
+
+ # Extract input tokens/positions, cross-attention slot-mapping,
+ # & seq len from each sequence group metadata
+ (
+ encoder_input_tokens,
+ encoder_input_positions,
+ cross_slot_mapping,
+ ) = (
+ [],
+ [],
+ [],
+ )
+ for seq_group_metadata in seq_group_metadata_list:
+ # Build seq lens
+ seq_len = seq_group_metadata.encoder_seq_data.get_len()
+ token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
+ encoder_seq_lens.append(seq_len)
+
+ # Build slot mapping
+ for i in range(0, seq_len):
+ block_number = seq_group_metadata.cross_block_table[
+ i // self.block_size]
+ block_offset = i % self.block_size
+ slot = block_number * self.block_size + block_offset
+ cross_slot_mapping.append(slot)
+
+ # Build encoder input tokens
+ encoder_input_tokens.extend(token_ids)
+ encoder_input_positions.extend(list(range(0, seq_len)))
+
+ # Convert tokens/positions & cross-attention
+ # slot-mapping to encoder input tensors
+ encoder_input_tokens_tensor = self._list_to_long_tensor(
+ encoder_input_tokens)
+ encoder_input_positions_tensor = self._list_to_long_tensor(
+ encoder_input_positions)
+ cross_slot_mapping_tensor = self._list_to_long_tensor(
+ cross_slot_mapping)
+
+ else:
+ # Decode phase.
+ encoder_input_tokens_tensor = self._empty_long_tensor()
+ encoder_input_positions_tensor = self._empty_long_tensor()
+ cross_slot_mapping_tensor = self._empty_long_tensor()
+ # Extract cross-attention block tables &
+ # seq len from each sequence group metadata.
+ # Cross-attention block tables are empty
+ # during vLLM memory profiling.
+ cross_block_tables = []
+ for seq_group_metadata in seq_group_metadata_list:
+ for _ in range(len(seq_group_metadata.seq_data)):
+ encoder_seq_lens.append(
+ seq_group_metadata.encoder_seq_data.get_len())
+ cross_block_table = seq_group_metadata.cross_block_table
+ cross_block_tables.append([] if (
+ cross_block_table is None) else cross_block_table)
+
+ max_len_of_block_table = max(
+ len(block_table) for block_table in cross_block_tables)
+
+ cross_block_tables = make_tensor_with_pad(
+ cross_block_tables,
+ max_len=max_len_of_block_table,
+ pad=0,
+ dtype=torch.int32,
+ device=self.device,
+ )
+
+ # Compute encoder sequence lengths & encoder
+ # sequence starting offset tensors
+ max_encoder_seq_len = max(encoder_seq_lens, default=0)
+ encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
+ encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] +
+ 1,
+ dtype=torch.int32,
+ device=self.device)
+ torch.cumsum(encoder_seq_lens_tensor,
+ dim=0,
+ dtype=encoder_seq_start_loc.dtype,
+ out=encoder_seq_start_loc[1:])
+
+ # Update attention metadata with encoder-oriented attributes
+ attn_metadata = model_input.attn_metadata
+ assert attn_metadata is not None
+ (
+ attn_metadata.num_encoder_tokens,
+ attn_metadata.encoder_seq_lens,
+ attn_metadata.encoder_seq_lens_tensor,
+ attn_metadata.max_encoder_seq_len,
+ attn_metadata.cross_slot_mapping,
+ attn_metadata.cross_block_tables,
+ ) = (
+ sum(encoder_seq_lens),
+ encoder_seq_lens,
+ encoder_seq_lens_tensor,
+ max_encoder_seq_len,
+ cross_slot_mapping_tensor,
+ cross_block_tables,
+ )
+
+ return (attn_metadata, encoder_input_tokens_tensor,
+ encoder_input_positions_tensor)
+
+ @torch.no_grad()
+ def execute_model(
+ self,
+ model_input: EncoderDecoderModelInputForCPU,
+ kv_caches: List[torch.Tensor],
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ num_steps: int = 1,
+ ) -> Optional[List[SamplerOutput]]:
+ if num_steps > 1:
+ raise ValueError(
+ "CPU worker does not support multi-step execution.")
+
+ model_executable = self.model
+ execute_model_kwargs = {
+ "input_ids":
+ model_input.input_tokens,
+ "positions":
+ model_input.input_positions,
+ "encoder_input_ids":
+ model_input.encoder_input_tokens,
+ "encoder_positions":
+ model_input.encoder_input_positions,
+ "kv_caches":
+ kv_caches,
+ "attn_metadata":
+ model_input.attn_metadata,
+ **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
+ device=self.device),
+ "intermediate_tensors":
+ intermediate_tensors,
+ }
+
+ hidden_states = model_executable(**execute_model_kwargs)
+
+ # Compute the logits.
+ logits = self.model.compute_logits(hidden_states,
+ model_input.sampling_metadata)
+
+ # Only perform sampling in the driver worker.
+ if not self.is_driver_worker:
+ return []
+
+ # Sample the next token.
+ output = self.model.sample(
+ logits=logits,
+ sampling_metadata=model_input.sampling_metadata,
+ )
+ return [output]
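The prefill branch of `_prepare_encoder_model_input_tensors` maps every encoder token through the sequence's cross block table to a physical KV-cache slot. A standalone sketch of just that arithmetic, with made-up block numbers:

```python
# slot(i) = cross_block_table[i // block_size] * block_size + i % block_size
block_size = 4
cross_block_table = [7, 2, 9]   # physical block numbers for one sequence (example values)
encoder_seq_len = 10

cross_slot_mapping = []
for i in range(encoder_seq_len):
    block_number = cross_block_table[i // block_size]
    block_offset = i % block_size
    cross_slot_mapping.append(block_number * block_size + block_offset)

print(cross_slot_mapping)
# [28, 29, 30, 31, 8, 9, 10, 11, 36, 37]
```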
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index 534d167d994fe..a03c562532179 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -19,7 +19,7 @@
MultiModalInputs)
from vllm.sequence import (IntermediateTensors, SequenceData,
SequenceGroupMetadata)
-from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad
+from vllm.utils import make_tensor_with_pad
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict,
@@ -434,10 +434,6 @@ def __init__(
# Lazy initialization.
self.model: nn.Module # Set after init_Model
- if self.model_config.is_encoder_decoder_model:
- raise NotImplementedError(
- STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU'])
-
@property
def model_is_mrope(self) -> bool:
"""Detect if the model has "mrope" rope_scaling type.
@@ -459,8 +455,8 @@ def load_model(self) -> None:
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str, Any],
- ) -> ModelInputForCPU:
- return ModelInputForCPU.from_broadcasted_tensor_dict(
+ ) -> ModelInputForCPUWithSamplingMetadata:
+ return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict( # noqa: E501
tensor_dict,
attn_backend=self.attn_backend,
)
diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py
index 5e36fba6ccdea..7384ffcb2c5e5 100644
--- a/vllm/worker/cpu_worker.py
+++ b/vllm/worker/cpu_worker.py
@@ -1,5 +1,5 @@
"""A CPU worker class."""
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Type
import torch
import torch.distributed
@@ -15,6 +15,7 @@
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
from vllm.worker.cpu_model_runner import CPUModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput)
@@ -163,7 +164,10 @@ def __init__(
else:
self.local_omp_cpuid = omp_cpuids.split("|")[rank]
- self.model_runner: CPUModelRunner = CPUModelRunner(
+ ModelRunnerClass: Type[CPUModelRunner] = CPUModelRunner
+ if self._is_encoder_decoder_model():
+ ModelRunnerClass = CPUEncoderDecoderModelRunner
+ self.model_runner: CPUModelRunner = ModelRunnerClass(
model_config,
parallel_config,
scheduler_config,
@@ -205,6 +209,9 @@ def stop_profile(self):
raise RuntimeError("Profiler is not enabled.")
self.profiler.stop()
+ def _is_encoder_decoder_model(self):
+ return self.model_config.is_encoder_decoder_model
+
def init_device(self) -> None:
if self.local_omp_cpuid != "all":
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
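With the worker now dispatching to `CPUEncoderDecoderModelRunner`, an encoder/decoder model can be exercised on CPU through the regular offline API. A rough usage sketch (not part of this patch; assumes a CPU build of vLLM):

```python
# Hypothetical usage sketch, assuming a CPU build of vLLM.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/bart-base", dtype="float", enforce_eager=True)

# A bare string is treated as a singleton encoder prompt; the decoder prompt
# is left as None (the case discussed in the test_bart.py docstring above).
outputs = llm.generate(
    "The quick brown fox jumps over the lazy dog.",
    SamplingParams(temperature=0.0, max_tokens=16),
)
print(outputs[0].outputs[0].text)
```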