diff --git a/vectorDB/classes/custom_loader.py b/vectorDB/classes/custom_loader.py index 9c6bf2e..c826cc2 100644 --- a/vectorDB/classes/custom_loader.py +++ b/vectorDB/classes/custom_loader.py @@ -1,24 +1,17 @@ import glob -from tqdm import tqdm +from typing import Iterator -from langchain_community.document_loaders import ( - PyPDFLoader, - TextLoader, - UnstructuredFileLoader, - UnstructuredMarkdownLoader, - UnstructuredWordDocumentLoader, -) +from langchain_community.document_loaders import PyPDFLoader, TextLoader, UnstructuredFileLoader, UnstructuredWordDocumentLoader +from langchain_community.document_loaders.github import GithubFileLoader from langchain_core.documents.base import Document - from qdrant_client import QdrantClient +from tqdm import tqdm -from utils import compute_file_hash, retrieve_file_record, delete_file_record +from utils import compute_file_hash, delete_file_record, retrieve_file_record class CustomDirectoryLoader: - def __init__( - self, directory_path: str, client: QdrantClient, glob_pattern: str = "**" - ): + def __init__(self, directory_path: str, client: QdrantClient, glob_pattern: str = "**"): """Initialize the loader with a directory path and a glob pattern. Args: @@ -47,9 +40,7 @@ def _load_file(self, file_path: str) -> list[Document]: list[Document]: List of Document objects loaded from the file. 
class CustomGithubFileLoader(GithubFileLoader):
    """GitHub file loader that only yields files whose blob sha has changed.

    For every file reported by ``get_file_paths()`` the loader asks Qdrant
    (through ``retrieve_file_record``) for an existing record of the same
    source URL. A file is (re)loaded only when no record exists or the
    recorded sha differs from the sha GitHub reports; in that case the stale
    record is removed first via ``delete_file_record``.
    """

    # Qdrant client used to look up and delete previously indexed records.
    client: QdrantClient

    class Config:
        # QdrantClient is an arbitrary (non-validated) type for the field above.
        # NOTE(review): presumably required because the base loader is a
        # pydantic model -- confirm against the installed langchain version.
        arbitrary_types_allowed = True

    def __init__(self, repo, access_token, github_api_url, branch, client: QdrantClient, file_filter=None):
        """Initialize the loader with a GitHub repository and a Qdrant client.

        Args:
            repo (str): GitHub repository name.
            access_token (str): GitHub access token.
            github_api_url (str): GitHub API URL.
            branch (str): GitHub branch to load files from.
            client (QdrantClient): Qdrant client used to check for, and delete,
                previously stored records of each file.
            file_filter (callable, optional): Function to filter files to load.
                Defaults to None.
        """
        super().__init__(
            repo=repo,
            access_token=access_token,
            github_api_url=github_api_url,
            branch=branch,
            file_filter=file_filter,
            client=client,
        )
        # Kept from the original implementation: explicit assignment in
        # addition to forwarding ``client`` to the base initializer.
        self.client = client

    def _load_file(self, file_path: str) -> list[Document]:
        """Load a single file through the parent loader, returning [] on error.

        NOTE(review): nothing in this class calls this method, and it relies on
        the base class exposing ``_load_file`` -- confirm that it does;
        otherwise every call ends up in the ``except`` branch.

        Args:
            file_path (str): Path to the file to load.

        Returns:
            list[Document]: Documents loaded from the file, or an empty list
                if loading raised.
        """
        try:
            return super()._load_file(file_path)
        except Exception as e:
            print(f"Error loading file {file_path} from GitHub: {e}")
            return []

    def _search_and_load(self, file) -> list[Document]:
        """Load ``file`` from GitHub unless Qdrant already holds the same sha.

        Args:
            file (dict): File entry as produced by ``get_file_paths()``; the
                keys ``"path"``, ``"sha"`` and ``"type"`` are read.

        Returns:
            list[Document]: A single-element list with the file's document when
                the file is new or changed and non-empty; otherwise an empty
                list.
        """
        sha = file["sha"]

        # Records are keyed by a github.com URL, so rewrite the public API
        # host. ``str.replace`` is a no-op for other hosts, which makes the
        # original ``if "api.github.com" in ...`` guard unnecessary.
        github_url = self.github_api_url.replace("api.github.com", "github.com")
        source = f"{github_url}/{self.repo}/{file['type']}/{self.branch}/{file['path']}"

        records = retrieve_file_record(self.client, source)
        old_sha = None
        if records:
            # File has already been saved in the vector db. A record without
            # the expected payload shape is treated as "changed" (old_sha
            # stays None) so it gets rebuilt instead of raising on every run.
            payload = records[0].payload or {}
            old_sha = (payload.get("metadata") or {}).get("sha")

        if sha == old_sha:
            # File has not been changed
            print(f"File {source} has not been changed")
            return []

        # File is new or has changed: drop the stale record, then reload.
        delete_file_record(self.client, source)
        content = self.get_file_content_by_path(file["path"])
        if content == "":
            return []

        metadata = {
            "path": file["path"],
            "sha": sha,
            # Must be the same URL used for the lookup/delete above. The
            # original stored the un-rewritten API URL here, so a later
            # lookup by ``source`` could not match the record it had written.
            "source": source,
        }
        return [Document(page_content=content, metadata=metadata)]

    def lazy_load(self) -> Iterator[Document]:
        """Yield documents for every new or changed file in the repository.

        Errors raised while processing a single file are printed and that file
        is skipped, so one failure does not abort the whole run.

        Yields:
            Document: One document per new or changed, non-empty file.
        """
        for file in self.get_file_paths():
            file_path = file["path"]
            print("Processing file: ", file_path)
            try:
                docs = self._search_and_load(file)
            except Exception as e:
                # Deliberate best-effort: report and continue with the next file.
                print(f"Error loading file {file_path}: {e}")
                continue
            # Yield outside the ``try`` so exceptions thrown into the generator
            # by the consumer are not mistaken for load errors.
            yield from docs
+ "CustomGithubFileLoader": ["source", "sha"], } @staticmethod @@ -19,7 +17,7 @@ def create(config: dict): if not loader_class: raise ValueError(f"Unsupported document loader: {config['name']}") - if config["name"] == "GithubFileLoader": + if config["name"] == "CustomGithubFileLoader": file_extension = config.get("file_extension", "") file_filter = lambda file_path: file_path.endswith(file_extension) config.pop("file_extension", None) diff --git a/vectorDB/read_docs.py b/vectorDB/read_docs.py index f4061ec..447019e 100644 --- a/vectorDB/read_docs.py +++ b/vectorDB/read_docs.py @@ -34,12 +34,13 @@ def get_doc_config(source: str, host: str | None = None): } else: documentLoaderConfig = { - "name": "GithubFileLoader", + "name": "CustomGithubFileLoader", "repo": "COSCUP/COSCUP-Volunteer", "access_token": ACCESS_TOKEN, "github_api_url": "https://api.github.com", "branch": "main", "file_extension": ".md", + "client": QdrantClient(host), } return documentLoaderConfig