Skip to content

Commit

Permalink
add terminal chat interface
Browse files Browse the repository at this point in the history
  • Loading branch information
xffxff committed Nov 14, 2024
1 parent cce0223 commit f8f0511
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 20 deletions.
36 changes: 34 additions & 2 deletions gptfast/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,40 @@ python scripts/download.py --repo_id $MODEL_REPO
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
```

## Generate Text
## Chat Interface

### Running the Chat

To start the chat interface, run:

```bash
python -m gptfast.chat
```

### Available Commands

The chat interface supports the following commands:

- `help` - Display all available commands
- `quit` - Exit the chat
- `reset` - Clear the chat history
- `image` - Start a chat with an image (supports local paths and URLs)

### Examples

Basic chat:
```bash
> You: Hello! Who are you?
> Aria: I am Aria, an AI assistant...

> You: What can you help me with?
> Aria: I can help you with various tasks...
```

Chat with images:
```bash
> You: image
Enter image path or URL: https://example.com/cat.jpg
Enter your message about the image: What do you see in this image?
> Aria: I can see a cat...
```
54 changes: 39 additions & 15 deletions gptfast/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ def chat(self, message: str, image_path: Optional[str] = None) -> str:
"""Send a message and get a response."""
self.add_message("user", message, image_path)
messages, image = self.format_prompt()
print(f"{messages=}")
print(f"{image=}")

response = self.generator.generate(messages, image)

Expand All @@ -75,23 +73,49 @@ def reset(self):

model_config = ModelConfig(
checkpoint_path=Path("checkpoints/rhymes-ai/Aria/model.pth"),
compile=True,
)
generation_config = GenerationConfig(
max_new_tokens=500,
top_k=200,
max_new_tokens=4096,
top_k=40,
temperature=0.8,
cache_size=8192,
)

chat = AriaChat(model_config, generation_config)

# Chat without images
response = chat.chat("Hello! Who are you?")
print(response)

# Chat with an image
image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
response = chat.chat("Describe the image", image_path)
print(response)

# Reset the chat
chat.reset()
def _print_commands():
    """Print the list of commands the terminal chat interface understands."""
    print("\nAvailable commands:")
    print(" 'help' - Show this help message")
    print(" 'quit' - Exit the chat")
    print(" 'reset' - Clear chat history")
    print(" 'image' - Chat with an image")


# Welcome banner: greet the user and show the command reference once up front.
print("\n=== AriaChat Terminal Interface ===")
_print_commands()
print("\nType your message or command to begin...")

while True:
    # Treat Ctrl-D (EOF) / Ctrl-C like 'quit' instead of dying with a traceback.
    try:
        user_input = input("\n> You: ").strip()
    except (EOFError, KeyboardInterrupt):
        break

    # Ignore blank lines rather than sending an empty prompt to the model.
    if not user_input:
        continue

    command = user_input.lower()
    if command == "quit":
        break
    elif command == "help":
        _print_commands()
        continue
    elif command == "reset":
        chat.reset()
        print("Chat history cleared.")
        continue
    elif command == "image":
        image_path = input("Enter image path or URL: ").strip()
        message = input("Enter your message about the image: ").strip()
        response = chat.chat(message, image_path)
    else:
        response = chat.chat(user_input)

    print(f"\n> Aria: {response}")

print("\nGoodbye!")
14 changes: 11 additions & 3 deletions gptfast/generate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

import random
import sys
import time
from pathlib import Path
Expand Down Expand Up @@ -222,22 +223,28 @@ def load_model_and_processor(checkpoint_path, device, precision):


def setup_model_compilation(
    model, compile, compile_prefill, apply_regional_compilation, device
):
    """Compile the model's hot paths and run warmup generations.

    Args:
        model: The loaded model; `model.llm.layers` is iterated for
            regional (per-layer) compilation.
        compile: If True, `torch.compile` the global `decode_one_token`
            (and optionally `prefill`) functions.
        compile_prefill: If True (and `compile` is set), also compile the
            global `prefill` function with dynamic shapes.
        apply_regional_compilation: If True, compile each transformer
            layer individually instead of/in addition to the full graph.
        device: Device on which warmup input tensors are created; must
            match the model's device.
    """
    print("Compiling model...")
    t0 = time.time()
    if apply_regional_compilation:
        for layer in model.llm.layers:
            layer.compile()

    if compile:
        global decode_one_token, prefill
        decode_one_token = torch.compile(
            decode_one_token, mode="reduce-overhead", fullgraph=True
        )
        if compile_prefill:
            prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

    # Warmup: trigger compilation now (not on the first user request) with a
    # few random-length prompts so dynamic-shape paths get exercised.
    # BUG FIX: input_ids was computed but never passed to generate(), so the
    # warmup always ran on a single hard-coded token.
    for _ in range(3):
        input_ids = torch.tensor([1] * random.randint(10, 100), device=device)
        generate(model, input_ids=input_ids, max_new_tokens=5)
    print(f"Compilation done in {time.time() - t0:.02f} seconds")


class GenerationConfig:
"""Configuration class for text generation parameters."""
Expand Down Expand Up @@ -302,6 +309,7 @@ def _setup_model(self):
self.model_config.compile,
self.model_config.compile_prefill,
self.model_config.apply_regional_compilation,
self.model_config.device,
)

def generate(
Expand Down

0 comments on commit f8f0511

Please sign in to comment.