[Bugfix] add input embedding #11684

Open · wants to merge 2 commits into base: main
1 change: 0 additions & 1 deletion vllm/attention/backends/flash_attn.py
@@ -719,7 +719,6 @@ def forward(
prefill_output = output[:num_prefill_query_tokens]
assert query.shape[0] == num_prefill_query_tokens
assert decode_query.shape[0] == num_decode_query_tokens

if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
7 changes: 7 additions & 0 deletions vllm/engine/async_llm_engine.py
@@ -486,6 +486,13 @@
if arrival_time is None:
arrival_time = time.time()

if isinstance(prompt, dict) and prompt.get("prompt_embeds",
None) is not None:
if not prompt.get("prompt_token_ids", None):

Check failure: GitHub Actions / ruff (3.12), SIM102 at vllm/engine/async_llm_engine.py:489:9: Use a single `if` statement instead of nested `if` statements
prompt["prompt_token_ids"] = [
0
] * prompt["prompt_embeds"].shape[0]

if self.tokenizer is not None:
tokenizer = await self.get_tokenizer_async(lora_request)
self._validate_token_prompt(prompt, tokenizer=tokenizer)
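For context, a minimal standalone sketch of the placeholder behaviour this hunk adds (the helper name is illustrative, not part of the PR): when a prompt carries embeddings but no token ids, the engine pads the token ids with zeros, one per embedding row, written as the single `if` that the SIM102 failure above asks for.

import torch

def fill_placeholder_token_ids(prompt: dict) -> dict:
    # One placeholder id (0) per embedding row, only when token ids are absent.
    embeds = prompt.get("prompt_embeds")
    if embeds is not None and not prompt.get("prompt_token_ids"):
        prompt["prompt_token_ids"] = [0] * embeds.shape[0]
    return prompt

# Example: a prompt of 4 embedding vectors gets 4 placeholder token ids.
prompt = {"prompt_embeds": torch.zeros(4, 4096)}
assert fill_placeholder_token_ids(prompt)["prompt_token_ids"] == [0, 0, 0, 0]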
4 changes: 4 additions & 0 deletions vllm/engine/llm_engine.py
@@ -774,6 +774,10 @@
if arrival_time is None:
arrival_time = time.time()

if isinstance(prompt, dict) and prompt.get("prompt_embeds", None) is not None:

Check failure: GitHub Actions / ruff (3.12), E501 at vllm/engine/llm_engine.py:777:81: Line too long (86 > 80)
if not prompt.get("prompt_token_ids", None):

Check failure: GitHub Actions / ruff (3.12), SIM102 at vllm/engine/llm_engine.py:777:9: Use a single `if` statement instead of nested `if` statements
prompt["prompt_token_ids"] = [0] * prompt["prompt_embeds"].shape[0]

Check failure: GitHub Actions / ruff (3.12), E501 at vllm/engine/llm_engine.py:779:81: Line too long (83 > 80)

if self.tokenizer is not None:
self._validate_token_prompt(
prompt,
6 changes: 6 additions & 0 deletions vllm/entrypoints/llm.py
@@ -6,6 +6,7 @@

from tqdm import tqdm
from typing_extensions import deprecated
import torch

from vllm import envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
@@ -356,7 +357,7 @@
) -> List[RequestOutput]:
...

@deprecate_kwargs(

Check failure: GitHub Actions / mypy (3.9, 3.10, 3.11, 3.12) at vllm/entrypoints/llm.py:360: Overloaded function implementation does not accept all possible arguments of signatures 2, 3, and 6 [misc]
"prompt_token_ids",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'prompts' parameter instead.",
@@ -368,6 +369,7 @@
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
prompt_embeds: Optional[torch.Tensor] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -432,6 +434,9 @@
parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
prompts)

if prompt_embeds is not None:
parsed_prompts.prompt_embeds = prompt_embeds

if isinstance(guided_options_request, dict):
if len(guided_options_request) > 1:
raise ValueError(
@@ -1086,6 +1091,7 @@
self,
prompts: Optional[Union[str, List[str]]],
prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
prompt_embeds: Optional[torch.Tensor] = None,
):
# skip_tokenizer_init is now checked in engine

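A sketch of how the new keyword would be called from the public API, assuming this diff's signature; the model name and embedding tensor are illustrative, and the exact calling convention may still change while the mypy overload failures above are resolved.

import torch
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2-7B-Instruct")

# Precomputed embeddings of shape [num_prompt_tokens, hidden_size].
embeds = torch.randn(16, 3584, dtype=torch.float16)

outputs = llm.generate(
    prompts="placeholder prompt",  # paired with the embeddings above
    sampling_params=SamplingParams(max_tokens=32),
    prompt_embeds=embeds,
)
print(outputs[0].outputs[0].text)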
14 changes: 13 additions & 1 deletion vllm/inputs/data.py
@@ -18,6 +18,9 @@ class TextPrompt(TypedDict):
prompt: str
"""The input text to be tokenized before passing to the model."""

prompt_embeds: NotRequired[torch.Tensor]
"""The embeddings of the prompt, if available."""

multi_modal_data: NotRequired["MultiModalDataDict"]
"""
Optional multi-modal data to pass to the model,
@@ -39,6 +42,9 @@ class TokensPrompt(TypedDict):
prompt_token_ids: List[int]
"""A list of token IDs to pass to the model."""

prompt_embeds: NotRequired[torch.Tensor]
"""The embeddings of the prompt, if available."""

token_type_ids: NotRequired[List[int]]
"""A list of token type IDs to pass to the cross encoder model."""

@@ -145,6 +151,9 @@ class TokenInputs(TypedDict):
The original prompt text corresponding to the token IDs, if available.
"""

prompt_embeds: NotRequired[torch.Tensor]
"""The embeddings of the prompt, if available."""

multi_modal_data: NotRequired["MultiModalDataDict"]
"""
Optional multi-modal data to pass to the model,
@@ -180,6 +189,7 @@ def token_inputs(
prompt_token_ids: List[int],
token_type_ids: Optional[List[int]] = None,
prompt: Optional[str] = None,
prompt_embeds: Optional[torch.Tensor] = None,
multi_modal_data: Optional["MultiModalDataDict"] = None,
multi_modal_inputs: Optional["MultiModalKwargs"] = None,
multi_modal_hashes: Optional[List[str]] = None,
@@ -193,6 +203,8 @@
inputs["prompt"] = prompt
if token_type_ids is not None:
inputs["token_type_ids"] = token_type_ids
if prompt_embeds is not None:
inputs["prompt_embeds"] = prompt_embeds
if multi_modal_data is not None:
inputs["multi_modal_data"] = multi_modal_data
if multi_modal_inputs is not None:
@@ -275,7 +287,7 @@ def prompt_embeds(self) -> Optional[torch.Tensor]:
inputs = self.inputs

if inputs["type"] == "token" or inputs["type"] == "multimodal":
return None
return inputs.get("prompt_embeds")

assert_never(inputs) # type: ignore[arg-type]

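A small sketch of the extended prompt schema: both TextPrompt and TokensPrompt can now carry a prompt_embeds tensor alongside their existing fields. The hidden size below is illustrative.

import torch
from vllm.inputs import TextPrompt, TokensPrompt

hidden_size = 4096
embeds = torch.randn(8, hidden_size)

# Token-based prompt with precomputed embeddings (placeholder token ids).
tokens_prompt = TokensPrompt(prompt_token_ids=[0] * embeds.shape[0],
                             prompt_embeds=embeds)

# Text prompt that also carries embeddings.
text_prompt = TextPrompt(prompt="Hello, world!", prompt_embeds=embeds)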
3 changes: 3 additions & 0 deletions vllm/inputs/parse.py
@@ -95,6 +95,9 @@ def parse_singleton_prompt(
elif "prompt" in prompt:
return ParsedTextPrompt(type="text", content=prompt)

elif "prompt_embeds" in prompt:
return ParsedTokensPrompt(type="tokens", content=prompt)

raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")


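With the new branch, a dict that supplies only prompt_embeds is classified as a tokens prompt; the placeholder token ids are filled in later by the engine. A quick illustrative check:

import torch
from vllm.inputs.parse import parse_singleton_prompt

parsed = parse_singleton_prompt({"prompt_embeds": torch.randn(8, 4096)})
assert parsed["type"] == "tokens"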
4 changes: 4 additions & 0 deletions vllm/inputs/preprocess.py
@@ -332,6 +332,7 @@ def _prompt_to_llm_inputs(

return token_inputs(
prompt_token_ids=prompt_token_ids,
prompt_embeds=tokens_content.get("prompt_embeds"),
token_type_ids=token_type_ids,
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
@@ -361,6 +362,7 @@ def _prompt_to_llm_inputs(
return token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
prompt_embeds=text_content.get("prompt_embeds"),
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
)
@@ -406,6 +408,7 @@ async def _prompt_to_llm_inputs_async(

return token_inputs(
prompt_token_ids=prompt_token_ids,
prompt_embeds=tokens_content.get("prompt_embeds"),
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
)
@@ -434,6 +437,7 @@ async def _prompt_to_llm_inputs_async(
return token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
prompt_embeds=tokens_content.get("prompt_embeds"),
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
)
3 changes: 2 additions & 1 deletion vllm/model_executor/models/qwen2.py
@@ -450,6 +450,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
True,
Review comment (Member): I think this will break models that don't have bias weights. Can you read this from the HF config?

quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
@@ -485,7 +486,7 @@ def compute_logits(
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
sampling_metadata, self.lm_head.bias)
return logits

def sample(
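One way to address the review comment above, sketched as a standalone helper: derive the lm_head bias setting from the HF config instead of hardcoding True, so models without bias weights keep loading. The lm_head_bias attribute name is hypothetical and would need to match whatever the actual config exposes.

from transformers import PretrainedConfig

def wants_lm_head_bias(config: PretrainedConfig) -> bool:
    # Hypothetical config attribute; default to no bias when absent.
    return bool(getattr(config, "lm_head_bias", False))

The third positional argument to ParallelLMHead would then become wants_lm_head_bias(config) rather than the hardcoded True.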
14 changes: 13 additions & 1 deletion vllm/sequence.py
@@ -157,6 +157,8 @@ class SequenceData(msgspec.Struct,
_output_token_ids: array = msgspec.field(
default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))

_prompt_embeds: Optional[torch.Tensor] = None

### The below fields should not be passed as an argument ###
_cumulative_logprob: float = 0.0
_prompt_token_ids_tuple: Tuple[int,
@@ -253,7 +255,7 @@ def prompt_token_ids_array(self) -> array:
@property
def output_token_ids(self) -> Tuple[int, ...]:
return tuple(self._output_token_ids)

@output_token_ids.setter
def output_token_ids(self,
new_output_token_ids: GenericSequence[int]) -> None:
@@ -270,6 +272,14 @@ def output_token_ids_array(self) -> array:
"""
assert isinstance(self._output_token_ids, array)
return self._output_token_ids

@property
def prompt_embeds(self) -> Optional[torch.Tensor]:
return self._prompt_embeds

@prompt_embeds.setter
def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None:
self._prompt_embeds = prompt_embeds

@property
def mrope_position_delta(self) -> Optional[int]:
@@ -378,6 +388,7 @@ def stage(self) -> SequenceStage:
def __repr__(self) -> str:
return (f"SequenceData("
f"prompt_token_ids={self._prompt_token_ids}, "
f"prompt_embeds={getattr(self._prompt_embeds, 'shape', None)}, "
f"output_token_ids={self.output_token_ids}, "
f"cumulative_logprob={self.cumulative_logprob}, "
f"get_num_computed_tokens={self.get_num_computed_tokens()}")
@@ -417,6 +428,7 @@ def __init__(
self.prompt_adapter_request = prompt_adapter_request

self.data = SequenceData.from_seqs(self.prompt_token_ids)
self.data.prompt_embeds = self.inputs.prompt_embeds
self.output_logprobs: SampleLogprobs = []
self.output_text = ""

59 changes: 59 additions & 0 deletions vllm/worker/model_runner.py
@@ -83,6 +83,7 @@
additional fields.
"""
input_tokens: Optional[torch.Tensor] = None
inputs_embeds: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
token_types: Optional[torch.Tensor] = None
seq_lens: Optional[List[int]] = None
@@ -103,6 +104,7 @@
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"inputs_embeds": self.inputs_embeds,
"input_positions": self.input_positions,
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
@@ -153,6 +155,7 @@
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"inputs_embeds": self.inputs_embeds,
"input_positions": self.input_positions,
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
@@ -192,6 +195,7 @@

def simple_reinit(self):
self.input_tokens[0].clear() # type: ignore
self.inputs_embeds = None # type: ignore
self.input_positions[0].clear() # type: ignore
self.token_types[0].clear() # type: ignore
self.mrope_input_positions = None # type: ignore
@@ -219,6 +223,7 @@

# Input tokens and positions.
input_tokens: Optional[List[List[int]]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
input_positions: Optional[List[List[int]]] = None,
token_types: Optional[List[List[int]]] = None,
mrope_input_positions: Optional[List[List[List[int]]]] = None,
@@ -280,6 +285,11 @@
for seq_id in range(len(self.seq_ids)):
self.input_tokens[seq_id].clear()

if inputs_embeds is not None:
self.inputs_embeds = inputs_embeds
else:
self.inputs_embeds = None

if input_positions:
self.input_positions = input_positions
else:
@@ -354,6 +364,9 @@

else:
self.input_tokens = input_tokens or []
self.inputs_embeds = (
inputs_embeds if inputs_embeds is not None else None
)
self.input_positions = input_positions or []
self.token_types = token_types or []
self.mrope_input_positions = mrope_input_positions or None
@@ -399,6 +412,26 @@
self.lora_index_mapping = []
self.lora_prompt_mapping = []

def __repr__(self) -> str:
return (
f"InterDataForSeqGroup("
f"request_id={self.request_id}, "
f"seq_ids={self.seq_ids}, "
f"is_prompt={self.is_prompt}, "
f"block_tables={self.block_tables}, "
f"computed_block_nums={self.computed_block_nums}, "
f"n_seqs={self.n_seqs}, "
f"input_tokens={self.input_tokens}, "
f"inputs_embeds={getattr(self.inputs_embeds, 'shape', None)}, "
f"input_positions={self.input_positions}, "
f"token_types={self.token_types}, "
f"mrope_input_positions={self.mrope_input_positions}, "
f"seq_lens={self.seq_lens}, "
f"orig_seq_lens={self.orig_seq_lens}, "
f"query_lens={self.query_lens}, "
f"context_lens={self.context_lens}, "
f"multi_modal_kwargs={self.multi_modal_kwargs}")

def gen_inter_data_builder(self, num_seqs: int):
return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup(
request_id="",
@@ -501,12 +534,19 @@

# Compute tokens.
tokens = seq_data.get_token_ids()[context_len:seq_len]
if seq_data.prompt_embeds is not None and seq_data.get_output_len() == 0:

Check failure: GitHub Actions / ruff (3.12), E501 at vllm/worker/model_runner.py:537:81: Line too long (81 > 80)
prompt_embeds = seq_data.prompt_embeds[context_len:seq_len]
else:
seq_data.prompt_embeds = None # release memory
prompt_embeds = None

token_types = seq_group_metadata.token_type_ids

inter_data.seq_lens[seq_idx] = seq_len
inter_data.orig_seq_lens[seq_idx] = seq_len
inter_data.context_lens[seq_idx] = context_len
inter_data.input_tokens[seq_idx].extend(tokens)
inter_data.inputs_embeds = prompt_embeds
inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
inter_data.token_types[seq_idx].extend(
token_types if token_types else [])
@@ -815,12 +855,21 @@
"""
# Combine and flatten intermediate data.
input_tokens = []
inputs_embeds = []
token_types = []
for inter_data in self.inter_data_list:
for cur_input_tokens in inter_data.input_tokens:
input_tokens.extend(cur_input_tokens)
for cur_token_types in inter_data.token_types:
token_types.extend(cur_token_types)
if inter_data.inputs_embeds is not None:
inputs_embeds.append(inter_data.inputs_embeds.to(self.runner.device))
if len(inputs_embeds) == 0:
inputs_embeds = None
elif len(inputs_embeds) == 1:
inputs_embeds = inputs_embeds[0]
else:
inputs_embeds = torch.cat(inputs_embeds, dim=0)

if not input_tokens:
# This may happen when all prefill requests hit
@@ -972,6 +1021,7 @@

return self.model_input_cls(
input_tokens=input_tokens_tensor,
inputs_embeds=inputs_embeds,
input_positions=input_positions_tensor,
token_types=token_types_tensor,
attn_metadata=attn_metadata,
@@ -1698,6 +1748,9 @@
self.vllm_config):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
**{
"inputs_embeds": model_input.inputs_embeds,
} if model_input.inputs_embeds is not None else {},
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
@@ -1812,6 +1865,9 @@
# check if the current run is prefill
is_prefill_run = prefill_meta is not None

if self.vllm_config.kv_transfer_config is None:
return False

return self.vllm_config.kv_transfer_config.is_kv_consumer and (
not is_profile_run) and is_prefill_run

@@ -1837,6 +1893,9 @@
# check if the current run is prefill
is_prefill_run = prefill_meta is not None

if self.vllm_config.kv_transfer_config is None:
return False

return self.vllm_config.kv_transfer_config.is_kv_producer and (
not is_profile_run) and is_prefill_run

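A standalone sketch of the gathering step added in build(): per-sequence prompt embeddings are moved to the target device and concatenated along the token dimension, and when no sequence supplies embeddings the model falls back to input_ids alone. Shapes and device are illustrative.

from typing import List, Optional
import torch

def gather_inputs_embeds(per_seq_embeds: List[Optional[torch.Tensor]],
                         device: torch.device) -> Optional[torch.Tensor]:
    # Keep only sequences that actually supplied embeddings.
    chunks = [e.to(device) for e in per_seq_embeds if e is not None]
    if not chunks:
        return None
    return chunks[0] if len(chunks) == 1 else torch.cat(chunks, dim=0)

# Two sequences with 4 and 3 prompt tokens, hidden size 8.
embeds = gather_inputs_embeds([torch.randn(4, 8), None, torch.randn(3, 8)],
                              torch.device("cpu"))
assert embeds is not None and embeds.shape == (7, 8)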