diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 86a4f39952b8..28c4ccc16f88 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -131,14 +131,19 @@ class AssistantAgent(BaseChatAgent): .. code-block:: python + import asyncio from autogen_ext.models import OpenAIChatCompletionClient from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.task import MaxMessageTermination - model_client = OpenAIChatCompletionClient(model="gpt-4o") - agent = AssistantAgent(name="assistant", model_client=model_client) + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") + agent = AssistantAgent(name="assistant", model_client=model_client) - await agent.run("What is the capital of France?", termination_condition=MaxMessageTermination(2)) + result = await agent.run("What is the capital of France?", termination_condition=MaxMessageTermination(2)) + print(result) + + asyncio.run(main()) The following example demonstrates how to create an assistant agent with @@ -146,6 +151,7 @@ class AssistantAgent(BaseChatAgent): .. code-block:: python + import asyncio from autogen_ext.models import OpenAIChatCompletionClient from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.task import MaxMessageTermination @@ -155,14 +161,17 @@ async def get_current_time() -> str: return "The current time is 12:00 PM." 
- model_client = OpenAIChatCompletionClient(model="gpt-4o") - agent = AssistantAgent(name="assistant", model_client=model_client, tools=[get_current_time]) + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") + agent = AssistantAgent(name="assistant", model_client=model_client, tools=[get_current_time]) + + stream = agent.run_stream("What is the current time?", termination_condition=MaxMessageTermination(3)) - stream = agent.run_stream("What is the current time?", termination_condition=MaxMessageTermination(3)) + async for message in stream: + print(message) - async for message in stream: - print(message) + asyncio.run(main()) """ diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_termination.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_termination.py index 0d0a056eab4c..1442dd51358a 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_termination.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_termination.py @@ -22,19 +22,25 @@ class TerminationCondition(ABC): .. code-block:: python + import asyncio from autogen_agentchat.teams import MaxTurnsTermination, TextMentionTermination - # Terminate the conversation after 10 turns or if the text "TERMINATE" is mentioned. - cond1 = MaxTurnsTermination(10) | TextMentionTermination("TERMINATE") - # Terminate the conversation after 10 turns and if the text "TERMINATE" is mentioned. - cond2 = MaxTurnsTermination(10) & TextMentionTermination("TERMINATE") + async def main() -> None: + # Terminate the conversation after 10 turns or if the text "TERMINATE" is mentioned. + cond1 = MaxTurnsTermination(10) | TextMentionTermination("TERMINATE") - ... + # Terminate the conversation after 10 turns and if the text "TERMINATE" is mentioned. + cond2 = MaxTurnsTermination(10) & TextMentionTermination("TERMINATE") - # Reset the termination condition. - await cond1.reset() - await cond2.reset() + # ... 
+ + # Reset the termination condition. + await cond1.reset() + await cond2.reset() + + + asyncio.run(main()) """ @property diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py index b2dcdb640297..6fe0be858c39 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py @@ -61,46 +61,55 @@ class RoundRobinGroupChat(BaseGroupChat): .. code-block:: python + import asyncio from autogen_ext.models import OpenAIChatCompletionClient from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.teams import RoundRobinGroupChat from autogen_agentchat.task import StopMessageTermination - model_client = OpenAIChatCompletionClient(model="gpt-4o") + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") - async def get_weather(location: str) -> str: - return f"The weather in {location} is sunny." + async def get_weather(location: str) -> str: + return f"The weather in {location} is sunny." + assistant = AssistantAgent( + "Assistant", + model_client=model_client, + tools=[get_weather], + ) + team = RoundRobinGroupChat([assistant]) + stream = team.run_stream("What's the weather in New York?", termination_condition=StopMessageTermination()) + async for message in stream: + print(message) - assistant = AssistantAgent( - "Assistant", - model_client=model_client, - tools=[get_weather], - ) - team = RoundRobinGroupChat([assistant]) - stream = team.run_stream("What's the weather in New York?", termination_condition=StopMessageTermination()) - async for message in stream: - print(message) + + asyncio.run(main()) A team with multiple participants: .. 
code-block:: python + import asyncio from autogen_ext.models import OpenAIChatCompletionClient from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.teams import RoundRobinGroupChat from autogen_agentchat.task import StopMessageTermination - model_client = OpenAIChatCompletionClient(model="gpt-4o") - agent1 = AssistantAgent("Assistant1", model_client=model_client) - agent2 = AssistantAgent("Assistant2", model_client=model_client) - team = RoundRobinGroupChat([agent1, agent2]) - stream = team.run_stream("Tell me some jokes.", termination_condition=StopMessageTermination()) - async for message in stream: - print(message) + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + agent1 = AssistantAgent("Assistant1", model_client=model_client) + agent2 = AssistantAgent("Assistant2", model_client=model_client) + team = RoundRobinGroupChat([agent1, agent2]) + stream = team.run_stream("Tell me some jokes.", termination_condition=StopMessageTermination()) + async for message in stream: + print(message) + + asyncio.run(main()) """ def __init__(self, participants: List[ChatAgent]): diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index d91d89e26a5a..ee9c006c3fbb 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -1,12 +1,12 @@ import logging import re -from typing import Callable, Dict, List +from typing import Callable, Dict, List, Sequence from autogen_core.components.models import ChatCompletionClient, SystemMessage from ... 
import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME from ...base import ChatAgent, TerminationCondition -from ...messages import MultiModalMessage, StopMessage, TextMessage +from ...messages import ChatMessage, MultiModalMessage, StopMessage, TextMessage from .._events import ( GroupChatPublishEvent, GroupChatSelectSpeakerEvent, @@ -20,7 +20,7 @@ class SelectorGroupChatManager(BaseGroupChatManager): """A group chat manager that selects the next speaker using a ChatCompletion - model.""" + model and a custom selector function.""" def __init__( self, @@ -32,6 +32,7 @@ def __init__( model_client: ChatCompletionClient, selector_prompt: str, allow_repeated_speaker: bool, + selector_func: Callable[[Sequence[ChatMessage]], str | None] | None, ) -> None: super().__init__( parent_topic_type, @@ -44,12 +45,24 @@ def __init__( self._selector_prompt = selector_prompt self._previous_speaker: str | None = None self._allow_repeated_speaker = allow_repeated_speaker + self._selector_func = selector_func async def select_speaker(self, thread: List[GroupChatPublishEvent]) -> str: - """Selects the next speaker in a group chat using a ChatCompletion client. + """Selects the next speaker in a group chat using a ChatCompletion client, + with the selector function as override if it returns a speaker name. A key assumption is that the agent type is the same as the topic type, which we use as the agent name. """ + + # Use the selector function if provided. + if self._selector_func is not None: + speaker = self._selector_func([msg.agent_message for msg in thread]) + if speaker is not None: + # Skip the model based selection. + event_logger.debug(GroupChatSelectSpeakerEvent(selected_speaker=speaker, source=self.id)) + return speaker + + # Construct the history of the conversation. history_messages: List[str] = [] for event in thread: msg = event.agent_message @@ -160,6 +173,10 @@ class SelectorGroupChat(BaseGroupChat): Must contain '{roles}', '{participants}', and '{history}' to be filled in. 
allow_repeated_speaker (bool, optional): Whether to allow the same speaker to be selected consecutively. Defaults to False. + selector_func (Callable[[Sequence[ChatMessage]], str | None], optional): A custom selector + function that takes the conversation history and returns the name of the next speaker. + If provided, this function will be used to override the model to select the next speaker. + If the function returns None, the model will be used to select the next speaker. Raises: ValueError: If the number of participants is less than two or if the selector prompt is invalid. @@ -175,43 +192,97 @@ class SelectorGroupChat(BaseGroupChat): from autogen_agentchat.teams import SelectorGroupChat from autogen_agentchat.task import StopMessageTermination - model_client = OpenAIChatCompletionClient(model="gpt-4o") + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") - async def lookup_hotel(location: str) -> str: - return f"Here are some hotels in {location}: hotel1, hotel2, hotel3." + async def lookup_hotel(location: str) -> str: + return f"Here are some hotels in {location}: hotel1, hotel2, hotel3." + async def lookup_flight(origin: str, destination: str) -> str: + return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3." - async def lookup_flight(origin: str, destination: str) -> str: - return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3." + async def book_trip() -> str: + return "Your trip is booked!" 
+ travel_advisor = AssistantAgent( + "Travel_Advisor", + model_client, + tools=[book_trip], + description="Helps with travel planning.", + ) + hotel_agent = AssistantAgent( + "Hotel_Agent", + model_client, + tools=[lookup_hotel], + description="Helps with hotel booking.", + ) + flight_agent = AssistantAgent( + "Flight_Agent", + model_client, + tools=[lookup_flight], + description="Helps with flight booking.", + ) + team = SelectorGroupChat([travel_advisor, hotel_agent, flight_agent], model_client=model_client) + stream = team.run_stream("Book a 3-day trip to new york.", termination_condition=StopMessageTermination()) + async for message in stream: + print(message) - async def book_trip() -> str: - return "Your trip is booked!" + import asyncio - travel_advisor = AssistantAgent( - "Travel_Advisor", - model_client, - tools=[book_trip], - description="Helps with travel planning.", - ) - hotel_agent = AssistantAgent( - "Hotel_Agent", - model_client, - tools=[lookup_hotel], - description="Helps with hotel booking.", - ) - flight_agent = AssistantAgent( - "Flight_Agent", - model_client, - tools=[lookup_flight], - description="Helps with flight booking.", - ) - team = SelectorGroupChat([travel_advisor, hotel_agent, flight_agent], model_client=model_client) - stream = team.run_stream("Book a 3-day trip to new york.", termination_condition=StopMessageTermination()) - async for message in stream: - print(message) + asyncio.run(main()) + + A team with a custom selector function: + + .. code-block:: python + + from autogen_ext.models import OpenAIChatCompletionClient + from autogen_agentchat.agents import AssistantAgent + from autogen_agentchat.teams import SelectorGroupChat + from autogen_agentchat.task import TextMentionTermination + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + def check_caculation(x: int, y: int, answer: int) -> str: + if x + y == answer: + return "Correct!" + else: + return "Incorrect!" 
+ + agent1 = AssistantAgent( + "Agent1", + model_client, + description="For calculation", + system_message="Calculate the sum of two numbers", + ) + agent2 = AssistantAgent( + "Agent2", + model_client, + tools=[check_caculation], + description="For checking calculation", + system_message="Check the answer and respond with 'Correct!' or 'Incorrect!'", + ) + + def selector_func(messages): + if len(messages) == 1 or messages[-1].content == "Incorrect!": + return "Agent1" + if messages[-1].source == "Agent1": + return "Agent2" + return None + + team = SelectorGroupChat([agent1, agent2], model_client=model_client, selector_func=selector_func) + + stream = team.run_stream("What is 1 + 1?", termination_condition=TextMentionTermination("Correct!")) + async for message in stream: + print(message) + + + import asyncio + + asyncio.run(main()) """ def __init__( @@ -219,7 +290,6 @@ def __init__( participants: List[ChatAgent], model_client: ChatCompletionClient, *, - termination_condition: TerminationCondition | None = None, selector_prompt: str = """You are in a role play game. The following roles are available: {roles}. Read the following conversation. Then select the next role from {participants} to play. Only return the role. @@ -229,6 +299,7 @@ def __init__( Read the above conversation. Then select the next role from {participants} to play. Only return the role. """, allow_repeated_speaker: bool = False, + selector_func: Callable[[Sequence[ChatMessage]], str | None] | None = None, ): super().__init__(participants, group_chat_manager_class=SelectorGroupChatManager) # Validate the participants. 
@@ -244,6 +315,7 @@ def __init__( self._selector_prompt = selector_prompt self._model_client = model_client self._allow_repeated_speaker = allow_repeated_speaker + self._selector_func = selector_func def _create_group_chat_manager_factory( self, @@ -262,4 +334,5 @@ def _create_group_chat_manager_factory( self._model_client, self._selector_prompt, self._allow_repeated_speaker, + self._selector_func, ) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py index 0afc41afe98b..fcaee4d80c7e 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py @@ -61,28 +61,34 @@ class Swarm(BaseGroupChat): .. code-block:: python + import asyncio from autogen_ext.models import OpenAIChatCompletionClient from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.teams import Swarm from autogen_agentchat.task import MaxMessageTermination - model_client = OpenAIChatCompletionClient(model="gpt-4o") - agent1 = AssistantAgent( - "Alice", - model_client=model_client, - handoffs=["Bob"], - system_message="You are Alice and you only answer questions about yourself.", - ) - agent2 = AssistantAgent( - "Bob", model_client=model_client, system_message="You are Bob and your birthday is on 1st January." - ) + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + agent1 = AssistantAgent( + "Alice", + model_client=model_client, + handoffs=["Bob"], + system_message="You are Alice and you only answer questions about yourself.", + ) + agent2 = AssistantAgent( + "Bob", model_client=model_client, system_message="You are Bob and your birthday is on 1st January." 
+ ) + + team = Swarm([agent1, agent2]) + + stream = team.run_stream("What is bob's birthday?", termination_condition=MaxMessageTermination(3)) + async for message in stream: + print(message) - team = Swarm([agent1, agent2]) - stream = team.run_stream("What is bob's birthday?", termination_condition=MaxMessageTermination(3)) - async for message in stream: - print(message) + asyncio.run(main()) """ def __init__(self, participants: List[ChatAgent]): diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index 72fdc6bcd68a..1ff4b124994b 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -493,6 +493,54 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pyte index += 1 +@pytest.mark.asyncio +async def test_selector_group_chat_custom_selector(monkeypatch: pytest.MonkeyPatch) -> None: + model = "gpt-4o-2024-05-13" + chat_completions = [ + ChatCompletion( + id="id2", + choices=[ + Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent3", role="assistant")) + ], + created=0, + model=model, + object="chat.completion", + usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), + ), + ] + mock = _MockChatCompletion(chat_completions) + monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create) + agent1 = _EchoAgent("agent1", description="echo agent 1") + agent2 = _EchoAgent("agent2", description="echo agent 2") + agent3 = _EchoAgent("agent3", description="echo agent 3") + agent4 = _EchoAgent("agent4", description="echo agent 4") + + def _select_agent(messages: Sequence[ChatMessage]) -> str | None: + if len(messages) == 0: + return "agent1" + elif messages[-1].source == "agent1": + return "agent2" + elif messages[-1].source == "agent2": + return None + elif messages[-1].source == "agent3": + return "agent4" + else: + return "agent1" + 
+ team = SelectorGroupChat( + participants=[agent1, agent2, agent3, agent4], + model_client=OpenAIChatCompletionClient(model=model, api_key=""), + selector_func=_select_agent, + ) + result = await team.run("task", termination_condition=MaxMessageTermination(6)) + assert len(result.messages) == 6 + assert result.messages[1].source == "agent1" + assert result.messages[2].source == "agent2" + assert result.messages[3].source == "agent3" + assert result.messages[4].source == "agent4" + assert result.messages[5].source == "agent1" + + class _HandOffAgent(BaseChatAgent): def __init__(self, name: str, description: str, next_agent: str) -> None: super().__init__(name, description)