
Commit

Merge pull request #72 from rhymes-ai/gptfast_chat
add a terminal chat interface to the gptfast version
xffxff authored Nov 14, 2024
2 parents 02ecc11 + f8f0511 commit 6b3fb24
Showing 3 changed files with 197 additions and 66 deletions.
36 changes: 34 additions & 2 deletions gptfast/README.md
@@ -8,8 +8,40 @@ python scripts/download.py --repo_id $MODEL_REPO
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
```

## Generate Text
## Chat Interface

### Running the Chat

To start the chat interface, run:

```bash
python -m gptfast.chat
```

### Available Commands

The chat interface supports the following commands:

- `help` - Display all available commands
- `quit` - Exit the chat
- `reset` - Clear the chat history
- `image` - Start a chat with an image (supports local paths and URLs)

### Examples

Basic chat:
```bash
You: Hello! Who are you?
Assistant: I am Aria, an AI assistant...

You: What can you help me with?
Assistant: I can help you with various tasks...
```

Chat with images:
```bash
python generate.py checkpoints/rhymes-ai/Aria/model.pth --compile --apply_regional_compilation --prompt "What is the meaning of life?"
You: image
Enter image path or URL: https://example.com/cat.jpg
Enter your message about the image: What do you see in this image?
Assistant: I can see a cat...
```
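
For scripted use outside the terminal loop, the same `AriaChat` class added in this PR can be driven directly from Python. A minimal sketch, mirroring the `__main__` block of `chat.py` below (the local image path is a placeholder, and the checkpoint is assumed to be converted as described above):

```python
from pathlib import Path

from gptfast.chat import AriaChat
from gptfast.generate import GenerationConfig, ModelConfig

# Same configuration values as chat.py's __main__ block
chat = AriaChat(
    ModelConfig(checkpoint_path=Path("checkpoints/rhymes-ai/Aria/model.pth"), compile=True),
    GenerationConfig(max_new_tokens=4096, top_k=40, temperature=0.8, cache_size=8192),
)

print(chat.chat("Hello! Who are you?"))                    # text-only turn
print(chat.chat("What do you see?", "/path/to/cat.jpg"))   # turn with a local image
chat.reset()                                               # clear the history
```
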
121 changes: 121 additions & 0 deletions gptfast/chat.py
@@ -0,0 +1,121 @@
from typing import List, Optional

import requests
from generate import GenerationConfig, Generator, ModelConfig
from PIL import Image


class ChatMessage:
    def __init__(self, role: str, content: str, image_path: Optional[str] = None):
        self.role = role
        self.content = content
        self.image_path = image_path


class AriaChat:
    def __init__(self, model_config: ModelConfig, generation_config: GenerationConfig):
        self.generator = Generator(model_config, generation_config)
        self.history: List[ChatMessage] = []

    def add_message(self, role: str, content: str, image_path: Optional[str] = None):
        """Add a message to the chat history."""
        self.history.append(ChatMessage(role, content, image_path))

    def format_prompt(self) -> tuple[list[dict], list[Image.Image]]:
        """Format the chat history into messages and processed images for the model."""
        messages = []
        images = []
        for msg in self.history:
            content = []
            if msg.image_path:
                content.append({"text": None, "type": "image"})
                images.append(msg.image_path)
            content.append({"text": msg.content, "type": "text"})
            messages.append({"role": msg.role, "content": content})

        processed_images = []
        for image in images:
            if isinstance(image, str):
                if image.startswith("http://") or image.startswith("https://"):
                    image = Image.open(requests.get(image, stream=True).raw)
                else:
                    image = Image.open(image)
                image = image.convert("RGB")
            processed_images.append(image)
        return messages, processed_images

    def chat(self, message: str, image_path: Optional[str] = None) -> str:
        """Send a message and get a response."""
        self.add_message("user", message, image_path)
        messages, images = self.format_prompt()

        response = self.generator.generate(messages, images)

        # Extract the assistant's response from the full generated text
        assistant_message = response.split("<|assistant|>")[-1].strip()
        # Remove the end token if present
        for stop_string in self.generator.generation_config.stop_strings:
            assistant_message = assistant_message.replace(stop_string, "").strip()

        self.add_message("assistant", assistant_message)
        return assistant_message

    def reset(self):
        """Clear the chat history."""
        self.history = []


if __name__ == "__main__":
    from pathlib import Path

    from gptfast.chat import AriaChat
    from gptfast.generate import GenerationConfig, ModelConfig

    model_config = ModelConfig(
        checkpoint_path=Path("checkpoints/rhymes-ai/Aria/model.pth"),
        compile=True,
    )
    generation_config = GenerationConfig(
        max_new_tokens=4096,
        top_k=40,
        temperature=0.8,
        cache_size=8192,
    )

    chat = AriaChat(model_config, generation_config)

    # Add welcome message and command instructions
    print("\n=== AriaChat Terminal Interface ===")
    print("\nAvailable commands:")
    print(" 'help' - Show this help message")
    print(" 'quit' - Exit the chat")
    print(" 'reset' - Clear chat history")
    print(" 'image' - Chat with an image")
    print("\nType your message or command to begin...")

    while True:
        user_input = input("\n> You: ").strip()

        if user_input.lower() == "quit":
            break
        elif user_input.lower() == "help":
            print("\nAvailable commands:")
            print(" 'help' - Show this help message")
            print(" 'quit' - Exit the chat")
            print(" 'reset' - Clear chat history")
            print(" 'image' - Chat with an image")
            continue
        elif user_input.lower() == "reset":
            chat.reset()
            print("Chat history cleared.")
            continue
        elif user_input.lower() == "image":
            image_path = input("Enter image path or URL: ").strip()
            message = input("Enter your message about the image: ").strip()
            response = chat.chat(message, image_path)
        else:
            response = chat.chat(user_input)

        print(f"\n> Aria: {response}")

    print("\nGoodbye!")
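
For reference, `format_prompt` is what turns the stored history into the processor's message format: each turn becomes a `{"role", "content"}` dict, and an image placeholder entry precedes the text when a turn carries an image. A rough illustration of the shape it produces for a hypothetical one-image exchange (the strings here are placeholders, not model output):

```python
messages = [
    {
        "role": "user",
        "content": [
            {"text": None, "type": "image"},               # slot filled by the matching PIL image
            {"text": "What do you see?", "type": "text"},
        ],
    },
    {
        "role": "assistant",
        "content": [{"text": "I can see a cat...", "type": "text"}],
    },
]
# The second return value is the list of PIL.Image objects (URLs fetched via
# requests, local paths opened directly), in the same order as the image slots.
```
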
106 changes: 42 additions & 64 deletions gptfast/generate.py
@@ -1,6 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

import random
import sys
import time
from pathlib import Path
@@ -13,7 +14,7 @@
from model import Aria, ModelArgs, Transformer, prepare_inputs_for_model
from PIL import Image
from torch.nn.attention import SDPBackend
from transformers import AutoProcessor, AutoTokenizer
from transformers import AutoProcessor


def device_sync(device):
@@ -172,7 +173,7 @@ def generate(
        **sampling_kwargs,
    )

    seq = torch.cat((seq[: T + 1], *generated_tokens))
    seq = torch.cat(generated_tokens)

    return seq

@@ -206,7 +207,7 @@ def recommended_inductor_config_setter():
    torch.set_float32_matmul_precision("high")


def load_model_and_tokenizer(checkpoint_path, device, precision):
def load_model_and_processor(checkpoint_path, device, precision):
    print(f"Using device={device}")
    print("Loading model ...")
    t0 = time.time()
@@ -216,67 +217,34 @@ def load_model_and_tokenizer(checkpoint_path, device, precision):
    print(f"Time to load model: {time.time() - t0:.02f} seconds")

    tokenizer_path = checkpoint_path.parent
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path, use_fast=False, trust_remote_code=True
    )
    processor = AutoProcessor.from_pretrained(tokenizer_path, trust_remote_code=True)

    return model, tokenizer, processor


def prepare_image_inputs(image_path, prompt, processor, precision):
    messages = [
        {
            "role": "user",
            "content": [
                {"text": None, "type": "image"},
                {"text": prompt, "type": "text"},
            ],
        }
    ]

    image = Image.open(
        requests.get(image_path, stream=True).raw
        if image_path.startswith(("http://", "https://"))
        else image_path
    )

    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=text, images=image, return_tensors="pt")
    del inputs["attention_mask"]
    inputs["pixel_values"] = inputs["pixel_values"].to(precision)
    return inputs


def prepare_text_inputs(prompt, tokenizer):
    messages = [{"role": "user", "content": [{"text": prompt, "type": "text"}]}]
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    return {
        "input_ids": tokenizer(text, return_tensors="pt").input_ids.to(torch.int32),
        "pixel_values": None,
        "pixel_mask": None,
    }
    return model, processor


def setup_model_compilation(
    model, compile, compile_prefill, apply_regional_compilation
    model, compile, compile_prefill, apply_regional_compilation, device
):
    print("Compiling model...")
    t0 = time.time()
    if apply_regional_compilation:
        print("Compiling Model")
        for layer in model.llm.layers:
            layer.compile()

    if compile:
        print("Compiling Model")
        global decode_one_token, prefill
        decode_one_token = torch.compile(
            decode_one_token, mode="reduce-overhead", fullgraph=True
        )
        if compile_prefill:
            prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

    # warmup
    for _ in range(3):
        input_ids = torch.tensor([1] * random.randint(10, 100), device=device)
        generate(model, input_ids=input_ids, max_new_tokens=5)
    print(f"Compilation done in {time.time() - t0:.02f} seconds")


class GenerationConfig:
"""Configuration class for text generation parameters."""
@@ -325,14 +293,13 @@ def __init__(self, model_config: ModelConfig, generation_config: GenerationConfig):
        self.model_config = model_config
        self.generation_config = generation_config
        self.model = None
        self.tokenizer = None
        self.processor = None

        self._setup_model()

    def _setup_model(self):
        """Initialize model, tokenizer and processor."""
        self.model, self.tokenizer, self.processor = load_model_and_tokenizer(
        self.model, self.processor = load_model_and_processor(
            self.model_config.checkpoint_path,
            self.model_config.device,
            self.model_config.precision,
@@ -342,28 +309,29 @@ def _setup_model(self):
            self.model_config.compile,
            self.model_config.compile_prefill,
            self.model_config.apply_regional_compilation,
            self.model_config.device,
        )

    def generate(self, prompt: str, image_path: Optional[str] = None) -> str:
        """Generate text from prompt and optional image."""
        inputs = (
            prepare_image_inputs(
                image_path, prompt, self.processor, self.model_config.precision
            )
            if image_path
            else prepare_text_inputs(prompt, self.tokenizer)
        )
        inputs = {
            k: v.to(self.model_config.device) if v is not None else v
            for k, v in inputs.items()
        }
    def generate(
        self, messages: list[dict], image: Optional[Image.Image] = None
    ) -> str:
        text = self.processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = self.processor(text=text, images=image, return_tensors="pt")
        del inputs["attention_mask"]
        for k, v in inputs.items():
            if k == "pixel_values":
                inputs[k] = v.to(self.model_config.precision).to(
                    self.model_config.device
                )
            else:
                inputs[k] = v.to(self.model_config.device)

        def early_stop_generation(tokens):
            # This is not efficient, but it works
            for stop_string in self.generation_config.stop_strings:

                token_list = torch.cat(tokens)
                decoded_string = self.tokenizer.decode(token_list)
                decoded_string = self.processor.tokenizer.decode(token_list)
                if decoded_string.endswith(stop_string):
                    return True
            return False
Expand All @@ -379,7 +347,7 @@ def early_stop_generation(tokens):
callback=early_stop_generation,
)

        return self.tokenizer.decode(output)
        return self.processor.tokenizer.decode(output)


if __name__ == "__main__":
@@ -394,4 +362,14 @@ def early_stop_generation(tokens):
    generator = Generator(model_config, generation_config)

    image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
    print(generator.generate("describe the image", image_path))
    image = Image.open(requests.get(image_path, stream=True).raw)
    messages = [
        {
            "role": "user",
            "content": [
                {"text": None, "type": "image"},
                {"text": "describe the image", "type": "text"},
            ],
        },
    ]
    print(generator.generate(messages, image))
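
With `prepare_text_inputs` removed, a text-only request now goes through the same `messages` interface; a minimal sketch, assuming the Aria processor accepts a call without images (the prompt string is only an example):

```python
messages = [
    {
        "role": "user",
        "content": [{"text": "What is the meaning of life?", "type": "text"}],
    }
]
print(generator.generate(messages))  # image defaults to None
```
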
