From 50fad64acd888e03b227478add5e4a7ed865e42f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juraj=20Puchk=C3=BD?= Date: Wed, 5 Jun 2024 21:47:56 +0200 Subject: [PATCH 1/4] feat:Add support zip, tar.gz and git repo uploading as files --- agency_swarm/agency/agency.py | 193 +++++++++++++----- agency_swarm/util/helpers/__init__.py | 5 +- .../util/helpers/file_upload_helpers.py | 127 ++++++++++++ requirements.txt | 3 +- 4 files changed, 272 insertions(+), 56 deletions(-) create mode 100644 agency_swarm/util/helpers/file_upload_helpers.py mode change 100644 => 100755 requirements.txt diff --git a/agency_swarm/agency/agency.py b/agency_swarm/agency/agency.py index 47d9062d..eda991e3 100644 --- a/agency_swarm/agency/agency.py +++ b/agency_swarm/agency/agency.py @@ -21,9 +21,10 @@ from agency_swarm.tools import BaseTool, FileSearch, CodeInterpreter from agency_swarm.user import User from agency_swarm.util.files import determine_file_type +from agency_swarm.util.helpers import extract_zip, extract_tar, git_clone from agency_swarm.util.shared_state import SharedState -from openai.types.beta.threads.runs.tool_call import ToolCall, FunctionToolCall, CodeInterpreterToolCall, FileSearchToolCall - +from openai.types.beta.threads.runs.tool_call import ToolCall, FunctionToolCall, CodeInterpreterToolCall, \ + FileSearchToolCall from agency_swarm.util.streaming import AgencyEventHandler @@ -138,12 +139,12 @@ def get_completion(self, message: str, Generator or final response: Depending on the 'yield_messages' flag, this method returns either a generator yielding intermediate messages or the final response from the main thread. """ res = self.main_thread.get_completion(message=message, - message_files=message_files, - attachments=attachments, - recipient_agent=recipient_agent, - additional_instructions=additional_instructions, - tool_choice=tool_choice, - yield_messages=yield_messages) + message_files=message_files, + attachments=attachments, + recipient_agent=recipient_agent, + additional_instructions=additional_instructions, + tool_choice=tool_choice, + yield_messages=yield_messages) if not yield_messages: while True: @@ -154,7 +155,6 @@ def get_completion(self, message: str, return res - def get_completion_stream(self, message: str, event_handler: type(AgencyEventHandler), @@ -183,13 +183,13 @@ def get_completion_stream(self, raise Exception("Event handler must not be an instance.") res = self.main_thread.get_completion_stream(message=message, - message_files=message_files, - event_handler=event_handler, - attachments=attachments, - recipient_agent=recipient_agent, - additional_instructions=additional_instructions, - tool_choice=tool_choice - ) + message_files=message_files, + event_handler=event_handler, + attachments=attachments, + recipient_agent=recipient_agent, + additional_instructions=additional_instructions, + tool_choice=tool_choice + ) while True: try: @@ -230,6 +230,7 @@ def demo_gradio(self, height=450, dark_mode=True, **kwargs): images = [] message_file_names = None uploading_files = False + cloning_files = False recipient_agents = [agent.name for agent in self.main_recipients] recipient_agent = self.main_recipients[0] @@ -240,15 +241,75 @@ def demo_gradio(self, height=450, dark_mode=True, **kwargs): with gr.Column(scale=9): dropdown = gr.Dropdown(label="Recipient Agent", choices=recipient_agents, value=recipient_agent.name) - msg = gr.Textbox(label="Your Message", lines=4) + msg = gr.Textbox(label="Your Message", lines=10) with gr.Column(scale=1): file_upload = gr.Files(label="OpenAI Files", type="filepath") + repo_clone = gr.Textbox(label="OpenAI GIT URL", lines=1) button = gr.Button(value="Send", variant="primary") def handle_dropdown_change(selected_option): nonlocal recipient_agent recipient_agent = self._get_agent_by_name(selected_option) + def handle_file_clone(repo_url): + nonlocal attachments + nonlocal message_file_names + nonlocal cloning_files + nonlocal images + cloning_files = True + attachments = [] + message_file_names = [] + + try: + extracted_files = [] + for file_path in git_clone(repo_url): + if file_path.endswith('.zip'): + extracted_files.extend(extract_zip(file_path)) + elif file_path.endswith('.tar.gz'): + extracted_files.extend(extract_tar(file_path)) + else: + extracted_files.append(file_path) + + print(f"Found {', '.join(extracted_files)}") + + for file in extracted_files: + file_type = determine_file_type(file) + purpose = "assistants" if file_type != "vision" else "vision" + tools = [{ + "type": "code_interpreter"}] if file_type == "assistants.code_interpreter" else [ + {"type": "file_search"}] + + with open(file, 'rb') as f: + try: + # Upload the file to OpenAI + uploaded_file = self.main_thread.client.files.create( + file=f, + purpose=purpose + ) + + if file_type == "vision": + images.append({ + "type": "image_file", + "image_file": {"file_id": uploaded_file.id} + }) + else: + attachments.append({ + "file_id": uploaded_file.id, + "tools": tools + }) + + message_file_names.append(uploaded_file.filename) + print(f"Uploaded file ID: {uploaded_file.id}: {uploaded_file.filename}") + except Exception as e: + print(f"Uploading error: {e}") + return attachments + except Exception as e: + print(f"Error: {e}") + finally: + cloning_files = False + cloning_files = False + return "No files uploaded" + def handle_file_upload(file_list): nonlocal attachments nonlocal message_file_names @@ -259,47 +320,64 @@ def handle_file_upload(file_list): message_file_names = [] if file_list: try: + extracted_files = [] for file_obj in file_list: - file_type = determine_file_type(file_obj.name) - purpose = "assistants" if file_type != "vision" else "vision" - tools = [{"type": "code_interpreter"}] if file_type == "assistants.code_interpreter" else [{"type": "file_search"}] + file_path = file_obj.name - with open(file_obj.name, 'rb') as f: - # Upload the file to OpenAI - file = self.main_thread.client.files.create( - file=f, - purpose=purpose - ) - - if file_type == "vision": - images.append({ - "type": "image_file", - "image_file": {"file_id": file.id} - }) + if file_path.endswith('.zip'): + extracted_files.extend(extract_zip(file_path)) + elif file_path.endswith('.tar.gz'): + extracted_files.extend(extract_tar(file_path)) else: - attachments.append({ - "file_id": file.id, - "tools": tools - }) + extracted_files.append(file_path) + + print(f"Found {', '.join(extracted_files)}") - message_file_names.append(file.filename) - print(f"Uploaded file ID: {file.id}") + for file in extracted_files: + file_type = determine_file_type(file) + purpose = "assistants" if file_type != "vision" else "vision" + tools = [{ + "type": "code_interpreter"}] if file_type == "assistants.code_interpreter" else [ + {"type": "file_search"}] + + with open(file, 'rb') as f: + try: + # Upload the file to OpenAI + uploaded_file = self.main_thread.client.files.create( + file=f, + purpose=purpose + ) + + if file_type == "vision": + images.append({ + "type": "image_file", + "image_file": {"file_id": uploaded_file.id} + }) + else: + attachments.append({ + "file_id": uploaded_file.id, + "tools": tools + }) + + message_file_names.append(uploaded_file.filename) + print(f"Uploaded file ID: {uploaded_file.id}: {uploaded_file.filename}") + except Exception as e: + print(f"Uploading error: {e}") return attachments except Exception as e: print(f"Error: {e}") - return str(e) finally: uploading_files = False - uploading_files = False return "No files uploaded" def user(user_message, history): if not user_message.strip(): return user_message, history - + nonlocal message_file_names nonlocal uploading_files + nonlocal cloning_files nonlocal images nonlocal attachments nonlocal recipient_agent @@ -312,13 +390,15 @@ def check_and_add_tools_in_attachments(attachments, recipient_agent): if not any(isinstance(t, FileSearch) for t in recipient_agent.tools): # Add FileSearch tool if it does not exist recipient_agent.tools.append(FileSearch) - recipient_agent.client.beta.assistants.update(recipient_agent.id, tools=recipient_agent.get_oai_tools()) + recipient_agent.client.beta.assistants.update(recipient_agent.id, + tools=recipient_agent.get_oai_tools()) print("Added FileSearch tool to recipient agent to analyze the file.") elif tool["type"] == "code_interpreter": if not any(isinstance(t, CodeInterpreter) for t in recipient_agent.tools): # Add CodeInterpreter tool if it does not exist recipient_agent.tools.append(CodeInterpreter) - recipient_agent.client.beta.assistants.update(recipient_agent.id, tools=recipient_agent.get_oai_tools()) + recipient_agent.client.beta.assistants.update(recipient_agent.id, + tools=recipient_agent.get_oai_tools()) print("Added CodeInterpreter tool to recipient agent to analyze the file.") return None @@ -361,7 +441,6 @@ def on_message_created(self, message: Message) -> None: if content.type == "text": full_content += content.text.value + "\n" - self.message_output = MessageOutput("text", self.agent_name, self.recipient_agent_name, full_content) @@ -381,7 +460,7 @@ def on_tool_call_created(self, tool_call: ToolCall): if isinstance(tool_call, dict): if "type" not in tool_call: tool_call["type"] = "function" - + if tool_call["type"] == "function": tool_call = FunctionToolCall(**tool_call) elif tool_call["type"] == "code_interpreter": @@ -403,7 +482,7 @@ def on_tool_call_done(self, snapshot: ToolCall): if isinstance(snapshot, dict): if "type" not in snapshot: snapshot["type"] = "function" - + if snapshot["type"] == "function": snapshot = FunctionToolCall(**snapshot) elif snapshot["type"] == "code_interpreter": @@ -412,7 +491,7 @@ def on_tool_call_done(self, snapshot: ToolCall): snapshot = FileSearchToolCall(**snapshot) else: raise ValueError("Invalid tool call type: " + snapshot["type"]) - + self.message_output = None # TODO: add support for code interpreter and retrieval tools @@ -470,15 +549,21 @@ def bot(original_message, history): nonlocal recipient_agent nonlocal images nonlocal uploading_files + nonlocal cloning_files if uploading_files: history.append([None, "Uploading files... Please wait."]) yield "", history return "", history + if cloning_files: + history.append([None, "Cloning files... Please wait."]) + yield "", history + return "", history + print("Message files: ", attachments) print("Images: ", images) - + if images and len(images) > 0: original_message = [ { @@ -488,7 +573,6 @@ def bot(original_message, history): *images ] - completion_thread = threading.Thread(target=self.get_completion_stream, args=( original_message, GradioEventHandler, [], recipient_agent, "", attachments, None)) completion_thread.start() @@ -497,7 +581,7 @@ def bot(original_message, history): message_file_names = [] images = [] uploading_files = False - + cloning_files = False new_message = True while True: try: @@ -530,6 +614,7 @@ def bot(original_message, history): ) dropdown.change(handle_dropdown_change, dropdown) file_upload.change(handle_file_upload, file_upload) + repo_clone.change(handle_file_clone, repo_clone) msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then( bot, [msg, chatbot], [msg, chatbot] ) @@ -605,7 +690,7 @@ def on_tool_call_created(self, tool_call): if isinstance(tool_call, dict): if "type" not in tool_call: tool_call["type"] = "function" - + if tool_call["type"] == "function": tool_call = FunctionToolCall(**tool_call) elif tool_call["type"] == "code_interpreter": @@ -626,7 +711,7 @@ def on_tool_call_delta(self, delta, snapshot): if isinstance(snapshot, dict): if "type" not in snapshot: snapshot["type"] = "function" - + if snapshot["type"] == "function": snapshot = FunctionToolCall(**snapshot) elif snapshot["type"] == "code_interpreter": @@ -635,7 +720,7 @@ def on_tool_call_delta(self, delta, snapshot): snapshot = FileSearchToolCall(**snapshot) else: raise ValueError("Invalid tool call type: " + snapshot["type"]) - + self.message_output.cprint_update(str(snapshot.function)) @override @@ -762,7 +847,7 @@ def _init_agents(self): agent.max_completion_tokens = self.max_completion_tokens if self.truncation_strategy is not None and agent.truncation_strategy is None: agent.truncation_strategy = self.truncation_strategy - + if not agent.shared_state: agent.shared_state = self.shared_state diff --git a/agency_swarm/util/helpers/__init__.py b/agency_swarm/util/helpers/__init__.py index 607b52b1..8027a241 100644 --- a/agency_swarm/util/helpers/__init__.py +++ b/agency_swarm/util/helpers/__init__.py @@ -1,2 +1,5 @@ from .get_available_agent_descriptions import get_available_agent_descriptions -from .list_available_agents import list_available_agents \ No newline at end of file +from .list_available_agents import list_available_agents +from .file_upload_helpers import extract_tar +from .file_upload_helpers import extract_zip +from .file_upload_helpers import git_clone \ No newline at end of file diff --git a/agency_swarm/util/helpers/file_upload_helpers.py b/agency_swarm/util/helpers/file_upload_helpers.py new file mode 100644 index 00000000..7b9799f5 --- /dev/null +++ b/agency_swarm/util/helpers/file_upload_helpers.py @@ -0,0 +1,127 @@ +import os +import tarfile +import tempfile +import zipfile +from typing import List + +from git import Repo + +excluded_folders = ['node_modules', 'venv', 'vendor', '.git', '.idea'] +supported_extensions = [ + # Text Files + ".txt", ".csv", ".json", ".xml", + + # Spreadsheet Files + ".xls", ".xlsx", + + # Document Files + ".doc", ".docx", ".pdf", + + # Presentation Files + ".ppt", ".pptx", + + # Image Files + ".jpg", ".jpeg", ".png", ".gif", ".bmp", + + # Compressed Files + ".zip", ".rar", + + # Code Files + ".py", ".java", ".js", ".html", ".css", ".cpp", ".c", ".rb", ".php", + + # Audio Files + ".mp3", ".wav", + + # Video Files + ".mp4", ".avi", ".mkv" +] + + +def sanitize_file_name(file_path: str) -> str: + new_file_path = file_path + try: + if os.path.basename(file_path).index('.') < 1: + new_file_path += '.txt' + os.rename(file_path, new_file_path) + except ValueError as e: + new_file_path += '.txt' + os.rename(file_path, new_file_path) + return new_file_path + + +def is_file_extension_supported(file_path: str) -> bool: + for extension in supported_extensions: + if file_path.endswith(extension): + return True + + return False + + +def extract_zip(file_path: str) -> List[str]: + """ + Extracts files from a zip archive. + + Parameters: + file_path (str): The path to the zip file. + + Returns: + List[str]: A list of paths to the extracted files. + """ + extracted_files = [] + with zipfile.ZipFile(file_path, 'r') as zip_ref: + temp_dir = tempfile.mkdtemp() + zip_ref.extractall(temp_dir) + for root, dirs, files in os.walk(temp_dir): + dirs[:] = [d for d in dirs if d not in excluded_folders] + for file in files: + sanitized_file_name = sanitize_file_name(os.path.join(root, file)) + if is_file_extension_supported(sanitized_file_name): + extracted_files.append(sanitized_file_name) + return extracted_files + + +def git_clone(repo_url: str) -> List[str]: + """ + Clones a Git repository to a specified directory. + + Parameters: + repo_url (str): The URL of the repository to clone. + """ + cloned_files = [] + temp_dir = tempfile.mkdtemp() + try: + Repo.clone_from(repo_url, temp_dir) + for root, dirs, files in os.walk(temp_dir): + dirs[:] = [d for d in dirs if d not in excluded_folders] + for file in files: + sanitized_file_name = sanitize_file_name(os.path.join(root, file)) + if is_file_extension_supported(sanitized_file_name): + cloned_files.append(sanitized_file_name) + except Exception as e: + print(f"Error cloning repository: {e}") + finally: + print(f"Repository cloned into {temp_dir}") + return cloned_files + + +def extract_tar(file_path: str) -> List[str]: + """ + Extracts files from a tar.gz archive. + + Parameters: + file_path (str): The path to the tar.gz file. + + Returns: + List[str]: A list of paths to the extracted files. + """ + extracted_files = [] + with tarfile.open(file_path, 'r:gz') as tar_ref: + temp_dir = tempfile.mkdtemp() + tar_ref.extractall(temp_dir) + for root, dirs, files in os.walk(temp_dir): + dirs[:] = [d for d in dirs if d not in excluded_folders] + for file in files: + sanitized_file_name = sanitize_file_name(os.path.join(root, file)) + if is_file_extension_supported(sanitized_file_name): + extracted_files.append(sanitized_file_name) + return extracted_files diff --git a/requirements.txt b/requirements.txt old mode 100644 new mode 100755 index eefb501a..76965dcc --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ deepdiff==6.7.1 termcolor==2.4.0 python-dotenv==1.0.1 rich==13.7.1 -jsonref==1.1.0 \ No newline at end of file +jsonref==1.1.0 +gitpython==3.1.41 \ No newline at end of file From ff6cb14c1f716fa5934f5db7b84251d75a87bab8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juraj=20Puchk=C3=BD?= Date: Wed, 5 Jun 2024 22:04:43 +0200 Subject: [PATCH 2/4] feat:Add additional filter of supported extensions. --- agency_swarm/agency/agency.py | 119 +++++++++--------- .../util/helpers/file_upload_helpers.py | 12 +- 2 files changed, 64 insertions(+), 67 deletions(-) diff --git a/agency_swarm/agency/agency.py b/agency_swarm/agency/agency.py index eda991e3..6d002164 100644 --- a/agency_swarm/agency/agency.py +++ b/agency_swarm/agency/agency.py @@ -22,6 +22,7 @@ from agency_swarm.user import User from agency_swarm.util.files import determine_file_type from agency_swarm.util.helpers import extract_zip, extract_tar, git_clone +from agency_swarm.util.helpers.file_upload_helpers import is_file_extension_supported from agency_swarm.util.shared_state import SharedState from openai.types.beta.threads.runs.tool_call import ToolCall, FunctionToolCall, CodeInterpreterToolCall, \ FileSearchToolCall @@ -273,35 +274,36 @@ def handle_file_clone(repo_url): print(f"Found {', '.join(extracted_files)}") for file in extracted_files: - file_type = determine_file_type(file) - purpose = "assistants" if file_type != "vision" else "vision" - tools = [{ - "type": "code_interpreter"}] if file_type == "assistants.code_interpreter" else [ - {"type": "file_search"}] - - with open(file, 'rb') as f: - try: - # Upload the file to OpenAI - uploaded_file = self.main_thread.client.files.create( - file=f, - purpose=purpose - ) - - if file_type == "vision": - images.append({ - "type": "image_file", - "image_file": {"file_id": uploaded_file.id} - }) - else: - attachments.append({ - "file_id": uploaded_file.id, - "tools": tools - }) - - message_file_names.append(uploaded_file.filename) - print(f"Uploaded file ID: {uploaded_file.id}: {uploaded_file.filename}") - except Exception as e: - print(f"Uploading error: {e}") + if is_file_extension_supported(file): + file_type = determine_file_type(file) + purpose = "assistants" if file_type != "vision" else "vision" + tools = [{ + "type": "code_interpreter"}] if file_type == "assistants.code_interpreter" else [ + {"type": "file_search"}] + + with open(file, 'rb') as f: + try: + # Upload the file to OpenAI + uploaded_file = self.main_thread.client.files.create( + file=f, + purpose=purpose + ) + + if file_type == "vision": + images.append({ + "type": "image_file", + "image_file": {"file_id": uploaded_file.id} + }) + else: + attachments.append({ + "file_id": uploaded_file.id, + "tools": tools + }) + + message_file_names.append(uploaded_file.filename) + print(f"Uploaded file ID: {uploaded_file.id}: {uploaded_file.filename}") + except Exception as e: + print(f"Uploading error: {e}") return attachments except Exception as e: print(f"Error: {e}") @@ -334,35 +336,36 @@ def handle_file_upload(file_list): print(f"Found {', '.join(extracted_files)}") for file in extracted_files: - file_type = determine_file_type(file) - purpose = "assistants" if file_type != "vision" else "vision" - tools = [{ - "type": "code_interpreter"}] if file_type == "assistants.code_interpreter" else [ - {"type": "file_search"}] - - with open(file, 'rb') as f: - try: - # Upload the file to OpenAI - uploaded_file = self.main_thread.client.files.create( - file=f, - purpose=purpose - ) - - if file_type == "vision": - images.append({ - "type": "image_file", - "image_file": {"file_id": uploaded_file.id} - }) - else: - attachments.append({ - "file_id": uploaded_file.id, - "tools": tools - }) - - message_file_names.append(uploaded_file.filename) - print(f"Uploaded file ID: {uploaded_file.id}: {uploaded_file.filename}") - except Exception as e: - print(f"Uploading error: {e}") + if is_file_extension_supported(file): + file_type = determine_file_type(file) + purpose = "assistants" if file_type != "vision" else "vision" + tools = [{ + "type": "code_interpreter"}] if file_type == "assistants.code_interpreter" else [ + {"type": "file_search"}] + + with open(file, 'rb') as f: + try: + # Upload the file to OpenAI + uploaded_file = self.main_thread.client.files.create( + file=f, + purpose=purpose + ) + + if file_type == "vision": + images.append({ + "type": "image_file", + "image_file": {"file_id": uploaded_file.id} + }) + else: + attachments.append({ + "file_id": uploaded_file.id, + "tools": tools + }) + + message_file_names.append(uploaded_file.filename) + print(f"Uploaded file ID: {uploaded_file.id}: {uploaded_file.filename}") + except Exception as e: + print(f"Uploading error: {e}") return attachments except Exception as e: print(f"Error: {e}") diff --git a/agency_swarm/util/helpers/file_upload_helpers.py b/agency_swarm/util/helpers/file_upload_helpers.py index 7b9799f5..373ee927 100644 --- a/agency_swarm/util/helpers/file_upload_helpers.py +++ b/agency_swarm/util/helpers/file_upload_helpers.py @@ -74,9 +74,7 @@ def extract_zip(file_path: str) -> List[str]: for root, dirs, files in os.walk(temp_dir): dirs[:] = [d for d in dirs if d not in excluded_folders] for file in files: - sanitized_file_name = sanitize_file_name(os.path.join(root, file)) - if is_file_extension_supported(sanitized_file_name): - extracted_files.append(sanitized_file_name) + extracted_files.append(sanitize_file_name(os.path.join(root, file))) return extracted_files @@ -94,9 +92,7 @@ def git_clone(repo_url: str) -> List[str]: for root, dirs, files in os.walk(temp_dir): dirs[:] = [d for d in dirs if d not in excluded_folders] for file in files: - sanitized_file_name = sanitize_file_name(os.path.join(root, file)) - if is_file_extension_supported(sanitized_file_name): - cloned_files.append(sanitized_file_name) + cloned_files.append(sanitize_file_name(os.path.join(root, file))) except Exception as e: print(f"Error cloning repository: {e}") finally: @@ -121,7 +117,5 @@ def extract_tar(file_path: str) -> List[str]: for root, dirs, files in os.walk(temp_dir): dirs[:] = [d for d in dirs if d not in excluded_folders] for file in files: - sanitized_file_name = sanitize_file_name(os.path.join(root, file)) - if is_file_extension_supported(sanitized_file_name): - extracted_files.append(sanitized_file_name) + extracted_files.append(sanitize_file_name(os.path.join(root, file))) return extracted_files From 2686dcb5605d65a3e2610466d0caade7b5b90357 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juraj=20Puchk=C3=BD?= Date: Wed, 5 Jun 2024 22:07:52 +0200 Subject: [PATCH 3/4] feat:Compressed files are not supported. --- agency_swarm/util/helpers/file_upload_helpers.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/agency_swarm/util/helpers/file_upload_helpers.py b/agency_swarm/util/helpers/file_upload_helpers.py index 373ee927..2b61f171 100644 --- a/agency_swarm/util/helpers/file_upload_helpers.py +++ b/agency_swarm/util/helpers/file_upload_helpers.py @@ -23,9 +23,6 @@ # Image Files ".jpg", ".jpeg", ".png", ".gif", ".bmp", - # Compressed Files - ".zip", ".rar", - # Code Files ".py", ".java", ".js", ".html", ".css", ".cpp", ".c", ".rb", ".php", From 41fe10b6840077d10ca1512e6cb53a20866fd61a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juraj=20Puchk=C3=BD?= Date: Sun, 23 Jun 2024 10:10:23 +0200 Subject: [PATCH 4/4] fix:add files to files_folder --- agency_swarm/agency/agency.py | 19 +++++++++++++------ .../util/helpers/file_upload_helpers.py | 14 ++++++++------ 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/agency_swarm/agency/agency.py b/agency_swarm/agency/agency.py index 6d002164..f7db96b6 100644 --- a/agency_swarm/agency/agency.py +++ b/agency_swarm/agency/agency.py @@ -233,7 +233,14 @@ def demo_gradio(self, height=450, dark_mode=True, **kwargs): uploading_files = False cloning_files = False recipient_agents = [agent.name for agent in self.main_recipients] - recipient_agent = self.main_recipients[0] + recipient_agent: Agent = self.main_recipients[0] + + if isinstance(recipient_agent.files_folder, list): + to_folder = recipient_agent.files_folder[0] + else: + to_folder = recipient_agent.files_folder + + os.makedirs(to_folder, exist_ok=True) with gr.Blocks(js=js) as demo: chatbot_queue = queue.Queue() @@ -263,11 +270,11 @@ def handle_file_clone(repo_url): try: extracted_files = [] - for file_path in git_clone(repo_url): + for file_path in git_clone(repo_url, to_folder): if file_path.endswith('.zip'): - extracted_files.extend(extract_zip(file_path)) + extracted_files.extend(extract_zip(file_path, to_folder)) elif file_path.endswith('.tar.gz'): - extracted_files.extend(extract_tar(file_path)) + extracted_files.extend(extract_tar(file_path, to_folder)) else: extracted_files.append(file_path) @@ -327,9 +334,9 @@ def handle_file_upload(file_list): file_path = file_obj.name if file_path.endswith('.zip'): - extracted_files.extend(extract_zip(file_path)) + extracted_files.extend(extract_zip(file_path, to_folder)) elif file_path.endswith('.tar.gz'): - extracted_files.extend(extract_tar(file_path)) + extracted_files.extend(extract_tar(file_path, to_folder)) else: extracted_files.append(file_path) diff --git a/agency_swarm/util/helpers/file_upload_helpers.py b/agency_swarm/util/helpers/file_upload_helpers.py index 2b61f171..91428e08 100644 --- a/agency_swarm/util/helpers/file_upload_helpers.py +++ b/agency_swarm/util/helpers/file_upload_helpers.py @@ -54,19 +54,20 @@ def is_file_extension_supported(file_path: str) -> bool: return False -def extract_zip(file_path: str) -> List[str]: +def extract_zip(file_path: str, to_folder: str) -> List[str]: """ Extracts files from a zip archive. Parameters: file_path (str): The path to the zip file. + to_folder (str): Path where should be files extracted Returns: List[str]: A list of paths to the extracted files. """ extracted_files = [] with zipfile.ZipFile(file_path, 'r') as zip_ref: - temp_dir = tempfile.mkdtemp() + temp_dir = tempfile.mkdtemp('zip', 'extracted', to_folder) zip_ref.extractall(temp_dir) for root, dirs, files in os.walk(temp_dir): dirs[:] = [d for d in dirs if d not in excluded_folders] @@ -75,15 +76,16 @@ def extract_zip(file_path: str) -> List[str]: return extracted_files -def git_clone(repo_url: str) -> List[str]: +def git_clone(repo_url: str, to_folder: str) -> List[str]: """ Clones a Git repository to a specified directory. Parameters: repo_url (str): The URL of the repository to clone. + to_folder (str): Path where should be git cloned into. """ cloned_files = [] - temp_dir = tempfile.mkdtemp() + temp_dir = tempfile.mkdtemp('git', 'cloned', to_folder) try: Repo.clone_from(repo_url, temp_dir) for root, dirs, files in os.walk(temp_dir): @@ -97,7 +99,7 @@ def git_clone(repo_url: str) -> List[str]: return cloned_files -def extract_tar(file_path: str) -> List[str]: +def extract_tar(file_path: str, to_folder: str) -> List[str]: """ Extracts files from a tar.gz archive. @@ -109,7 +111,7 @@ def extract_tar(file_path: str) -> List[str]: """ extracted_files = [] with tarfile.open(file_path, 'r:gz') as tar_ref: - temp_dir = tempfile.mkdtemp() + temp_dir = tempfile.mkdtemp('tar', 'extracted', to_folder) tar_ref.extractall(temp_dir) for root, dirs, files in os.walk(temp_dir): dirs[:] = [d for d in dirs if d not in excluded_folders]