Skip to content

Commit

Permalink
feat: add custom GitHub loader (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
iiirischen authored Aug 16, 2024
1 parent e736a24 commit 8ea0b99
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 23 deletions.
112 changes: 96 additions & 16 deletions vectorDB/classes/custom_loader.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,17 @@
import glob
from tqdm import tqdm
from typing import Iterator

from langchain_community.document_loaders import (
PyPDFLoader,
TextLoader,
UnstructuredFileLoader,
UnstructuredMarkdownLoader,
UnstructuredWordDocumentLoader,
)
from langchain_community.document_loaders import PyPDFLoader, TextLoader, UnstructuredFileLoader, UnstructuredWordDocumentLoader
from langchain_community.document_loaders.github import GithubFileLoader
from langchain_core.documents.base import Document

from qdrant_client import QdrantClient
from tqdm import tqdm

from utils import compute_file_hash, retrieve_file_record, delete_file_record
from utils import compute_file_hash, delete_file_record, retrieve_file_record


class CustomDirectoryLoader:
def __init__(
self, directory_path: str, client: QdrantClient, glob_pattern: str = "**"
):
def __init__(self, directory_path: str, client: QdrantClient, glob_pattern: str = "**"):
"""Initialize the loader with a directory path and a glob pattern.
Args:
Expand Down Expand Up @@ -47,9 +40,7 @@ def _load_file(self, file_path: str) -> list[Document]:
list[Document]: List of Document objects loaded from the file.
"""
file_extension = file_path.split(".")[-1]
loader_cls, loader_kwargs = self.filetype_mapping.get(
file_extension, (UnstructuredFileLoader, {})
)
loader_cls, loader_kwargs = self.filetype_mapping.get(file_extension, (UnstructuredFileLoader, {}))
loader = loader_cls(file_path=file_path, **loader_kwargs)
try:
return loader.load()
Expand Down Expand Up @@ -106,3 +97,92 @@ def load(self) -> list[Document]:
except Exception as e:
print(f"Error loading documents from file {file_path}: {e}")
return documents


class CustomGithubFileLoader(GithubFileLoader):
    """GitHub file loader that only yields files whose blob SHA differs from the stored one.

    For each file returned by ``get_file_paths()`` the loader looks up an
    existing record in Qdrant (keyed by the file's source URL), compares the
    stored ``sha`` with the current one, and only downloads and yields the
    file when the two differ or no record exists.
    """

    # Qdrant client used to look up and delete previously stored file records.
    client: QdrantClient

    class Config:
        # NOTE(review): presumably the base loader is a pydantic model, which
        # needs this flag to accept a non-pydantic field type such as
        # QdrantClient -- confirm against the installed langchain version.
        arbitrary_types_allowed = True

    def __init__(self, repo, access_token, github_api_url, branch, client: QdrantClient, file_filter=None):
        """Initialize the loader with a GitHub repository and a Qdrant client.

        Args:
            repo (str): GitHub repository name.
            access_token (str): GitHub access token.
            github_api_url (str): GitHub API URL.
            branch (str): GitHub branch to load files from.
            client (QdrantClient): Qdrant client used to query and delete stored file records.
            file_filter (callable, optional): Function to filter files to load. Defaults to None.
        """
        super().__init__(
            repo=repo,
            access_token=access_token,
            github_api_url=github_api_url,
            branch=branch,
            file_filter=file_filter,
            client=client,
        )
        # Kept from the original implementation; redundant if the base
        # initializer already assigns the declared ``client`` field.
        self.client = client

    def _load_file(self, file_path: str) -> list[Document]:
        """Load a single file through the parent loader, returning [] on any error.

        NOTE(review): the parent class is not visible here; confirm that it
        actually defines ``_load_file``. If it does not, the resulting
        AttributeError is caught below and this method always returns [].

        Args:
            file_path (str): Path to the file to load.

        Returns:
            list[Document]: Documents loaded from the file, or an empty list on failure.
        """
        try:
            return super()._load_file(file_path)
        except Exception as e:
            print(f"Error loading file {file_path} from GitHub: {e}")
            return []

    def _source_url(self, file: dict) -> str:
        """Build the source URL that identifies ``file`` in the vector store."""
        # ``str.replace`` is a no-op when the host is not "api.github.com",
        # so other API URLs pass through unchanged.
        github_url = self.github_api_url.replace("api.github.com", "github.com")
        return f"{github_url}/{self.repo}/{file['type']}/{self.branch}/{file['path']}"

    def _search_and_load(self, file: dict) -> list[Document]:
        """Load ``file`` from GitHub unless Qdrant already holds it with the same SHA.

        Args:
            file (dict): File entry from ``get_file_paths()``; the keys
                ``"path"``, ``"sha"`` and ``"type"`` are read.

        Returns:
            list[Document]: A single-element list with the file's content when
                the file is new or changed; an empty list when it is unchanged
                or its content is empty.
        """
        sha = file["sha"]
        source = self._source_url(file)

        records = retrieve_file_record(self.client, source)
        # A stored record means the file has already been saved in the vector DB.
        old_sha = records[0].payload["metadata"]["sha"] if records else None

        if sha == old_sha:
            # File has not been changed
            print(f"File {source} has not been changed")
            return []

        # Fetch the new content before deleting the stale record, so that a
        # failed download does not leave the vector DB without the file.
        content = self.get_file_content_by_path(file["path"])
        delete_file_record(self.client, source)
        if content == "":
            return []

        metadata = {
            "path": file["path"],
            "sha": sha,
            # Store exactly the key used for the lookup above; previously the
            # raw API URL was stored here, so later lookups could not match it.
            "source": source,
        }
        return [Document(page_content=content, metadata=metadata)]

    def lazy_load(self) -> Iterator[Document]:
        """Yield documents for every new or changed file in the repository.

        Errors raised while processing one file are printed and do not stop
        the iteration over the remaining files.

        Yields:
            Document: One document per new or changed, non-empty file.
        """
        for file in self.get_file_paths():
            file_path = file["path"]
            print("Processing file: ", file_path)
            try:
                yield from self._search_and_load(file)
            except Exception as e:
                print(f"Error loading file {file_path}: {e}")
10 changes: 4 additions & 6 deletions vectorDB/classes/document_loader.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from langchain_community.document_loaders.github import GithubFileLoader

from classes.custom_loader import CustomDirectoryLoader
from classes.custom_loader import CustomDirectoryLoader, CustomGithubFileLoader


class DocumentLoader:
loaders = {
"CustomDirectoryLoader": CustomDirectoryLoader,
"GithubFileLoader": GithubFileLoader,
"CustomGithubFileLoader": CustomGithubFileLoader,
}
metadata = {
"CustomDirectoryLoader": ["source", "hash"],
"GithubFileLoader": ["source", "sha"],
"CustomGithubFileLoader": ["source", "sha"],
}

@staticmethod
Expand All @@ -19,7 +17,7 @@ def create(config: dict):
if not loader_class:
raise ValueError(f"Unsupported document loader: {config['name']}")

if config["name"] == "GithubFileLoader":
if config["name"] == "CustomGithubFileLoader":
file_extension = config.get("file_extension", "")
file_filter = lambda file_path: file_path.endswith(file_extension)
config.pop("file_extension", None)
Expand Down
3 changes: 2 additions & 1 deletion vectorDB/read_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@ def get_doc_config(source: str, host: str | None = None):
}
else:
documentLoaderConfig = {
"name": "GithubFileLoader",
"name": "CustomGithubFileLoader",
"repo": "COSCUP/COSCUP-Volunteer",
"access_token": ACCESS_TOKEN,
"github_api_url": "https://api.github.com",
"branch": "main",
"file_extension": ".md",
"client": QdrantClient(host),
}
return documentLoaderConfig

Expand Down

0 comments on commit 8ea0b99

Please sign in to comment.