Skip to content

Commit

Permalink
Fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
VRSEN committed Apr 26, 2024
1 parent 0bca6f6 commit 3afc5a5
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 18 deletions.
24 changes: 14 additions & 10 deletions agency_swarm/agency/agency.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,10 @@ def __init__(self,
self._init_threads()

def get_completion(self, message: str,
message_files=None,
yield_messages=False,
recipient_agent=None,
additional_instructions=None,
message_files: List[str] = None,
yield_messages: bool = False,
recipient_agent: Agent = None,
additional_instructions: str = None,
attachments: List[dict] = None,
tool_choice: dict = None,
):
Expand All @@ -127,6 +127,7 @@ def get_completion(self, message: str,
recipient_agent (Agent, optional): The agent to which the message should be sent. Defaults to the first agent in the agency chart.
additional_instructions (str, optional): Additional instructions to be sent with the message. Defaults to None.
attachments (List[dict], optional): A list of attachments to be sent with the message, following openai format. Defaults to None.
tool_choice (dict, optional): The tool choice for the recipient agent to use. Defaults to None.
Returns:
Generator or final response: Depending on the 'yield_messages' flag, this method returns either a generator yielding intermediate messages or the final response from the main thread.
Expand All @@ -144,11 +145,12 @@ def get_completion(self, message: str,
def get_completion_stream(self,
message: str,
event_handler: type(AgencyEventHandler),
message_files=None,
recipient_agent=None,
message_files: List[str] = None,
recipient_agent: Agent = None,
additional_instructions: str = None,
attachments: List[dict] = None,
tool_choice: dict = None):
tool_choice: dict = None
):
"""
Generates a stream of completions for a given message from the main thread.
Expand All @@ -159,6 +161,8 @@ def get_completion_stream(self,
recipient_agent (Agent, optional): The agent to which the message should be sent. Defaults to the first agent in the agency chart.
additional_instructions (str, optional): Additional instructions to be sent with the message. Defaults to None.
attachments (List[dict], optional): A list of attachments to be sent with the message, following openai format. Defaults to None.
tool_choice (dict, optional): The tool choice for the recipient agent to use. Defaults to None.
Returns:
Final response: Final response from the main thread.
"""
Expand Down Expand Up @@ -815,8 +819,8 @@ class SendMessage(BaseTool):
"clarifying what the task entails, rather than providing exact "
"instructions.")
message_files: Optional[List[str]] = Field(default=None,
description="A list of file ids to be sent as attachments to this message. Only use this if you have the file id that starts with 'file-'.",
examples=["file-1234", "file-5678"])
description="A list of file ids to be sent as attachments to this message. Only use this if you have the file id that starts with 'file-'.",
examples=["file-1234", "file-5678"])
additional_instructions: str = Field(default=None,
description="Any additional instructions or clarifications that you would like to provide to the recipient agent.")
one_call_at_a_time: bool = True
Expand Down Expand Up @@ -948,4 +952,4 @@ def delete(self):
This method deletes the agency and all its agents, cleaning up any files and vector stores associated with each agent.
"""
for agent in self.agents:
agent.delete()
agent.delete()
17 changes: 9 additions & 8 deletions tests/test_agency.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,12 +357,11 @@ def test_7_init_async_agency(self):
# reset loaded thread ids
self.__class__.loaded_thread_ids = {}

self.__class__.agent1.instructions = "Your task is to say 'success' and nothing else."

self.__class__.agency = Agency([
self.__class__.ceo,
[self.__class__.ceo, self.__class__.agent1]],
shared_instructions="This is a shared instruction",
[self.__class__.ceo, self.__class__.agent1],
[self.__class__.agent1, self.__class__.agent2]],
shared_instructions="",
settings_callbacks=self.__class__.settings_callbacks,
threads_callbacks=self.__class__.threads_callbacks,
async_mode='threading',
Expand All @@ -374,14 +373,16 @@ def test_7_init_async_agency(self):
def test_8_async_agent_communication(self):
"""it should communicate between agents asynchronously"""
print("TestAgent1 tools", self.__class__.agent1.tools)
self.__class__.agency.get_completion("Please tell TestAgent1 hello.",
tool_choice={"type": "function", "function": {"name": "SendMessage"}})
self.__class__.agency.get_completion("Please tell TestAgent2 hello.",
tool_choice={"type": "function", "function": {"name": "SendMessage"}},
recipient_agent=self.__class__.agent1)

time.sleep(10)

message = self.__class__.agency.get_completion(
"Please check response. If output includes `TestAgent1's Response`, say 'success'. If the function output does not include `TestAgent1's Response`, or if you get a System Notification, or an error instead, say 'error'.",
tool_choice={"type": "function", "function": {"name": "GetResponse"}})
"Please check response. If output includes `TestAgent2's Response`, say 'success'. If the function output does not include `TestAgent2's Response`, or if you get a System Notification, or an error instead, say 'error'.",
tool_choice={"type": "function", "function": {"name": "GetResponse"}},
recipient_agent=self.__class__.agent1)

if 'error' in message.lower():
print(self.__class__.agency.get_completion("Explain why you said error."))
Expand Down

0 comments on commit 3afc5a5

Please sign in to comment.