diff --git a/gptfast/README.md b/gptfast/README.md
index c4c59e7..3a6fb56 100644
--- a/gptfast/README.md
+++ b/gptfast/README.md
@@ -8,8 +8,40 @@ python scripts/download.py --repo_id $MODEL_REPO
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
 ```
 
-## Generate Text
+## Chat Interface
 
+### Running the Chat
+
+To start the chat interface, run:
+
+```bash
+python -m gptfast.chat
+```
+
+### Available Commands
+
+The chat interface supports the following commands:
+
+- `help` - Display all available commands
+- `quit` - Exit the chat
+- `reset` - Clear the chat history
+- `image` - Start a chat with an image (supports local paths and URLs)
+
+### Examples
+
+Basic chat:
+```bash
+You: Hello! Who are you?
+Assistant: I am Aria, an AI assistant...
+
+You: What can you help me with?
+Assistant: I can help you with various tasks...
+```
+
+Chat with images:
 ```bash
-python generate.py checkpoints/rhymes-ai/Aria/model.pth --compile --apply_regional_compilation --prompt "What is the meaning of life?"
+You: image
+Enter image path or URL: https://example.com/cat.jpg
+Enter your message about the image: What do you see in this image?
+Assistant: I can see a cat...
 ```
\ No newline at end of file
diff --git a/gptfast/chat.py b/gptfast/chat.py
index 916d5fb..aac4b1b 100644
--- a/gptfast/chat.py
+++ b/gptfast/chat.py
@@ -48,8 +48,6 @@ def chat(self, message: str, image_path: Optional[str] = None) -> str:
         """Send a message and get a response."""
         self.add_message("user", message, image_path)
         messages, image = self.format_prompt()
-        print(f"{messages=}")
-        print(f"{image=}")
 
         response = self.generator.generate(messages, image)
 
@@ -75,23 +73,49 @@ def reset(self):
 
     model_config = ModelConfig(
         checkpoint_path=Path("checkpoints/rhymes-ai/Aria/model.pth"),
+        compile=True,
     )
     generation_config = GenerationConfig(
-        max_new_tokens=500,
-        top_k=200,
+        max_new_tokens=4096,
+        top_k=40,
         temperature=0.8,
+        cache_size=8192,
     )
     chat = AriaChat(model_config, generation_config)
 
-    # Chat without images
-    response = chat.chat("Hello! Who are you?")
-    print(response)
-
-    # Chat with an image
-    image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
-    response = chat.chat("Describe the image", image_path)
-    print(response)
-
-    # Reset the chat
-    chat.reset()
+    # Welcome message and command instructions
+    print("\n=== AriaChat Terminal Interface ===")
+    print("\nAvailable commands:")
+    print("  'help' - Show this help message")
+    print("  'quit' - Exit the chat")
+    print("  'reset' - Clear chat history")
+    print("  'image' - Chat with an image")
+    print("\nType your message or command to begin...")
+
+    while True:
+        user_input = input("\n> You: ").strip()
+
+        if user_input.lower() == "quit":
+            break
+        elif user_input.lower() == "help":
+            print("\nAvailable commands:")
+            print("  'help' - Show this help message")
+            print("  'quit' - Exit the chat")
+            print("  'reset' - Clear chat history")
+            print("  'image' - Chat with an image")
+            continue
+        elif user_input.lower() == "reset":
+            chat.reset()
+            print("Chat history cleared.")
+            continue
+        elif user_input.lower() == "image":
+            image_path = input("Enter image path or URL: ").strip()
+            message = input("Enter your message about the image: ").strip()
+            response = chat.chat(message, image_path)
+        else:
+            response = chat.chat(user_input)
+
+        print(f"\n> Aria: {response}")
+
+    print("\nGoodbye!")
 
diff --git a/gptfast/generate.py b/gptfast/generate.py
index e85e96b..637ea90 100644
--- a/gptfast/generate.py
+++ b/gptfast/generate.py
@@ -1,6 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 
+import random
 import sys
 import time
 from pathlib import Path
@@ -222,15 +223,15 @@ def load_model_and_processor(checkpoint_path, device, precision):
 
 
 def setup_model_compilation(
-    model, compile, compile_prefill, apply_regional_compilation
+    model, compile, compile_prefill, apply_regional_compilation, device
 ):
+    print("Compiling model...")
+    t0 = time.time()
     if apply_regional_compilation:
-        print("Compiling Model")
         for layer in model.llm.layers:
             layer.compile()
     if compile:
-        print("Compiling Model")
         global decode_one_token, prefill
         decode_one_token = torch.compile(
             decode_one_token, mode="reduce-overhead", fullgraph=True
@@ -238,6 +239,12 @@ def setup_model_compilation(
         if compile_prefill:
             prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
 
+    # warmup: dummy generations with varying prompt lengths, so compilation cost is paid here
+    for _ in range(3):
+        input_ids = torch.tensor([1] * random.randint(10, 100), device=device)
+        generate(model, input_ids=input_ids, max_new_tokens=5)
+    print(f"Compilation done in {time.time() - t0:.02f} seconds")
+
 
 class GenerationConfig:
     """Configuration class for text generation parameters."""
@@ -302,6 +309,7 @@ def _setup_model(self):
             self.model_config.compile,
             self.model_config.compile_prefill,
             self.model_config.apply_regional_compilation,
+            self.model_config.device,
        )
 
    def generate(