diff --git a/agency_swarm/agency/agency.py b/agency_swarm/agency/agency.py index 4e5001d6..4cb88385 100644 --- a/agency_swarm/agency/agency.py +++ b/agency_swarm/agency/agency.py @@ -26,9 +26,14 @@ from agency_swarm.threads import Thread from agency_swarm.tools import BaseTool, CodeInterpreter, FileSearch from agency_swarm.user import User -from agency_swarm.util.errors import RefusalError -from agency_swarm.util.files import get_tools, get_file_purpose +from agency_swarm.util.files import determine_file_type, get_tools, get_file_purpose +from agency_swarm.util.helpers import extract_zip, extract_tar, git_clone +from agency_swarm.util.helpers.file_upload_helpers import is_file_extension_supported from agency_swarm.util.shared_state import SharedState +from openai.types.beta.threads.runs.tool_call import ToolCall, FunctionToolCall, CodeInterpreterToolCall, \ + FileSearchToolCall +from agency_swarm.util.errors import RefusalError + from agency_swarm.util.streaming import AgencyEventHandler console = Console() @@ -156,14 +161,13 @@ def get_completion(self, message: str, raise Exception("Verbose mode is not compatible with yield_messages=True") res = self.main_thread.get_completion(message=message, - message_files=message_files, - attachments=attachments, - recipient_agent=recipient_agent, - additional_instructions=additional_instructions, - tool_choice=tool_choice, - yield_messages=yield_messages or verbose, - response_format=response_format) - + message_files=message_files, + attachments=attachments, + recipient_agent=recipient_agent, + additional_instructions=additional_instructions, + tool_choice=tool_choice, + yield_messages=yield_messages or verbose, + response_format=response_format) if not yield_messages or verbose: while True: try: @@ -175,7 +179,6 @@ def get_completion(self, message: str, return res - def get_completion_stream(self, message: str, event_handler: type(AgencyEventHandler), @@ -205,14 +208,14 @@ def get_completion_stream(self, raise Exception("Event handler must not be an instance.") res = self.main_thread.get_completion_stream(message=message, - message_files=message_files, - event_handler=event_handler, - attachments=attachments, - recipient_agent=recipient_agent, - additional_instructions=additional_instructions, - tool_choice=tool_choice, - response_format=response_format) - + message_files=message_files, + event_handler=event_handler, + attachments=attachments, + recipient_agent=recipient_agent, + additional_instructions=additional_instructions, + tool_choice=tool_choice, + response_format=response_format + ) while True: try: next(res) @@ -300,8 +303,16 @@ def demo_gradio(self, height=450, dark_mode=True, **kwargs): images = [] message_file_names = None uploading_files = False + cloning_files = False recipient_agents = [agent.name for agent in self.main_recipients] - recipient_agent = self.main_recipients[0] + recipient_agent: Agent = self.main_recipients[0] + + if isinstance(recipient_agent.files_folder, list): + to_folder = recipient_agent.files_folder[0] + else: + to_folder = recipient_agent.files_folder + + os.makedirs(to_folder, exist_ok=True) with gr.Blocks(js=js) as demo: chatbot_queue = queue.Queue() @@ -310,15 +321,76 @@ def demo_gradio(self, height=450, dark_mode=True, **kwargs): with gr.Column(scale=9): dropdown = gr.Dropdown(label="Recipient Agent", choices=recipient_agents, value=recipient_agent.name) - msg = gr.Textbox(label="Your Message", lines=4) + msg = gr.Textbox(label="Your Message", lines=10) with gr.Column(scale=1): file_upload = gr.Files(label="OpenAI Files", type="filepath") + repo_clone = gr.Textbox(label="OpenAI GIT URL", lines=1) button = gr.Button(value="Send", variant="primary") def handle_dropdown_change(selected_option): nonlocal recipient_agent recipient_agent = self._get_agent_by_name(selected_option) + def handle_file_clone(repo_url): + nonlocal attachments + nonlocal message_file_names + nonlocal cloning_files + nonlocal images + cloning_files = True + attachments = [] + message_file_names = [] + + try: + extracted_files = [] + for file_path in git_clone(repo_url, to_folder): + if file_path.endswith('.zip'): + extracted_files.extend(extract_zip(file_path, to_folder)) + elif file_path.endswith('.tar.gz'): + extracted_files.extend(extract_tar(file_path, to_folder)) + else: + extracted_files.append(file_path) + + print(f"Found {', '.join(extracted_files)}") + + for file in extracted_files: + if is_file_extension_supported(file): + file_type = determine_file_type(file) + purpose = "assistants" if file_type != "vision" else "vision" + tools = [{ + "type": "code_interpreter"}] if file_type == "assistants.code_interpreter" else [ + {"type": "file_search"}] + + with open(file, 'rb') as f: + try: + # Upload the file to OpenAI + uploaded_file = self.main_thread.client.files.create( + file=f, + purpose=purpose + ) + + if file_type == "vision": + images.append({ + "type": "image_file", + "image_file": {"file_id": uploaded_file.id} + }) + else: + attachments.append({ + "file_id": uploaded_file.id, + "tools": tools + }) + + message_file_names.append(uploaded_file.filename) + print(f"Uploaded file ID: {uploaded_file.id}: {uploaded_file.filename}") + except Exception as e: + print(f"Uploading error: {e}") + return attachments + except Exception as e: + print(f"Error: {e}") + finally: + cloning_files = False + cloning_files = False + return "No files uploaded" + def handle_file_upload(file_list): nonlocal attachments nonlocal message_file_names @@ -329,7 +401,9 @@ def handle_file_upload(file_list): message_file_names = [] if file_list: try: + extracted_files = [] for file_obj in file_list: + file_path = file_obj.name purpose = get_file_purpose(file_obj.name) with open(file_obj.name, 'rb') as f: @@ -349,25 +423,62 @@ def handle_file_upload(file_list): "file_id": file.id, "tools": get_tools(file.filename) }) - - message_file_names.append(file.filename) - print(f"Uploaded file ID: {file.id}") + + if file_path.endswith('.zip'): + extracted_files.extend(extract_zip(file_path, to_folder)) + elif file_path.endswith('.tar.gz'): + extracted_files.extend(extract_tar(file_path, to_folder)) + else: + extracted_files.append(file_path) + + print(f"Found {', '.join(extracted_files)}") + + for file in extracted_files: + if is_file_extension_supported(file): + file_type = determine_file_type(file) + purpose = "assistants" if file_type != "vision" else "vision" + tools = [{ + "type": "code_interpreter"}] if file_type == "assistants.code_interpreter" else [ + {"type": "file_search"}] + + with open(file, 'rb') as f: + try: + # Upload the file to OpenAI + uploaded_file = self.main_thread.client.files.create( + file=f, + purpose=purpose + ) + + if file_type == "vision": + images.append({ + "type": "image_file", + "image_file": {"file_id": uploaded_file.id} + }) + else: + attachments.append({ + "file_id": uploaded_file.id, + "tools": tools + }) + + message_file_names.append(uploaded_file.filename) + print(f"Uploaded file ID: {uploaded_file.id}: {uploaded_file.filename}") + except Exception as e: + print(f"Uploading error: {e}") return attachments except Exception as e: print(f"Error: {e}") - return str(e) finally: uploading_files = False - uploading_files = False return "No files uploaded" def user(user_message, history): if not user_message.strip(): return user_message, history - + nonlocal message_file_names nonlocal uploading_files + nonlocal cloning_files nonlocal images nonlocal attachments nonlocal recipient_agent @@ -380,13 +491,15 @@ def check_and_add_tools_in_attachments(attachments, recipient_agent): if not any(isinstance(t, FileSearch) for t in recipient_agent.tools): # Add FileSearch tool if it does not exist recipient_agent.tools.append(FileSearch) - recipient_agent.client.beta.assistants.update(recipient_agent.id, tools=recipient_agent.get_oai_tools()) + recipient_agent.client.beta.assistants.update(recipient_agent.id, + tools=recipient_agent.get_oai_tools()) print("Added FileSearch tool to recipient agent to analyze the file.") elif tool["type"] == "code_interpreter": if not any(isinstance(t, CodeInterpreter) for t in recipient_agent.tools): # Add CodeInterpreter tool if it does not exist recipient_agent.tools.append(CodeInterpreter) - recipient_agent.client.beta.assistants.update(recipient_agent.id, tools=recipient_agent.get_oai_tools()) + recipient_agent.client.beta.assistants.update(recipient_agent.id, + tools=recipient_agent.get_oai_tools()) print("Added CodeInterpreter tool to recipient agent to analyze the file.") return None @@ -429,7 +542,6 @@ def on_message_created(self, message: Message) -> None: if content.type == "text": full_content += content.text.value + "\n" - self.message_output = MessageOutput("text", self.agent_name, self.recipient_agent_name, full_content) @@ -449,7 +561,7 @@ def on_tool_call_created(self, tool_call: ToolCall): if isinstance(tool_call, dict): if "type" not in tool_call: tool_call["type"] = "function" - + if tool_call["type"] == "function": tool_call = FunctionToolCall(**tool_call) elif tool_call["type"] == "code_interpreter": @@ -471,7 +583,7 @@ def on_tool_call_done(self, snapshot: ToolCall): if isinstance(snapshot, dict): if "type" not in snapshot: snapshot["type"] = "function" - + if snapshot["type"] == "function": snapshot = FunctionToolCall(**snapshot) elif snapshot["type"] == "code_interpreter": @@ -480,7 +592,7 @@ def on_tool_call_done(self, snapshot: ToolCall): snapshot = FileSearchToolCall(**snapshot) else: raise ValueError("Invalid tool call type: " + snapshot["type"]) - + self.message_output = None # TODO: add support for code interpreter and retrieval tools @@ -538,15 +650,21 @@ def bot(original_message, history): nonlocal recipient_agent nonlocal images nonlocal uploading_files + nonlocal cloning_files if uploading_files: history.append([None, "Uploading files... Please wait."]) yield "", history return "", history + if cloning_files: + history.append([None, "Cloning files... Please wait."]) + yield "", history + return "", history + print("Message files: ", attachments) print("Images: ", images) - + if images and len(images) > 0: original_message = [ { @@ -556,7 +674,6 @@ def bot(original_message, history): *images ] - completion_thread = threading.Thread(target=self.get_completion_stream, args=( original_message, GradioEventHandler, [], recipient_agent, "", attachments, None)) completion_thread.start() @@ -565,7 +682,7 @@ def bot(original_message, history): message_file_names = [] images = [] uploading_files = False - + cloning_files = False new_message = True while True: try: @@ -598,6 +715,7 @@ def bot(original_message, history): ) dropdown.change(handle_dropdown_change, dropdown) file_upload.change(handle_file_upload, file_upload) + repo_clone.change(handle_file_clone, repo_clone) msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then( bot, [msg, chatbot], [msg, chatbot] ) @@ -673,7 +791,7 @@ def on_tool_call_created(self, tool_call): if isinstance(tool_call, dict): if "type" not in tool_call: tool_call["type"] = "function" - + if tool_call["type"] == "function": tool_call = FunctionToolCall(**tool_call) elif tool_call["type"] == "code_interpreter": @@ -694,7 +812,7 @@ def on_tool_call_delta(self, delta, snapshot): if isinstance(snapshot, dict): if "type" not in snapshot: snapshot["type"] = "function" - + if snapshot["type"] == "function": snapshot = FunctionToolCall(**snapshot) elif snapshot["type"] == "code_interpreter": @@ -703,7 +821,7 @@ def on_tool_call_delta(self, delta, snapshot): snapshot = FileSearchToolCall(**snapshot) else: raise ValueError("Invalid tool call type: " + snapshot["type"]) - + self.message_output.cprint_update(str(snapshot.function)) @override @@ -830,7 +948,7 @@ def _init_agents(self): agent.max_completion_tokens = self.max_completion_tokens if self.truncation_strategy is not None and agent.truncation_strategy is None: agent.truncation_strategy = self.truncation_strategy - + if not agent.shared_state: agent.shared_state = self.shared_state diff --git a/agency_swarm/util/helpers/__init__.py b/agency_swarm/util/helpers/__init__.py index 607b52b1..8027a241 100644 --- a/agency_swarm/util/helpers/__init__.py +++ b/agency_swarm/util/helpers/__init__.py @@ -1,2 +1,5 @@ from .get_available_agent_descriptions import get_available_agent_descriptions -from .list_available_agents import list_available_agents \ No newline at end of file +from .list_available_agents import list_available_agents +from .file_upload_helpers import extract_tar +from .file_upload_helpers import extract_zip +from .file_upload_helpers import git_clone \ No newline at end of file diff --git a/agency_swarm/util/helpers/file_upload_helpers.py b/agency_swarm/util/helpers/file_upload_helpers.py new file mode 100644 index 00000000..91428e08 --- /dev/null +++ b/agency_swarm/util/helpers/file_upload_helpers.py @@ -0,0 +1,120 @@ +import os +import tarfile +import tempfile +import zipfile +from typing import List + +from git import Repo + +excluded_folders = ['node_modules', 'venv', 'vendor', '.git', '.idea'] +supported_extensions = [ + # Text Files + ".txt", ".csv", ".json", ".xml", + + # Spreadsheet Files + ".xls", ".xlsx", + + # Document Files + ".doc", ".docx", ".pdf", + + # Presentation Files + ".ppt", ".pptx", + + # Image Files + ".jpg", ".jpeg", ".png", ".gif", ".bmp", + + # Code Files + ".py", ".java", ".js", ".html", ".css", ".cpp", ".c", ".rb", ".php", + + # Audio Files + ".mp3", ".wav", + + # Video Files + ".mp4", ".avi", ".mkv" +] + + +def sanitize_file_name(file_path: str) -> str: + new_file_path = file_path + try: + if os.path.basename(file_path).index('.') < 1: + new_file_path += '.txt' + os.rename(file_path, new_file_path) + except ValueError as e: + new_file_path += '.txt' + os.rename(file_path, new_file_path) + return new_file_path + + +def is_file_extension_supported(file_path: str) -> bool: + for extension in supported_extensions: + if file_path.endswith(extension): + return True + + return False + + +def extract_zip(file_path: str, to_folder: str) -> List[str]: + """ + Extracts files from a zip archive. + + Parameters: + file_path (str): The path to the zip file. + to_folder (str): Path where should be files extracted + + Returns: + List[str]: A list of paths to the extracted files. + """ + extracted_files = [] + with zipfile.ZipFile(file_path, 'r') as zip_ref: + temp_dir = tempfile.mkdtemp('zip', 'extracted', to_folder) + zip_ref.extractall(temp_dir) + for root, dirs, files in os.walk(temp_dir): + dirs[:] = [d for d in dirs if d not in excluded_folders] + for file in files: + extracted_files.append(sanitize_file_name(os.path.join(root, file))) + return extracted_files + + +def git_clone(repo_url: str, to_folder: str) -> List[str]: + """ + Clones a Git repository to a specified directory. + + Parameters: + repo_url (str): The URL of the repository to clone. + to_folder (str): Path where should be git cloned into. + """ + cloned_files = [] + temp_dir = tempfile.mkdtemp('git', 'cloned', to_folder) + try: + Repo.clone_from(repo_url, temp_dir) + for root, dirs, files in os.walk(temp_dir): + dirs[:] = [d for d in dirs if d not in excluded_folders] + for file in files: + cloned_files.append(sanitize_file_name(os.path.join(root, file))) + except Exception as e: + print(f"Error cloning repository: {e}") + finally: + print(f"Repository cloned into {temp_dir}") + return cloned_files + + +def extract_tar(file_path: str, to_folder: str) -> List[str]: + """ + Extracts files from a tar.gz archive. + + Parameters: + file_path (str): The path to the tar.gz file. + + Returns: + List[str]: A list of paths to the extracted files. + """ + extracted_files = [] + with tarfile.open(file_path, 'r:gz') as tar_ref: + temp_dir = tempfile.mkdtemp('tar', 'extracted', to_folder) + tar_ref.extractall(temp_dir) + for root, dirs, files in os.walk(temp_dir): + dirs[:] = [d for d in dirs if d not in excluded_folders] + for file in files: + extracted_files.append(sanitize_file_name(os.path.join(root, file))) + return extracted_files diff --git a/requirements.txt b/requirements.txt old mode 100644 new mode 100755 index f2100dfe..34ff3068 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ deepdiff==6.7.1 termcolor==2.4.0 python-dotenv==1.0.1 rich==13.7.1 -jsonref==1.1.0 \ No newline at end of file +jsonref==1.1.0 +gitpython==3.1.41 \ No newline at end of file