Skip to content

Commit

Permalink
Updated AFTER_WORK Callable to allow an AfterWorkOption to be returne…
Browse files Browse the repository at this point in the history
…d as well

Signed-off-by: Mark Sze <[email protected]>
  • Loading branch information
marklysze committed Dec 12, 2024
1 parent b5d3d98 commit e0abaca
Showing 1 changed file with 31 additions and 26 deletions.
57 changes: 31 additions & 26 deletions autogen/agentchat/contrib/swarm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ def initiate_swarm_chat(
- REVERT_TO_USER : Revert to the user agent if a user agent is provided. If not provided, terminate the conversation.
- STAY : Stay with the last speaker.
Callable: A custom function that takes the current agent, messages, groupchat, and context_variables as arguments and returns the next agent. The function should return None to terminate.
Callable: A custom function that takes the current agent, messages, and groupchat as arguments and returns an AfterWorkOption or a SwarmAgent.
```python
def custom_afterwork_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any]], groupchat: GroupChat, context_variables: Optional[Dict[str, Any]]) -> Optional[SwarmAgent]:
def custom_afterwork_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any]], groupchat: GroupChat) -> Union[AfterWorkOption, SwarmAgent]:
```
Returns:
ChatResult: Conversations chat history.
Expand Down Expand Up @@ -166,28 +166,36 @@ def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat):
if (user_agent and last_speaker == user_agent) or groupchat.messages[-1]["role"] == "tool":
return last_swarm_speaker

# No agent selected via hand-offs (tool calls)
# Assume the work is Done
# override if agent-level after_work is defined, else use the global after_work
tmp_after_work = last_swarm_speaker.after_work if last_swarm_speaker.after_work is not None else after_work
if isinstance(tmp_after_work, AFTER_WORK):
tmp_after_work = tmp_after_work.agent

if isinstance(tmp_after_work, SwarmAgent):
return tmp_after_work
elif isinstance(tmp_after_work, AfterWorkOption):
if tmp_after_work == AfterWorkOption.TERMINATE or (
user_agent is None and tmp_after_work == AfterWorkOption.REVERT_TO_USER
):
return None
elif tmp_after_work == AfterWorkOption.REVERT_TO_USER:
return user_agent
elif tmp_after_work == AfterWorkOption.STAY:
return last_speaker
elif isinstance(tmp_after_work, Callable):
return tmp_after_work(last_speaker, groupchat.messages, groupchat, context_variables)
# Resolve after_work condition (agent-level overrides global)
after_work_condition = (
last_swarm_speaker.after_work if last_swarm_speaker.after_work is not None else after_work
)
if isinstance(after_work_condition, AFTER_WORK):
after_work_condition = after_work_condition.agent

# Evaluate callable after_work
if isinstance(after_work_condition, Callable):
after_work_condition = after_work_condition(last_speaker, groupchat.messages, groupchat)

if isinstance(after_work_condition, str): # Agent name in a string
if after_work_condition in swarm_agent_names:
after_work_condition = groupchat.agent_by_name(name=after_work_condition)
else:
raise ValueError(f"Invalid agent name in after_work: {after_work_condition}")

# Determine next action based on after_work_condition
if isinstance(after_work_condition, SwarmAgent):
return after_work_condition
elif after_work_condition == AfterWorkOption.TERMINATE or (
user_agent is None and after_work_condition == AfterWorkOption.REVERT_TO_USER
):
return None
elif after_work_condition == AfterWorkOption.REVERT_TO_USER:
return user_agent
elif after_work_condition == AfterWorkOption.STAY:
return last_speaker
else:
raise ValueError("Invalid After Work condition")
raise ValueError("Invalid After Work condition or return value from callable")

def create_nested_chats(agent: SwarmAgent, nested_chat_agents: List[SwarmAgent]):
"""Create nested chat agents and register nested chats"""
Expand Down Expand Up @@ -461,8 +469,6 @@ def _update_conditional_functions(self):
else:
is_available = True

print(f"DEBUG INFO: Function {func_name} available? {'Yes' if is_available else 'No'}")

if is_available:
if func_name not in self._function_map:
self.add_single_function(func, func_name, on_condition.condition)
Expand All @@ -471,7 +477,6 @@ def _update_conditional_functions(self):
if func_name in self._function_map:
self.update_tool_signature(func_name, is_remove=True)
del self._function_map[func_name]
print(f"Removed function: {func_name} from {self.name}")

def generate_swarm_tool_reply(
self,
Expand Down

0 comments on commit e0abaca

Please sign in to comment.