Skip to content

Commit

Permalink
polishing
Browse files Browse the repository at this point in the history
  • Loading branch information
davorrunje committed Dec 31, 2024
1 parent 887a8ef commit 6034a4d
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 40 deletions.
64 changes: 53 additions & 11 deletions test/agentchat/realtime_agent/realtime_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,71 @@
# SPDX-License-Identifier: Apache-2.0

import base64
import os
from functools import wraps
from typing import Any, Callable
from typing import Any, Callable, Literal, Optional, TypeVar, Union
from unittest.mock import MagicMock

from anyio import Event
from openai import OpenAI
from openai import NotGiven, OpenAI


def generate_voice_input(text: str) -> str:
tts_client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
response = tts_client.audio.speech.create(model="tts-1", voice="alloy", input=text, response_format="pcm")
def text_to_speech(
*,
text: str,
openai_api_key: str,
model: str = "tts-1",
voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = "alloy",
response_format: Union[Literal["mp3", "opus", "aac", "flac", "wav", "pcm"], NotGiven] = "pcm",
) -> str:
"""Convert text to voice using OpenAI API.
Args:
text (str): Text to convert to voice.
openai_api_key (str): OpenAI API key.
model (str, optional): Model to use for the conversion. Defaults to "tts-1".
voice (Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"], optional): Voice to use for the conversion. Defaults to "alloy".
response_format (Union[Literal["mp3", "opus", "aac", "flac", "wav", "pcm"], NotGiven], optional): Response format. Defaults to "pcm".
Returns:
str: Base64 encoded audio.
"""
tts_client = OpenAI(api_key=openai_api_key)
response = tts_client.audio.speech.create(model=model, voice=voice, input=text, response_format=response_format)
return base64.b64encode(response.content).decode("utf-8")


def trace(mock: MagicMock, event: Event) -> Callable[..., Any]:
def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
F = TypeVar("F", bound=Callable[..., Any])


def trace(
mock: MagicMock, *, precall_event: Optional[Event] = None, postcall_event: Optional[Event] = None
) -> Callable[[F], F]:
"""Decorator to trace a function
Mock will be called before the function.
If defined, precall_event will be set before the function call and postcall_event will be set after the function call.
Args:
mock (MagicMock): Mock object.
precall_event (Optional[Event], optional): Event to set before the function call. Defaults to None.
postcall_event (Optional[Event], optional): Event to set after the function call. Defaults to None.
Returns:
Callable[[F], F]: Function decorator.
"""

def decorator(f: F) -> F:
@wraps(f)
def _inner(*args: Any, **kwargs: Any) -> Any:
mock(*args, **kwargs)
event.set()
return f(*args, **kwargs)
if precall_event is not None:
precall_event.set()
retval = f(*args, **kwargs)
if postcall_event is not None:
postcall_event.set()

return retval

return _inner
return _inner # type: ignore[return-value]

return decorator
36 changes: 29 additions & 7 deletions test/agentchat/realtime_agent/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from logging import getLogger
from typing import Annotated, Any
from unittest.mock import MagicMock

Expand All @@ -16,13 +17,16 @@
import autogen
from autogen.agentchat.realtime_agent import RealtimeAgent, RealtimeObserver, WebSocketAudioAdapter

from .realtime_test_utils import generate_voice_input, trace
from .realtime_test_utils import text_to_speech, trace

logger = getLogger(__name__)


@pytest.mark.skipif(skip_openai, reason=reason)
class TestE2E:
@pytest.fixture
def llm_config(self) -> dict[str, Any]:
"""Fixture to load the LLM config."""
config_list = autogen.config_list_from_json(
OAI_CONFIG_LIST,
filter_dict={
Expand All @@ -36,7 +40,17 @@ def llm_config(self) -> dict[str, Any]:
"temperature": 0.0,
}

async def _test_e2e(self, llm_config: dict[str, Any]) -> None:
@pytest.fixture
def openai_api_key(self, llm_config: dict[str, Any]) -> str:
"""Fixture to get the OpenAI API key."""
return llm_config["config_list"][0]["api_key"] # type: ignore[no-any-return]

async def _test_e2e(self, llm_config: dict[str, Any], openai_api_key: str) -> None:
"""End-to-end test for the RealtimeAgent.
Create a FastAPI app with a WebSocket endpoint that handles audio stream and OpenAI.
"""
# Event for synchronization and tracking state
weather_func_called_event = Event()
weather_func_mock = MagicMock()
Expand All @@ -60,7 +74,7 @@ async def handle_media_stream(websocket: WebSocket) -> None:
agent.register_observer(mock_observer)

@agent.register_realtime_function(name="get_weather", description="Get the current weather")
@trace(weather_func_mock, weather_func_called_event)
@trace(weather_func_mock, postcall_event=weather_func_called_event)
def get_weather(location: Annotated[str, "city"]) -> str:
return "The weather is cloudy." if location == "Seattle" else "The weather is sunny."

Expand All @@ -80,7 +94,7 @@ def get_weather(location: Annotated[str, "city"]) -> str:
"event": "media",
"media": {
"timestamp": 0,
"payload": generate_voice_input(text="How is the weather in Seattle?"),
"payload": text_to_speech(text="How is the weather in Seattle?", openai_api_key=openai_api_key),
},
}
)
Expand All @@ -102,14 +116,22 @@ def get_weather(location: Annotated[str, "city"]) -> str:
assert "cloudy" in last_response_transcript, "Weather response did not include the weather condition"

@pytest.mark.asyncio()
async def test_e2e(self, llm_config: dict[str, Any]) -> None:
async def test_e2e(self, llm_config: dict[str, Any], openai_api_key: str) -> None:
"""End-to-end test for the RealtimeAgent.
Retry the test up to 3 times if it fails. Sometimes the test fails due to voice not being recognized by the OpenAI API.
"""
last_exception = None

for _ in range(3):
for i in range(3):
try:
await self._test_e2e(llm_config)
await self._test_e2e(llm_config, openai_api_key=openai_api_key)
return # Exit the function if the test passes
except Exception as e:
logger.warning(
f"Test 'TestE2E.test_e2e' failed on attempt {i + 1} with exception: {e}", stack_info=True
)
last_exception = e # Keep track of the last exception

# If the loop finishes without success, raise the last exception
Expand Down
21 changes: 0 additions & 21 deletions test/agentchat/realtime_agent/test_oai_realtime_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def test_init(self) -> None:
@pytest.mark.skipif(skip_openai, reason=reason)
@pytest.mark.asyncio()
async def test_not_connected(self, client: OpenAIRealtimeClient) -> None:

with pytest.raises(RuntimeError, match=r"Client is not connected, call connect\(\) first."):
with move_on_after(1) as scope:
async for _ in client.read_events():
Expand All @@ -71,7 +70,6 @@ async def test_not_connected(self, client: OpenAIRealtimeClient) -> None:
@pytest.mark.skipif(skip_openai, reason=reason)
@pytest.mark.asyncio()
async def test_start_read_events(self, client: OpenAIRealtimeClient) -> None:

mock = MagicMock()

async with client.connect():
Expand All @@ -94,7 +92,6 @@ async def test_start_read_events(self, client: OpenAIRealtimeClient) -> None:
@pytest.mark.skipif(skip_openai, reason=reason)
@pytest.mark.asyncio()
async def test_send_text(self, client: OpenAIRealtimeClient) -> None:

mock = MagicMock()

async with client.connect():
Expand All @@ -121,21 +118,3 @@ async def test_send_text(self, client: OpenAIRealtimeClient) -> None:

assert calls_kwargs[3]["type"] == "conversation.item.created"
assert calls_kwargs[3]["item"]["content"][0]["text"] == "Hello, how are you?"

@pytest.mark.skip(reason="Not implemented")
@pytest.mark.skipif(skip_openai, reason=reason)
@pytest.mark.asyncio()
async def test_send_audio(self, client: OpenAIRealtimeClient) -> None:
raise NotImplementedError

@pytest.mark.skip(reason="Not implemented")
@pytest.mark.skipif(skip_openai, reason=reason)
@pytest.mark.asyncio()
async def test_truncate_audio(self, client: OpenAIRealtimeClient) -> None:
raise NotImplementedError

@pytest.mark.skip(reason="Not implemented")
@pytest.mark.skipif(skip_openai, reason=reason)
@pytest.mark.asyncio()
async def test_initialize_session(self, client: OpenAIRealtimeClient) -> None:
raise NotImplementedError
1 change: 0 additions & 1 deletion test/agentchat/realtime_agent/test_realtime_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ async def on_event(self, event: dict[str, Any]) -> None:
class TestRealtimeObserver:
@pytest.mark.asyncio()
async def test_shutdown(self) -> None:

mock = MagicMock()
observer = MyObserver(mock)

Expand Down

0 comments on commit 6034a4d

Please sign in to comment.