Update benchmark_throughput.py to support image input
This reverts commit eb6e01b.

Signed-off-by: Linkun Chen <[email protected]>
Linkun Chen committed Nov 4, 2024
1 parent 1dc8821 commit 4701791
Showing 1 changed file with 31 additions and 15 deletions.
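With this change, the throughput benchmark can take a ShareGPT-style dataset whose entries also reference images. A rough invocation might look like the following; the model id and dataset filename are illustrative, and the flags correspond to the `args.model`, `args.dataset`, and `args.num_prompts` fields read in `sample_requests` below:

    # Hypothetical invocation; model id and dataset path are examples only.
    python benchmarks/benchmark_throughput.py \
        --model mistralai/Pixtral-12B-2409 \
        --dataset sharegpt_with_images.json \
        --num-prompts 100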
benchmarks/benchmark_throughput.py (31 additions, 15 deletions)
@@ -8,6 +8,7 @@
 
 import torch
 import uvloop
+from PIL import Image
 from tqdm import tqdm
 from transformers import (AutoModelForCausalLM, AutoTokenizer,
                           PreTrainedTokenizerBase)
@@ -38,12 +39,31 @@ class SampleRequest:
     multi_modal_data: Optional[MultiModalDataDict] = None
 
 
-def sample_requests(
-    dataset_path: str,
-    num_requests: int,
-    tokenizer: PreTrainedTokenizerBase,
-    fixed_output_len: Optional[int],
-) -> List[SampleRequest]:
+def _get_prompt_for_image_model(question: str, *, model: str) -> str:
+    """Prepend and append special tokens around the question to form a prompt.
+
+    Args:
+        question: The input question text to wrap with special tokens
+        model: The name of the model being used, to determine which special
+            tokens to add
+
+    Returns:
+        The formatted prompt string with appropriate special tokens for the
+            model
+
+    Raises:
+        ValueError: If an unsupported model name is provided
+    """
+    model = model.lower()
+    if "pixtral" in model:
+        return f"<s>[INST]{question}\n[IMG][/INST]"
+    raise ValueError(f"Unsupported model {model}")
+
+
+def sample_requests(tokenizer: PreTrainedTokenizerBase,
+                    args: argparse.Namespace) -> List[SampleRequest]:
+    dataset_path: str = args.dataset
+    num_requests: int = args.num_prompts
+    fixed_output_len: Optional[int] = args.output_len
+    model: str = args.model
+
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
 
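For a quick sanity check of the helper added above, the wrapped prompt for a Pixtral checkpoint looks like this (the model id is only an example that satisfies the `"pixtral" in model` test):

    # Example of the prompt produced by the new helper.
    prompt = _get_prompt_for_image_model(
        question="What is shown in this image?",
        model="mistralai/Pixtral-12B-2409")  # example model id
    # prompt == "<s>[INST]What is shown in this image?\n[IMG][/INST]"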
@@ -52,16 +72,12 @@ def sample_requests(
         dataset = json.load(f)
     # Filter out the conversations with less than 2 turns.
     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
-    # Only keep the first two turns of each conversation.
-    dataset = [(data["conversations"][0]["value"],
-                data["conversations"][1]["value"]) for data in dataset]
 
     # Shuffle the dataset.
     random.shuffle(dataset)
 
     # Filter out sequences that are too long or too short
     filtered_dataset: List[SampleRequest] = []
-    for i in range(len(dataset)):
+    for data in dataset:
         if len(filtered_dataset) == num_requests:
             break
 
@@ -85,13 +101,11 @@ def sample_requests(
+            prompt = _get_prompt_for_image_model(question=prompt, model=model)
 
         # Tokenize the prompts and completions.
-        prompt = dataset[i][0]
         prompt_token_ids = tokenizer(prompt).input_ids
-        completion = dataset[i][1]
         completion_token_ids = tokenizer(completion).input_ids
         prompt_len = len(prompt_token_ids)
         output_len = len(completion_token_ids
-                         ) if fixed_output_len is None else fixed_output_len
+                        ) if fixed_output_len is None else fixed_output_len
         if prompt_len < 4 or output_len < 4:
             # Prune too short sequences.
             continue
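The per-request image handling sits in the collapsed region between the hunks above, so only its last visible line (`prompt = _get_prompt_for_image_model(...)`) appears here. Judging from the new `PIL` import and the `multi_modal_data` field on `SampleRequest`, it plausibly looks something like this hypothetical sketch, not the literal hidden code:

    # Hypothetical reconstruction of the collapsed image-handling step.
    # Assumes each dataset entry may carry a path to a local image file;
    # image_path is an assumed local variable, not taken from the diff.
    multi_modal_data: Optional[MultiModalDataDict] = None
    if image_path is not None:
        multi_modal_data = {"image": Image.open(image_path)}
        prompt = _get_prompt_for_image_model(question=prompt, model=model)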
@@ -119,7 +133,9 @@ def run_vllm(
     prompts: List[TextPrompt] = []
     sampling_params: List[SamplingParams] = []
     for request in requests:
-        prompts.append(TextPrompt(prompt=request.prompt))
+        prompts.append(
+            TextPrompt(prompt=request.prompt,
+                       multi_modal_data=request.multi_modal_data))
         sampling_params.append(
             SamplingParams(
                 n=n,
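Downstream, each `TextPrompt` now carries the request's `multi_modal_data` into `LLM.generate`. A minimal end-to-end sketch of the same pattern outside the benchmark, assuming vLLM's `{"image": <PIL image>}` convention for image inputs and an illustrative model id and image file:

    # Minimal sketch; the model id and image file are examples only.
    from PIL import Image
    from vllm import LLM, SamplingParams

    llm = LLM(model="mistralai/Pixtral-12B-2409")
    outputs = llm.generate(
        {
            "prompt": "<s>[INST]What is shown in this image?\n[IMG][/INST]",
            "multi_modal_data": {"image": Image.open("example.jpg")},
        },
        SamplingParams(max_tokens=64))
    print(outputs[0].outputs[0].text)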
