Skip to content

Commit

Permalink
oai assistant agent example and custom event handler for streaming mo…
Browse files Browse the repository at this point in the history
…de (autogenhub#56)

* oai assistant agent example

* wip

* open ai assistant with custom event handler

* doc
  • Loading branch information
ekzhu authored Jun 7, 2024
1 parent b4ade8b commit 21b730e
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 11 deletions.
4 changes: 3 additions & 1 deletion examples/chess_game.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""This is an example of simulating a chess game with two agents
that play against each other, using tools to reason about the game state
and make moves."""
and make moves.
You must have OPENAI_API_KEY set up in your environment to run this example.
"""

import argparse
import asyncio
Expand Down
133 changes: 133 additions & 0 deletions examples/oai_assistant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
"""This is an example of a chat with an OAI assistant agent.
You must have OPENAI_API_KEY set up in your environment to
run this example.
"""

from typing import Any

import openai
from agnext.application import SingleThreadedAgentRuntime
from agnext.chat.agents.base import BaseChatAgent
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
from agnext.chat.patterns.group_chat import GroupChatOutput
from agnext.chat.patterns.two_agent_chat import TwoAgentChat
from agnext.chat.types import RespondNow, TextMessage
from agnext.components import TypeRoutedAgent, message_handler
from agnext.core import AgentRuntime, CancellationToken
from openai import AsyncAssistantEventHandler
from openai.types.beta import AssistantStreamEvent
from openai.types.beta.threads import Text, TextDelta
from openai.types.beta.threads.runs import RunStep, RunStepDelta
from typing_extensions import override


class TwoAgentChatOutput(GroupChatOutput): # type: ignore
def on_message_received(self, message: Any) -> None:
pass

def get_output(self) -> Any:
return None

def reset(self) -> None:
pass


sep = "-" * 50


class UserProxyAgent(BaseChatAgent, TypeRoutedAgent): # type: ignore
def __init__(self, name: str, runtime: AgentRuntime) -> None: # type: ignore
super().__init__(
name=name,
description="A human user",
runtime=runtime,
)

@message_handler() # type: ignore
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None: # type: ignore
# TODO: render image if message has image.
# print(f"{message.source}: {message.content}")
pass

@message_handler() # type: ignore
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage: # type: ignore
user_input = input(f"\n{sep}\nYou: ")
# TODO: add parsing for special commands e.g., upload files, exit, etc.
return TextMessage(content=user_input, source=self.name)


class EventHandler(AsyncAssistantEventHandler):
@override
async def on_event(self, event: AssistantStreamEvent) -> None:
if event.event == "thread.run.step.created":
details = event.data.step_details
if details.type == "tool_calls":
print("\nGenerating code to interpret:\n\n```python")
elif event.event == "thread.message.created":
print(f"{sep}\nAssistant:\n")

@override
async def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
print(delta.value, end="", flush=True)

@override
async def on_run_step_done(self, run_step: RunStep) -> None:
details = run_step.step_details
if details.type == "tool_calls":
for tool in details.tool_calls:
if tool.type == "code_interpreter":
print("\n```\nExecuting code...")

@override
async def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None:
details = delta.step_details
if details is not None and details.type == "tool_calls":
for tool in details.tool_calls or []:
if tool.type == "code_interpreter" and tool.code_interpreter and tool.code_interpreter.input:
print(tool.code_interpreter.input, end="", flush=True)


def assistant_chat(runtime: AgentRuntime) -> TwoAgentChat: # type: ignore
user = UserProxyAgent(name="User", runtime=runtime)
oai_assistant = openai.beta.assistants.create(
model="gpt-4-turbo",
description="An AI assistant that helps with everyday tasks.",
instructions="Help the user with their task.",
tools=[{"type": "code_interpreter"}],
)
thread = openai.beta.threads.create()
assistant = OpenAIAssistantAgent(
name="Assistant",
description="An AI assistant that helps with everyday tasks.",
runtime=runtime,
client=openai.AsyncClient(),
assistant_id=oai_assistant.id,
thread_id=thread.id,
assistant_event_handler_factory=lambda: EventHandler(),
)
return TwoAgentChat(
name="AssistantChat",
description="A chat with an AI assistant",
runtime=runtime,
initial_sender=user,
initial_recipient=assistant,
num_rounds=100,
output=TwoAgentChatOutput(),
)


async def main() -> None:
runtime = SingleThreadedAgentRuntime()
chat = assistant_chat(runtime)
future = runtime.send_message(
TextMessage(content="Hello.", source="User"),
chat,
)
while not future.done():
await runtime.process_next()


if __name__ == "__main__":
import asyncio

asyncio.run(main())
27 changes: 20 additions & 7 deletions src/agnext/chat/agents/oai_assistant.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, List, Mapping
from typing import Any, Callable, List, Mapping

import openai
from openai import AsyncAssistantEventHandler
from openai.types.beta import AssistantResponseFormatParam

from agnext.chat.agents.base import BaseChatAgent
Expand All @@ -18,11 +19,13 @@ def __init__(
client: openai.AsyncClient,
assistant_id: str,
thread_id: str,
assistant_event_handler_factory: Callable[[], AsyncAssistantEventHandler] | None = None,
) -> None:
super().__init__(name, description, runtime)
self._client = client
self._assistant_id = assistant_id
self._thread_id = thread_id
self._assistant_event_handler_factory = assistant_event_handler_factory

@message_handler()
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
Expand Down Expand Up @@ -60,12 +63,22 @@ async def on_respond_now(self, message: RespondNow, cancellation_token: Cancella
else:
response_format = AssistantResponseFormatParam(type="text")

# Create a run and wait until it finishes.
run = await self._client.beta.threads.runs.create_and_poll(
thread_id=self._thread_id,
assistant_id=self._assistant_id,
response_format=response_format,
)
if self._assistant_event_handler_factory is not None:
# Use event handler and streaming mode if available.
async with self._client.beta.threads.runs.stream(
thread_id=self._thread_id,
assistant_id=self._assistant_id,
event_handler=self._assistant_event_handler_factory(),
response_format=response_format,
) as stream:
run = await stream.get_final_run()
else:
# Use blocking mode.
run = await self._client.beta.threads.runs.create(
thread_id=self._thread_id,
assistant_id=self._assistant_id,
response_format=response_format,
)

if run.status != "completed":
# TODO: handle other statuses.
Expand Down
15 changes: 12 additions & 3 deletions src/agnext/chat/patterns/two_agent_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,24 @@
from ..agents.base import BaseChatAgent


# TODO: rewrite this with a new message type calling for add to message
# history.
class TwoAgentChat(GroupChat):
def __init__(
self,
name: str,
description: str,
runtime: AgentRuntime,
agent1: BaseChatAgent,
agent2: BaseChatAgent,
initial_sender: BaseChatAgent,
initial_recipient: BaseChatAgent,
num_rounds: int,
output: GroupChatOutput,
) -> None:
super().__init__(name, description, runtime, [agent1, agent2], num_rounds, output)
super().__init__(
name,
description,
runtime,
[initial_recipient, initial_sender],
num_rounds,
output,
)

0 comments on commit 21b730e

Please sign in to comment.