Skip to content

Commit

Permalink
Improve assistant and chess examples to make them more robust and bet…
Browse files Browse the repository at this point in the history
…ter presentation. (autogenhub#58)

* Improve assistant and chess examples to make them more robust and better presentation.

* type dep

* format
  • Loading branch information
ekzhu authored Jun 8, 2024
1 parent 21b730e commit 37cc6bc
Show file tree
Hide file tree
Showing 5 changed files with 388 additions and 172 deletions.
239 changes: 239 additions & 0 deletions examples/assistant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
"""This is an example of a chat with an OpenAIAssistantAgent.
You must have OPENAI_API_KEY set up in your environment to
run this example.
"""

import os
import re
from typing import Any, List

import aiofiles
import openai
from agnext.application import SingleThreadedAgentRuntime
from agnext.chat.agents.base import BaseChatAgent
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
from agnext.chat.patterns.group_chat import GroupChatOutput
from agnext.chat.patterns.two_agent_chat import TwoAgentChat
from agnext.chat.types import RespondNow, TextMessage
from agnext.components import TypeRoutedAgent, message_handler
from agnext.core import AgentRuntime, CancellationToken
from openai import AsyncAssistantEventHandler
from openai.types.beta.thread import ToolResources
from openai.types.beta.threads import Message, Text, TextDelta
from openai.types.beta.threads.runs import RunStep, RunStepDelta
from typing_extensions import override


class TwoAgentChatOutput(GroupChatOutput): # type: ignore
def on_message_received(self, message: Any) -> None:
pass

def get_output(self) -> Any:
return None

def reset(self) -> None:
pass


sep = "-" * 50


class UserProxyAgent(BaseChatAgent, TypeRoutedAgent): # type: ignore
def __init__(
self,
name: str,
runtime: AgentRuntime,
client: openai.AsyncClient,
assistant_id: str,
thread_id: str,
vector_store_id: str,
) -> None: # type: ignore
super().__init__(
name=name,
description="A human user",
runtime=runtime,
)
self._client = client
self._assistant_id = assistant_id
self._thread_id = thread_id
self._vector_store_id = vector_store_id

@message_handler() # type: ignore
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None: # type: ignore
# TODO: render image if message has image.
# print(f"{message.source}: {message.content}")
pass

@message_handler() # type: ignore
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage: # type: ignore
while True:
user_input = input(f"\n{sep}\nYou: ")
# Parse upload file command '[upload code_interpreter | file_search filename]'.
match = re.search(r"\[upload\s+(code_interpreter|file_search)\s+(.+)\]", user_input)
if match:
# Purpose of the file.
purpose = match.group(1)
# Extract file path.
file_path = match.group(2)
if not os.path.exists(file_path):
print(f"File not found: {file_path}")
continue
# Filename.
file_name = os.path.basename(file_path)
# Read file content.
async with aiofiles.open(file_path, "rb") as f:
file_content = await f.read()
if purpose == "code_interpreter":
# Upload file.
file = await self._client.files.create(file=(file_name, file_content), purpose="assistants")
# Get existing file ids from tool resources.
thread = await self._client.beta.threads.retrieve(thread_id=self._thread_id)
tool_resources: ToolResources = thread.tool_resources if thread.tool_resources else ToolResources()
assert tool_resources.code_interpreter is not None
if tool_resources.code_interpreter.file_ids:
file_ids = tool_resources.code_interpreter.file_ids
else:
file_ids = [file.id]
# Update thread with new file.
await self._client.beta.threads.update(
thread_id=self._thread_id,
tool_resources={"code_interpreter": {"file_ids": file_ids}},
)
elif purpose == "file_search":
# Upload file to vector store.
file_batch = await self._client.beta.vector_stores.file_batches.upload_and_poll(
vector_store_id=self._vector_store_id,
files=[(file_name, file_content)],
)
assert file_batch.status == "completed"
print(f"Uploaded file: {file_name}")
continue
elif user_input.startswith("[upload"):
print("Invalid upload command. Please use '[upload code_interpreter | file_search filename]'.")
continue
else:
# Send user input to assistant.
return TextMessage(content=user_input, source=self.name)


class EventHandler(AsyncAssistantEventHandler):
@override
async def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
print(delta.value, end="", flush=True)

@override
async def on_run_step_created(self, run_step: RunStep) -> None:
details = run_step.step_details
if details.type == "tool_calls":
for tool in details.tool_calls:
if tool.type == "code_interpreter":
print("\nGenerating code to interpret:\n\n```python")

@override
async def on_run_step_done(self, run_step: RunStep) -> None:
details = run_step.step_details
if details.type == "tool_calls":
for tool in details.tool_calls:
if tool.type == "code_interpreter":
print("\n```\nExecuting code...")

@override
async def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None:
details = delta.step_details
if details is not None and details.type == "tool_calls":
for tool in details.tool_calls or []:
if tool.type == "code_interpreter" and tool.code_interpreter and tool.code_interpreter.input:
print(tool.code_interpreter.input, end="", flush=True)

@override
async def on_message_created(self, message: Message) -> None:
print(f"{sep}\nAssistant:\n")

@override
async def on_message_done(self, message: Message) -> None:
# print a citation to the file searched
if not message.content:
return
content = message.content[0]
if not content.type == "text":
return
text_content = content.text
annotations = text_content.annotations
citations: List[str] = []
for index, annotation in enumerate(annotations):
text_content.value = text_content.value.replace(annotation.text, f"[{index}]")
if file_citation := getattr(annotation, "file_citation", None):
client = openai.AsyncClient()
cited_file = await client.files.retrieve(file_citation.file_id)
citations.append(f"[{index}] {cited_file.filename}")
if citations:
print("\n".join(citations))


def assistant_chat(runtime: AgentRuntime) -> TwoAgentChat: # type: ignore
oai_assistant = openai.beta.assistants.create(
model="gpt-4-turbo",
description="An AI assistant that helps with everyday tasks.",
instructions="Help the user with their task.",
tools=[{"type": "code_interpreter"}, {"type": "file_search"}],
)
vector_store = openai.beta.vector_stores.create()
thread = openai.beta.threads.create(
tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}},
)
assistant = OpenAIAssistantAgent(
name="Assistant",
description="An AI assistant that helps with everyday tasks.",
runtime=runtime,
client=openai.AsyncClient(),
assistant_id=oai_assistant.id,
thread_id=thread.id,
assistant_event_handler_factory=lambda: EventHandler(),
)
user = UserProxyAgent(
name="User",
runtime=runtime,
client=openai.AsyncClient(),
assistant_id=oai_assistant.id,
thread_id=thread.id,
vector_store_id=vector_store.id,
)
return TwoAgentChat(
name="AssistantChat",
description="A chat with an AI assistant",
runtime=runtime,
first_speaker=assistant,
second_speaker=user,
num_rounds=100,
output=TwoAgentChatOutput(),
)


async def main() -> None:
usage = """Chat with an AI assistant backed by OpenAI Assistant API.
You can upload files to the assistant using the command:
[upload code_interpreter | file_search filename]
where 'code_interpreter' or 'file_search' is the purpose of the file and
'filename' is the path to the file. For example:
[upload code_interpreter data.csv]
This will upload data.csv to the assistant for use with the code interpreter tool.
"""
runtime = SingleThreadedAgentRuntime()
chat = assistant_chat(runtime)
print(usage)
future = runtime.send_message(
TextMessage(content="Hello.", source="User"),
chat,
)
while not future.done():
await runtime.process_next()


if __name__ == "__main__":
import asyncio

asyncio.run(main())
Loading

0 comments on commit 37cc6bc

Please sign in to comment.