Skip to content

Commit

Permalink
Cleanup, minor refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
bonk1t committed Dec 9, 2024
1 parent dbaad0f commit c8e91a0
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 28 deletions.
3 changes: 2 additions & 1 deletion agency_swarm/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
ToolFactory,
)
from agency_swarm.tools.oai.FileSearch import FileSearchConfig
from agency_swarm.util.constants import DEFAULT_MODEL
from agency_swarm.util.oai import get_openai_client
from agency_swarm.util.openapi import validate_openapi_spec
from agency_swarm.util.shared_state import SharedState
Expand Down Expand Up @@ -94,7 +95,7 @@ def __init__(
api_params: Dict[str, Dict[str, str]] = None,
file_ids: List[str] = None,
metadata: Dict[str, str] = None,
model: str = "gpt-4o-2024-08-06",
model: str = DEFAULT_MODEL,
validation_attempts: int = 1,
max_prompt_tokens: int = None,
max_completion_tokens: int = None,
Expand Down
2 changes: 2 additions & 0 deletions agency_swarm/util/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
DEFAULT_MODEL = "gpt-4o-2024-08-06"
DEFAULT_MODEL_MINI = "gpt-4o-mini"
3 changes: 2 additions & 1 deletion agency_swarm/util/streaming/agency_event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from openai.types.beta.threads.runs.run_step import RunStep
from typing_extensions import override

from agency_swarm.util.constants import DEFAULT_MODEL
from agency_swarm.util.oai import get_tracker


Expand Down Expand Up @@ -43,7 +44,7 @@ def on_run_step_done(cls, run_step: RunStep) -> None:
"""
if run_step.usage:
tracker = get_tracker()
model = cls.agent.model if cls.agent else "gpt-4o"
model = cls.agent.model if cls.agent else DEFAULT_MODEL
tracker.track_usage(
usage=run_step.usage,
assistant_id=run_step.assistant_id,
Expand Down
3 changes: 2 additions & 1 deletion agency_swarm/util/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from openai import OpenAI
from pydantic import BaseModel, Field

from agency_swarm.util.constants import DEFAULT_MODEL_MINI
from agency_swarm.util.oai import get_openai_client


Expand All @@ -29,7 +30,7 @@ def llm_validator(
statement: str,
client: OpenAI = None,
allow_override: bool = False,
model: str = "gpt-4o-mini",
model: str = DEFAULT_MODEL_MINI,
temperature: float = 0,
) -> Callable[[str], str]:
"""
Expand Down
1 change: 1 addition & 0 deletions tests/test_agency.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def setUpClass(cls):
cls.agent2 = None
cls.agency = None
cls.client = get_openai_client()
cls.client.timeout = 60.0

# testing loading agents from db
cls.loaded_thread_ids = {}
Expand Down
2 changes: 1 addition & 1 deletion tests/test_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_send_message_swarm(self):
len(main_thread.get_messages()) >= 4
) # sometimes run does not cancel immediately, so there might be 5 messages

def test_send_message_double_recepient_error(self):
def test_send_message_double_recipient_error(self):
ceo = Agent(
name="CEO",
description="Responsible for client communication, task planning and management.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,7 @@
import pytest
from openai.types.beta.threads.runs.run_step import Usage

from agency_swarm.util.tracking import LangfuseTracker, SQLiteTracker


@pytest.fixture
def sqlite_tracker():
tracker = SQLiteTracker(":memory:")
yield tracker
from agency_swarm.util.tracking import LangfuseTracker


@pytest.fixture
Expand All @@ -27,22 +21,6 @@ def test_sqlite_track_and_get_total_tokens(sqlite_tracker):
assert totals == usage


def test_sqlite_multiple_entries(sqlite_tracker):
# Insert multiple usage entries
usages = [
Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
Usage(prompt_tokens=20, completion_tokens=10, total_tokens=30),
]
for u in usages:
sqlite_tracker.track_usage(
u, "assistant", "thread", "gpt-4o", "sender", "recipient"
)

totals = sqlite_tracker.get_total_tokens()
# Expected totals: prompt=30, completion=15, total=45
assert totals == Usage(prompt_tokens=30, completion_tokens=15, total_tokens=45)


@patch("agency_swarm.util.tracking.langfuse_tracker.Langfuse")
def test_langfuse_track_usage(mock_langfuse, langfuse_tracker):
# Create mock instance and set it as the client
Expand Down Expand Up @@ -112,7 +90,6 @@ def test_langfuse_get_total_tokens_multiple(mock_langfuse, langfuse_tracker):
langfuse_tracker.client = mock_langfuse_instance # Set the mocked client

totals = langfuse_tracker.get_total_tokens()
# Expected totals: prompt=30, completion=15, total=45
assert totals == Usage(prompt_tokens=30, completion_tokens=15, total_tokens=45)


Expand Down
34 changes: 34 additions & 0 deletions tests/tracking/test_sqlite_tracking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pytest
from openai.types.beta.threads.runs.run_step import Usage

from agency_swarm.util.tracking import SQLiteTracker


@pytest.fixture
def sqlite_tracker():
tracker = SQLiteTracker(":memory:")
yield tracker


def test_sqlite_track_and_get_total_tokens(sqlite_tracker):
usage = Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15)
sqlite_tracker.track_usage(
usage, "test_assistant", "test_thread", "gpt-4o", "sender", "recipient"
)
totals = sqlite_tracker.get_total_tokens()
assert totals == usage


def test_sqlite_multiple_entries(sqlite_tracker):
# Insert multiple usage entries
usages = [
Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
Usage(prompt_tokens=20, completion_tokens=10, total_tokens=30),
]
for u in usages:
sqlite_tracker.track_usage(
u, "assistant", "thread", "gpt-4o", "sender", "recipient"
)

totals = sqlite_tracker.get_total_tokens()
assert totals == Usage(prompt_tokens=30, completion_tokens=15, total_tokens=45)

0 comments on commit c8e91a0

Please sign in to comment.