diff --git a/agency_swarm/agents/agent.py b/agency_swarm/agents/agent.py index deea53b7..76d9a217 100644 --- a/agency_swarm/agents/agent.py +++ b/agency_swarm/agents/agent.py @@ -17,6 +17,7 @@ ToolFactory, ) from agency_swarm.tools.oai.FileSearch import FileSearchConfig +from agency_swarm.util.constants import DEFAULT_MODEL from agency_swarm.util.oai import get_openai_client from agency_swarm.util.openapi import validate_openapi_spec from agency_swarm.util.shared_state import SharedState @@ -94,7 +95,7 @@ def __init__( api_params: Dict[str, Dict[str, str]] = None, file_ids: List[str] = None, metadata: Dict[str, str] = None, - model: str = "gpt-4o-2024-08-06", + model: str = DEFAULT_MODEL, validation_attempts: int = 1, max_prompt_tokens: int = None, max_completion_tokens: int = None, diff --git a/agency_swarm/util/constants.py b/agency_swarm/util/constants.py new file mode 100644 index 00000000..cd951782 --- /dev/null +++ b/agency_swarm/util/constants.py @@ -0,0 +1,2 @@ +DEFAULT_MODEL = "gpt-4o-2024-08-06" +DEFAULT_MODEL_MINI = "gpt-4o-mini" diff --git a/agency_swarm/util/streaming/agency_event_handler.py b/agency_swarm/util/streaming/agency_event_handler.py index b38d23ac..1563a4b7 100644 --- a/agency_swarm/util/streaming/agency_event_handler.py +++ b/agency_swarm/util/streaming/agency_event_handler.py @@ -4,6 +4,7 @@ from openai.types.beta.threads.runs.run_step import RunStep from typing_extensions import override +from agency_swarm.util.constants import DEFAULT_MODEL from agency_swarm.util.oai import get_tracker @@ -43,7 +44,7 @@ def on_run_step_done(cls, run_step: RunStep) -> None: """ if run_step.usage: tracker = get_tracker() - model = cls.agent.model if cls.agent else "gpt-4o" + model = cls.agent.model if cls.agent else DEFAULT_MODEL tracker.track_usage( usage=run_step.usage, assistant_id=run_step.assistant_id, diff --git a/agency_swarm/util/validators.py b/agency_swarm/util/validators.py index 86130f44..3ac2f4bc 100644 --- a/agency_swarm/util/validators.py +++ b/agency_swarm/util/validators.py @@ -3,6 +3,7 @@ from openai import OpenAI from pydantic import BaseModel, Field +from agency_swarm.util.constants import DEFAULT_MODEL_MINI from agency_swarm.util.oai import get_openai_client @@ -29,7 +30,7 @@ def llm_validator( statement: str, client: OpenAI = None, allow_override: bool = False, - model: str = "gpt-4o-mini", + model: str = DEFAULT_MODEL_MINI, temperature: float = 0, ) -> Callable[[str], str]: """ diff --git a/tests/test_agency.py b/tests/test_agency.py index d39fd566..0efb258e 100644 --- a/tests/test_agency.py +++ b/tests/test_agency.py @@ -51,6 +51,7 @@ def setUpClass(cls): cls.agent2 = None cls.agency = None cls.client = get_openai_client() + cls.client.timeout = 60.0 # testing loading agents from db cls.loaded_thread_ids = {} diff --git a/tests/test_communication.py b/tests/test_communication.py index 205ef8c2..be62b02b 100644 --- a/tests/test_communication.py +++ b/tests/test_communication.py @@ -66,7 +66,7 @@ def test_send_message_swarm(self): len(main_thread.get_messages()) >= 4 ) # sometimes run does not cancel immediately, so there might be 5 messages - def test_send_message_double_recepient_error(self): + def test_send_message_double_recipient_error(self): ceo = Agent( name="CEO", description="Responsible for client communication, task planning and management.", diff --git a/tests/test_tracking.py b/tests/tracking/test_langfuse_tracking.py similarity index 80% rename from tests/test_tracking.py rename to tests/tracking/test_langfuse_tracking.py index 97587388..32df0cd2 100644 --- a/tests/test_tracking.py +++ b/tests/tracking/test_langfuse_tracking.py @@ -3,13 +3,7 @@ import pytest from openai.types.beta.threads.runs.run_step import Usage -from agency_swarm.util.tracking import LangfuseTracker, SQLiteTracker - - -@pytest.fixture -def sqlite_tracker(): - tracker = SQLiteTracker(":memory:") - yield tracker +from agency_swarm.util.tracking import LangfuseTracker @pytest.fixture @@ -27,22 +21,6 @@ def test_sqlite_track_and_get_total_tokens(sqlite_tracker): assert totals == usage -def test_sqlite_multiple_entries(sqlite_tracker): - # Insert multiple usage entries - usages = [ - Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15), - Usage(prompt_tokens=20, completion_tokens=10, total_tokens=30), - ] - for u in usages: - sqlite_tracker.track_usage( - u, "assistant", "thread", "gpt-4o", "sender", "recipient" - ) - - totals = sqlite_tracker.get_total_tokens() - # Expected totals: prompt=30, completion=15, total=45 - assert totals == Usage(prompt_tokens=30, completion_tokens=15, total_tokens=45) - - @patch("agency_swarm.util.tracking.langfuse_tracker.Langfuse") def test_langfuse_track_usage(mock_langfuse, langfuse_tracker): # Create mock instance and set it as the client @@ -112,7 +90,6 @@ def test_langfuse_get_total_tokens_multiple(mock_langfuse, langfuse_tracker): langfuse_tracker.client = mock_langfuse_instance # Set the mocked client totals = langfuse_tracker.get_total_tokens() - # Expected totals: prompt=30, completion=15, total=45 assert totals == Usage(prompt_tokens=30, completion_tokens=15, total_tokens=45) diff --git a/tests/tracking/test_sqlite_tracking.py b/tests/tracking/test_sqlite_tracking.py new file mode 100644 index 00000000..3247603a --- /dev/null +++ b/tests/tracking/test_sqlite_tracking.py @@ -0,0 +1,34 @@ +import pytest +from openai.types.beta.threads.runs.run_step import Usage + +from agency_swarm.util.tracking import SQLiteTracker + + +@pytest.fixture +def sqlite_tracker(): + tracker = SQLiteTracker(":memory:") + yield tracker + + +def test_sqlite_track_and_get_total_tokens(sqlite_tracker): + usage = Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15) + sqlite_tracker.track_usage( + usage, "test_assistant", "test_thread", "gpt-4o", "sender", "recipient" + ) + totals = sqlite_tracker.get_total_tokens() + assert totals == usage + + +def test_sqlite_multiple_entries(sqlite_tracker): + # Insert multiple usage entries + usages = [ + Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15), + Usage(prompt_tokens=20, completion_tokens=10, total_tokens=30), + ] + for u in usages: + sqlite_tracker.track_usage( + u, "assistant", "thread", "gpt-4o", "sender", "recipient" + ) + + totals = sqlite_tracker.get_total_tokens() + assert totals == Usage(prompt_tokens=30, completion_tokens=15, total_tokens=45)