add a terminal chat interface to gptfast version #72

Merged
merged 5 commits on Nov 14, 2024
36 changes: 34 additions & 2 deletions gptfast/README.md
@@ -8,8 +8,40 @@ python scripts/download.py --repo_id $MODEL_REPO
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
```

## Generate Text
## Chat Interface

### Running the Chat

To start the chat interface, run:

```bash
python -m gptfast.chat
```

### Available Commands

The chat interface supports the following commands:

- `help` - Display all available commands
- `quit` - Exit the chat
- `reset` - Clear the chat history
- `image` - Start a chat with an image (supports local paths and URLs)

### Examples

Basic chat:
```bash
You: Hello! Who are you?
Assistant: I am Aria, an AI assistant...

You: What can you help me with?
Assistant: I can help you with various tasks...
```
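
Resetting the conversation mid-session (illustrative transcript; the confirmation line is what `gptfast/chat.py` prints, though the actual terminal prompt is rendered as `> You:`):
```bash
You: reset
Chat history cleared.
```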

Chat with images:
```bash
python generate.py checkpoints/rhymes-ai/Aria/model.pth --compile --apply_regional_compilation --prompt "What is the meaning of life?"
You: image
Enter image path or URL: https://example.com/cat.jpg
Enter your message about the image: What do you see in this image?
Assistant: I can see a cat...
```
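
### Programmatic Usage

`AriaChat` can also be driven from Python. The following is a minimal sketch based on the `__main__` block of `gptfast/chat.py`; the checkpoint path and generation settings mirror the defaults used there, and the image URL is a placeholder:

```python
from pathlib import Path

from gptfast.chat import AriaChat
from gptfast.generate import GenerationConfig, ModelConfig

# Same defaults as the __main__ block of gptfast/chat.py
model_config = ModelConfig(
    checkpoint_path=Path("checkpoints/rhymes-ai/Aria/model.pth"),
    compile=True,
)
generation_config = GenerationConfig(
    max_new_tokens=4096,
    top_k=40,
    temperature=0.8,
    cache_size=8192,
)

chat = AriaChat(model_config, generation_config)

# Text-only turn
print(chat.chat("Hello! Who are you?"))

# Turn with an image (local path or URL); the URL below is illustrative
print(chat.chat("What do you see in this image?", image_path="https://example.com/cat.jpg"))

# Clear the history before starting a new conversation
chat.reset()
```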
121 changes: 121 additions & 0 deletions gptfast/chat.py
@@ -0,0 +1,121 @@
from typing import List, Optional

import requests
from generate import GenerationConfig, Generator, ModelConfig
from PIL import Image


class ChatMessage:
    def __init__(self, role: str, content: str, image_path: Optional[str] = None):
        self.role = role
        self.content = content
        self.image_path = image_path


class AriaChat:
    def __init__(self, model_config: ModelConfig, generation_config: GenerationConfig):
        self.generator = Generator(model_config, generation_config)
        self.history: List[ChatMessage] = []

    def add_message(self, role: str, content: str, image_path: Optional[str] = None):
        """Add a message to the chat history."""
        self.history.append(ChatMessage(role, content, image_path))

    def format_prompt(self) -> tuple[list[dict], list[Image.Image]]:
        """Format the chat history into messages and processed images for the model."""
        messages = []
        images = []
        for msg in self.history:
            content = []
            if msg.image_path:
                content.append({"text": None, "type": "image"})
                images.append(msg.image_path)
            content.append({"text": msg.content, "type": "text"})
            messages.append({"role": msg.role, "content": content})

        processed_images = []
        for image in images:
            if isinstance(image, str):
                if image.startswith("http://") or image.startswith("https://"):
                    image = Image.open(requests.get(image, stream=True).raw)
                else:
                    image = Image.open(image)
                image = image.convert("RGB")
            processed_images.append(image)
        return messages, processed_images

    def chat(self, message: str, image_path: Optional[str] = None) -> str:
        """Send a message and get a response."""
        self.add_message("user", message, image_path)
        messages, images = self.format_prompt()

        response = self.generator.generate(messages, images)

        # Extract the assistant's response from the full generated text
        assistant_message = response.split("<|assistant|>")[-1].strip()
        # Remove the end token if present
        for stop_string in self.generator.generation_config.stop_strings:
            assistant_message = assistant_message.replace(stop_string, "").strip()

        self.add_message("assistant", assistant_message)
        return assistant_message

    def reset(self):
        """Clear the chat history."""
        self.history = []


if __name__ == "__main__":
    from pathlib import Path

    from gptfast.chat import AriaChat
    from gptfast.generate import GenerationConfig, ModelConfig

    model_config = ModelConfig(
        checkpoint_path=Path("checkpoints/rhymes-ai/Aria/model.pth"),
        compile=True,
    )
    generation_config = GenerationConfig(
        max_new_tokens=4096,
        top_k=40,
        temperature=0.8,
        cache_size=8192,
    )

    chat = AriaChat(model_config, generation_config)

    # Print welcome message and command instructions
    print("\n=== AriaChat Terminal Interface ===")
    print("\nAvailable commands:")
    print(" 'help' - Show this help message")
    print(" 'quit' - Exit the chat")
    print(" 'reset' - Clear chat history")
    print(" 'image' - Chat with an image")
    print("\nType your message or command to begin...")

    while True:
        user_input = input("\n> You: ").strip()

        if user_input.lower() == "quit":
            break
        elif user_input.lower() == "help":
            print("\nAvailable commands:")
            print(" 'help' - Show this help message")
            print(" 'quit' - Exit the chat")
            print(" 'reset' - Clear chat history")
            print(" 'image' - Chat with an image")
            continue
        elif user_input.lower() == "reset":
            chat.reset()
            print("Chat history cleared.")
            continue
        elif user_input.lower() == "image":
            image_path = input("Enter image path or URL: ").strip()
            message = input("Enter your message about the image: ").strip()
            response = chat.chat(message, image_path)
        else:
            response = chat.chat(user_input)

        print(f"\n> Aria: {response}")

    print("\nGoodbye!")
106 changes: 42 additions & 64 deletions gptfast/generate.py
@@ -1,6 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

import random
import sys
import time
from pathlib import Path
@@ -13,7 +14,7 @@
from model import Aria, ModelArgs, Transformer, prepare_inputs_for_model
from PIL import Image
from torch.nn.attention import SDPBackend
from transformers import AutoProcessor, AutoTokenizer
from transformers import AutoProcessor


def device_sync(device):
@@ -172,7 +173,7 @@ def generate(
**sampling_kwargs,
)

seq = torch.cat((seq[: T + 1], *generated_tokens))
seq = torch.cat(generated_tokens)

return seq

@@ -206,7 +207,7 @@ def recommended_inductor_config_setter():
torch.set_float32_matmul_precision("high")


def load_model_and_tokenizer(checkpoint_path, device, precision):
def load_model_and_processor(checkpoint_path, device, precision):
print(f"Using device={device}")
print("Loading model ...")
t0 = time.time()
@@ -216,67 +217,34 @@ def load_model_and_tokenizer(checkpoint_path, device, precision):
print(f"Time to load model: {time.time() - t0:.02f} seconds")

tokenizer_path = checkpoint_path.parent
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path, use_fast=False, trust_remote_code=True
)
processor = AutoProcessor.from_pretrained(tokenizer_path, trust_remote_code=True)

return model, tokenizer, processor


def prepare_image_inputs(image_path, prompt, processor, precision):
messages = [
{
"role": "user",
"content": [
{"text": None, "type": "image"},
{"text": prompt, "type": "text"},
],
}
]

image = Image.open(
requests.get(image_path, stream=True).raw
if image_path.startswith(("http://", "https://"))
else image_path
)

text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=text, images=image, return_tensors="pt")
del inputs["attention_mask"]
inputs["pixel_values"] = inputs["pixel_values"].to(precision)
return inputs


def prepare_text_inputs(prompt, tokenizer):
messages = [{"role": "user", "content": [{"text": prompt, "type": "text"}]}]
text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return {
"input_ids": tokenizer(text, return_tensors="pt").input_ids.to(torch.int32),
"pixel_values": None,
"pixel_mask": None,
}
return model, processor


def setup_model_compilation(
model, compile, compile_prefill, apply_regional_compilation
model, compile, compile_prefill, apply_regional_compilation, device
):
print("Compiling model...")
t0 = time.time()
if apply_regional_compilation:
print("Compiling Model")
for layer in model.llm.layers:
layer.compile()

if compile:
print("Compiling Model")
global decode_one_token, prefill
decode_one_token = torch.compile(
decode_one_token, mode="reduce-overhead", fullgraph=True
)
if compile_prefill:
prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

# warmup
for _ in range(3):
input_ids = torch.tensor([1] * random.randint(10, 100), device=device)
generate(model, input_ids=torch.tensor([1], device=device), max_new_tokens=5)
print(f"Compilation done in {time.time() - t0:.02f} seconds")


class GenerationConfig:
"""Configuration class for text generation parameters."""
@@ -325,14 +293,13 @@ def __init__(self, model_config: ModelConfig, generation_config: GenerationConfi
self.model_config = model_config
self.generation_config = generation_config
self.model = None
self.tokenizer = None
self.processor = None

self._setup_model()

def _setup_model(self):
"""Initialize model, tokenizer and processor."""
self.model, self.tokenizer, self.processor = load_model_and_tokenizer(
self.model, self.processor = load_model_and_processor(
self.model_config.checkpoint_path,
self.model_config.device,
self.model_config.precision,
@@ -342,28 +309,29 @@ def _setup_model(self):
self.model_config.compile,
self.model_config.compile_prefill,
self.model_config.apply_regional_compilation,
self.model_config.device,
)

def generate(self, prompt: str, image_path: Optional[str] = None) -> str:
"""Generate text from prompt and optional image."""
inputs = (
prepare_image_inputs(
image_path, prompt, self.processor, self.model_config.precision
)
if image_path
else prepare_text_inputs(prompt, self.tokenizer)
)
inputs = {
k: v.to(self.model_config.device) if v is not None else v
for k, v in inputs.items()
}
def generate(
self, messages: list[dict], image: Optional[Image.Image] = None
) -> str:
text = self.processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = self.processor(text=text, images=image, return_tensors="pt")
del inputs["attention_mask"]
for k, v in inputs.items():
if k == "pixel_values":
inputs[k] = v.to(self.model_config.precision).to(
self.model_config.device
)
else:
inputs[k] = v.to(self.model_config.device)

def early_stop_generation(tokens):
# This is not efficient, but it works
for stop_string in self.generation_config.stop_strings:

token_list = torch.cat(tokens)
decoded_string = self.tokenizer.decode(token_list)
decoded_string = self.processor.tokenizer.decode(token_list)
if decoded_string.endswith(stop_string):
return True
return False
@@ -379,7 +347,7 @@ def early_stop_generation(tokens):
callback=early_stop_generation,
)

return self.tokenizer.decode(output)
return self.processor.tokenizer.decode(output)


if __name__ == "__main__":
@@ -394,4 +362,14 @@ def early_stop_generation(tokens):
generator = Generator(model_config, generation_config)

image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
print(generator.generate("describe the image", image_path))
image = Image.open(requests.get(image_path, stream=True).raw)
messages = [
{
"role": "user",
"content": [
{"text": None, "type": "image"},
{"text": "describe the image", "type": "text"},
],
},
]
print(generator.generate(messages, image))