From 7261503d7634b10f51d4a6beaefef4b27a9f8406 Mon Sep 17 00:00:00 2001 From: Robert Jambrecic Date: Wed, 18 Dec 2024 16:25:33 +0100 Subject: [PATCH] Fix tests --- test/interop/pydantic_ai/test_pydantic_ai.py | 2 +- test/interop/test_helpers.py | 11 +++++++++-- test/interop/test_interoperability.py | 6 +++--- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/test/interop/pydantic_ai/test_pydantic_ai.py b/test/interop/pydantic_ai/test_pydantic_ai.py index 4bf3304366..3e29ceec57 100644 --- a/test/interop/pydantic_ai/test_pydantic_ai.py +++ b/test/interop/pydantic_ai/test_pydantic_ai.py @@ -162,7 +162,7 @@ def get_player(ctx: RunContext[Player], additional_info: Optional[str] = None) - self.tool = self.pydantic_ai_interop.convert_tool(tool=pydantic_ai_tool, deps=player) def test_expected_tools(self) -> None: - config_list = [{"model": "gpt-4o", "api_key": os.environ["OPENAI_API_KEY"]}] + config_list = [{"model": "gpt-4o", "api_key": "abc"}] chatbot = AssistantAgent( name="chatbot", llm_config={"config_list": config_list}, diff --git a/test/interop/test_helpers.py b/test/interop/test_helpers.py index f3a5ad09ef..76f52a85f4 100644 --- a/test/interop/test_helpers.py +++ b/test/interop/test_helpers.py @@ -37,9 +37,11 @@ def test_find_classes_implementing_protocol(self) -> None: if sys.version_info >= (3, 9): from autogen.interop.langchain import LangchainInteroperability + from autogen.interop.pydantic_ai import PydanticAIInteroperability assert LangchainInteroperability in actual - expected_count += 1 + assert PydanticAIInteroperability in actual + expected_count += 2 assert len(actual) == expected_count @@ -53,8 +55,13 @@ def test_get_all_interoperability_classes(self) -> None: if sys.version_info >= (3, 10) and sys.version_info < (3, 13): from autogen.interop.crewai import CrewAIInteroperability from autogen.interop.langchain import LangchainInteroperability + from autogen.interop.pydantic_ai import PydanticAIInteroperability - assert actual == {"crewai": CrewAIInteroperability, "langchain": LangchainInteroperability} + assert actual == { + "pydanticai": PydanticAIInteroperability, + "crewai": CrewAIInteroperability, + "langchain": LangchainInteroperability, + } if (sys.version_info >= (3, 9) and sys.version_info < (3, 10)) and sys.version_info >= (3, 13): from autogen.interop.langchain import LangchainInteroperability diff --git a/test/interop/test_interoperability.py b/test/interop/test_interoperability.py index dacdc16c87..8dba1c763c 100644 --- a/test/interop/test_interoperability.py +++ b/test/interop/test_interoperability.py @@ -20,13 +20,13 @@ def test_supported_types(self) -> None: assert actual == [] if sys.version_info >= (3, 9) and sys.version_info < (3, 10): - assert actual == ["langchain"] + assert actual == ["langchain", "pydanticai"] if sys.version_info >= (3, 10) and sys.version_info < (3, 13): - assert actual == ["crewai", "langchain"] + assert actual == ["crewai", "langchain", "pydanticai"] if sys.version_info >= (3, 13): - assert actual == ["langchain"] + assert actual == ["langchain", "pydanticai"] def test_register_interoperability_class(self) -> None: org_interoperability_classes = Interoperability._interoperability_classes