diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index ffd6923721..b2f22ce9c5 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -2600,7 +2600,7 @@ def update_function_signature(self, func_sig: Union[str, Dict], is_remove: None) self.client = OpenAIWrapper(**self.llm_config) - def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: None): + def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: bool): """update a tool_signature in the LLM configuration for tool_call. Args: diff --git a/autogen/interop/interoperability.py b/autogen/interop/interoperability.py index b68a8ce4f8..2c9ed1ffbf 100644 --- a/autogen/interop/interoperability.py +++ b/autogen/interop/interoperability.py @@ -7,6 +7,8 @@ from .helpers import get_all_interoperability_classes from .interoperable import Interoperable +__all__ = ["Interoperable"] + class Interoperability: _interoperability_classes: Dict[str, Type[Interoperable]] = get_all_interoperability_classes() diff --git a/autogen/interop/pydantic_ai/__init__.py b/autogen/interop/pydantic_ai/__init__.py new file mode 100644 index 0000000000..55d52347ce --- /dev/null +++ b/autogen/interop/pydantic_ai/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 + +import sys + +if sys.version_info < (3, 9): + raise ImportError("This submodule is only supported for Python versions 3.9 and above") + +try: + import pydantic_ai.tools +except ImportError: + raise ImportError( + "Please install `interop-pydantic-ai` extra to use this module:\n\n\tpip install ag2[interop-pydantic-ai]" + ) + +from .pydantic_ai import PydanticAIInteroperability + +__all__ = ["PydanticAIInteroperability"] diff --git a/autogen/interop/pydantic_ai/pydantic_ai.py b/autogen/interop/pydantic_ai/pydantic_ai.py new file mode 100644 index 0000000000..27a146704c --- /dev/null +++ 
b/autogen/interop/pydantic_ai/pydantic_ai.py @@ -0,0 +1,86 @@ +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 + + +from functools import wraps +from inspect import signature +from typing import Any, Callable, Optional + +from pydantic_ai import RunContext +from pydantic_ai.tools import Tool as PydanticAITool + +from ...tools import PydanticAITool as AG2PydanticAITool +from ..interoperability import Interoperable + +__all__ = ["PydanticAIInteroperability"] + + +class PydanticAIInteroperability(Interoperable): + @staticmethod + def inject_params( # type: ignore[no-any-unimported] + ctx: Optional[RunContext[Any]], + tool: PydanticAITool, + ) -> Callable[..., Any]: + max_retries = tool.max_retries if tool.max_retries is not None else 1 + f = tool.function + + @wraps(f) + def wrapper(*args: Any, **kwargs: Any) -> Any: + if tool.current_retry >= max_retries: + raise ValueError(f"{tool.name} failed after {max_retries} retries") + + try: + if ctx is not None: + kwargs.pop("ctx", None) + ctx.retry = tool.current_retry + result = f(**kwargs, ctx=ctx) + else: + result = f(**kwargs) + tool.current_retry = 0 + except Exception as e: + tool.current_retry += 1 + raise e + + return result + + sig = signature(f) + if ctx is not None: + new_params = [param for name, param in sig.parameters.items() if name != "ctx"] + else: + new_params = list(sig.parameters.values()) + + wrapper.__signature__ = sig.replace(parameters=new_params) # type: ignore[attr-defined] + + return wrapper + + def convert_tool(self, tool: Any, deps: Any = None) -> AG2PydanticAITool: + if not isinstance(tool, PydanticAITool): + raise ValueError(f"Expected an instance of `pydantic_ai.tools.Tool`, got {type(tool)}") + + # needed for type checking + pydantic_ai_tool: PydanticAITool = tool # type: ignore[no-any-unimported] + + if deps is not None: + ctx = RunContext( + deps=deps, + retry=0, + # All messages send to or returned by a model. 
+ # This is mostly used on pydantic_ai Agent level. + messages=None, # TODO: check in the future if this is needed on Tool level + tool_name=pydantic_ai_tool.name, + ) + else: + ctx = None + + func = PydanticAIInteroperability.inject_params( + ctx=ctx, + tool=pydantic_ai_tool, + ) + + return AG2PydanticAITool( + name=pydantic_ai_tool.name, + description=pydantic_ai_tool.description, + func=func, + parameters_json_schema=pydantic_ai_tool._parameters_json_schema, + ) diff --git a/autogen/oai/cerebras.py b/autogen/oai/cerebras.py index 7c02afcdca..201fb2ee55 100644 --- a/autogen/oai/cerebras.py +++ b/autogen/oai/cerebras.py @@ -42,7 +42,7 @@ CEREBRAS_PRICING_1K = { # Convert pricing per million to per thousand tokens. "llama3.1-8b": (0.10 / 1000, 0.10 / 1000), - "llama3.1-70b": (0.60 / 1000, 0.60 / 1000), + "llama-3.3-70b": (0.85 / 1000, 1.20 / 1000), } diff --git a/autogen/tools/__init__.py b/autogen/tools/__init__.py index 3451ac3111..d787ae1a60 100644 --- a/autogen/tools/__init__.py +++ b/autogen/tools/__init__.py @@ -1,3 +1,4 @@ +from .pydantic_ai_tool import PydanticAITool from .tool import Tool -__all__ = ["Tool"] +__all__ = ["PydanticAITool", "Tool"] diff --git a/autogen/tools/pydantic_ai_tool.py b/autogen/tools/pydantic_ai_tool.py new file mode 100644 index 0000000000..5c999eba3a --- /dev/null +++ b/autogen/tools/pydantic_ai_tool.py @@ -0,0 +1,29 @@ +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Dict + +from autogen.agentchat.conversable_agent import ConversableAgent + +from .tool import Tool + +__all__ = ["PydanticAITool"] + + +class PydanticAITool(Tool): + def __init__( + self, name: str, description: str, func: Callable[..., Any], parameters_json_schema: Dict[str, Any] + ) -> None: + super().__init__(name, description, func) + self._func_schema = { + "type": "function", + "function": { + "name": name, + "description": description, + "parameters": 
parameters_json_schema, + }, + } + + def register_for_llm(self, agent: ConversableAgent) -> None: + agent.update_tool_signature(self._func_schema, is_remove=False) diff --git a/notebook/tools_crewai_tools_integration.ipynb b/notebook/tools_crewai_tools_integration.ipynb index f32d2c2591..07bdd0e341 100644 --- a/notebook/tools_crewai_tools_integration.ipynb +++ b/notebook/tools_crewai_tools_integration.ipynb @@ -50,7 +50,7 @@ "from crewai_tools import FileWriterTool, ScrapeWebsiteTool\n", "\n", "from autogen import AssistantAgent, UserProxyAgent\n", - "from autogen.interoperability.crewai import CrewAIInteroperability" + "from autogen.interop.crewai import CrewAIInteroperability" ] }, { @@ -168,7 +168,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": ".venv", "language": "python", "name": "python3" }, @@ -182,7 +182,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.10" + "version": "3.10.16" } }, "nbformat": 4, diff --git a/notebook/tools_langchain_tools_integration.ipynb b/notebook/tools_langchain_tools_integration.ipynb index 3fb4d68594..48a1d0986a 100644 --- a/notebook/tools_langchain_tools_integration.ipynb +++ b/notebook/tools_langchain_tools_integration.ipynb @@ -52,7 +52,7 @@ "from langchain_community.utilities import WikipediaAPIWrapper\n", "\n", "from autogen import AssistantAgent, UserProxyAgent\n", - "from autogen.interoperability.langchain import LangchainInteroperability" + "from autogen.interop.langchain import LangchainInteroperability" ] }, { @@ -112,7 +112,7 @@ "ag2_tool = langchain_interop.convert_tool(langchain_tool)\n", "\n", "ag2_tool.register_for_execution(user_proxy)\n", - "ag2_tool.register_for_llm(chatbot)\n" + "ag2_tool.register_for_llm(chatbot)" ] }, { diff --git a/notebook/tools_pydantic_ai_tools_integration.ipynb b/notebook/tools_pydantic_ai_tools_integration.ipynb new file mode 100644 index 0000000000..edf2170135 --- /dev/null +++ 
b/notebook/tools_pydantic_ai_tools_integration.ipynb @@ -0,0 +1,183 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Integrating PydanticAI Tools with the AG2 Framework\n", + "\n", + "In this tutorial, we demonstrate how to integrate [PydanticAI Tools](https://ai.pydantic.dev/tools/) into the AG2 framework. This process enables smooth interoperability between the two systems, allowing developers to leverage PydanticAI's powerful tools within AG2's flexible agent-based architecture. By the end of this guide, you will understand how to configure agents, convert PydanticAI tools for use in AG2, and validate the integration with a practical example.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Installation\n", + "To integrate PydanticAI tools into the AG2 framework, install the required dependencies:\n", + "\n", + "```bash\n", + "pip install ag2[interop-pydantic-ai]\n", + "```\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Imports\n", + "\n", + "Import necessary modules and tools.\n", + "- `BaseModel`: Used to define data structures for tool inputs and outputs.\n", + "- `RunContext`: Provides context during the execution of tools.\n", + "- `PydanticAITool`: Represents a tool in the PydanticAI framework.\n", + "- `AssistantAgent` and `UserProxyAgent`: Agents that facilitate communication in the AG2 framework.\n", + "- `PydanticAIInteroperability`: A bridge for integrating PydanticAI tools with the AG2 framework." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from typing import Optional\n", + "\n", + "from pydantic import BaseModel\n", + "from pydantic_ai import RunContext\n", + "from pydantic_ai.tools import Tool as PydanticAITool\n", + "\n", + "from autogen import AssistantAgent, UserProxyAgent\n", + "from autogen.interop.pydantic_ai import PydanticAIInteroperability" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Agent Configuration\n", + "\n", + "Configure the agents for the interaction.\n", + "- `config_list` defines the LLM configurations, including the model and API key.\n", + "- `UserProxyAgent` simulates user inputs without requiring actual human interaction (set to `NEVER`).\n", + "- `AssistantAgent` represents the AI agent, configured with the LLM settings." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "config_list = [{\"model\": \"gpt-4o\", \"api_key\": os.environ[\"OPENAI_API_KEY\"]}]\n", + "user_proxy = UserProxyAgent(\n", + " name=\"User\",\n", + " human_input_mode=\"NEVER\",\n", + ")\n", + "\n", + "chatbot = AssistantAgent(\n", + " name=\"chatbot\",\n", + " llm_config={\"config_list\": config_list},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tool Integration\n", + "\n", + "Integrate the PydanticAI tool with AG2.\n", + "\n", + "- Define a `Player` model using `BaseModel` to structure the input data.\n", + "- Use `RunContext` to securely inject dependencies (like the `Player` instance) into the tool function without exposing them to the LLM.\n", + "- Implement `get_player` to define the tool's functionality, accessing `ctx.deps` for injected data.\n", + "- Convert the tool to an AG2-compatible format with `PydanticAIInteroperability` and register it for execution and LLM communication.\n", + "- Convert the PydanticAI tool into an 
AG2-compatible format using `convert_tool`.\n", + "- Register the tool for both execution and communication with the LLM by associating it with the `user_proxy` and `chatbot`." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "class Player(BaseModel):\n", + " name: str\n", + " age: int\n", + "\n", + "\n", + "def get_player(ctx: RunContext[Player], additional_info: Optional[str] = None) -> str: # type: ignore[valid-type]\n", + " \"\"\"Get the player's name.\n", + "\n", + " Args:\n", + " additional_info: Additional information which can be used.\n", + " \"\"\"\n", + " return f\"Name: {ctx.deps.name}, Age: {ctx.deps.age}, Additional info: {additional_info}\" # type: ignore[attr-defined]\n", + "\n", + "\n", + "pydantic_ai_interop = PydanticAIInteroperability()\n", + "pydantic_ai_tool = PydanticAITool(get_player, takes_ctx=True)\n", + "\n", + "# player will be injected as a dependency\n", + "player = Player(name=\"Luka\", age=25)\n", + "ag2_tool = pydantic_ai_interop.convert_tool(tool=pydantic_ai_tool, deps=player)\n", + "\n", + "ag2_tool.register_for_execution(user_proxy)\n", + "ag2_tool.register_for_llm(chatbot)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Initiate a conversation between the `UserProxyAgent` and the `AssistantAgent`.\n", + "\n", + "- Use the `initiate_chat` method to send a message from the `user_proxy` to the `chatbot`.\n", + "- In this example, the user requests the chatbot to retrieve player information, providing \"goal keeper\" as additional context.\n", + "- The `Player` instance is securely injected into the tool using `RunContext`, ensuring the chatbot can retrieve and use this data during the interaction." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "user_proxy.initiate_chat(\n", + " recipient=chatbot, message=\"Get player, for additional information use 'goal keeper'\", max_turns=3\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/setup.py b/setup.py index 88e6a62e96..63cade834a 100644 --- a/setup.py +++ b/setup.py @@ -71,6 +71,7 @@ interop_crewai = ["crewai[tools]>=0.86,<1; python_version>='3.10' and python_version<'3.13'"] interop_langchain = ["langchain-community>=0.3.12,<1; python_version>='3.9'"] +interop_pydantic_ai = ["pydantic-ai>=0.0.13,<1; python_version>='3.9'"] if current_os in ["Windows", "Darwin"]: retrieve_chat_pgvector.extend(["psycopg[binary]>=3.1.18"]) @@ -127,7 +128,8 @@ "bedrock": ["boto3>=1.34.149"], "interop-crewai": interop_crewai, "interop-langchain": interop_langchain, - "interop": interop_crewai + interop_langchain, + "interop-pydantic-ai": interop_pydantic_ai, + "interop": interop_crewai + interop_langchain + interop_pydantic_ai, "neo4j": neo4j, } diff --git a/setup_ag2.py b/setup_ag2.py index f0015dda36..0b8a9dfec7 100644 --- a/setup_ag2.py +++ b/setup_ag2.py @@ -56,6 +56,7 @@ "bedrock": ["pyautogen[bedrock]==" + __version__], "interop-crewai": ["pyautogen[interop-crewai]==" + __version__], "interop-langchain": ["pyautogen[interop-langchain]==" + __version__], + "interop-pydantic-ai": ["pyautogen[interop-pydantic-ai]==" + __version__], "interop": ["pyautogen[interop]==" + 
__version__], "neo4j": ["pyautogen[neo4j]==" + __version__], }, diff --git a/setup_autogen.py b/setup_autogen.py index 520b7294d5..93690e4bea 100644 --- a/setup_autogen.py +++ b/setup_autogen.py @@ -56,6 +56,7 @@ "bedrock": ["pyautogen[bedrock]==" + __version__], "interop-crewai": ["pyautogen[interop-crewai]==" + __version__], "interop-langchain": ["pyautogen[interop-langchain]==" + __version__], + "interop-pydantic-ai": ["pyautogen[interop-pydantic-ai]==" + __version__], "interop": ["pyautogen[interop]==" + __version__], "neo4j": ["pyautogen[neo4j]==" + __version__], }, diff --git a/test/interop/pydantic_ai/__init__.py b/test/interop/pydantic_ai/__init__.py new file mode 100644 index 0000000000..bcd5401d54 --- /dev/null +++ b/test/interop/pydantic_ai/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/test/interop/pydantic_ai/test_pydantic_ai.py b/test/interop/pydantic_ai/test_pydantic_ai.py new file mode 100644 index 0000000000..4bf3304366 --- /dev/null +++ b/test/interop/pydantic_ai/test_pydantic_ai.py @@ -0,0 +1,221 @@ +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 + +import os +import random +import sys +import unittest +from inspect import signature +from typing import Any, Dict, Optional + +import pytest +from conftest import reason, skip_openai +from pydantic import BaseModel + +from autogen import AssistantAgent, UserProxyAgent +from autogen.interop import Interoperable + +if sys.version_info >= (3, 9): + from pydantic_ai import RunContext + from pydantic_ai.tools import Tool as PydanticAITool + + from autogen.interop.pydantic_ai import PydanticAIInteroperability +else: + RunContext = unittest.mock.MagicMock() + PydanticAITool = unittest.mock.MagicMock() + PydanticAIInteroperability = unittest.mock.MagicMock() + + +# skip if python version is not >= 3.9 +@pytest.mark.skipif( + sys.version_info < (3, 9), 
reason="Only Python 3.9 and above are supported for LangchainInteroperability" +) +class TestPydanticAIInteroperabilityWithotContext: + @pytest.fixture(autouse=True) + def setup(self) -> None: + def roll_dice() -> str: + """Roll a six-sided dice and return the result.""" + return str(random.randint(1, 6)) + + self.pydantic_ai_interop = PydanticAIInteroperability() + pydantic_ai_tool = PydanticAITool(roll_dice, max_retries=3) + self.tool = self.pydantic_ai_interop.convert_tool(pydantic_ai_tool) + + def test_type_checks(self) -> None: + # mypy should fail if the type checks are not correct + interop: Interoperable = self.pydantic_ai_interop + # runtime check + assert isinstance(interop, Interoperable) + + def test_init(self) -> None: + assert isinstance(self.pydantic_ai_interop, Interoperable) + + def test_convert_tool(self) -> None: + assert self.tool.name == "roll_dice" + assert self.tool.description == "Roll a six-sided dice and return the result." + assert self.tool.func() in ["1", "2", "3", "4", "5", "6"] + + @pytest.mark.skipif(skip_openai, reason=reason) + def test_with_llm(self) -> None: + config_list = [{"model": "gpt-4o", "api_key": os.environ["OPENAI_API_KEY"]}] + user_proxy = UserProxyAgent( + name="User", + human_input_mode="NEVER", + ) + + chatbot = AssistantAgent( + name="chatbot", + llm_config={"config_list": config_list}, + ) + + self.tool.register_for_execution(user_proxy) + self.tool.register_for_llm(chatbot) + + user_proxy.initiate_chat(recipient=chatbot, message="roll a dice", max_turns=2) + + for message in user_proxy.chat_messages[chatbot]: + if "tool_responses" in message: + assert message["tool_responses"][0]["content"] in ["1", "2", "3", "4", "5", "6"] + return + + assert False, "No tool response found in chat messages" + + +class TestPydanticAIInteroperabilityDependencyInjection: + + def test_dependency_injection(self) -> None: + def f( + ctx: RunContext[int], # type: ignore[valid-type] + city: str, + date: str, + ) -> str: + """Random 
function for testing.""" + return f"{city} {date} {ctx.deps}" # type: ignore[attr-defined] + + ctx = RunContext( + deps=123, + retry=0, + messages=None, + tool_name=f.__name__, + ) + pydantic_ai_tool = PydanticAITool(f, takes_ctx=True) + g = PydanticAIInteroperability.inject_params( + ctx=ctx, + tool=pydantic_ai_tool, + ) + assert list(signature(g).parameters.keys()) == ["city", "date"] + kwargs: Dict[str, Any] = {"city": "Zagreb", "date": "2021-01-01"} + assert g(**kwargs) == "Zagreb 2021-01-01 123" + + def test_dependency_injection_with_retry(self) -> None: + def f( + ctx: RunContext[int], # type: ignore[valid-type] + city: str, + date: str, + ) -> str: + """Random function for testing.""" + raise ValueError("Retry") + + ctx = RunContext( + deps=123, + retry=0, + messages=None, + tool_name=f.__name__, + ) + + pydantic_ai_tool = PydanticAITool(f, takes_ctx=True, max_retries=3) + g = PydanticAIInteroperability.inject_params( + ctx=ctx, + tool=pydantic_ai_tool, + ) + + for i in range(3): + with pytest.raises(ValueError, match="Retry"): + g(city="Zagreb", date="2021-01-01") + assert pydantic_ai_tool.current_retry == i + 1 + assert ctx.retry == i + + with pytest.raises(ValueError, match="f failed after 3 retries"): + g(city="Zagreb", date="2021-01-01") + assert pydantic_ai_tool.current_retry == 3 + + +class TestPydanticAIInteroperabilityWithContext: + @pytest.fixture(autouse=True) + def setup(self) -> None: + class Player(BaseModel): + name: str + age: int + + def get_player(ctx: RunContext[Player], additional_info: Optional[str] = None) -> str: # type: ignore[valid-type] + """Get the player's name. + + Args: + additional_info: Additional information which can be used. 
+ """ + return f"Name: {ctx.deps.name}, Age: {ctx.deps.age}, Additional info: {additional_info}" # type: ignore[attr-defined] + + self.pydantic_ai_interop = PydanticAIInteroperability() + pydantic_ai_tool = PydanticAITool(get_player, takes_ctx=True) + player = Player(name="Luka", age=25) + self.tool = self.pydantic_ai_interop.convert_tool(tool=pydantic_ai_tool, deps=player) + + def test_expected_tools(self) -> None: + config_list = [{"model": "gpt-4o", "api_key": os.environ["OPENAI_API_KEY"]}] + chatbot = AssistantAgent( + name="chatbot", + llm_config={"config_list": config_list}, + ) + self.tool.register_for_llm(chatbot) + + expected_tools = [ + { + "type": "function", + "function": { + "name": "get_player", + "description": "Get the player's name.", + "parameters": { + "properties": { + "additional_info": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "description": "Additional information which can be used.", + "title": "Additional Info", + } + }, + "required": ["additional_info"], + "type": "object", + "additionalProperties": False, + }, + }, + } + ] + + assert chatbot.llm_config["tools"] == expected_tools # type: ignore[index] + + @pytest.mark.skipif(skip_openai, reason=reason) + def test_with_llm(self) -> None: + config_list = [{"model": "gpt-4o", "api_key": os.environ["OPENAI_API_KEY"]}] + user_proxy = UserProxyAgent( + name="User", + human_input_mode="NEVER", + ) + + chatbot = AssistantAgent( + name="chatbot", + llm_config={"config_list": config_list}, + ) + + self.tool.register_for_execution(user_proxy) + self.tool.register_for_llm(chatbot) + + user_proxy.initiate_chat( + recipient=chatbot, message="Get player, for additional information use 'goal keeper'", max_turns=3 + ) + + for message in user_proxy.chat_messages[chatbot]: + if "tool_responses" in message: + assert message["tool_responses"][0]["content"] == "Name: Luka, Age: 25, Additional info: goal keeper" + return + + assert False, "No tool response found in chat messages" diff --git 
a/test/oai/test_cerebras.py b/test/oai/test_cerebras.py index 202887d2e1..b9dc2c786b 100644 --- a/test/oai/test_cerebras.py +++ b/test/oai/test_cerebras.py @@ -142,7 +142,7 @@ def test_cost_calculation(mock_response): choices=[{"message": "Test message 1"}], usage={"prompt_tokens": 500, "completion_tokens": 300, "total_tokens": 800}, cost=None, - model="llama3.1-70b", + model="llama-3.3-70b", ) calculated_cost = calculate_cerebras_cost( response.usage["prompt_tokens"], response.usage["completion_tokens"], response.model @@ -166,7 +166,7 @@ def test_create_response(mock_chat, cerebras_client): MagicMock(finish_reason="stop", message=MagicMock(content="Example Cerebras response", tool_calls=None)) ] mock_cerebras_response.id = "mock_cerebras_response_id" - mock_cerebras_response.model = "llama3.1-70b" + mock_cerebras_response.model = "llama-3.3-70b" mock_cerebras_response.usage = MagicMock(prompt_tokens=10, completion_tokens=20) # Example token usage mock_chat.return_value = mock_cerebras_response @@ -174,7 +174,7 @@ def test_create_response(mock_chat, cerebras_client): # Test parameters params = { "messages": [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "World"}], - "model": "llama3.1-70b", + "model": "llama-3.3-70b", } # Call the create method @@ -185,7 +185,7 @@ def test_create_response(mock_chat, cerebras_client): response.choices[0].message.content == "Example Cerebras response" ), "Response content should match expected output" assert response.id == "mock_cerebras_response_id", "Response ID should match the mocked response ID" - assert response.model == "llama3.1-70b", "Response model should match the mocked response model" + assert response.model == "llama-3.3-70b", "Response model should match the mocked response model" assert response.usage.prompt_tokens == 10, "Response prompt tokens should match the mocked response usage" assert response.usage.completion_tokens == 20, "Response completion tokens should match the mocked response 
usage" @@ -217,7 +217,7 @@ def test_create_response_with_tool_call(mock_chat, cerebras_client): ) ], id="mock_cerebras_response_id", - model="llama3.1-70b", + model="llama-3.3-70b", usage=MagicMock(prompt_tokens=10, completion_tokens=20), ) @@ -245,7 +245,7 @@ def test_create_response_with_tool_call(mock_chat, cerebras_client): # Call the create method response = cerebras_client.create( - {"messages": cerebras_messages, "tools": converted_functions, "model": "llama3.1-70b"} + {"messages": cerebras_messages, "tools": converted_functions, "model": "llama-3.3-70b"} ) # Assertions to check if the functions and content are included in the response diff --git a/website/docs/topics/non-openai-models/cloud-cerebras.ipynb b/website/docs/topics/non-openai-models/cloud-cerebras.ipynb index b3bf68a6c3..6e48fe692a 100644 --- a/website/docs/topics/non-openai-models/cloud-cerebras.ipynb +++ b/website/docs/topics/non-openai-models/cloud-cerebras.ipynb @@ -47,7 +47,7 @@ " \"api_type\": \"cerebras\"\n", " },\n", " {\n", - " \"model\": \"llama3.1-70b\",\n", + " \"model\": \"llama-3.3-70b\",\n", " \"api_key\": \"your Cerebras API Key goes here\",\n", " \"api_type\": \"cerebras\"\n", " }\n", @@ -86,7 +86,7 @@ "```python\n", "[\n", " {\n", - " \"model\": \"llama3.1-70b\",\n", + " \"model\": \"llama-3.3-70b\",\n", " \"api_key\": \"your Cerebras API Key goes here\",\n", " \"api_type\": \"cerebras\"\n", " \"max_tokens\": 10000,\n", @@ -120,7 +120,7 @@ "\n", "from autogen.oai.cerebras import CerebrasClient, calculate_cerebras_cost\n", "\n", - "config_list = [{\"model\": \"llama3.1-70b\", \"api_key\": os.environ.get(\"CEREBRAS_API_KEY\"), \"api_type\": \"cerebras\"}]" + "config_list = [{\"model\": \"llama-3.3-70b\", \"api_key\": os.environ.get(\"CEREBRAS_API_KEY\"), \"api_type\": \"cerebras\"}]" ] }, { @@ -270,7 +270,7 @@ "\n", "config_list = [\n", " {\n", - " \"model\": \"llama3.1-70b\",\n", + " \"model\": \"llama-3.3-70b\",\n", " \"api_key\": os.environ.get(\"CEREBRAS_API_KEY\"),\n", " 
\"api_type\": \"cerebras\",\n", " }\n",