
Commit

improve embedding input
Bryce1010 committed Jan 6, 2025
1 parent 448e884 commit f70bbb3
Showing 7 changed files with 28 additions and 31 deletions.
1 change: 0 additions & 1 deletion vllm/attention/backends/flash_attn.py
@@ -719,7 +719,6 @@ def forward(
prefill_output = output[:num_prefill_query_tokens]
assert query.shape[0] == num_prefill_query_tokens
assert decode_query.shape[0] == num_decode_query_tokens

if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
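The slicing pattern these assertions guard is easier to see in isolation. A minimal sketch, assuming (as the slicing above implies) that the flattened token batch is laid out as prefill tokens followed by decode tokens; the tensor sizes below are made up:

import torch

# Prefill tokens come first in the flattened batch, decode tokens follow,
# so slicing the output at num_prefill_query_tokens splits it cleanly.
num_prefill_query_tokens, num_decode_query_tokens = 5, 3
output = torch.zeros(num_prefill_query_tokens + num_decode_query_tokens, 16)

prefill_output = output[:num_prefill_query_tokens]
decode_output = output[num_prefill_query_tokens:]
assert prefill_output.shape[0] == num_prefill_query_tokens
assert decode_output.shape[0] == num_decode_query_tokens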
5 changes: 2 additions & 3 deletions vllm/engine/llm_engine.py
@@ -774,10 +774,9 @@ def add_request(
if arrival_time is None:
arrival_time = time.time()

if isinstance(prompt, dict) and prompt.get("prompt_embeds",
None) is not None:
if isinstance(prompt, dict) and prompt.get("prompt_embeds", None) is not None:

Check failure [GitHub Actions / ruff (3.12)]: vllm/engine/llm_engine.py:777:81: E501 Line too long (86 > 80)
if not prompt.get("prompt_token_ids", None):

Check failure [GitHub Actions / ruff (3.12)]: vllm/engine/llm_engine.py:777:9: SIM102 Use a single `if` statement instead of nested `if` statements
prompt["prompt_token_ids"] = [0] * len(prompt["prompt_embeds"])
prompt["prompt_token_ids"] = [0] * prompt["prompt_embeds"].shape[0]

Check failure [GitHub Actions / ruff (3.12)]: vllm/engine/llm_engine.py:779:81: E501 Line too long (83 > 80)

if self.tokenizer is not None:
self._validate_token_prompt(
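The hunk above switches from len() to tensor.shape[0] when synthesizing placeholder token ids for an embeddings-only prompt. A minimal sketch of that logic outside the engine; the helper name is illustrative, not vLLM API:

import torch

def fill_placeholder_token_ids(prompt: dict) -> dict:
    # If the prompt carries a prompt_embeds tensor but no prompt_token_ids,
    # insert one placeholder id (0) per embedding row so downstream code that
    # expects token ids still sees the right sequence length.
    embeds = prompt.get("prompt_embeds")
    if embeds is not None and not prompt.get("prompt_token_ids"):
        prompt["prompt_token_ids"] = [0] * embeds.shape[0]
    return prompt

# A prompt made of 4 embedding vectors of hidden size 8.
prompt = {"prompt_embeds": torch.zeros(4, 8)}
print(fill_placeholder_token_ids(prompt)["prompt_token_ids"])  # [0, 0, 0, 0]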
6 changes: 3 additions & 3 deletions vllm/inputs/data.py
@@ -143,9 +143,6 @@ class TokenInputs(TypedDict):
prompt_token_ids: List[int]
"""The token IDs of the prompt."""

prompt_embeds: NotRequired[torch.Tensor]
"""The embeddings of the prompt, if available."""

token_type_ids: NotRequired[List[int]]
"""The token type IDs of the prompt."""

@@ -154,6 +151,9 @@ class TokenInputs(TypedDict):
The original prompt text corresponding to the token IDs, if available.
"""

prompt_embeds: NotRequired[torch.Tensor]
"""The embeddings of the prompt, if available."""

multi_modal_data: NotRequired["MultiModalDataDict"]
"""
Optional multi-modal data to pass to the model,
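For reference, a condensed sketch of how the reordered fields fit together; the class name is illustrative and only the fields shown in the hunk are included:

from typing import List

import torch
from typing_extensions import NotRequired, TypedDict

class TokenInputsSketch(TypedDict):
    prompt_token_ids: List[int]               # token IDs of the prompt (required)
    token_type_ids: NotRequired[List[int]]    # token type IDs, if any
    prompt: NotRequired[str]                  # original prompt text, if available
    prompt_embeds: NotRequired[torch.Tensor]  # embeddings of the prompt, now
                                              # declared after the text fields

inputs: TokenInputsSketch = {
    "prompt_token_ids": [0, 0, 0],
    "prompt_embeds": torch.zeros(3, 8),
}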
8 changes: 4 additions & 4 deletions vllm/inputs/preprocess.py
@@ -332,7 +332,7 @@ def _prompt_to_llm_inputs(

return token_inputs(
prompt_token_ids=prompt_token_ids,
prompt_embeds=tokens_content.get('prompt_embeds'),
prompt_embeds=tokens_content.get("prompt_embeds"),
token_type_ids=token_type_ids,
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
@@ -362,7 +362,7 @@ def _prompt_to_llm_inputs(
return token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
prompt_embeds=text_content.get('prompt_embeds'),
prompt_embeds=text_content.get("prompt_embeds"),
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
)
@@ -408,7 +408,7 @@ async def _prompt_to_llm_inputs_async(

return token_inputs(
prompt_token_ids=prompt_token_ids,
prompt_embeds=tokens_content.get('prompt_embeds'),
prompt_embeds=tokens_content.get("prompt_embeds"),
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
)
@@ -437,7 +437,7 @@ async def _prompt_to_llm_inputs_async(
return token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
prompt_embeds=text_content.get('prompt_embeds'),
prompt_embeds=text_content.get("prompt_embeds"),
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
)
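These four hunks all follow the same pattern: the optional embeddings are read with dict.get, which yields None when the prompt has no prompt_embeds entry. An illustrative stand-in (not the real token_inputs helper):

from typing import Any, Dict

import torch

def build_token_inputs(content: Dict[str, Any]) -> Dict[str, Any]:
    # Optional fields are read with .get(), so a prompt without embeddings
    # simply carries prompt_embeds=None.
    return {
        "prompt_token_ids": content["prompt_token_ids"],
        "prompt_embeds": content.get("prompt_embeds"),
    }

with_embeds = build_token_inputs({"prompt_token_ids": [0, 0],
                                  "prompt_embeds": torch.zeros(2, 8)})
without_embeds = build_token_inputs({"prompt_token_ids": [1, 2, 3]})
assert with_embeds["prompt_embeds"] is not None
assert without_embeds["prompt_embeds"] is None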
4 changes: 2 additions & 2 deletions vllm/model_executor/models/qwen2.py
@@ -477,7 +477,7 @@ def forward(
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
inputs_embeds, self.lm_head.bias)
inputs_embeds)
return hidden_states

def compute_logits(
@@ -486,7 +486,7 @@ def compute_logits(
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
sampling_metadata, self.lm_head.bias)
return logits

def sample(
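The change moves the lm_head bias out of the model's forward call and into the logits computation. A simplified sketch of that placement, not vLLM's LogitsProcessor API; the shapes are made up:

import torch

def compute_logits_sketch(hidden_states: torch.Tensor,
                          lm_head_weight: torch.Tensor,
                          lm_head_bias: torch.Tensor) -> torch.Tensor:
    # The bias is applied when projecting hidden states to vocabulary logits,
    # rather than being threaded through the decoder forward pass.
    return hidden_states @ lm_head_weight.T + lm_head_bias

hidden = torch.randn(4, 16)   # (num_tokens, hidden_size)
weight = torch.randn(32, 16)  # (vocab_size, hidden_size)
bias = torch.zeros(32)        # (vocab_size,)
print(compute_logits_sketch(hidden, weight, bias).shape)  # torch.Size([4, 32])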
18 changes: 9 additions & 9 deletions vllm/sequence.py
@@ -263,14 +263,6 @@ def output_token_ids(self,
new_output_token_ids)
self._update_cached_all_tokens()

@property
def prompt_embeds(self) -> Optional[torch.Tensor]:
return self._prompt_embeds

@prompt_embeds.setter
def prompt_embeds(self, prompt_embeds: Optional[torch.Tensor]) -> None:
self._prompt_embeds = prompt_embeds

@property
def output_token_ids_array(self) -> array:
"""Return the prompt token ids in array type.
@@ -280,6 +272,14 @@ def output_token_ids_array(self) -> array:
"""
assert isinstance(self._output_token_ids, array)
return self._output_token_ids

@property
def prompt_embeds(self) -> Optional[torch.Tensor]:
return self._prompt_embeds

@prompt_embeds.setter
def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None:
self._prompt_embeds = prompt_embeds

@property
def mrope_position_delta(self) -> Optional[int]:
@@ -388,8 +388,8 @@ def stage(self) -> SequenceStage:
def __repr__(self) -> str:
return (f"SequenceData("
f"prompt_token_ids={self._prompt_token_ids}, "
f"prompt_embeds={getattr(self._prompt_embeds, 'shape', None)}, "
f"output_token_ids={self.output_token_ids}, "
f"prompt_embeds={getattr(self.prompt_embeds, 'shape', None)}, "
f"cumulative_logprob={self.cumulative_logprob}, "
f"get_num_computed_tokens={self.get_num_computed_tokens()}")

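The property pair is unchanged apart from its position, and the __repr__ now reports the embedding shape via getattr, which tolerates a None value. A toy stand-in for SequenceData showing both pieces together:

from typing import Optional

import torch

class SequenceDataSketch:
    def __init__(self) -> None:
        self._prompt_embeds: Optional[torch.Tensor] = None

    @property
    def prompt_embeds(self) -> Optional[torch.Tensor]:
        return self._prompt_embeds

    @prompt_embeds.setter
    def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None:
        self._prompt_embeds = prompt_embeds

    def __repr__(self) -> str:
        # getattr(..., 'shape', None) prints the tensor shape without dumping
        # the tensor itself, and degrades to None when no embeddings are set.
        return ("SequenceDataSketch("
                f"prompt_embeds={getattr(self._prompt_embeds, 'shape', None)})")

s = SequenceDataSketch()
print(s)                             # prompt_embeds=None
s.prompt_embeds = torch.zeros(3, 8)
print(s)                             # prompt_embeds=torch.Size([3, 8])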
17 changes: 8 additions & 9 deletions vllm/worker/model_runner.py
@@ -364,8 +364,9 @@ def __init__(

else:
self.input_tokens = input_tokens or []
self.inputs_embeds = (inputs_embeds
if inputs_embeds is not None else None)
self.inputs_embeds = (
inputs_embeds if inputs_embeds is not None else None
)
self.input_positions = input_positions or []
self.token_types = token_types or []
self.mrope_input_positions = mrope_input_positions or None
@@ -533,12 +534,12 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,

# Compute tokens.
tokens = seq_data.get_token_ids()[context_len:seq_len]
if seq_data.prompt_embeds is not None and seq_data.get_output_len(
) == 0:
prompt_embeds = seq_data.prompt_embeds[context_len:seq_len]
if seq_data.prompt_embeds is not None and seq_data.get_output_len() == 0:

Check failure [GitHub Actions / ruff (3.12)]: vllm/worker/model_runner.py:537:81: E501 Line too long (81 > 80)
prompt_embeds = seq_data.prompt_embeds[context_len:seq_len]
else:
seq_data.prompt_embeds = None
seq_data.prompt_embeds = None # release memory
prompt_embeds = None

token_types = seq_group_metadata.token_type_ids

inter_data.seq_lens[seq_idx] = seq_len
@@ -862,9 +863,7 @@ def build(self) -> ModelInputForGPU:
for cur_token_types in inter_data.token_types:
token_types.extend(cur_token_types)
if inter_data.inputs_embeds is not None:
inputs_embeds.append(
inter_data.inputs_embeds.to(self.runner.device))

inputs_embeds.append(inter_data.inputs_embeds.to(self.runner.device))
if len(inputs_embeds) == 0:
inputs_embeds = None
elif len(inputs_embeds) == 1:
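The build() hunk collapses the append onto one line; the surrounding logic gathers each sequence's embeddings onto the device and then handles the empty and single-tensor cases. A hedged sketch of that flow; the hunk does not show the multi-tensor branch, so the concatenation below is an assumption:

from typing import List, Optional

import torch

def gather_inputs_embeds(per_seq_embeds: List[Optional[torch.Tensor]],
                         device: torch.device) -> Optional[torch.Tensor]:
    # Collect each sequence's embeddings (if any) on the target device, then
    # return None, the single tensor, or (assumed) their concatenation.
    gathered = [e.to(device) for e in per_seq_embeds if e is not None]
    if len(gathered) == 0:
        return None
    if len(gathered) == 1:
        return gathered[0]
    return torch.cat(gathered, dim=0)

batch = [torch.zeros(2, 8), None, torch.ones(3, 8)]
out = gather_inputs_embeds(batch, torch.device("cpu"))
print(out.shape)  # torch.Size([5, 8])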
