diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index 5647f913e1f9c..afa7dafc7d9ad 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -8,6 +8,7 @@
 
 import torch
 import uvloop
+from PIL import Image
 from tqdm import tqdm
 from transformers import (AutoModelForCausalLM, AutoTokenizer,
                           PreTrainedTokenizerBase)
@@ -38,12 +39,31 @@ class SampleRequest:
     multi_modal_data: Optional[MultiModalDataDict] = None
 
 
-def sample_requests(
-    dataset_path: str,
-    num_requests: int,
-    tokenizer: PreTrainedTokenizerBase,
-    fixed_output_len: Optional[int],
-) -> List[SampleRequest]:
+def _get_prompt_for_image_model(question: str, *, model: str) -> str:
+    """Prepend and append special tokens around the question to form a prompt.
+
+    Args:
+        question: The input question text to wrap with special tokens
+        model: The name of the model being used, to determine which special tokens to add
+
+    Returns:
+        The formatted prompt string with appropriate special tokens for the model
+
+    Raises:
+        ValueError: If an unsupported model name is provided
+    """
+    model = model.lower()
+    if "pixtral" in model:
+        return f"[INST]{question}\n[IMG][/INST]"
+    raise ValueError(f"Unsupported model {model}")
+
+
+def sample_requests(tokenizer: PreTrainedTokenizerBase,
+                    args: argparse.Namespace) -> List[SampleRequest]:
+    dataset_path: str = args.dataset
+    num_requests: int = args.num_prompts
+    fixed_output_len: Optional[int] = args.output_len
+    model: str = args.model
 
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
@@ -52,16 +72,12 @@ def sample_requests(
         dataset = json.load(f)
     # Filter out the conversations with less than 2 turns.
     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
-    # Only keep the first two turns of each conversation.
-    dataset = [(data["conversations"][0]["value"],
-                data["conversations"][1]["value"]) for data in dataset]
-
     # Shuffle the dataset.
     random.shuffle(dataset)
 
     # Filter out sequences that are too long or too short
     filtered_dataset: List[SampleRequest] = []
-    for i in range(len(dataset)):
+    for data in dataset:
         if len(filtered_dataset) == num_requests:
             break
 
@@ -85,13 +101,11 @@ def sample_requests(
             prompt = _get_prompt_for_image_model(question=prompt, model=model)
 
         # Tokenize the prompts and completions.
-        prompt = dataset[i][0]
         prompt_token_ids = tokenizer(prompt).input_ids
-        completion = dataset[i][1]
         completion_token_ids = tokenizer(completion).input_ids
         prompt_len = len(prompt_token_ids)
         output_len = len(completion_token_ids
                          ) if fixed_output_len is None else fixed_output_len
         if prompt_len < 4 or output_len < 4:
             # Prune too short sequences.
             continue
@@ -119,7 +133,9 @@ def run_vllm(
     prompts: List[TextPrompt] = []
     sampling_params: List[SamplingParams] = []
    for request in requests:
-        prompts.append(TextPrompt(prompt=request.prompt))
+        prompts.append(
+            TextPrompt(prompt=request.prompt,
+                       multi_modal_data=request.multi_modal_data))
         sampling_params.append(
             SamplingParams(
                 n=n,
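
As a quick sanity check of the new prompt path, here is a standalone sketch that copies `_get_prompt_for_image_model` out of the patch and prints the Pixtral-style prompt it produces. The sample question and model name are illustrative only, not part of the diff:

```python
# Standalone copy of the helper added by this patch, for a quick local check.
# The sample question and model name below are illustrative only.
def _get_prompt_for_image_model(question: str, *, model: str) -> str:
    model = model.lower()
    if "pixtral" in model:
        # Pixtral expects the image placeholder inside the [INST] block.
        return f"[INST]{question}\n[IMG][/INST]"
    raise ValueError(f"Unsupported model {model}")


prompt = _get_prompt_for_image_model(
    question="What is shown in this image?",
    model="mistralai/Pixtral-12B-2409")
print(prompt)
# [INST]What is shown in this image?
# [IMG][/INST]
```

Since `sample_requests` now pulls `dataset`, `num_prompts`, `output_len`, and `model` from the parsed `argparse.Namespace`, the caller in `main()` (not shown in these hunks) is presumably updated to invoke it as `sample_requests(tokenizer, args)`.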