[Bugfix] Prevent benchmark_throughput.py from using duplicated random prompts (#10753)
mgoin authored Dec 3, 2024
1 parent 4c05edb commit 4433195
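
Before this change, the code below synthesized a single "hi hi hi …" prompt and reused it verbatim for all args.num_prompts requests, so every request in the benchmark was identical. Duplicated prompts can make throughput look better than it really is, since optimizations such as prefix caching can reuse work across identical inputs. A minimal illustration of the pre-fix duplication (the values are placeholders, not the script's defaults):

    # Pre-fix behavior, simplified: one synthesized prompt shared by all requests.
    input_len, num_prompts = 8, 4
    prompts = ["hi " * input_len for _ in range(num_prompts)]
    assert len(set(prompts)) == 1  # every request carries the identical prompt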
Showing 1 changed file with 30 additions and 17 deletions.
benchmarks/benchmark_throughput.py (47 changes: 30 additions & 17 deletions)
@@ -294,23 +294,36 @@ def main(args: argparse.Namespace):
     tokenizer = AutoTokenizer.from_pretrained(
         args.tokenizer, trust_remote_code=args.trust_remote_code)
     if args.dataset is None:
-        # Synthesize a prompt with the given input length.
-        # As tokenizer may add additional tokens like BOS, we need to try
-        # different lengths to get the desired input length.
-        for i in range(-10, 10):
-            prompt = "hi " * (args.input_len + i)
-            tokenized_prompt = tokenizer(prompt).input_ids
-            if len(tokenized_prompt) == args.input_len:
-                break
-        else:
-            raise ValueError(
-                f"Failed to synthesize a prompt with {args.input_len} tokens.")
-        requests = [
-            SampleRequest(prompt=prompt,
-                          prompt_len=args.input_len,
-                          expected_output_len=args.output_len)
-            for _ in range(args.num_prompts)
-        ]
+        vocab_size = tokenizer.vocab_size
+        requests = []
+        for _ in range(args.num_prompts):
+            # Synthesize a prompt with the given input length.
+            candidate_ids = [
+                random.randint(0, vocab_size - 1)
+                for _ in range(args.input_len)
+            ]
+            # As tokenizer may add additional tokens like BOS, we need to try
+            # different lengths to get the desired input length.
+            for _ in range(5):  # Max attempts to correct
+                candidate_prompt = tokenizer.decode(candidate_ids)
+                tokenized_len = len(tokenizer.encode(candidate_prompt))
+
+                if tokenized_len == args.input_len:
+                    break
+
+                # Adjust length based on difference
+                diff = args.input_len - tokenized_len
+                if diff > 0:
+                    candidate_ids.extend([
+                        random.randint(100, vocab_size - 100)
+                        for _ in range(diff)
+                    ])
+                else:
+                    candidate_ids = candidate_ids[:diff]
+            requests.append(
+                SampleRequest(prompt=candidate_prompt,
+                              prompt_len=args.input_len,
+                              expected_output_len=args.output_len))
     else:
         requests = sample_requests(tokenizer, args)
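
For reference, here is a self-contained sketch of the new per-request synthesis loop, assuming the Hugging Face transformers tokenizer API; synthesize_prompt and the model name below are illustrative placeholders, not names from the script:

    import random

    from transformers import AutoTokenizer

    def synthesize_prompt(tokenizer, input_len: int, max_attempts: int = 5) -> str:
        # Start from input_len uniformly random token ids so each prompt
        # is (almost surely) unique across requests.
        vocab_size = tokenizer.vocab_size
        candidate_ids = [
            random.randint(0, vocab_size - 1) for _ in range(input_len)
        ]
        for _ in range(max_attempts):
            candidate_prompt = tokenizer.decode(candidate_ids)
            tokenized_len = len(tokenizer.encode(candidate_prompt))
            if tokenized_len == input_len:
                break
            # Decoding then re-encoding can merge or split tokens (and add
            # specials such as BOS), so pad or trim by the observed difference.
            diff = input_len - tokenized_len
            if diff > 0:
                # Sample away from the vocab edges, where special tokens often live.
                candidate_ids.extend(
                    random.randint(100, vocab_size - 100) for _ in range(diff))
            else:
                candidate_ids = candidate_ids[:diff]
        # Like the script, fall back to the last candidate if no attempt
        # hit the exact length.
        return candidate_prompt

    tok = AutoTokenizer.from_pretrained("facebook/opt-125m")  # placeholder model
    prompts = [synthesize_prompt(tok, input_len=32) for _ in range(4)]
    assert len(set(prompts)) > 1  # distinct with overwhelming probability

The decode/encode round trip is what makes the correction loop necessary: a random id sequence rarely re-tokenizes to exactly the requested length, and padding with mid-range ids keeps the correction away from the edges of the vocabulary.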
