Skip to content

Commit

Permalink
[Core] Make encoder-decoder inputs a nested structure to be more comp…
Browse files Browse the repository at this point in the history
…osable (#9604)

Signed-off-by: DarkLight1337 <[email protected]>
  • Loading branch information
DarkLight1337 authored Nov 5, 2024
1 parent 04bbf38 commit bbc3619
Show file tree
Hide file tree
Showing 14 changed files with 372 additions and 349 deletions.
57 changes: 26 additions & 31 deletions tests/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Tuple

from vllm import SamplingParams
from vllm.inputs import EncoderDecoderInputs, token_inputs
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob, Sequence, SequenceGroup

Expand All @@ -27,10 +28,7 @@ def create_dummy_prompt(
prompt_tokens = list(range(prompt_length))
prompt_str = " ".join([str(t) for t in prompt_tokens])
prompt = Sequence(int(request_id),
inputs={
"prompt": prompt_str,
"prompt_token_ids": prompt_tokens,
},
inputs=token_inputs(prompt_tokens, prompt=prompt_str),
block_size=block_size)
seq_group = SequenceGroup(request_id=request_id,
seqs=[prompt],
Expand Down Expand Up @@ -63,23 +61,21 @@ def create_dummy_prompt_encoder_decoder(
encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length))))
encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens])

inputs = {
"prompt": decoder_prompt_str,
"prompt_token_ids": decoder_prompt_tokens,
"encoder_prompt": encoder_prompt_str,
"encoder_prompt_token_ids": encoder_prompt_tokens,
"multi_modal_data": None,
inputs: EncoderDecoderInputs = {
"decoder": token_inputs(decoder_prompt_tokens,
prompt=decoder_prompt_str),
"encoder": token_inputs(encoder_prompt_tokens,
prompt=encoder_prompt_str),
}

decoder_prompt = Sequence(int(request_id),
inputs=inputs,
block_size=block_size,
from_decoder_prompt=True)
inputs=inputs["decoder"],
block_size=block_size)

encoder_prompt = Sequence(int(request_id),
inputs=inputs,
block_size=block_size,
from_decoder_prompt=False)
inputs=inputs["encoder"],
block_size=block_size)

seq_group = SequenceGroup(request_id=request_id,
seqs=[decoder_prompt],
sampling_params=SamplingParams(best_of=best_of),
Expand Down Expand Up @@ -108,7 +104,7 @@ def create_seq_group(
for seq_id_offset, output_len in enumerate(seq_output_lens):
seq = Sequence(
seq_id=seq_id_start + seq_id_offset,
inputs={"prompt_token_ids": prompt_token_ids},
inputs=token_inputs(prompt_token_ids),
block_size=16,
)

Expand Down Expand Up @@ -143,21 +139,19 @@ def create_seq_group_encoder_decoder(

prompt_token_ids = [0] * seq_prompt_len

inputs = {
"prompt": "",
"prompt_token_ids": prompt_token_ids,
"encoder_prompt": "",
"encoder_prompt_token_ids": prompt_token_ids,
"multi_modal_data": None,
inputs: EncoderDecoderInputs = {
"decoder": token_inputs(prompt_token_ids),
"encoder": token_inputs(prompt_token_ids),
}

seqs = []
for seq_id_offset, output_len in enumerate(seq_output_lens):
# Construct decoder input sequences
seq = Sequence(seq_id=seq_id_start + seq_id_offset,
inputs=inputs,
block_size=16,
from_decoder_prompt=True)
seq = Sequence(
seq_id=seq_id_start + seq_id_offset,
inputs=inputs["decoder"],
block_size=16,
)

for i in range(output_len):
seq.append_token_id(
Expand All @@ -167,10 +161,11 @@ def create_seq_group_encoder_decoder(
seqs.append(seq)

# Encoder input sequence
encoder_seq = Sequence(seq_id=seq_id_start + len(seq_output_lens),
inputs=inputs,
block_size=16,
from_decoder_prompt=False)
encoder_seq = Sequence(
seq_id=seq_id_start + len(seq_output_lens),
inputs=inputs["encoder"],
block_size=16,
)

return SequenceGroup(request_id=request_id,
seqs=seqs,
Expand Down
3 changes: 2 additions & 1 deletion tests/engine/output_processor/test_stop_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from transformers import PreTrainedTokenizer

from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.inputs import token_inputs
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob, Sequence, SequenceStatus

Expand All @@ -15,7 +16,7 @@ def sequence_with_eos(text: str, eos_token: str,
"""
seq = Sequence(
seq_id=0,
inputs={"prompt_token_ids": []},
inputs=token_inputs([]),
block_size=16,
eos_token_id=eos_token_id,
)
Expand Down
7 changes: 3 additions & 4 deletions tests/test_cache_block_hashing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import pytest

from vllm.inputs import token_inputs
from vllm.lora.request import LoRARequest
from vllm.sequence import Sequence
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
Expand Down Expand Up @@ -70,10 +71,8 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
hashes[-1].append([])
prompt_token_ids = tokenizer.encode(prompt)
seq = Sequence(seq_id,
inputs={
"prompt": prompt,
"prompt_token_ids": prompt_token_ids,
},
inputs=token_inputs(prompt_token_ids,
prompt=prompt),
block_size=block_size,
eos_token_id=tokenizer.tokenizer.eos_token_id,
lora_request=lora_request)
Expand Down
6 changes: 2 additions & 4 deletions tests/tokenization/test_detokenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from transformers import AutoTokenizer

from vllm.inputs import token_inputs
from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
from vllm.transformers_utils.detokenizer import (Detokenizer,
detokenize_incrementally)
Expand Down Expand Up @@ -169,10 +170,7 @@ def create_sequence(prompt_token_ids=None):
prompt_token_ids = prompt_token_ids or [1]
return Sequence(
seq_id=0,
inputs={
"prompt": "<s>",
"prompt_token_ids": prompt_token_ids,
},
inputs=token_inputs(prompt_token_ids, prompt="<s>"),
block_size=16,
)

Expand Down
51 changes: 24 additions & 27 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Set, Type, Union, cast, overload

import torch
from typing_extensions import TypeIs, TypeVar
from typing_extensions import TypeVar

import vllm.envs as envs
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
Expand All @@ -29,9 +29,9 @@
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
EncoderDecoderInputs, InputRegistry, PromptType,
TokensPrompt)
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType)
from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.logits_process import get_bad_words_logits_processors
Expand Down Expand Up @@ -638,7 +638,7 @@ def _verify_args(self) -> None:
def _add_processed_request(
self,
request_id: str,
processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
processed_inputs: ProcessorInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: float,
lora_request: Optional[LoRARequest],
Expand Down Expand Up @@ -669,18 +669,19 @@ def _add_processed_request(
seq_id = next(self.seq_counter)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
if is_encoder_decoder_inputs(processed_inputs):
decoder_inputs = processed_inputs["decoder"]
encoder_inputs = processed_inputs["encoder"]
else:
decoder_inputs = processed_inputs
encoder_inputs = None

seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
lora_request, prompt_adapter_request)

encoder_seq = None
if 'encoder_prompt_token_ids' in processed_inputs:
encoder_seq = Sequence(seq_id,
processed_inputs,
block_size,
eos_token_id,
lora_request,
prompt_adapter_request,
from_decoder_prompt=False)
encoder_seq = (None if encoder_inputs is None else Sequence(
seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
prompt_adapter_request))

# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
Expand Down Expand Up @@ -874,7 +875,7 @@ def _validate_token_prompt(self, prompt: PromptType,
# This needs to happen before multimodal input pre-processing, which
# may add dummy <image> tokens that aren't part of the tokenizer's
# vocabulary.
if self._is_token_prompt(prompt):
if is_token_prompt(prompt):
prompt_ids = prompt["prompt_token_ids"]
if len(prompt_ids) == 0:
# Empty prompt check is handled later
Expand All @@ -884,10 +885,6 @@ def _validate_token_prompt(self, prompt: PromptType,
raise ValueError(
"Token id {} is out of vocabulary".format(max_input_id))

@staticmethod
def _is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
return isinstance(prompt, dict) and "prompt_token_ids" in prompt

def _create_sequence_group_with_sampling(
self,
request_id: str,
Expand Down Expand Up @@ -1978,17 +1975,17 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None:
def is_encoder_decoder_model(self):
return self.input_preprocessor.is_encoder_decoder_model()

def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
EncoderDecoderInputs],
def _validate_model_inputs(self, inputs: ProcessorInputs,
lora_request: Optional[LoRARequest]):
if self.model_config.is_multimodal_model:
if is_encoder_decoder_inputs(inputs):
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
prompt_ids = inputs.get("prompt_token_ids")
elif self.is_encoder_decoder_model():
prompt_ids = inputs.get("encoder_prompt_token_ids")
prompt_inputs = inputs["decoder" if self.model_config.
is_multimodal_model else "encoder"]
else:
prompt_ids = inputs.get("prompt_token_ids")
prompt_inputs = inputs

prompt_ids = prompt_inputs.get("prompt_token_ids")

if prompt_ids is None or len(prompt_ids) == 0:
raise ValueError("Prompt cannot be empty")
Expand Down
23 changes: 16 additions & 7 deletions vllm/engine/protocol.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import asyncio
from abc import ABC, abstractmethod
from typing import AsyncGenerator, List, Mapping, Optional, Union
from typing import AsyncGenerator, List, Mapping, Optional

from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
Expand Down Expand Up @@ -60,7 +61,7 @@ def generate(

async def beam_search(
self,
prompt: Union[PromptType, List[int]],
prompt: PromptType,
model_config: ModelConfig,
request_id: str,
params: BeamSearchParams,
Expand All @@ -76,11 +77,19 @@ async def beam_search(
tokenizer = await self.get_tokenizer()
input_preprocessor = InputPreprocessor(model_config, tokenizer)

(prompt_text, prompt_token_ids, multi_modal_data,
mm_processor_kwargs) = input_preprocessor._extract_prompt_components(
prompt,
request_id=request_id,
)
if is_explicit_encoder_decoder_prompt(prompt):
raise NotImplementedError
else:
processed_inputs = input_preprocessor._prompt_to_llm_inputs(
prompt,
request_id=request_id,
)

prompt_token_ids = processed_inputs["prompt_token_ids"]
prompt_text = processed_inputs.get("prompt")
multi_modal_data = processed_inputs.get("multi_modal_data")
mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs")

tokenized_length = len(prompt_token_ids)

sort_beams_key = create_sort_beams_key_function(
Expand Down
11 changes: 6 additions & 5 deletions vllm/inputs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
ExplicitEncoderDecoderPrompt, PromptType, SingletonInputs,
SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
build_explicit_enc_dec_prompt, to_enc_dec_tuple_list,
token_inputs, zip_enc_dec_prompts)
ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
TokensPrompt, build_explicit_enc_dec_prompt,
to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
from .registry import DummyData, InputContext, InputRegistry

INPUT_REGISTRY = InputRegistry()
Expand All @@ -22,9 +22,10 @@
"ExplicitEncoderDecoderPrompt",
"TokenInputs",
"token_inputs",
"SingletonInputs",
"DecoderOnlyInputs",
"EncoderDecoderInputs",
"ProcessorInputs",
"SingletonInputs",
"build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list",
"zip_enc_dec_prompts",
Expand Down
Loading

0 comments on commit bbc3619

Please sign in to comment.