From 279f85da8ae81d90aded6d583f354bf68925df2e Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Thu, 14 Nov 2024 07:19:15 +0000 Subject: [PATCH 1/5] refactor the interface of Generator --- gptfast/generate.py | 82 ++++++++++++--------------------------------- 1 file changed, 21 insertions(+), 61 deletions(-) diff --git a/gptfast/generate.py b/gptfast/generate.py index dc3c676..da73e31 100644 --- a/gptfast/generate.py +++ b/gptfast/generate.py @@ -172,7 +172,7 @@ def generate( **sampling_kwargs, ) - seq = torch.cat((seq[: T + 1], *generated_tokens)) + seq = torch.cat(generated_tokens) return seq @@ -206,7 +206,7 @@ def recommended_inductor_config_setter(): torch.set_float32_matmul_precision("high") -def load_model_and_tokenizer(checkpoint_path, device, precision): +def load_model_and_processor(checkpoint_path, device, precision): print(f"Using device={device}") print("Loading model ...") t0 = time.time() @@ -216,49 +216,9 @@ def load_model_and_tokenizer(checkpoint_path, device, precision): print(f"Time to load model: {time.time() - t0:.02f} seconds") tokenizer_path = checkpoint_path.parent - tokenizer = AutoTokenizer.from_pretrained( - tokenizer_path, use_fast=False, trust_remote_code=True - ) processor = AutoProcessor.from_pretrained(tokenizer_path, trust_remote_code=True) - return model, tokenizer, processor - - -def prepare_image_inputs(image_path, prompt, processor, precision): - messages = [ - { - "role": "user", - "content": [ - {"text": None, "type": "image"}, - {"text": prompt, "type": "text"}, - ], - } - ] - - image = Image.open( - requests.get(image_path, stream=True).raw - if image_path.startswith(("http://", "https://")) - else image_path - ) - - text = processor.apply_chat_template(messages, add_generation_prompt=True) - inputs = processor(text=text, images=image, return_tensors="pt") - del inputs["attention_mask"] - inputs["pixel_values"] = inputs["pixel_values"].to(precision) - return inputs - - -def prepare_text_inputs(prompt, tokenizer): - messages = [{"role": "user", "content": [{"text": prompt, "type": "text"}]}] - text = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - return { - "input_ids": tokenizer(text, return_tensors="pt").input_ids.to(torch.int32), - "pixel_values": None, - "pixel_mask": None, - } - + return model, processor def setup_model_compilation( model, compile, compile_prefill, apply_regional_compilation @@ -325,14 +285,13 @@ def __init__(self, model_config: ModelConfig, generation_config: GenerationConfi self.model_config = model_config self.generation_config = generation_config self.model = None - self.tokenizer = None self.processor = None self._setup_model() def _setup_model(self): """Initialize model, tokenizer and processor.""" - self.model, self.tokenizer, self.processor = load_model_and_tokenizer( + self.model, self.processor = load_model_and_processor( self.model_config.checkpoint_path, self.model_config.device, self.model_config.precision, @@ -344,26 +303,20 @@ def _setup_model(self): self.model_config.apply_regional_compilation, ) - def generate(self, prompt: str, image_path: Optional[str] = None) -> str: - """Generate text from prompt and optional image.""" - inputs = ( - prepare_image_inputs( - image_path, prompt, self.processor, self.model_config.precision - ) - if image_path - else prepare_text_inputs(prompt, self.tokenizer) - ) - inputs = { - k: v.to(self.model_config.device) if v is not None else v - for k, v in inputs.items() - } + def generate(self, messages: list[dict], image: 
Optional[Image.Image] = None) -> str: + text = self.processor.apply_chat_template(messages, add_generation_prompt=True) + inputs = self.processor(text=text, images=image, return_tensors="pt") + del inputs["attention_mask"] + inputs["pixel_values"] = inputs["pixel_values"].to(self.model_config.precision) + for k, v in inputs.items(): + inputs[k] = v.to(self.model_config.device) def early_stop_generation(tokens): # This is not efficient, but it works for stop_string in self.generation_config.stop_strings: token_list = torch.cat(tokens) - decoded_string = self.tokenizer.decode(token_list) + decoded_string = self.processor.tokenizer.decode(token_list) if decoded_string.endswith(stop_string): return True return False @@ -379,7 +332,7 @@ def early_stop_generation(tokens): callback=early_stop_generation, ) - return self.tokenizer.decode(output) + return self.processor.tokenizer.decode(output) if __name__ == "__main__": @@ -394,4 +347,11 @@ def early_stop_generation(tokens): generator = Generator(model_config, generation_config) image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png" - print(generator.generate("describe the image", image_path)) + image = Image.open(requests.get(image_path, stream=True).raw) + messages = [ + { + "role": "user", + "content": [{"text": None, "type": "image"}, {"text": "describe the image", "type": "text"}], + }, + ] + print(generator.generate(messages, image)) From f9ef55fddc13df9218c58841532882b1be86eff2 Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Thu, 14 Nov 2024 07:42:15 +0000 Subject: [PATCH 2/5] fix: ensure pixel values are properly cast to precision --- gptfast/generate.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/gptfast/generate.py b/gptfast/generate.py index da73e31..b4ea6b9 100644 --- a/gptfast/generate.py +++ b/gptfast/generate.py @@ -307,9 +307,11 @@ def generate(self, messages: list[dict], image: Optional[Image.Image] = None) -> text = self.processor.apply_chat_template(messages, add_generation_prompt=True) inputs = self.processor(text=text, images=image, return_tensors="pt") del inputs["attention_mask"] - inputs["pixel_values"] = inputs["pixel_values"].to(self.model_config.precision) for k, v in inputs.items(): - inputs[k] = v.to(self.model_config.device) + if k == "pixel_values": + inputs[k] = v.to(self.model_config.precision).to(self.model_config.device) + else: + inputs[k] = v.to(self.model_config.device) def early_stop_generation(tokens): # This is not efficient, but it works From 404dadbf9d909f0c8594b8a86a80f1ae38e01220 Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Thu, 14 Nov 2024 07:42:28 +0000 Subject: [PATCH 3/5] add chat support --- gptfast/chat.py | 92 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 gptfast/chat.py diff --git a/gptfast/chat.py b/gptfast/chat.py new file mode 100644 index 0000000..8772084 --- /dev/null +++ b/gptfast/chat.py @@ -0,0 +1,92 @@ +from typing import List, Dict, Optional +from generate import Generator, ModelConfig, GenerationConfig +from PIL import Image +import requests +class ChatMessage: + def __init__(self, role: str, content: str, image_path: Optional[str] = None): + self.role = role + self.content = content + self.image_path = image_path + +class AriaChat: + def __init__(self, model_config: ModelConfig, generation_config: GenerationConfig): + self.generator = Generator(model_config, generation_config) + self.history: List[ChatMessage] = 
[] + + def add_message(self, role: str, content: str, image_path: Optional[str] = None): + """Add a message to the chat history.""" + self.history.append(ChatMessage(role, content, image_path)) + + def format_prompt(self) -> tuple[str, Optional[str]]: + """Format the chat history into a prompt for the model.""" + messages = [] + images = [] + for msg in self.history: + content = [] + if msg.image_path: + content.append({"text": None, "type": "image"}) + images.append(msg.image_path) + content.append({"text": msg.content, "type": "text"}) + messages.append({"role": msg.role, "content": content}) + + processed_images = [] + for image in images: + if isinstance(image, str): + if image.startswith("http://") or image.startswith("https://"): + image = Image.open(requests.get(image, stream=True).raw) + else: + image = Image.open(image) + image = image.convert("RGB") + processed_images.append(image) + return messages, processed_images + + def chat(self, message: str, image_path: Optional[str] = None) -> str: + """Send a message and get a response.""" + self.add_message("user", message, image_path) + messages, image = self.format_prompt() + print(f"{messages=}") + print(f"{image=}") + + response = self.generator.generate(messages, image) + + # Extract the assistant's response from the full generated text + assistant_message = response.split("<|assistant|>")[-1].strip() + # Remove the end token if present + for stop_string in self.generator.generation_config.stop_strings: + assistant_message = assistant_message.replace(stop_string, "").strip() + + self.add_message("assistant", assistant_message) + return assistant_message + + def reset(self): + """Clear the chat history.""" + self.history = [] + + +if __name__ == "__main__": + from pathlib import Path + from gptfast.generate import ModelConfig, GenerationConfig + from gptfast.chat import AriaChat + + model_config = ModelConfig( + checkpoint_path=Path("checkpoints/rhymes-ai/Aria/model.pth"), + ) + generation_config = GenerationConfig( + max_new_tokens=500, + top_k=200, + temperature=0.8, + ) + + chat = AriaChat(model_config, generation_config) + + # Chat without images + response = chat.chat("Hello! 
Who are you?") + print(response) + + # Chat with an image + image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png" + response = chat.chat("Describe the image", image_path) + print(response) + + # Reset the chat + chat.reset() \ No newline at end of file From cce0223cc17ffe3bd378f8bef0d6a313161a61a1 Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Thu, 14 Nov 2024 07:42:49 +0000 Subject: [PATCH 4/5] make format happy --- gptfast/chat.py | 23 ++++++++++++++--------- gptfast/generate.py | 16 ++++++++++++---- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/gptfast/chat.py b/gptfast/chat.py index 8772084..916d5fb 100644 --- a/gptfast/chat.py +++ b/gptfast/chat.py @@ -1,13 +1,17 @@ -from typing import List, Dict, Optional -from generate import Generator, ModelConfig, GenerationConfig -from PIL import Image +from typing import List, Optional + import requests +from generate import GenerationConfig, Generator, ModelConfig +from PIL import Image + + class ChatMessage: def __init__(self, role: str, content: str, image_path: Optional[str] = None): self.role = role self.content = content self.image_path = image_path + class AriaChat: def __init__(self, model_config: ModelConfig, generation_config: GenerationConfig): self.generator = Generator(model_config, generation_config) @@ -28,7 +32,7 @@ def format_prompt(self) -> tuple[str, Optional[str]]: images.append(msg.image_path) content.append({"text": msg.content, "type": "text"}) messages.append({"role": msg.role, "content": content}) - + processed_images = [] for image in images: if isinstance(image, str): @@ -46,15 +50,15 @@ def chat(self, message: str, image_path: Optional[str] = None) -> str: messages, image = self.format_prompt() print(f"{messages=}") print(f"{image=}") - + response = self.generator.generate(messages, image) - + # Extract the assistant's response from the full generated text assistant_message = response.split("<|assistant|>")[-1].strip() # Remove the end token if present for stop_string in self.generator.generation_config.stop_strings: assistant_message = assistant_message.replace(stop_string, "").strip() - + self.add_message("assistant", assistant_message) return assistant_message @@ -65,8 +69,9 @@ def reset(self): if __name__ == "__main__": from pathlib import Path - from gptfast.generate import ModelConfig, GenerationConfig + from gptfast.chat import AriaChat + from gptfast.generate import GenerationConfig, ModelConfig model_config = ModelConfig( checkpoint_path=Path("checkpoints/rhymes-ai/Aria/model.pth"), @@ -89,4 +94,4 @@ def reset(self): print(response) # Reset the chat - chat.reset() \ No newline at end of file + chat.reset() diff --git a/gptfast/generate.py b/gptfast/generate.py index b4ea6b9..e85e96b 100644 --- a/gptfast/generate.py +++ b/gptfast/generate.py @@ -13,7 +13,7 @@ from model import Aria, ModelArgs, Transformer, prepare_inputs_for_model from PIL import Image from torch.nn.attention import SDPBackend -from transformers import AutoProcessor, AutoTokenizer +from transformers import AutoProcessor def device_sync(device): @@ -220,6 +220,7 @@ def load_model_and_processor(checkpoint_path, device, precision): return model, processor + def setup_model_compilation( model, compile, compile_prefill, apply_regional_compilation ): @@ -303,13 +304,17 @@ def _setup_model(self): self.model_config.apply_regional_compilation, ) - def generate(self, messages: list[dict], image: Optional[Image.Image] = None) -> str: + def generate( + self, 
messages: list[dict], image: Optional[Image.Image] = None + ) -> str: text = self.processor.apply_chat_template(messages, add_generation_prompt=True) inputs = self.processor(text=text, images=image, return_tensors="pt") del inputs["attention_mask"] for k, v in inputs.items(): if k == "pixel_values": - inputs[k] = v.to(self.model_config.precision).to(self.model_config.device) + inputs[k] = v.to(self.model_config.precision).to( + self.model_config.device + ) else: inputs[k] = v.to(self.model_config.device) @@ -353,7 +358,10 @@ def early_stop_generation(tokens): messages = [ { "role": "user", - "content": [{"text": None, "type": "image"}, {"text": "describe the image", "type": "text"}], + "content": [ + {"text": None, "type": "image"}, + {"text": "describe the image", "type": "text"}, + ], }, ] print(generator.generate(messages, image)) From f8f05110c2e8bc332071ada1fee7a6386c0c6d70 Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Thu, 14 Nov 2024 08:26:51 +0000 Subject: [PATCH 5/5] add terminal chat interface --- gptfast/README.md | 36 ++++++++++++++++++++++++++++-- gptfast/chat.py | 54 ++++++++++++++++++++++++++++++++------------- gptfast/generate.py | 14 +++++++++--- 3 files changed, 84 insertions(+), 20 deletions(-) diff --git a/gptfast/README.md b/gptfast/README.md index c4c59e7..3a6fb56 100644 --- a/gptfast/README.md +++ b/gptfast/README.md @@ -8,8 +8,40 @@ python scripts/download.py --repo_id $MODEL_REPO python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO ``` -## Generate Text +## Chat Interface +### Running the Chat + +To start the chat interface, run: + +```bash +python -m gptfast.chat +``` + +### Available Commands + +The chat interface supports the following commands: + +- `help` - Display all available commands +- `quit` - Exit the chat +- `reset` - Clear the chat history +- `image` - Start a chat with an image (supports local paths and URLs) + +### Examples + +Basic chat: +```bash +You: Hello! Who are you? +Assistant: I am Aria, an AI assistant... + +You: What can you help me with? +Assistant: I can help you with various tasks... +``` + +Chat with images: ```bash -python generate.py checkpoints/rhymes-ai/Aria/model.pth --compile --apply_regional_compilation --prompt "What is the meaning of life?" +You: image +Enter image path or URL: https://example.com/cat.jpg +Enter your message about the image: What do you see in this image? +Assistant: I can see a cat... ``` \ No newline at end of file diff --git a/gptfast/chat.py b/gptfast/chat.py index 916d5fb..aac4b1b 100644 --- a/gptfast/chat.py +++ b/gptfast/chat.py @@ -48,8 +48,6 @@ def chat(self, message: str, image_path: Optional[str] = None) -> str: """Send a message and get a response.""" self.add_message("user", message, image_path) messages, image = self.format_prompt() - print(f"{messages=}") - print(f"{image=}") response = self.generator.generate(messages, image) @@ -75,23 +73,49 @@ def reset(self): model_config = ModelConfig( checkpoint_path=Path("checkpoints/rhymes-ai/Aria/model.pth"), + compile=True, ) generation_config = GenerationConfig( - max_new_tokens=500, - top_k=200, + max_new_tokens=4096, + top_k=40, temperature=0.8, + cache_size=8192, ) chat = AriaChat(model_config, generation_config) - # Chat without images - response = chat.chat("Hello! 
Who are you?")
-    print(response)
-
-    # Chat with an image
-    image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
-    response = chat.chat("Describe the image", image_path)
-    print(response)
-
-    # Reset the chat
-    chat.reset()
+    # Add welcome message and command instructions
+    print("\n=== AriaChat Terminal Interface ===")
+    print("\nAvailable commands:")
+    print("  'help' - Show this help message")
+    print("  'quit' - Exit the chat")
+    print("  'reset' - Clear chat history")
+    print("  'image' - Chat with an image")
+    print("\nType your message or command to begin...")
+
+    while True:
+        user_input = input("\n> You: ").strip()
+
+        if user_input.lower() == "quit":
+            break
+        elif user_input.lower() == "help":
+            print("\nAvailable commands:")
+            print("  'help' - Show this help message")
+            print("  'quit' - Exit the chat")
+            print("  'reset' - Clear chat history")
+            print("  'image' - Chat with an image")
+            continue
+        elif user_input.lower() == "reset":
+            chat.reset()
+            print("Chat history cleared.")
+            continue
+        elif user_input.lower() == "image":
+            image_path = input("Enter image path or URL: ").strip()
+            message = input("Enter your message about the image: ").strip()
+            response = chat.chat(message, image_path)
+        else:
+            response = chat.chat(user_input)
+
+        print(f"\n> Aria: {response}")
+
+    print("\nGoodbye!")
diff --git a/gptfast/generate.py b/gptfast/generate.py
index e85e96b..637ea90 100644
--- a/gptfast/generate.py
+++ b/gptfast/generate.py
@@ -1,6 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 
+import random
 import sys
 import time
 from pathlib import Path
@@ -222,15 +223,15 @@ def load_model_and_processor(checkpoint_path, device, precision):
 
 
 def setup_model_compilation(
-    model, compile, compile_prefill, apply_regional_compilation
+    model, compile, compile_prefill, apply_regional_compilation, device
 ):
+    print("Compiling model...")
+    t0 = time.time()
     if apply_regional_compilation:
-        print("Compiling Model")
         for layer in model.llm.layers:
             layer.compile()
     if compile:
-        print("Compiling Model")
         global decode_one_token, prefill
         decode_one_token = torch.compile(
             decode_one_token, mode="reduce-overhead", fullgraph=True
@@ -238,6 +239,12 @@ def setup_model_compilation(
     if compile_prefill:
         prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
 
+    # warmup: run a few dummy generations with random prompt lengths to trigger compilation up front
+    for _ in range(3):
+        input_ids = torch.tensor([1] * random.randint(10, 100), device=device)
+        generate(model, input_ids=input_ids, max_new_tokens=5)
+    print(f"Compilation done in {time.time() - t0:.02f} seconds")
+
 
 class GenerationConfig:
     """Configuration class for text generation parameters."""
@@ -302,6 +309,7 @@ def _setup_model(self):
             self.model_config.compile,
             self.model_config.compile_prefill,
             self.model_config.apply_regional_compilation,
+            self.model_config.device,
         )
 
     def generate(