-
Notifications
You must be signed in to change notification settings - Fork 121
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #196 from ag2ai/realtime-agent
Implement RealtimeAgent for Real-Time Conversational AI Support in ag2 Framework
- Loading branch information
Showing
25 changed files
with
2,155 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from .function_observer import FunctionObserver | ||
from .realtime_agent import RealtimeAgent | ||
from .twilio_observer import TwilioAudioAdapter | ||
from .websocket_observer import WebsocketAudioAdapter | ||
|
||
__all__ = ["RealtimeAgent", "FunctionObserver", "TwilioAudioAdapter", "WebsocketAudioAdapter"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License. | ||
# SPDX-License-Identifier: MIT | ||
|
||
# import asyncio | ||
import json | ||
import logging | ||
from typing import Any, Optional | ||
|
||
import anyio | ||
import websockets | ||
from asyncer import TaskGroup, asyncify, create_task_group, syncify | ||
|
||
from autogen.agentchat.contrib.swarm_agent import AfterWorkOption, initiate_swarm_chat | ||
|
||
from .function_observer import FunctionObserver | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class OpenAIRealtimeClient:
    """(Experimental) Client for OpenAI Realtime API.

    Owns the websocket connection to the Realtime API, forwards every message
    received on it to the registered observers, and provides helpers for
    sending function results, text messages and session updates.
    """

    def __init__(self, agent, audio_adapter, function_observer: "FunctionObserver"):
        """(Experimental) Client for OpenAI Realtime API.

        Args:
            agent: Agent instance
                the agent to be used for the conversation
            audio_adapter: RealtimeObserver
                adapter for streaming the audio from the client
            function_observer: FunctionObserver
                observer for handling function calls

        Raises:
            KeyError: if ``agent.llm_config`` lacks ``config_list``,
                ``temperature``, or the first config entry lacks ``model`` /
                ``api_key``.
        """
        self._agent = agent
        self._observers = []
        self._openai_ws = None  # todo factor out to OpenAIClient
        self.register(audio_adapter)
        self.register(function_observer)

        # LLM config: only the first entry of config_list is used.
        llm_config = self._agent.llm_config
        config = llm_config["config_list"][0]

        self.model = config["model"]
        self.temperature = llm_config["temperature"]
        self.api_key = config["api_key"]

        # Task group managing the background tasks; set while run() is active.
        self.tg: Optional[TaskGroup] = None

    def register(self, observer):
        """Register an observer to the client.

        The observer is handed a reference to this client via
        ``observer.register_client`` and will receive every message passed to
        ``notify_observers``.
        """
        observer.register_client(self)
        self._observers.append(observer)

    async def notify_observers(self, message):
        """Notify all observers of a message from the OpenAI Realtime API.

        Observers are awaited one after another, in registration order.
        """
        for observer in self._observers:
            await observer.update(message)

    async def function_result(self, call_id, result):
        """Send the result of a function call to the OpenAI Realtime API.

        Args:
            call_id: identifier of the function call being answered
            result: the function output to report (sent as-is in the JSON payload)
        """
        result_item = {
            "type": "conversation.item.create",
            "item": {
                "type": "function_call_output",
                "call_id": call_id,
                "output": result,
            },
        }
        await self._openai_ws.send(json.dumps(result_item))
        # Ask for a new response now that the function output is available.
        await self._openai_ws.send(json.dumps({"type": "response.create"}))

    async def send_text(self, *, role: str, text: str):
        """Send a text message to the OpenAI Realtime API.

        A ``response.cancel`` event is sent first, then the message item,
        then a ``response.create`` event.

        Args:
            role: role attached to the message item
            text: the message text
        """
        await self._openai_ws.send(json.dumps({"type": "response.cancel"}))
        text_item = {
            "type": "conversation.item.create",
            "item": {"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]},
        }
        await self._openai_ws.send(json.dumps(text_item))
        await self._openai_ws.send(json.dumps({"type": "response.create"}))

    # todo override in specific clients
    async def initialize_session(self):
        """Control initial session with OpenAI.

        Sends a session update carrying the agent's voice and system message.
        """
        session_update = {
            # todo: move to config
            "turn_detection": {"type": "server_vad"},
            "voice": self._agent.voice,
            "instructions": self._agent.system_message,
            "modalities": ["audio", "text"],
            # NOTE(review): fixed at 0.8; self.temperature read in __init__ is
            # not used here -- confirm whether the configured value should be.
            "temperature": 0.8,
        }
        await self.session_update(session_update)

    # todo override in specific clients
    async def session_update(self, session_options):
        """Send a session update to the OpenAI Realtime API.

        Args:
            session_options: dict placed under the ``session`` key of the
                ``session.update`` event
        """
        update = {"type": "session.update", "session": session_options}
        payload = json.dumps(update)
        # A %s placeholder is required: passing the payload as a bare extra
        # argument makes logging fail to format the record.
        logger.info("Sending session update: %s", payload)
        await self._openai_ws.send(payload)
        logger.info("Sending session update finished")

    async def _read_from_client(self):
        """Read messages from the OpenAI Realtime API.

        Each message is JSON-decoded and passed to the observers. Any
        exception ends the loop and is logged as a warning (best effort).
        """
        try:
            async for openai_message in self._openai_ws:
                response = json.loads(openai_message)
                await self.notify_observers(response)
        except Exception as e:
            logger.warning("Error in _read_from_client: %s", e)

    async def run(self):
        """Run the client.

        Opens the websocket, sends the initial session configuration, then
        starts the reader, every observer's ``run`` and (when the agent asks
        for it) the swarm chat inside one task group.
        """
        async with websockets.connect(
            f"wss://api.openai.com/v1/realtime?model={self.model}",
            additional_headers={
                "Authorization": f"Bearer {self.api_key}",
                "OpenAI-Beta": "realtime=v1",
            },
        ) as openai_ws:
            self._openai_ws = openai_ws
            await self.initialize_session()
            # The task group keeps the connection open until all tasks finish.
            async with create_task_group() as tg:
                self.tg = tg
                self.tg.soonify(self._read_from_client)()
                for observer in self._observers:
                    self.tg.soonify(observer.run)()
                if self._agent._start_swarm_chat:
                    # initiate_swarm_chat is wrapped with asyncify so it can be
                    # scheduled on the task group alongside the async tasks.
                    self.tg.soonify(asyncify(initiate_swarm_chat))(
                        initial_agent=self._agent._initial_agent,
                        agents=self._agent._agents,
                        user_agent=self._agent,
                        messages="Find out what the user wants.",
                        after_work=AfterWorkOption.REVERT_TO_USER,
                    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License. | ||
# SPDX-License-Identifier: MIT | ||
|
||
import asyncio | ||
import json | ||
import logging | ||
|
||
from asyncer import asyncify | ||
from pydantic import BaseModel | ||
|
||
from .realtime_observer import RealtimeObserver | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class FunctionObserver(RealtimeObserver):
    """Observer for handling function calls from the OpenAI Realtime API.

    Registers the agent's realtime functions as tools on the session and
    executes them when the API reports a completed function call.
    """

    def __init__(self, agent):
        """Observer for handling function calls from the OpenAI Realtime API.

        Args:
            agent: Agent instance
                the agent to be used for the conversation
        """
        super().__init__()
        self._agent = agent

    async def update(self, response):
        """Handle function call events from the OpenAI Realtime API.

        Only ``response.function_call_arguments.done`` events are acted on;
        every other event type is ignored.

        Args:
            response: decoded event dict received from the API
        """
        if response.get("type") == "response.function_call_arguments.done":
            # Explicit placeholders so the event payload is actually logged.
            logger.info("Received event: %s %s", response["type"], response)
            await self.call_function(
                call_id=response["call_id"], name=response["name"], kwargs=json.loads(response["arguments"])
            )

    async def call_function(self, call_id, name, kwargs):
        """Call a function registered with the agent.

        The result is serialized to a string and sent back through the
        client. If the function raises, the string "Function call failed" is
        sent instead. Unregistered names are logged and otherwise ignored.

        Args:
            call_id: identifier of the function call to answer
            name: name of the function to look up in ``agent.realtime_functions``
            kwargs: keyword arguments passed to the function
        """
        if name not in self._agent.realtime_functions:
            logger.warning("Ignoring call to unregistered function: %s", name)
            return

        _, func = self._agent.realtime_functions[name]
        # Plain functions are wrapped with asyncify so both kinds can be awaited.
        func = func if asyncio.iscoroutinefunction(func) else asyncify(func)
        try:
            result = await func(**kwargs)
        except Exception:
            result = "Function call failed"
            # Keep the traceback in the log; the API only gets the generic text.
            logger.warning("Function call failed: %s", name, exc_info=True)

        if isinstance(result, BaseModel):
            result = result.model_dump_json()
        elif not isinstance(result, str):
            result = json.dumps(result)

        # self._client is presumably set by register_client on the base class
        # -- defined outside this file section; confirm there.
        await self._client.function_result(call_id, result)

    async def run(self):
        """Run the observer.

        Initialize the session with the OpenAI Realtime API.
        """
        await self.initialize_session()

    async def initialize_session(self):
        """Add registered tools to OpenAI with a session update."""
        session_update = {
            "tools": [schema for schema, _ in self._agent.realtime_functions.values()],
            "tool_choice": "auto",
        }
        await self._client.session_update(session_update)
Oops, something went wrong.