diff --git a/gptfast/README.md b/gptfast/README.md
index c4c59e7..3a6fb56 100644
--- a/gptfast/README.md
+++ b/gptfast/README.md
@@ -8,8 +8,40 @@ python scripts/download.py --repo_id $MODEL_REPO
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
 ```
 
-## Generate Text
+## Chat Interface
 
+### Running the Chat
+
+To start the chat interface, run:
+
+```bash
+python -m gptfast.chat
+```
+
+### Available Commands
+
+The chat interface supports the following commands:
+
+- `help` - Display all available commands
+- `quit` - Exit the chat
+- `reset` - Clear the chat history
+- `image` - Start a chat with an image (supports local paths and URLs)
+
+### Examples
+
+Basic chat:
+```bash
+You: Hello! Who are you?
+Assistant: I am Aria, an AI assistant...
+
+You: What can you help me with?
+Assistant: I can help you with various tasks...
+```
+
+Chat with images:
 ```bash
-python generate.py checkpoints/rhymes-ai/Aria/model.pth --compile --apply_regional_compilation --prompt "What is the meaning of life?"
+You: image
+Enter image path or URL: https://example.com/cat.jpg
+Enter your message about the image: What do you see in this image?
+Assistant: I can see a cat...
 ```
\ No newline at end of file
diff --git a/gptfast/chat.py b/gptfast/chat.py
new file mode 100644
index 0000000..aac4b1b
--- /dev/null
+++ b/gptfast/chat.py
@@ -0,0 +1,118 @@
+from typing import List, Optional
+
+import requests
+from generate import GenerationConfig, Generator, ModelConfig
+from PIL import Image
+
+
+class ChatMessage:
+    def __init__(self, role: str, content: str, image_path: Optional[str] = None):
+        self.role = role
+        self.content = content
+        self.image_path = image_path
+
+
+class AriaChat:
+    def __init__(self, model_config: ModelConfig, generation_config: GenerationConfig):
+        self.generator = Generator(model_config, generation_config)
+        self.history: List[ChatMessage] = []
+
+    def add_message(self, role: str, content: str, image_path: Optional[str] = None):
+        """Add a message to the chat history."""
+        self.history.append(ChatMessage(role, content, image_path))
+
+    def format_prompt(self) -> tuple[list[dict], Optional[list[Image.Image]]]:
+        """Format the chat history into model messages and processed images."""
+        messages = []
+        images = []
+        for msg in self.history:
+            content = []
+            if msg.image_path:
+                content.append({"text": None, "type": "image"})
+                images.append(msg.image_path)
+            content.append({"text": msg.content, "type": "text"})
+            messages.append({"role": msg.role, "content": content})
+
+        processed_images = []
+        for image in images:
+            if isinstance(image, str):
+                if image.startswith("http://") or image.startswith("https://"):
+                    image = Image.open(requests.get(image, stream=True).raw)
+                else:
+                    image = Image.open(image)
+                image = image.convert("RGB")
+            processed_images.append(image)
+        return messages, (processed_images if processed_images else None)
+
+    def chat(self, message: str, image_path: Optional[str] = None) -> str:
+        """Send a message and get a response."""
+        self.add_message("user", message, image_path)
+        messages, images = self.format_prompt()
+
+        response = self.generator.generate(messages, images)
+
+        # Extract the assistant's response from the full generated text
+        assistant_message = response.split("<|assistant|>")[-1].strip()
+        # Remove the end token if present
+        for stop_string in self.generator.generation_config.stop_strings:
+            assistant_message = assistant_message.replace(stop_string, "").strip()
+
+        self.add_message("assistant", assistant_message)
+        return assistant_message
+
+    def reset(self):
+        """Clear the chat history."""
+        self.history = []
+
+
+if __name__ == "__main__":
+    from pathlib import Path
+
+    model_config = ModelConfig(
+        checkpoint_path=Path("checkpoints/rhymes-ai/Aria/model.pth"),
+        compile=True,
+    )
+    generation_config = GenerationConfig(
+        max_new_tokens=4096,
+        top_k=40,
+        temperature=0.8,
+        cache_size=8192,
+    )
+
+    chat = AriaChat(model_config, generation_config)
+
+    # Add welcome message and command instructions
+    print("\n=== AriaChat Terminal Interface ===")
+    print("\nAvailable commands:")
+    print("  'help'  - Show this help message")
+    print("  'quit'  - Exit the chat")
+    print("  'reset' - Clear chat history")
+    print("  'image' - Chat with an image")
+    print("\nType your message or command to begin...")
+
+    while True:
+        user_input = input("\n> You: ").strip()
+
+        if user_input.lower() == "quit":
+            break
+        elif user_input.lower() == "help":
+            print("\nAvailable commands:")
+            print("  'help'  - Show this help message")
+            print("  'quit'  - Exit the chat")
+            print("  'reset' - Clear chat history")
+            print("  'image' - Chat with an image")
+            continue
+        elif user_input.lower() == "reset":
+            chat.reset()
+            print("Chat history cleared.")
+            continue
+        elif user_input.lower() == "image":
+            image_path = input("Enter image path or URL: ").strip()
+            message = input("Enter your message about the image: ").strip()
+            response = chat.chat(message, image_path)
+        else:
+            response = chat.chat(user_input)
+
+        print(f"\n> Aria: {response}")
+
+    print("\nGoodbye!")
diff --git a/gptfast/generate.py b/gptfast/generate.py
index dc3c676..637ea90 100644
--- a/gptfast/generate.py
+++ b/gptfast/generate.py
@@ -1,6 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 
+import random
 import sys
 import time
 from pathlib import Path
@@ -13,7 +14,7 @@
 from model import Aria, ModelArgs, Transformer, prepare_inputs_for_model
 from PIL import Image
 from torch.nn.attention import SDPBackend
-from transformers import AutoProcessor, AutoTokenizer
+from transformers import AutoProcessor
 
 
 def device_sync(device):
@@ -172,7 +173,7 @@
         **sampling_kwargs,
     )
 
-    seq = torch.cat((seq[: T + 1], *generated_tokens))
+    seq = torch.cat(generated_tokens)
 
     return seq
 
@@ -206,7 +207,7 @@
     torch.set_float32_matmul_precision("high")
 
 
-def load_model_and_tokenizer(checkpoint_path, device, precision):
+def load_model_and_processor(checkpoint_path, device, precision):
     print(f"Using device={device}")
     print("Loading model ...")
     t0 = time.time()
@@ -216,60 +217,21 @@
     print(f"Time to load model: {time.time() - t0:.02f} seconds")
 
     tokenizer_path = checkpoint_path.parent
-    tokenizer = AutoTokenizer.from_pretrained(
-        tokenizer_path, use_fast=False, trust_remote_code=True
-    )
     processor = AutoProcessor.from_pretrained(tokenizer_path, trust_remote_code=True)
-    return model, tokenizer, processor
-
-
-def prepare_image_inputs(image_path, prompt, processor, precision):
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                {"text": None, "type": "image"},
-                {"text": prompt, "type": "text"},
-            ],
-        }
-    ]
-
-    image = Image.open(
-        requests.get(image_path, stream=True).raw
-        if image_path.startswith(("http://", "https://"))
-        else image_path
-    )
-
-    text = processor.apply_chat_template(messages, add_generation_prompt=True)
-    inputs = processor(text=text, images=image, return_tensors="pt")
-    del inputs["attention_mask"]
-    inputs["pixel_values"] = inputs["pixel_values"].to(precision)
-    return inputs
-
-
-def prepare_text_inputs(prompt, tokenizer):
-    messages = [{"role": "user", "content": [{"text": prompt, "type": "text"}]}]
-    text = tokenizer.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
-    return {
-        "input_ids": tokenizer(text, return_tensors="pt").input_ids.to(torch.int32),
-        "pixel_values": None,
-        "pixel_mask": None,
-    }
+    return model, processor
 
 
 def setup_model_compilation(
-    model, compile, compile_prefill, apply_regional_compilation
+    model, compile, compile_prefill, apply_regional_compilation, device
 ):
+    print("Compiling model...")
+    t0 = time.time()
     if apply_regional_compilation:
-        print("Compiling Model")
         for layer in model.llm.layers:
             layer.compile()
 
     if compile:
-        print("Compiling Model")
         global decode_one_token, prefill
         decode_one_token = torch.compile(
            decode_one_token, mode="reduce-overhead", fullgraph=True
         )
@@ -277,6 +239,12 @@
     if compile_prefill:
         prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
 
+    # warmup with a few variable-length dummy prompts
+    for _ in range(3):
+        input_ids = torch.tensor([1] * random.randint(10, 100), device=device)
+        generate(model, input_ids=input_ids, max_new_tokens=5)
+    print(f"Compilation done in {time.time() - t0:.02f} seconds")
+
 
 class GenerationConfig:
     """Configuration class for text generation parameters."""
@@ -325,14 +293,13 @@ def __init__(self, model_config: ModelConfig, generation_config: GenerationConfi
         self.model_config = model_config
         self.generation_config = generation_config
         self.model = None
-        self.tokenizer = None
         self.processor = None
 
         self._setup_model()
 
     def _setup_model(self):
         """Initialize model, tokenizer and processor."""
-        self.model, self.tokenizer, self.processor = load_model_and_tokenizer(
+        self.model, self.processor = load_model_and_processor(
             self.model_config.checkpoint_path,
             self.model_config.device,
             self.model_config.precision,
@@ -342,28 +309,29 @@
             self.model_config.compile,
             self.model_config.compile_prefill,
             self.model_config.apply_regional_compilation,
+            self.model_config.device,
         )
 
-    def generate(self, prompt: str, image_path: Optional[str] = None) -> str:
-        """Generate text from prompt and optional image."""
-        inputs = (
-            prepare_image_inputs(
-                image_path, prompt, self.processor, self.model_config.precision
-            )
-            if image_path
-            else prepare_text_inputs(prompt, self.tokenizer)
-        )
-
-        inputs = {
-            k: v.to(self.model_config.device) if v is not None else v
-            for k, v in inputs.items()
-        }
+    def generate(
+        self, messages: list[dict], image: Optional[Image.Image] = None
+    ) -> str:
+        text = self.processor.apply_chat_template(messages, add_generation_prompt=True)
+        inputs = self.processor(text=text, images=image, return_tensors="pt")
+        del inputs["attention_mask"]
+
+        for k, v in inputs.items():
+            if k == "pixel_values":
+                inputs[k] = v.to(self.model_config.precision).to(
+                    self.model_config.device
+                )
+            else:
+                inputs[k] = v.to(self.model_config.device)
 
         def early_stop_generation(tokens):
             # This is not efficient, but it works
             for stop_string in self.generation_config.stop_strings:
                 token_list = torch.cat(tokens)
-                decoded_string = self.tokenizer.decode(token_list)
+                decoded_string = self.processor.tokenizer.decode(token_list)
                 if decoded_string.endswith(stop_string):
                     return True
             return False
@@ -379,7 +347,7 @@
             callback=early_stop_generation,
         )
 
-        return self.tokenizer.decode(output)
+        return self.processor.tokenizer.decode(output)
 
 
 if __name__ == "__main__":
@@ -394,4 +362,14 @@
     generator = Generator(model_config, generation_config)
 
     image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
-    print(generator.generate("describe the image", image_path))
+    image = Image.open(requests.get(image_path, stream=True).raw)
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"text": None, "type": "image"},
+                {"text": "describe the image", "type": "text"},
+            ],
+        },
+    ]
+    print(generator.generate(messages, image))
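
For programmatic use outside the terminal loop, the `AriaChat` class added in `gptfast/chat.py` can be driven directly. The sketch below is illustrative only: it reuses the configuration values shown in the diff, assumes the checkpoint has already been downloaded and converted as described in the README, and assumes the imports resolve the same way they do for `chat.py` itself (adjust the import paths to your layout); the image URL is a placeholder.

```python
from pathlib import Path

from chat import AriaChat  # or `from gptfast.chat import AriaChat` if gptfast is importable as a package
from generate import GenerationConfig, ModelConfig

# Same configuration the terminal entry point builds; compile=True triggers
# torch.compile plus the warmup pass added to setup_model_compilation.
model_config = ModelConfig(
    checkpoint_path=Path("checkpoints/rhymes-ai/Aria/model.pth"),
    compile=True,
)
generation_config = GenerationConfig(
    max_new_tokens=512,
    top_k=40,
    temperature=0.8,
)

chat = AriaChat(model_config, generation_config)

# Text-only turn; AriaChat keeps the running history internally.
print(chat.chat("Hello! Who are you?"))

# Image turn; image_path may be a local file path or an http(s) URL,
# mirroring the `image` command of the terminal interface.
print(chat.chat("What do you see in this image?", image_path="https://example.com/cat.jpg"))

# Drop the accumulated history before starting a new conversation.
chat.reset()
```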