From 5ab0ba6b6dd8a04c258c5d6c54cc04b73ce25612 Mon Sep 17 00:00:00 2001 From: "Leo.Wang" Date: Wed, 11 Sep 2024 16:09:53 +0800 Subject: [PATCH] Update Gitlab query field, add query by path (#8244) --- .../builtin/gitlab/tools/gitlab_commits.py | 164 ++++++++++-------- .../builtin/gitlab/tools/gitlab_commits.yaml | 13 +- .../builtin/gitlab/tools/gitlab_files.py | 84 +++++---- .../builtin/gitlab/tools/gitlab_files.yaml | 13 +- 4 files changed, 161 insertions(+), 113 deletions(-) diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py index dceb37db493ee1..45ab15f437e19a 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py @@ -1,4 +1,5 @@ import json +import urllib.parse from datetime import datetime, timedelta from typing import Any, Union @@ -13,13 +14,14 @@ def _invoke( self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: project = tool_parameters.get("project", "") + repository = tool_parameters.get("repository", "") employee = tool_parameters.get("employee", "") start_time = tool_parameters.get("start_time", "") end_time = tool_parameters.get("end_time", "") change_type = tool_parameters.get("change_type", "all") - if not project: - return self.create_text_message("Project is required") + if not project and not repository: + return self.create_text_message("Either project or repository is required") if not start_time: start_time = (datetime.utcnow() - timedelta(days=1)).isoformat() @@ -35,91 +37,105 @@ def _invoke( site_url = "https://gitlab.com" # Get commit content - result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time, change_type) + if repository: + result = self.fetch_commits( + site_url, access_token, repository, employee, start_time, end_time, change_type, is_repository=True + ) + else: + result = self.fetch_commits( + site_url, access_token, project, employee, start_time, end_time, change_type, is_repository=False + ) return [self.create_json_message(item) for item in result] - def fetch( + def fetch_commits( self, - user_id: str, site_url: str, access_token: str, - project: str, - employee: str = None, - start_time: str = "", - end_time: str = "", - change_type: str = "", + identifier: str, + employee: str, + start_time: str, + end_time: str, + change_type: str, + is_repository: bool, ) -> list[dict[str, Any]]: domain = site_url headers = {"PRIVATE-TOKEN": access_token} results = [] try: - # Get all of projects - url = f"{domain}/api/v4/projects" - response = requests.get(url, headers=headers) - response.raise_for_status() - projects = response.json() - - filtered_projects = [p for p in projects if project == "*" or p["name"] == project] - - for project in filtered_projects: - project_id = project["id"] - project_name = project["name"] - print(f"Project: {project_name}") - - # Get all of project commits - commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits" - params = {"since": start_time, "until": end_time} - if employee: - params["author"] = employee - - commits_response = requests.get(commits_url, headers=headers, params=params) - commits_response.raise_for_status() - commits = commits_response.json() - - for commit in commits: - commit_sha = commit["id"] - author_name = commit["author_name"] - + if is_repository: + # URL encode the repository path + encoded_identifier = urllib.parse.quote(identifier, safe="") + commits_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits" + else: + # Get all projects + url = f"{domain}/api/v4/projects" + response = requests.get(url, headers=headers) + response.raise_for_status() + projects = response.json() + + filtered_projects = [p for p in projects if identifier == "*" or p["name"] == identifier] + + for project in filtered_projects: + project_id = project["id"] + project_name = project["name"] + print(f"Project: {project_name}") + + commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits" + + params = {"since": start_time, "until": end_time} + if employee: + params["author"] = employee + + commits_response = requests.get(commits_url, headers=headers, params=params) + commits_response.raise_for_status() + commits = commits_response.json() + + for commit in commits: + commit_sha = commit["id"] + author_name = commit["author_name"] + + if is_repository: + diff_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits/{commit_sha}/diff" + else: diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff" - diff_response = requests.get(diff_url, headers=headers) - diff_response.raise_for_status() - diffs = diff_response.json() - - for diff in diffs: - # Calculate code lines of changed - added_lines = diff["diff"].count("\n+") - removed_lines = diff["diff"].count("\n-") - total_changes = added_lines + removed_lines - - if change_type == "new": - if added_lines > 1: - final_code = "".join( - [ - line[1:] - for line in diff["diff"].split("\n") - if line.startswith("+") and not line.startswith("+++") - ] - ) - results.append( - {"commit_sha": commit_sha, "author_name": author_name, "diff": final_code} - ) - else: - if total_changes > 1: - final_code = "".join( - [ - line[1:] - for line in diff["diff"].split("\n") - if (line.startswith("+") or line.startswith("-")) - and not line.startswith("+++") - and not line.startswith("---") - ] - ) - final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code - results.append( - {"commit_sha": commit_sha, "author_name": author_name, "diff": final_code_escaped} - ) + + diff_response = requests.get(diff_url, headers=headers) + diff_response.raise_for_status() + diffs = diff_response.json() + + for diff in diffs: + # Calculate code lines of changes + added_lines = diff["diff"].count("\n+") + removed_lines = diff["diff"].count("\n-") + total_changes = added_lines + removed_lines + + if change_type == "new": + if added_lines > 1: + final_code = "".join( + [ + line[1:] + for line in diff["diff"].split("\n") + if line.startswith("+") and not line.startswith("+++") + ] + ) + results.append({"commit_sha": commit_sha, "author_name": author_name, "diff": final_code}) + else: + if total_changes > 1: + final_code = "".join( + [ + line[1:] + for line in diff["diff"].split("\n") + if (line.startswith("+") or line.startswith("-")) + and not line.startswith("+++") + and not line.startswith("---") + ] + ) + final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code + results.append( + {"commit_sha": commit_sha, "author_name": author_name, "diff": final_code_escaped} + ) except requests.RequestException as e: print(f"Error fetching data from GitLab: {e}") diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml index d38d943958c734..669378ac97c89a 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml @@ -21,9 +21,20 @@ parameters: zh_Hans: 员工用户名 llm_description: User name for GitLab form: llm + - name: repository + type: string + required: false + label: + en_US: repository + zh_Hans: 仓库路径 + human_description: + en_US: repository + zh_Hans: 仓库路径,以namespace/project_name的形式。 + llm_description: Repository path for GitLab, like namespace/project_name. + form: llm - name: project type: string - required: true + required: false label: en_US: project zh_Hans: 项目名 diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py index 4a42b0fd7306c9..7606eee7af6cfb 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py @@ -1,3 +1,4 @@ +import urllib.parse from typing import Any, Union import requests @@ -11,14 +12,14 @@ def _invoke( self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: project = tool_parameters.get("project", "") + repository = tool_parameters.get("repository", "") branch = tool_parameters.get("branch", "") path = tool_parameters.get("path", "") - if not project: - return self.create_text_message("Project is required") + if not project and not repository: + return self.create_text_message("Either project or repository is required") if not branch: return self.create_text_message("Branch is required") - if not path: return self.create_text_message("Path is required") @@ -30,56 +31,51 @@ def _invoke( if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"): site_url = "https://gitlab.com" - # Get project ID from project name - project_id = self.get_project_id(site_url, access_token, project) - if not project_id: - return self.create_text_message(f"Project '{project}' not found.") - - # Get commit content - result = self.fetch(user_id, project_id, site_url, access_token, branch, path) + # Get file content + if repository: + result = self.fetch_files(site_url, access_token, repository, branch, path, is_repository=True) + else: + result = self.fetch_files(site_url, access_token, project, branch, path, is_repository=False) return [self.create_json_message(item) for item in result] - def extract_project_name_and_path(self, path: str) -> tuple[str, str]: - parts = path.split("/", 1) - if len(parts) < 2: - return None, None - return parts[0], parts[1] - - def get_project_id(self, site_url: str, access_token: str, project_name: str) -> Union[str, None]: - headers = {"PRIVATE-TOKEN": access_token} - try: - url = f"{site_url}/api/v4/projects?search={project_name}" - response = requests.get(url, headers=headers) - response.raise_for_status() - projects = response.json() - for project in projects: - if project["name"] == project_name: - return project["id"] - except requests.RequestException as e: - print(f"Error fetching project ID from GitLab: {e}") - return None - - def fetch( - self, user_id: str, project_id: str, site_url: str, access_token: str, branch: str, path: str = None + def fetch_files( + self, site_url: str, access_token: str, identifier: str, branch: str, path: str, is_repository: bool ) -> list[dict[str, Any]]: domain = site_url headers = {"PRIVATE-TOKEN": access_token} results = [] try: - # List files and directories in the given path - url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}" - response = requests.get(url, headers=headers) + if is_repository: + # URL encode the repository path + encoded_identifier = urllib.parse.quote(identifier, safe="") + tree_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/tree?path={path}&ref={branch}" + else: + # Get project ID from project name + project_id = self.get_project_id(site_url, access_token, identifier) + if not project_id: + return self.create_text_message(f"Project '{identifier}' not found.") + tree_url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}" + + response = requests.get(tree_url, headers=headers) response.raise_for_status() items = response.json() for item in items: item_path = item["path"] if item["type"] == "tree": # It's a directory - results.extend(self.fetch(project_id, site_url, access_token, branch, item_path)) + results.extend( + self.fetch_files(site_url, access_token, identifier, branch, item_path, is_repository) + ) else: # It's a file - file_url = f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}" + if is_repository: + file_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/files/{item_path}/raw?ref={branch}" + else: + file_url = ( + f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}" + ) + file_response = requests.get(file_url, headers=headers) file_response.raise_for_status() file_content = file_response.text @@ -88,3 +84,17 @@ def fetch( print(f"Error fetching data from GitLab: {e}") return results + + def get_project_id(self, site_url: str, access_token: str, project_name: str) -> Union[str, None]: + headers = {"PRIVATE-TOKEN": access_token} + try: + url = f"{site_url}/api/v4/projects?search={project_name}" + response = requests.get(url, headers=headers) + response.raise_for_status() + projects = response.json() + for project in projects: + if project["name"] == project_name: + return project["id"] + except requests.RequestException as e: + print(f"Error fetching project ID from GitLab: {e}") + return None diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml index d99b6254c1b99c..4c733673f15254 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml @@ -10,9 +10,20 @@ description: zh_Hans: 一个用于查询 GitLab 文件的工具,输入的内容应该是分支和一个已存在文件或者文件夹路径。 llm: A tool for query GitLab files, Input should be a exists file or directory path. parameters: + - name: repository + type: string + required: false + label: + en_US: repository + zh_Hans: 仓库路径 + human_description: + en_US: repository + zh_Hans: 仓库路径,以namespace/project_name的形式。 + llm_description: Repository path for GitLab, like namespace/project_name. + form: llm - name: project type: string - required: true + required: false label: en_US: project zh_Hans: 项目