Skip to content

Commit

Permalink
Merge pull request #83 from ag2ai/swarmtests
Browse files Browse the repository at this point in the history
Swarm tests and bug fixes
  • Loading branch information
marklysze authored Nov 27, 2024
2 parents b0968a3 + 260082e commit f27f8bb
Show file tree
Hide file tree
Showing 3 changed files with 572 additions and 18 deletions.
40 changes: 40 additions & 0 deletions .github/workflows/contrib-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -732,3 +732,43 @@ jobs:
with:
file: ./coverage.xml
flags: unittests

SwarmTest:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
exclude:
- os: macos-latest
python-version: "3.9"
steps:
- uses: actions/checkout@v4
with:
lfs: true
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install packages and dependencies for all tests
run: |
python -m pip install --upgrade pip wheel
pip install pytest-cov>=5
- name: Install packages and dependencies for Swarms
run: |
pip install -e .
- name: Set AUTOGEN_USE_DOCKER based on OS
shell: bash
run: |
if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
fi
- name: Coverage
run: |
pytest test/agentchat/contrib/test_swarm.py --skip-openai
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
flags: unittests
85 changes: 67 additions & 18 deletions autogen/agentchat/contrib/swarm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ class ON_CONDITION:
agent: "SwarmAgent"
condition: str = ""

# Ensure that agent is a SwarmAgent
def __post_init__(self):
assert isinstance(self.agent, SwarmAgent), "Agent must be a SwarmAgent"


def initiate_swarm_chat(
initial_agent: "SwarmAgent",
Expand Down Expand Up @@ -80,7 +84,12 @@ def custom_afterwork_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any
SwarmAgent: Last speaker.
"""
assert isinstance(initial_agent, SwarmAgent), "initial_agent must be a SwarmAgent"
assert all(isinstance(agent, SwarmAgent) for agent in agents), "agents must be a list of SwarmAgents"
assert all(isinstance(agent, SwarmAgent) for agent in agents), "Agents must be a list of SwarmAgents"
# Ensure all agents in hand-off after-works are in the passed in agents list
for agent in agents:
if agent.after_work is not None:
if isinstance(agent.after_work.agent, SwarmAgent):
assert agent.after_work.agent in agents, "Agent in hand-off must be in the agents list"

context_variables = context_variables or {}
if isinstance(messages, str):
Expand Down Expand Up @@ -175,9 +184,12 @@ def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat):
last_message = messages[0]

if "name" in last_message:
if "name" in swarm_agent_names:
if last_message["name"] in swarm_agent_names:
# If there's a name in the message and it's a swarm agent, use that
last_agent = groupchat.agent_by_name(name=last_message["name"])
elif user_agent and last_message["name"] == user_agent.name:
# If the user agent is passed in and is the first message
last_agent = user_agent
else:
raise ValueError(f"Invalid swarm agent name in last message: {last_message['name']}")
else:
Expand Down Expand Up @@ -260,9 +272,13 @@ def __init__(
)

if isinstance(functions, list):
if not all(isinstance(func, Callable) for func in functions):
raise TypeError("All elements in the functions list must be callable")
self.add_functions(functions)
elif isinstance(functions, Callable):
self.add_single_function(functions)
elif functions is not None:
raise TypeError("Functions must be a callable or a list of callables")

self.after_work = None

Expand Down Expand Up @@ -299,11 +315,18 @@ def transfer_to_agent_name() -> SwarmAgent:
1. register the function with the agent
2. register the schema with the agent, description set to the condition
"""
# Ensure that hand_to is a list or ON_CONDITION or AFTER_WORK
if not isinstance(hand_to, (list, ON_CONDITION, AFTER_WORK)):
raise ValueError("hand_to must be a list of ON_CONDITION or AFTER_WORK")

if isinstance(hand_to, (ON_CONDITION, AFTER_WORK)):
hand_to = [hand_to]

for transit in hand_to:
if isinstance(transit, AFTER_WORK):
assert isinstance(
transit.agent, (AfterWorkOption, SwarmAgent, str, Callable)
), "Invalid After Work value"
self.after_work = transit
elif isinstance(transit, ON_CONDITION):

Expand Down Expand Up @@ -340,8 +363,18 @@ def generate_swarm_tool_reply(

message = messages[-1]
if "tool_calls" in message:
# 1. add context_variables to the tool call arguments
for tool_call in message["tool_calls"]:

tool_calls = len(message["tool_calls"])

# Loop through tool calls individually (so context can be updated after each function call)
next_agent = None
tool_responses_inner = []
contents = []
for index in range(tool_calls):

# 1. add context_variables to the tool call arguments
tool_call = message["tool_calls"][index]

if tool_call["type"] == "function":
function_name = tool_call["function"]["name"]

Expand All @@ -357,20 +390,36 @@ def generate_swarm_tool_reply(
# Update the tool call with new arguments
tool_call["function"]["arguments"] = json.dumps(current_args)

# 2. generate tool calls reply
_, tool_message = self.generate_tool_calls_reply([message])

# 3. update context_variables and next_agent, convert content to string
for tool_response in tool_message["tool_responses"]:
content = tool_response.get("content")
if isinstance(content, SwarmResult):
if content.context_variables != {}:
self._context_variables.update(content.context_variables)
if content.agent is not None:
self._next_agent = content.agent
elif isinstance(content, Agent):
self._next_agent = content
tool_response["content"] = str(tool_response["content"])
# Copy the message
message_copy = message.copy()
tool_calls_copy = message_copy["tool_calls"]

# remove all the tool calls except the one at the index
message_copy["tool_calls"] = [tool_calls_copy[index]]

# 2. generate tool calls reply
_, tool_message = self.generate_tool_calls_reply([message_copy])

# 3. update context_variables and next_agent, convert content to string
for tool_response in tool_message["tool_responses"]:
content = tool_response.get("content")
if isinstance(content, SwarmResult):
if content.context_variables != {}:
self._context_variables.update(content.context_variables)
if content.agent is not None:
next_agent = content.agent
elif isinstance(content, Agent):
next_agent = content

tool_responses_inner.append(tool_response)
contents.append(str(tool_response["content"]))

self._next_agent = next_agent

# Put the tool responses and content strings back into the response message
# Caters for multiple tool calls
tool_message["tool_responses"] = tool_responses_inner
tool_message["content"] = "\n".join(contents)

return True, tool_message
return False, None
Expand Down
Loading

0 comments on commit f27f8bb

Please sign in to comment.