feat(gptfast): benchmark the gptfast version
xffxff committed Nov 14, 2024
1 parent 62b9847 commit 625f11b
Showing 4 changed files with 111 additions and 15 deletions.
13 changes: 13 additions & 0 deletions gptfast/README.md
@@ -8,6 +8,19 @@ python scripts/download.py --repo_id $MODEL_REPO
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
```

## Benchmark

```bash
python benchmark.py --compile
```

### Performance Results (Single H100 GPU)

| Mode | Performance (tokens/s) |
|---------|----------------------:|
| Base | 25.2 |
| Compile | 130.0 |
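
The two rows correspond to the following invocations; base mode simply omits the `--compile` flag:

```bash
# Base (eager) mode
python benchmark.py

# torch.compile enabled
python benchmark.py --compile
```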

## Chat Interface

### Running the Chat
83 changes: 83 additions & 0 deletions gptfast/benchmark.py
@@ -0,0 +1,83 @@
import time
from pathlib import Path
from statistics import mean, stdev

import requests
from generate import GenerationConfig, Generator, ModelConfig
from PIL import Image


def run_benchmark(
generator: Generator, messages: list[dict], image: Image.Image, num_runs: int = 5
):
"""Run multimodal generation benchmark."""
# Warmup runs
for _ in range(2):
generator.generate(messages, image)

# Benchmark runs
latencies = []
token_counts = []

for i in range(num_runs):
print(f"Running benchmark {i+1}/{num_runs}")
start_time = time.perf_counter()
output = generator.generate(messages, image, detokenize=False)
end_time = time.perf_counter()

latencies.append(end_time - start_time)
token_counts.append(len(output))

results = {
"mean_latency": mean(latencies),
"std_latency": stdev(latencies) if len(latencies) > 1 else 0,
"mean_tokens": mean(token_counts),
"std_tokens": stdev(token_counts) if len(token_counts) > 1 else 0,
"tokens_per_second": mean(token_counts) / mean(latencies),
}

print("\nBenchmark Results:")
print(
f"Average Latency: {results['mean_latency']:.2f}s (±{results['std_latency']:.2f}s)"
)
print(
f"Average Tokens: {results['mean_tokens']:.1f}{results['std_tokens']:.1f})"
)
print(f"Tokens per Second: {results['tokens_per_second']:.1f}")

return results


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
"--compile", action="store_true", help="Enable model compilation"
)
args = parser.parse_args()

model_config = ModelConfig(
checkpoint_path=Path("checkpoints/rhymes-ai/Aria/model.pth"),
compile=args.compile,
)
generation_config = GenerationConfig(
max_new_tokens=200, top_k=200, temperature=0.8, stop_strings=None
)
generator = Generator(model_config, generation_config)

# Load test image
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
image = Image.open(requests.get(image_url, stream=True).raw)

messages = [
{
"role": "user",
"content": [
{"text": None, "type": "image"},
{"text": "Describe this image.", "type": "text"},
],
},
]

run_benchmark(generator, messages, image)
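
A minimal sketch of calling `run_benchmark` from another script, mirroring the `__main__` block above; it assumes it is run from the `gptfast/` directory so the local modules are importable, and `compile=True` / `num_runs=10` are illustrative choices rather than project defaults:

```python
from pathlib import Path

import requests
from PIL import Image

from benchmark import run_benchmark
from generate import GenerationConfig, Generator, ModelConfig

# Same setup as benchmark.py's __main__ block, with compilation enabled
# (equivalent to passing --compile on the CLI).
model_config = ModelConfig(
    checkpoint_path=Path("checkpoints/rhymes-ai/Aria/model.pth"),
    compile=True,
)
generation_config = GenerationConfig(
    max_new_tokens=200, top_k=200, temperature=0.8, stop_strings=None
)
generator = Generator(model_config, generation_config)

image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
image = Image.open(requests.get(image_url, stream=True).raw)

messages = [
    {
        "role": "user",
        "content": [
            {"text": None, "type": "image"},
            {"text": "Describe this image.", "type": "text"},
        ],
    },
]

# More runs trade a longer benchmark for tighter mean/std estimates.
results = run_benchmark(generator, messages, image, num_runs=10)
```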
1 change: 1 addition & 0 deletions gptfast/chat.py
@@ -80,6 +80,7 @@ def reset(self):
top_k=40,
temperature=0.8,
cache_size=8192,
stop_strings=["<|im_end|>"],
)

chat = AriaChat(model_config, generation_config)
29 changes: 14 additions & 15 deletions gptfast/generate.py
@@ -1,7 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

import random
import sys
import time
from pathlib import Path
@@ -225,8 +224,7 @@ def load_model_and_processor(checkpoint_path, device, precision):
def setup_model_compilation(
model, compile, compile_prefill, apply_regional_compilation, device
):
print("Compiling model...")
t0 = time.time()
recommended_inductor_config_setter()
if apply_regional_compilation:
for layer in model.llm.layers:
layer.compile()
@@ -239,12 +237,6 @@
if compile_prefill:
prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

# warmup
for _ in range(3):
input_ids = torch.tensor([1] * random.randint(10, 100), device=device)
generate(model, input_ids=torch.tensor([1], device=device), max_new_tokens=5)
print(f"Compilation done in {time.time() - t0:.02f} seconds")


class GenerationConfig:
"""Configuration class for text generation parameters."""
@@ -263,7 +255,7 @@ def __init__(
self.temperature = temperature
self.cache_size = cache_size
self.linear_causal_mask = linear_causal_mask
self.stop_strings = stop_strings or ["<|im_end|>"]
self.stop_strings = stop_strings


class ModelConfig:
@@ -313,7 +305,10 @@ def _setup_model(self):
)

def generate(
self, messages: list[dict], image: Optional[Image.Image] = None
self,
messages: list[dict],
image: Optional[Image.Image] = None,
detokenize: bool = True,
) -> str:
text = self.processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = self.processor(text=text, images=image, return_tensors="pt")
@@ -327,6 +322,9 @@
inputs[k] = v.to(self.model_config.device)

def early_stop_generation(tokens):
if self.generation_config.stop_strings is None:
return False

# This is not efficient, but it works
for stop_string in self.generation_config.stop_strings:

@@ -347,17 +345,18 @@ def early_stop_generation(tokens):
callback=early_stop_generation,
)

return self.processor.tokenizer.decode(output)
if detokenize:
return self.processor.tokenizer.decode(output)
else:
return output


if __name__ == "__main__":
model_config = ModelConfig(
checkpoint_path=Path("checkpoints/rhymes-ai/Aria/model.pth"),
)
generation_config = GenerationConfig(
max_new_tokens=500,
top_k=200,
temperature=0.8,
max_new_tokens=500, top_k=200, temperature=0.8, stop_strings=["<|im_end|>"]
)
generator = Generator(model_config, generation_config)

