Skip to content

Commit

Permalink
allow function to remove termination string in groupchat (microsoft#2804
Browse files Browse the repository at this point in the history
)

* allow function to remove termination string in groupchat

* improve docstring

Co-authored-by: Joshua Kim <[email protected]>

* improve docstring

Co-authored-by: Joshua Kim <[email protected]>

* improve test case description

Co-authored-by: Joshua Kim <[email protected]>

---------

Co-authored-by: Joshua Kim <[email protected]>
  • Loading branch information
aswny and joshkyh authored Jun 6, 2024
1 parent 8564bd4 commit 75f0808
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 11 deletions.
34 changes: 25 additions & 9 deletions autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,15 +1160,17 @@ async def a_run_chat(
def resume(
self,
messages: Union[List[Dict], str],
remove_termination_string: str = None,
remove_termination_string: Union[str, Callable[[str], str]] = None,
silent: Optional[bool] = False,
) -> Tuple[ConversableAgent, Dict]:
"""Resumes a group chat using the previous messages as a starting point. Requires the agents, group chat, and group chat manager to be established
as per the original group chat.
Args:
- messages Union[List[Dict], str]: The content of the previous chat's messages, either as a Json string or a list of message dictionaries.
- remove_termination_string str: Remove the provided string from the last message to prevent immediate termination
- remove_termination_string (str or function): Remove the termination string from the last message to prevent immediate termination
If a string is provided, this string will be removed from last message.
If a function is provided, the last message will be passed to this function.
- silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False.
Returns:
Expand Down Expand Up @@ -1263,15 +1265,17 @@ def resume(
async def a_resume(
self,
messages: Union[List[Dict], str],
remove_termination_string: str = None,
remove_termination_string: Union[str, Callable[[str], str]],
silent: Optional[bool] = False,
) -> Tuple[ConversableAgent, Dict]:
"""Resumes a group chat using the previous messages as a starting point, asynchronously. Requires the agents, group chat, and group chat manager to be established
as per the original group chat.
Args:
- messages Union[List[Dict], str]: The content of the previous chat's messages, either as a Json string or a list of message dictionaries.
- remove_termination_string str: Remove the provided string from the last message to prevent immediate termination
- remove_termination_string (str or function): Remove the termination string from the last message to prevent immediate termination
If a string is provided, this string will be removed from last message.
If a function is provided, the last message will be passed to this function, and the function returns the string after processing.
- silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False.
Returns:
Expand Down Expand Up @@ -1390,11 +1394,15 @@ def _valid_resume_messages(self, messages: List[Dict]):
):
raise Exception(f"Agent name in message doesn't exist as agent in group chat: {message['name']}")

def _process_resume_termination(self, remove_termination_string: str, messages: List[Dict]):
def _process_resume_termination(
self, remove_termination_string: Union[str, Callable[[str], str]], messages: List[Dict]
):
"""Removes termination string, if required, and checks if termination may occur.
args:
remove_termination_string (str): termination string to remove from the last message
remove_termination_string (str or function): Remove the termination string from the last message to prevent immediate termination
If a string is provided, this string will be removed from last message.
If a function is provided, the last message will be passed to this function, and the function returns the string after processing.
returns:
None
Expand All @@ -1403,9 +1411,17 @@ def _process_resume_termination(self, remove_termination_string: str, messages:
last_message = messages[-1]

# Replace any given termination string in the last message
if remove_termination_string:
if messages[-1].get("content") and remove_termination_string in messages[-1]["content"]:
messages[-1]["content"] = messages[-1]["content"].replace(remove_termination_string, "")
if isinstance(remove_termination_string, str):

def _remove_termination_string(content: str) -> str:
return content.replace(remove_termination_string, "")

else:
_remove_termination_string = remove_termination_string

if _remove_termination_string:
if messages[-1].get("content"):
messages[-1]["content"] = _remove_termination_string(messages[-1]["content"])

# Check if the last message meets termination (if it has one)
if self._is_termination_msg:
Expand Down
49 changes: 47 additions & 2 deletions test/agentchat/test_groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -1916,6 +1916,51 @@ def test_manager_resume_functions():
# TERMINATE should be removed
assert messages[-1]["content"] == final_msg.replace("TERMINATE", "")

# Tests termination message replacement with function
def termination_func(x: str) -> str:
if "APPROVED" in x:
x = x.replace("APPROVED", "")
else:
x = x.replace("TERMINATE", "")
return x

final_msg1 = "Product_Manager has created 3 new product ideas. APPROVED"
messages1 = [
{
"content": "You are an expert at finding the next speaker.",
"role": "system",
},
{
"content": final_msg1,
"name": "Coder",
"role": "assistant",
},
]

manager._process_resume_termination(remove_termination_string=termination_func, messages=messages1)

# APPROVED should be removed
assert messages1[-1]["content"] == final_msg1.replace("APPROVED", "")

final_msg2 = "Idea has been approved. TERMINATE"
messages2 = [
{
"content": "You are an expert at finding the next speaker.",
"role": "system",
},
{
"content": final_msg2,
"name": "Coder",
"role": "assistant",
},
]

manager._process_resume_termination(remove_termination_string=termination_func, messages=messages2)

# TERMINATE should be removed, "approved" should still be present as the termination_func only replaces upper-cased "APPROVED".
assert messages2[-1]["content"] == final_msg2.replace("TERMINATE", "")
assert "approved" in messages2[-1]["content"]

# Check if the termination string doesn't exist there's no replacing of content
final_msg = (
"Let's get this meeting started. First the Product_Manager will create 3 new product ideas. TERMINATE this."
Expand Down Expand Up @@ -2027,7 +2072,7 @@ def test_manager_resume_messages():
# test_clear_agents_history()
# test_custom_speaker_selection_overrides_transition_graph()
# test_role_for_select_speaker_messages()
test_select_speaker_message_and_prompt_templates()
# test_select_speaker_message_and_prompt_templates()
# test_speaker_selection_agent_name_match()
# test_role_for_reflection_summary()
# test_speaker_selection_auto_process_result()
Expand All @@ -2036,7 +2081,7 @@ def test_manager_resume_messages():
# test_select_speaker_auto_messages()
# test_manager_messages_to_string()
# test_manager_messages_from_string()
# test_manager_resume_functions()
test_manager_resume_functions()
# test_manager_resume_returns()
# test_manager_resume_messages()
pass

0 comments on commit 75f0808

Please sign in to comment.