Realtime agent with WebSocket #241

Merged · 7 commits · Dec 19, 2024
2 changes: 2 additions & 0 deletions autogen/agentchat/realtime_agent/__init__.py
@@ -1,9 +1,11 @@
from .function_observer import FunctionObserver
from .realtime_agent import RealtimeAgent
from .twilio_observer import TwilioAudioAdapter
from .websocket_observer import WebsocketAudioAdapter

__all__ = [
"RealtimeAgent",
"FunctionObserver",
"TwilioAudioAdapter",
"WebsocketAudioAdapter"
]
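
With this export, the new adapter is importable from the package root alongside the existing Twilio adapter; a quick sanity check, assuming the package is installed:

# Verify the new public export resolves from the package root.
from autogen.agentchat.realtime_agent import WebsocketAudioAdapter

print(WebsocketAudioAdapter.__name__)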
128 changes: 128 additions & 0 deletions autogen/agentchat/realtime_agent/websocket_observer.py
@@ -0,0 +1,128 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT

import base64
import json

from .realtime_observer import RealtimeObserver

LOG_EVENT_TYPES = [
"error",
"response.content.done",
"rate_limits.updated",
"response.done",
"input_audio_buffer.committed",
"input_audio_buffer.speech_stopped",
"input_audio_buffer.speech_started",
"session.created",
]
SHOW_TIMING_MATH = False


class WebsocketAudioAdapter(RealtimeObserver):
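    """Bridges Twilio-style media events on a FastAPI websocket to and from the OpenAI Realtime API client."""
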
def __init__(self, websocket):
super().__init__()
self.websocket = websocket

# Connection specific state
self.stream_sid = None
self.latest_media_timestamp = 0
self.last_assistant_item = None
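        # Names of "mark" events sent to the client; each echo received back means the audio up to that mark has played.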
self.mark_queue = []
self.response_start_timestamp_socket = None

async def update(self, response):
"""Receive events from the OpenAI Realtime API, send audio back to websocket."""
if response["type"] in LOG_EVENT_TYPES:
print(f"Received event: {response['type']}", response)

if response.get("type") == "response.audio.delta" and "delta" in response:
audio_payload = base64.b64encode(base64.b64decode(response["delta"])).decode("utf-8")
audio_delta = {"event": "media", "streamSid": self.stream_sid, "media": {"payload": audio_payload}}
await self.websocket.send_json(audio_delta)

if self.response_start_timestamp_socket is None:
self.response_start_timestamp_socket = self.latest_media_timestamp
if SHOW_TIMING_MATH:
print(f"Setting start timestamp for new response: {self.response_start_timestamp_socket}ms")

# Update last_assistant_item safely
if response.get("item_id"):
self.last_assistant_item = response["item_id"]

await self.send_mark()

# Trigger an interruption. Your use case might work better using `input_audio_buffer.speech_stopped`, or combining the two.
if response.get("type") == "input_audio_buffer.speech_started":
print("Speech started detected.")
if self.last_assistant_item:
print(f"Interrupting response with id: {self.last_assistant_item}")
await self.handle_speech_started_event()

async def handle_speech_started_event(self):
"""Handle interruption when the caller's speech starts."""
print("Handling speech started event.")
if self.mark_queue and self.response_start_timestamp_socket is not None:
elapsed_time = self.latest_media_timestamp - self.response_start_timestamp_socket
if SHOW_TIMING_MATH:
print(
f"Calculating elapsed time for truncation: {self.latest_media_timestamp} - {self.response_start_timestamp_socket} = {elapsed_time}ms"
)

if self.last_assistant_item:
if SHOW_TIMING_MATH:
print(f"Truncating item with ID: {self.last_assistant_item}, Truncated at: {elapsed_time}ms")

truncate_event = {
"type": "conversation.item.truncate",
"item_id": self.last_assistant_item,
"content_index": 0,
"audio_end_ms": elapsed_time,
}
await self._client._openai_ws.send(json.dumps(truncate_event))

await self.websocket.send_json({"event": "clear", "streamSid": self.stream_sid})

self.mark_queue.clear()
self.last_assistant_item = None
self.response_start_timestamp_socket = None

async def send_mark(self):
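        """Send a "mark" event so the client echoes it back once the audio queued so far has been played."""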
if self.stream_sid:
mark_event = {"event": "mark", "streamSid": self.stream_sid, "mark": {"name": "responsePart"}}
await self.websocket.send_json(mark_event)
self.mark_queue.append("responsePart")

async def run(self):
openai_ws = self._client._openai_ws
await self.initialize_session()

async for message in self.websocket.iter_text():
data = json.loads(message)
if data["event"] == "media":
self.latest_media_timestamp = int(data["media"]["timestamp"])
audio_append = {"type": "input_audio_buffer.append", "audio": data["media"]["payload"]}
await openai_ws.send(json.dumps(audio_append))
elif data["event"] == "start":
self.stream_sid = data["start"]["streamSid"]
print(f"Incoming stream has started {self.stream_sid}")
self.response_start_timestamp_socket = None
self.latest_media_timestamp = 0
self.last_assistant_item = None
elif data["event"] == "mark":
if self.mark_queue:
self.mark_queue.pop(0)

async def initialize_session(self):
"""Control initial session with OpenAI."""
        session_update = {
            "input_audio_format": "pcm16",  # Twilio streaming uses "g711_ulaw"; raw websocket audio is PCM16
            "output_audio_format": "pcm16",  # likewise "g711_ulaw" when targeting Twilio
        }
await self._client.session_update(session_update)
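
The adapter speaks Twilio-style event framing ("start", "media", "mark") over a plain websocket. A minimal sketch of a client exercising that protocol against the notebook server below, assuming it listens on localhost:5050; the URL, streamSid, and silent audio frame are illustrative:

import asyncio
import base64
import json

import websockets  # third-party client library, assumed installed


async def main() -> None:
    uri = "ws://localhost:5050/media-stream"  # assumed server address
    async with websockets.connect(uri) as ws:
        # Announce the stream with a client-chosen streamSid.
        await ws.send(json.dumps({"event": "start", "start": {"streamSid": "demo-stream"}}))

        # Send one frame of base64-encoded PCM16 silence with a millisecond timestamp.
        silence = base64.b64encode(b"\x00" * 320).decode("utf-8")
        await ws.send(json.dumps({"event": "media", "media": {"timestamp": 0, "payload": silence}}))

        # Print the events the adapter sends back ("media", "mark", "clear").
        async for message in ws:
            print(json.loads(message)["event"])


asyncio.run(main())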
137 changes: 137 additions & 0 deletions notebook/agentchat_realtime_websocket.ipynb
@@ -0,0 +1,137 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from typing import Annotated, Union\n",
"from pathlib import Path\n",
"\n",
"import nest_asyncio\n",
"import uvicorn\n",
"from fastapi import FastAPI, Request, WebSocket\n",
"from fastapi.responses import HTMLResponse, JSONResponse\n",
"from fastapi.templating import Jinja2Templates\n",
"from fastapi.staticfiles import StaticFiles\n",
"\n",
"from autogen.agentchat.realtime_agent import FunctionObserver, RealtimeAgent, WebsocketAudioAdapter\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Configuration\n",
"OPENAI_API_KEY = os.getenv(\"OPENAI_API_KEY\")\n",
"PORT = int(os.getenv(\"PORT\", 5050))\n",
"\n",
"if not OPENAI_API_KEY:\n",
" raise ValueError(\"Missing the OpenAI API key. Please set it in the .env file.\")\n",
"\n",
"llm_config = {\n",
" \"timeout\": 600,\n",
" \"cache_seed\": 45, # change the seed for different trials\n",
" \"config_list\": [\n",
" {\n",
" \"model\": \"gpt-4o-realtime-preview-2024-10-01\",\n",
" \"api_key\": OPENAI_API_KEY,\n",
" }\n",
" ],\n",
" \"temperature\": 0.8,\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"nest_asyncio.apply()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"app = FastAPI()\n",
"\n",
"notebook_path=os.getcwd()\n",
"\n",
"app.mount(\"/static\", StaticFiles(directory=Path(notebook_path) / \"agentchat_realtime_websocket\" / \"static\"), name=\"static\")\n",
"\n",
"# Templates for HTML responses\n",
"\n",
"templates = Jinja2Templates(directory=Path(notebook_path) / \"agentchat_realtime_websocket\" / \"templates\")\n",
"\n",
"@app.get(\"/\", response_class=JSONResponse)\n",
"async def index_page():\n",
" return {\"message\": \"Websocket Audio Stream Server is running!\"}\n",
"\n",
"@app.get(\"/start-chat/\", response_class=HTMLResponse)\n",
"async def start_chat(request: Request):\n",
" \"\"\"Endpoint to return the HTML page for audio chat.\"\"\"\n",
" port = PORT # Extract the client's port\n",
" return templates.TemplateResponse(\"chat.html\", {\"request\": request, \"port\": port})\n",
"\n",
"@app.websocket(\"/media-stream\")\n",
"async def handle_media_stream(websocket: WebSocket):\n",
" \"\"\"Handle WebSocket connections providing audio stream and OpenAI.\"\"\"\n",
" await websocket.accept()\n",
"\n",
" audio_adapter = WebsocketAudioAdapter(websocket)\n",
" openai_client = RealtimeAgent(\n",
" name=\"Weather Bot\",\n",
" system_message=\"Hello there! I am an AI voice assistant powered by Autogen and the OpenAI Realtime API. You can ask me about weather, jokes, or anything you can imagine. Start by saying How can I help you?\",\n",
" llm_config=llm_config,\n",
" audio_adapter=audio_adapter,\n",
" )\n",
"\n",
" @openai_client.register_handover(name=\"get_weather\", description=\"Get the current weather\")\n",
" def get_weather(location: Annotated[str, \"city\"]) -> str:\n",
" ...\n",
" return \"The weather is cloudy.\" if location == \"Seattle\" else \"The weather is sunny.\"\n",
"\n",
" await openai_client.run()\n",
"\n",
"\n",
"uvicorn.run(app, host=\"0.0.0.0\", port=PORT)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
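
Once the final cell is running, visiting http://localhost:5050/start-chat/ serves the chat page, and the browser streams audio over the /media-stream websocket. Extra tools can be registered with the same decorator the notebook uses for get_weather; a minimal sketch, placed inside the same handler where openai_client is created (get_time and its body are illustrative):

from datetime import datetime

@openai_client.register_handover(name="get_time", description="Get the current time")
def get_time() -> str:
    # Illustrative tool body; returns the server's local time as a string.
    return datetime.now().strftime("%H:%M")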