diff --git a/agency_swarm/agency/agency.py b/agency_swarm/agency/agency.py index b5193c03..1f4a628c 100644 --- a/agency_swarm/agency/agency.py +++ b/agency_swarm/agency/agency.py @@ -110,10 +110,10 @@ def __init__(self, self._init_threads() def get_completion(self, message: str, - message_files=None, - yield_messages=False, - recipient_agent=None, - additional_instructions=None, + message_files: List[str] = None, + yield_messages: bool = False, + recipient_agent: Agent = None, + additional_instructions: str = None, attachments: List[dict] = None, tool_choice: dict = None, ): @@ -127,6 +127,7 @@ def get_completion(self, message: str, recipient_agent (Agent, optional): The agent to which the message should be sent. Defaults to the first agent in the agency chart. additional_instructions (str, optional): Additional instructions to be sent with the message. Defaults to None. attachments (List[dict], optional): A list of attachments to be sent with the message, following openai format. Defaults to None. + tool_choice (dict, optional): The tool choice for the recipient agent to use. Defaults to None. Returns: Generator or final response: Depending on the 'yield_messages' flag, this method returns either a generator yielding intermediate messages or the final response from the main thread. @@ -144,11 +145,12 @@ def get_completion(self, message: str, def get_completion_stream(self, message: str, event_handler: type(AgencyEventHandler), - message_files=None, - recipient_agent=None, + message_files: List[str] = None, + recipient_agent: Agent = None, additional_instructions: str = None, attachments: List[dict] = None, - tool_choice: dict = None): + tool_choice: dict = None + ): """ Generates a stream of completions for a given message from the main thread. @@ -159,6 +161,8 @@ def get_completion_stream(self, recipient_agent (Agent, optional): The agent to which the message should be sent. Defaults to the first agent in the agency chart. additional_instructions (str, optional): Additional instructions to be sent with the message. Defaults to None. attachments (List[dict], optional): A list of attachments to be sent with the message, following openai format. Defaults to None. + tool_choice (dict, optional): The tool choice for the recipient agent to use. Defaults to None. + Returns: Final response: Final response from the main thread. """ @@ -815,8 +819,8 @@ class SendMessage(BaseTool): "clarifying what the task entails, rather than providing exact " "instructions.") message_files: Optional[List[str]] = Field(default=None, - description="A list of file ids to be sent as attachments to this message. Only use this if you have the file id that starts with 'file-'.", - examples=["file-1234", "file-5678"]) + description="A list of file ids to be sent as attachments to this message. Only use this if you have the file id that starts with 'file-'.", + examples=["file-1234", "file-5678"]) additional_instructions: str = Field(default=None, description="Any additional instructions or clarifications that you would like to provide to the recipient agent.") one_call_at_a_time: bool = True @@ -948,4 +952,4 @@ def delete(self): This method deletes the agency and all its agents, cleaning up any files and vector stores associated with each agent. """ for agent in self.agents: - agent.delete() \ No newline at end of file + agent.delete() diff --git a/tests/test_agency.py b/tests/test_agency.py index 0eaa012b..67a56270 100644 --- a/tests/test_agency.py +++ b/tests/test_agency.py @@ -357,12 +357,11 @@ def test_7_init_async_agency(self): # reset loaded thread ids self.__class__.loaded_thread_ids = {} - self.__class__.agent1.instructions = "Your task is to say 'success' and nothing else." - self.__class__.agency = Agency([ self.__class__.ceo, - [self.__class__.ceo, self.__class__.agent1]], - shared_instructions="This is a shared instruction", + [self.__class__.ceo, self.__class__.agent1], + [self.__class__.agent1, self.__class__.agent2]], + shared_instructions="", settings_callbacks=self.__class__.settings_callbacks, threads_callbacks=self.__class__.threads_callbacks, async_mode='threading', @@ -374,14 +373,16 @@ def test_7_init_async_agency(self): def test_8_async_agent_communication(self): """it should communicate between agents asynchronously""" print("TestAgent1 tools", self.__class__.agent1.tools) - self.__class__.agency.get_completion("Please tell TestAgent1 hello.", - tool_choice={"type": "function", "function": {"name": "SendMessage"}}) + self.__class__.agency.get_completion("Please tell TestAgent2 hello.", + tool_choice={"type": "function", "function": {"name": "SendMessage"}}, + recipient_agent=self.__class__.agent1) time.sleep(10) message = self.__class__.agency.get_completion( - "Please check response. If output includes `TestAgent1's Response`, say 'success'. If the function output does not include `TestAgent1's Response`, or if you get a System Notification, or an error instead, say 'error'.", - tool_choice={"type": "function", "function": {"name": "GetResponse"}}) + "Please check response. If output includes `TestAgent2's Response`, say 'success'. If the function output does not include `TestAgent2's Response`, or if you get a System Notification, or an error instead, say 'error'.", + tool_choice={"type": "function", "function": {"name": "GetResponse"}}, + recipient_agent=self.__class__.agent1) if 'error' in message.lower(): print(self.__class__.agency.get_completion("Explain why you said error."))