From f70bbb3653dab29091339d0f24edbdb2d9daa3b7 Mon Sep 17 00:00:00 2001 From: Bryce1010 Date: Mon, 6 Jan 2025 19:04:34 +0800 Subject: [PATCH] improve embedding input --- vllm/attention/backends/flash_attn.py | 1 - vllm/engine/llm_engine.py | 5 ++--- vllm/inputs/data.py | 6 +++--- vllm/inputs/preprocess.py | 8 ++++---- vllm/model_executor/models/qwen2.py | 4 ++-- vllm/sequence.py | 18 +++++++++--------- vllm/worker/model_runner.py | 17 ++++++++--------- 7 files changed, 28 insertions(+), 31 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index c69e12ad78c44..71db92626afc6 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -719,7 +719,6 @@ def forward( prefill_output = output[:num_prefill_query_tokens] assert query.shape[0] == num_prefill_query_tokens assert decode_query.shape[0] == num_decode_query_tokens - if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. if (kv_cache.numel() == 0 or prefill_meta.block_tables is None diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index bfc26acbacf29..687f1fef6a102 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -774,10 +774,9 @@ def add_request( if arrival_time is None: arrival_time = time.time() - if isinstance(prompt, dict) and prompt.get("prompt_embeds", - None) is not None: + if isinstance(prompt, dict) and prompt.get("prompt_embeds", None) is not None: if not prompt.get("prompt_token_ids", None): - prompt["prompt_token_ids"] = [0] * len(prompt["prompt_embeds"]) + prompt["prompt_token_ids"] = [0] * prompt["prompt_embeds"].shape[0] if self.tokenizer is not None: self._validate_token_prompt( diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 3e2148a78efb7..88888f4a6195c 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -143,9 +143,6 @@ class TokenInputs(TypedDict): prompt_token_ids: List[int] """The token IDs of the prompt.""" - prompt_embeds: NotRequired[torch.Tensor] - """The embeddings of the prompt, if available.""" - token_type_ids: NotRequired[List[int]] """The token type IDs of the prompt.""" @@ -154,6 +151,9 @@ class TokenInputs(TypedDict): The original prompt text corresponding to the token IDs, if available. """ + prompt_embeds: NotRequired[torch.Tensor] + """The embeddings of the prompt, if available.""" + multi_modal_data: NotRequired["MultiModalDataDict"] """ Optional multi-modal data to pass to the model, diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 0427643afe46f..8cd46cb2063d8 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -332,7 +332,7 @@ def _prompt_to_llm_inputs( return token_inputs( prompt_token_ids=prompt_token_ids, - prompt_embeds=tokens_content.get('prompt_embeds'), + prompt_embeds=tokens_content.get("prompt_embeds"), token_type_ids=token_type_ids, multi_modal_data=multi_modal_data, mm_processor_kwargs=mm_processor_kwargs, @@ -362,7 +362,7 @@ def _prompt_to_llm_inputs( return token_inputs( prompt=prompt_text, prompt_token_ids=prompt_token_ids, - prompt_embeds=text_content.get('prompt_embeds'), + prompt_embeds=text_content.get("prompt_embeds"), multi_modal_data=multi_modal_data, mm_processor_kwargs=mm_processor_kwargs, ) @@ -408,7 +408,7 @@ async def _prompt_to_llm_inputs_async( return token_inputs( prompt_token_ids=prompt_token_ids, - prompt_embeds=tokens_content.get('prompt_embeds'), + prompt_embeds=tokens_content.get("prompt_embeds"), multi_modal_data=multi_modal_data, mm_processor_kwargs=mm_processor_kwargs, ) @@ -437,7 +437,7 @@ async def _prompt_to_llm_inputs_async( return token_inputs( prompt=prompt_text, prompt_token_ids=prompt_token_ids, - prompt_embeds=text_content.get('prompt_embeds'), + prompt_embeds=tokens_content.get("prompt_embeds"), multi_modal_data=multi_modal_data, mm_processor_kwargs=mm_processor_kwargs, ) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 2c3c641f482aa..5c7addde4e750 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -477,7 +477,7 @@ def forward( ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors, - inputs_embeds, self.lm_head.bias) + inputs_embeds) return hidden_states def compute_logits( @@ -486,7 +486,7 @@ def compute_logits( sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + sampling_metadata, self.lm_head.bias) return logits def sample( diff --git a/vllm/sequence.py b/vllm/sequence.py index ee6c1a4f903eb..374c06bb7227d 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -263,14 +263,6 @@ def output_token_ids(self, new_output_token_ids) self._update_cached_all_tokens() - @property - def prompt_embeds(self) -> Optional[torch.Tensor]: - return self._prompt_embeds - - @prompt_embeds.setter - def prompt_embeds(self, prompt_embeds: Optional[torch.Tensor]) -> None: - self._prompt_embeds = prompt_embeds - @property def output_token_ids_array(self) -> array: """Return the prompt token ids in array type. @@ -280,6 +272,14 @@ def output_token_ids_array(self) -> array: """ assert isinstance(self._output_token_ids, array) return self._output_token_ids + + @property + def prompt_embeds(self) -> Optional[torch.Tensor]: + return self._prompt_embeds + + @prompt_embeds.setter + def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None: + self._prompt_embeds = prompt_embeds @property def mrope_position_delta(self) -> Optional[int]: @@ -388,8 +388,8 @@ def stage(self) -> SequenceStage: def __repr__(self) -> str: return (f"SequenceData(" f"prompt_token_ids={self._prompt_token_ids}, " + f"prompt_embeds={getattr(self._prompt_embeds, 'shape', None)}, " f"output_token_ids={self.output_token_ids}, " - f"prompt_embeds={getattr(self.prompt_embeds, 'shape', None)}, " f"cumulative_logprob={self.cumulative_logprob}, " f"get_num_computed_tokens={self.get_num_computed_tokens()}") diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 550782ba878a2..d5d8b1c9df162 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -364,8 +364,9 @@ def __init__( else: self.input_tokens = input_tokens or [] - self.inputs_embeds = (inputs_embeds - if inputs_embeds is not None else None) + self.inputs_embeds = ( + inputs_embeds if inputs_embeds is not None else None + ) self.input_positions = input_positions or [] self.token_types = token_types or [] self.mrope_input_positions = mrope_input_positions or None @@ -533,12 +534,12 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, # Compute tokens. tokens = seq_data.get_token_ids()[context_len:seq_len] - if seq_data.prompt_embeds is not None and seq_data.get_output_len( - ) == 0: - prompt_embeds = seq_data.prompt_embeds[context_len:seq_len] + if seq_data.prompt_embeds is not None and seq_data.get_output_len() == 0: + prompt_embeds = seq_data.prompt_embeds[context_len:seq_len] else: - seq_data.prompt_embeds = None + seq_data.prompt_embeds = None # release memory prompt_embeds = None + token_types = seq_group_metadata.token_type_ids inter_data.seq_lens[seq_idx] = seq_len @@ -862,9 +863,7 @@ def build(self) -> ModelInputForGPU: for cur_token_types in inter_data.token_types: token_types.extend(cur_token_types) if inter_data.inputs_embeds is not None: - inputs_embeds.append( - inter_data.inputs_embeds.to(self.runner.device)) - + inputs_embeds.append(inter_data.inputs_embeds.to(self.runner.device)) if len(inputs_embeds) == 0: inputs_embeds = None elif len(inputs_embeds) == 1: