diff --git a/langchain_benchmarks/tool_usage/agents/base.py b/langchain_benchmarks/tool_usage/agents/base.py new file mode 100644 index 00000000..aafdba8d --- /dev/null +++ b/langchain_benchmarks/tool_usage/agents/base.py @@ -0,0 +1,11 @@ +import abc + +from langchain_core.runnables import Runnable + + +class AgentFactory(abc.ABC): + """Abstract class for agent factory""" + + @abc.abstractmethod + def __call__(self) -> Runnable: + """Create a new agent""" diff --git a/langchain_benchmarks/tool_usage/agents/runnable_agent.py b/langchain_benchmarks/tool_usage/agents/runnable_agent.py index 8b130f55..b6f76b7b 100644 --- a/langchain_benchmarks/tool_usage/agents/runnable_agent.py +++ b/langchain_benchmarks/tool_usage/agents/runnable_agent.py @@ -10,9 +10,10 @@ from langchain_benchmarks.schema import ToolUsageTask from langchain_benchmarks.tool_usage.agents.adapters import apply_agent_executor_adapter +from langchain_benchmarks.tool_usage.agents.base import AgentFactory -class CustomRunnableAgentFactory: +class CustomRunnableAgentFactory(AgentFactory): """A factory for creating tool using agents. A factory for agents that do not leverage any special JSON mode for diff --git a/langchain_benchmarks/tool_usage/agents/tool_using_agent.py b/langchain_benchmarks/tool_usage/agents/tool_using_agent.py index 78672e4a..93653255 100644 --- a/langchain_benchmarks/tool_usage/agents/tool_using_agent.py +++ b/langchain_benchmarks/tool_usage/agents/tool_using_agent.py @@ -12,9 +12,10 @@ from langchain_benchmarks.rate_limiting import RateLimiter, with_rate_limit from langchain_benchmarks.schema import ToolUsageTask from langchain_benchmarks.tool_usage.agents.adapters import apply_agent_executor_adapter +from langchain_benchmarks.tool_usage.agents.base import AgentFactory -class StandardAgentFactory: +class StandardAgentFactory(AgentFactory): """A standard agent factory. Use this factory with chat models that support the standard LangChain tool diff --git a/tests/unit_tests/tool_usage/test_public_api.py b/tests/unit_tests/tool_usage/test_public_api.py index 1f422366..00110722 100644 --- a/tests/unit_tests/tool_usage/test_public_api.py +++ b/tests/unit_tests/tool_usage/test_public_api.py @@ -6,5 +6,11 @@ def test_public_api() -> None: # This test will also fail if __all__ is not sorted. # Please keep it sorted! assert __all__ == sorted( - ["apply_agent_executor_adapter", "get_eval_config"], key=str.lower + [ + "apply_agent_executor_adapter", + "get_eval_config", + "CustomRunnableAgentFactory", + "StandardAgentFactory", + ], + key=str.lower, )