Skip to content

Commit

Permalink
Merge pull request #313 from ag2ai/rewrite-OpenAIRealtimeClient-davor
Browse files Browse the repository at this point in the history
Refactoring of PR 281
  • Loading branch information
davorrunje authored Dec 30, 2024
2 parents 31c8299 + 7ec19ae commit c09fd02
Show file tree
Hide file tree
Showing 19 changed files with 1,025 additions and 596 deletions.
7 changes: 6 additions & 1 deletion autogen/agentchat/realtime_agent/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0

from .function_observer import FunctionObserver
from .realtime_agent import RealtimeAgent
from .realtime_observer import RealtimeObserver
from .twilio_observer import TwilioAudioAdapter
from .websocket_observer import WebsocketAudioAdapter

__all__ = ["RealtimeAgent", "FunctionObserver", "TwilioAudioAdapter", "WebsocketAudioAdapter"]
__all__ = ["FunctionObserver", "RealtimeAgent", "RealtimeObserver", "TwilioAudioAdapter", "WebsocketAudioAdapter"]
201 changes: 0 additions & 201 deletions autogen/agentchat/realtime_agent/client.py

This file was deleted.

68 changes: 29 additions & 39 deletions autogen/agentchat/realtime_agent/function_observer.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,37 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT

import asyncio
import json
import logging
from typing import TYPE_CHECKING, Any
from logging import Logger, getLogger
from typing import TYPE_CHECKING, Any, Optional

from asyncer import asyncify
from openai.types.beta.realtime.realtime_server_event import RealtimeServerEvent
from pydantic import BaseModel

from .realtime_observer import RealtimeObserver

if TYPE_CHECKING:
from .realtime_agent import RealtimeAgent

logger = logging.getLogger(__name__)


class FunctionObserver(RealtimeObserver):
"""Observer for handling function calls from the OpenAI Realtime API."""

def __init__(self, agent: "RealtimeAgent") -> None:
"""Observer for handling function calls from the OpenAI Realtime API.
Args:
agent (RealtimeAgent): The realtime agent attached to the observer.
"""
super().__init__()
self._agent = agent
def __init__(self, *, logger: Optional[Logger] = None) -> None:
"""Observer for handling function calls from the OpenAI Realtime API."""
super().__init__(logger=logger)

async def update(self, event: RealtimeServerEvent) -> None:
async def on_event(self, event: dict[str, Any]) -> None:
"""Handle function call events from the OpenAI Realtime API.
Args:
event (dict[str, Any]): The event from the OpenAI Realtime API.
"""
if event.type == "response.function_call_arguments.done":
logger.info(f"Received event: {event.type}", event)
if event["type"] == "response.function_call_arguments.done":
self.logger.info(f"Received event: {event['type']}", event)
await self.call_function(
call_id=event.call_id,
name=event.name, # type: ignore [attr-defined]
kwargs=json.loads(event.arguments),
call_id=event["call_id"],
name=event["name"],
kwargs=json.loads(event["arguments"]),
)

async def call_function(self, call_id: str, name: str, kwargs: dict[str, Any]) -> None:
Expand All @@ -57,33 +43,37 @@ async def call_function(self, call_id: str, name: str, kwargs: dict[str, Any]) -
kwargs (Any[str, Any]): The arguments to pass to the function.
"""

if name in self._agent.realtime_functions:
_, func = self._agent.realtime_functions[name]
if name in self.agent._registred_realtime_functions:
_, func = self.agent._registred_realtime_functions[name]
func = func if asyncio.iscoroutinefunction(func) else asyncify(func)
try:
result = await func(**kwargs)
except Exception:
result = "Function call failed"
logger.warning(f"Function call failed: {name}")
self.logger.info(f"Function call failed: {name=}, {kwargs=}", stack_info=True)

if isinstance(result, BaseModel):
result = result.model_dump_json()
elif not isinstance(result, str):
result = json.dumps(result)
try:
result = json.dumps(result)
except Exception:
result = str(result)

await self.client.function_result(call_id, result)

async def run(self) -> None:
"""Run the observer.
Initialize the session with the OpenAI Realtime API.
"""
await self.initialize_session()
await self.realtime_client.send_function_result(call_id, result)

async def initialize_session(self) -> None:
"""Add registered tools to OpenAI with a session update."""
session_update = {
"tools": [schema for schema, _ in self._agent.realtime_functions.values()],
"tools": [schema for schema, _ in self.agent._registred_realtime_functions.values()],
"tool_choice": "auto",
}
await self.client.session_update(session_update)
await self.realtime_client.session_update(session_update)

async def run_loop(self) -> None:
"""Run the observer loop."""
pass


if TYPE_CHECKING:
function_observer: RealtimeObserver = FunctionObserver()
Loading

0 comments on commit c09fd02

Please sign in to comment.