Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Swarm tests and bug fixes #83

Merged
merged 6 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions .github/workflows/contrib-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -732,3 +732,43 @@ jobs:
with:
file: ./coverage.xml
flags: unittests

SwarmTest:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
exclude:
- os: macos-latest
python-version: "3.9"
steps:
- uses: actions/checkout@v4
with:
lfs: true
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install packages and dependencies for all tests
run: |
python -m pip install --upgrade pip wheel
pip install pytest-cov>=5
- name: Install packages and dependencies for Swarms
run: |
pip install -e .
- name: Set AUTOGEN_USE_DOCKER based on OS
shell: bash
run: |
if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
fi
- name: Coverage
run: |
pytest test/agentchat/contrib/test_swarm.py --skip-openai
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
flags: unittests
85 changes: 67 additions & 18 deletions autogen/agentchat/contrib/swarm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ class ON_CONDITION:
agent: "SwarmAgent"
condition: str = ""

# Ensure that agent is a SwarmAgent
def __post_init__(self):
assert isinstance(self.agent, SwarmAgent), "Agent must be a SwarmAgent"


def initiate_swarm_chat(
initial_agent: "SwarmAgent",
Expand Down Expand Up @@ -80,7 +84,12 @@ def custom_afterwork_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any
SwarmAgent: Last speaker.
"""
assert isinstance(initial_agent, SwarmAgent), "initial_agent must be a SwarmAgent"
assert all(isinstance(agent, SwarmAgent) for agent in agents), "agents must be a list of SwarmAgents"
assert all(isinstance(agent, SwarmAgent) for agent in agents), "Agents must be a list of SwarmAgents"
# Ensure all agents in hand-off after-works are in the passed in agents list
for agent in agents:
if agent.after_work is not None:
if isinstance(agent.after_work.agent, SwarmAgent):
assert agent.after_work.agent in agents, "Agent in hand-off must be in the agents list"

context_variables = context_variables or {}
if isinstance(messages, str):
Expand Down Expand Up @@ -175,9 +184,12 @@ def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat):
last_message = messages[0]

if "name" in last_message:
if "name" in swarm_agent_names:
if last_message["name"] in swarm_agent_names:
# If there's a name in the message and it's a swarm agent, use that
last_agent = groupchat.agent_by_name(name=last_message["name"])
elif user_agent and last_message["name"] == user_agent.name:
# If the user agent is passed in and is the first message
last_agent = user_agent
else:
raise ValueError(f"Invalid swarm agent name in last message: {last_message['name']}")
else:
Expand Down Expand Up @@ -260,9 +272,13 @@ def __init__(
)

if isinstance(functions, list):
if not all(isinstance(func, Callable) for func in functions):
raise TypeError("All elements in the functions list must be callable")
self.add_functions(functions)
elif isinstance(functions, Callable):
self.add_single_function(functions)
elif functions is not None:
raise TypeError("Functions must be a callable or a list of callables")

self.after_work = None

Expand Down Expand Up @@ -299,11 +315,18 @@ def transfer_to_agent_name() -> SwarmAgent:
1. register the function with the agent
2. register the schema with the agent, description set to the condition
"""
# Ensure that hand_to is a list or ON_CONDITION or AFTER_WORK
if not isinstance(hand_to, (list, ON_CONDITION, AFTER_WORK)):
raise ValueError("hand_to must be a list of ON_CONDITION or AFTER_WORK")

if isinstance(hand_to, (ON_CONDITION, AFTER_WORK)):
hand_to = [hand_to]

for transit in hand_to:
if isinstance(transit, AFTER_WORK):
assert isinstance(
transit.agent, (AfterWorkOption, SwarmAgent, str, Callable)
), "Invalid After Work value"
self.after_work = transit
elif isinstance(transit, ON_CONDITION):

Expand Down Expand Up @@ -340,8 +363,18 @@ def generate_swarm_tool_reply(

message = messages[-1]
if "tool_calls" in message:
# 1. add context_variables to the tool call arguments
for tool_call in message["tool_calls"]:

tool_calls = len(message["tool_calls"])

# Loop through tool calls individually (so context can be updated after each function call)
next_agent = None
tool_responses_inner = []
contents = []
for index in range(tool_calls):

# 1. add context_variables to the tool call arguments
tool_call = message["tool_calls"][index]

marklysze marked this conversation as resolved.
Show resolved Hide resolved
if tool_call["type"] == "function":
function_name = tool_call["function"]["name"]

Expand All @@ -357,20 +390,36 @@ def generate_swarm_tool_reply(
# Update the tool call with new arguments
tool_call["function"]["arguments"] = json.dumps(current_args)

# 2. generate tool calls reply
_, tool_message = self.generate_tool_calls_reply([message])

# 3. update context_variables and next_agent, convert content to string
for tool_response in tool_message["tool_responses"]:
content = tool_response.get("content")
if isinstance(content, SwarmResult):
if content.context_variables != {}:
self._context_variables.update(content.context_variables)
if content.agent is not None:
self._next_agent = content.agent
elif isinstance(content, Agent):
self._next_agent = content
tool_response["content"] = str(tool_response["content"])
# Copy the message
message_copy = message.copy()
tool_calls_copy = message_copy["tool_calls"]

# remove all the tool calls except the one at the index
message_copy["tool_calls"] = [tool_calls_copy[index]]

# 2. generate tool calls reply
_, tool_message = self.generate_tool_calls_reply([message_copy])

# 3. update context_variables and next_agent, convert content to string
for tool_response in tool_message["tool_responses"]:
content = tool_response.get("content")
if isinstance(content, SwarmResult):
if content.context_variables != {}:
self._context_variables.update(content.context_variables)
if content.agent is not None:
next_agent = content.agent
elif isinstance(content, Agent):
next_agent = content

tool_responses_inner.append(tool_response)
contents.append(str(tool_response["content"]))

self._next_agent = next_agent

# Put the tool responses and content strings back into the response message
# Caters for multiple tool calls
tool_message["tool_responses"] = tool_responses_inner
tool_message["content"] = "\n".join(contents)

return True, tool_message
return False, None
Expand Down
Loading
Loading