From e97b6395affe1bfea4eefb799ace4fecbf1b9a94 Mon Sep 17 00:00:00 2001
From: bitnom <14287229+bitnom@users.noreply.github.com>
Date: Thu, 18 Jan 2024 22:46:20 -0500
Subject: [PATCH] Allow initiate_chat without passing message (#1244)

* allow initiate_chat without passing message

* test human input

* assert called

* Add missing method a_generate_init_message

* fix tests

* add back skipif

* Update test/agentchat/test_async_get_human_input.py

---------

Co-authored-by: Chi Wang
---
 autogen/agentchat/conversable_agent.py       | 21 ++++-
 test/agentchat/test_async_get_human_input.py | 21 +++++-
 test/agentchat/test_human_input.py           | 46 ++++++++++++++++++++
 3 files changed, 76 insertions(+), 12 deletions(-)
 create mode 100644 test/agentchat/test_human_input.py

diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py
index 33898dc976ff..e96cf36953e6 100644
--- a/autogen/agentchat/conversable_agent.py
+++ b/autogen/agentchat/conversable_agent.py
@@ -679,6 +679,7 @@ def initiate_chat(
             silent (bool or None): (Experimental) whether to print the messages for this conversation.
             **context: any context information. "message" needs to be provided if the
                 `generate_init_message` method is not overridden.
+                Otherwise, input() will be called to get the initial message.
 
         Raises:
             RuntimeError: if any async reply functions are registered and not ignored in sync chat.
@@ -707,9 +708,10 @@ async def a_initiate_chat(
             silent (bool or None): (Experimental) whether to print the messages for this conversation.
             **context: any context information. "message" needs to be provided if the
                 `generate_init_message` method is not overridden.
+                Otherwise, input() will be called to get the initial message.
         """
         self._prepare_chat(recipient, clear_history)
-        await self.a_send(self.generate_init_message(**context), recipient, silent=silent)
+        await self.a_send(await self.a_generate_init_message(**context), recipient, silent=silent)
 
     def reset(self):
         """Reset the agent."""
@@ -1583,7 +1585,24 @@ def generate_init_message(self, **context) -> Union[str, Dict]:
 
         Args:
             **context: any context information, and "message" parameter needs to be provided.
+                If message is not given, prompt for it via input()
         """
+        if "message" not in context:
+            context["message"] = self.get_human_input(">")
+        return context["message"]
+
+    async def a_generate_init_message(self, **context) -> Union[str, Dict]:
+        """Generate the initial message for the agent.
+
+        Override this function to customize the initial message based on user's request.
+        If not overridden, "message" needs to be provided in the context.
+
+        Args:
+            **context: any context information, and "message" parameter needs to be provided.
+                If message is not given, prompt for it via input()
+        """
+        if "message" not in context:
+            context["message"] = await self.a_get_human_input(">")
         return context["message"]
 
     def register_function(self, function_map: Dict[str, Callable]):
diff --git a/test/agentchat/test_async_get_human_input.py b/test/agentchat/test_async_get_human_input.py
index 1285696c03e4..7af4237fc86f 100644
--- a/test/agentchat/test_async_get_human_input.py
+++ b/test/agentchat/test_async_get_human_input.py
@@ -1,9 +1,11 @@
 import asyncio
+import os
+import sys
+from unittest.mock import AsyncMock
+
 import autogen
 import pytest
 from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
-import sys
-import os
 
 sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
 from conftest import skip_openai  # noqa: E402
@@ -25,20 +27,17 @@ async def test_async_get_human_input():
     assistant = autogen.AssistantAgent(
         name="assistant",
         max_consecutive_auto_reply=2,
-        llm_config={"timeout": 600, "cache_seed": 41, "config_list": config_list, "temperature": 0},
+        llm_config={"seed": 41, "config_list": config_list, "temperature": 0},
     )
 
     user_proxy = autogen.UserProxyAgent(name="user", human_input_mode="ALWAYS", code_execution_config=False)
 
-    async def custom_a_get_human_input(prompt):
-        return "This is a test"
-
-    user_proxy.a_get_human_input = custom_a_get_human_input
+    user_proxy.a_get_human_input = AsyncMock(return_value="This is a test")
 
     user_proxy.register_reply([autogen.Agent, None], autogen.ConversableAgent.a_check_termination_and_human_reply)
 
     await user_proxy.a_initiate_chat(assistant, clear_history=True, message="Hello.")
-
-
-if __name__ == "__main__":
-    test_async_get_human_input()
+    # Test without message
+    await user_proxy.a_initiate_chat(assistant, clear_history=True)
+    # Assert that custom a_get_human_input was called at least once
+    user_proxy.a_get_human_input.assert_called()
diff --git a/test/agentchat/test_human_input.py b/test/agentchat/test_human_input.py
new file mode 100644
index 000000000000..837044feaaeb
--- /dev/null
+++ b/test/agentchat/test_human_input.py
@@ -0,0 +1,46 @@
+import autogen
+import pytest
+from unittest.mock import MagicMock
+from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
+import sys
+import os
+
+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
+from conftest import skip_openai  # noqa: E402
+
+try:
+    from openai import OpenAI
+except ImportError:
+    skip = True
+else:
+    skip = False or skip_openai
+
+
+@pytest.mark.skipif(skip, reason="openai not installed OR requested to skip")
+def test_get_human_input():
+    config_list = autogen.config_list_from_json(OAI_CONFIG_LIST, KEY_LOC)
+
+    # create an AssistantAgent instance named "assistant"
+    assistant = autogen.AssistantAgent(
+        name="assistant",
+        max_consecutive_auto_reply=2,
+        llm_config={"timeout": 600, "cache_seed": 41, "config_list": config_list, "temperature": 0},
+    )
+
+    user_proxy = autogen.UserProxyAgent(name="user", human_input_mode="ALWAYS", code_execution_config=False)
+
+    # Use MagicMock to create a mock get_human_input function
+    user_proxy.get_human_input = MagicMock(return_value="This is a test")
+
+    user_proxy.register_reply([autogen.Agent, None], autogen.ConversableAgent.a_check_termination_and_human_reply)
+
+    user_proxy.initiate_chat(assistant, clear_history=True, message="Hello.")
+    # Test without supplying messages parameter
+    user_proxy.initiate_chat(assistant, clear_history=True)
+
+    # Assert that custom_a_get_human_input was called at least once
+    user_proxy.get_human_input.assert_called()
+
+
+if __name__ == "__main__":
+    test_get_human_input()
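
Usage sketch (illustrative, not part of the patch): with this change applied, calling `initiate_chat` without a `message` prompts the human for the opening message via `get_human_input(">")` (or `a_get_human_input` on the async path) instead of failing on the missing "message" key in `generate_init_message`. The snippet below assumes the pyautogen package is installed and an OAI_CONFIG_LIST file is available; the agent names are placeholders mirroring the tests above.

import autogen

config_list = autogen.config_list_from_json("OAI_CONFIG_LIST")  # assumed config file

assistant = autogen.AssistantAgent(
    name="assistant",
    llm_config={"config_list": config_list},
)
user_proxy = autogen.UserProxyAgent(
    name="user",
    human_input_mode="ALWAYS",
    code_execution_config=False,
)

# No message= argument: generate_init_message() now falls back to get_human_input(">")
# and sends whatever the user types as the first message of the chat.
user_proxy.initiate_chat(assistant)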