From d6edb83e488d3de23cee4878e9b47f5bd2359f28 Mon Sep 17 00:00:00 2001 From: "M.Abdulrahman Alnaseer" <20760062+abdalrohman@users.noreply.github.com> Date: Mon, 12 Aug 2024 21:42:29 +0300 Subject: [PATCH 1/5] =?UTF-8?q?=F0=9F=9A=80=20feat:=20Add=20support=20for?= =?UTF-8?q?=20other=20embedding=20providers:=20Google,=20VoyageAI,=20Coher?= =?UTF-8?q?eAI=20and=20ShuttleAI?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.py | 43 ++++++++++++++++++++++++++++++++++++++++++- requirements.lite.txt | 3 +++ requirements.txt | 3 +++ 3 files changed, 48 insertions(+), 1 deletion(-) diff --git a/config.py b/config.py index e4263612..270a8aac 100644 --- a/config.py +++ b/config.py @@ -28,6 +28,10 @@ class EmbeddingsProvider(Enum): HUGGINGFACE = "huggingface" HUGGINGFACETEI = "huggingfacetei" OLLAMA = "ollama" + GOOGLE = "google" + VOYAGE = "voyage" + SHUTTLEAI = "shuttleai" + COHERE = "cohere" def get_env_variable( @@ -171,7 +175,10 @@ async def dispatch(self, request, call_next): ).rstrip("/") HF_TOKEN = get_env_variable("HF_TOKEN", "") OLLAMA_BASE_URL = get_env_variable("OLLAMA_BASE_URL", "http://ollama:11434") - +GOOGLE_API_KEY = get_env_variable("GOOGLE_KEY", "") +VOYAGE_API_KEY = get_env_variable("VOYAGE_API_KEY", "") +SHUTTLEAI_KEY = get_env_variable("SHUTTLEAI_KEY", "") # use embeddings from shuttleai +COHERE_API_KEY = get_env_variable("COHERE_API_KEY", "") ## Embeddings @@ -198,6 +205,31 @@ def init_embeddings(provider, model): return HuggingFaceHubEmbeddings(model=model) elif provider == EmbeddingsProvider.OLLAMA: return OllamaEmbeddings(model=model, base_url=OLLAMA_BASE_URL) + elif provider == EmbeddingsProvider.GOOGLE: + from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings + + return GoogleGenerativeAIEmbeddings( + model=model, + api_key=GOOGLE_API_KEY, + ) + elif provider == EmbeddingsProvider.VOYAGE: + from langchain_voyageai import VoyageAIEmbeddings + + return VoyageAIEmbeddings( + model=model, + ) + elif provider == EmbeddingsProvider.SHUTTLEAI: + return OpenAIEmbeddings( + model=model, + api_key=SHUTTLEAI_KEY, + openai_api_base="https://api.shuttleai.app/v1", + ) + elif provider == EmbeddingsProvider.COHERE: + from langchain_cohere import CohereEmbeddings + + return CohereEmbeddings( + model=model, + ) else: raise ValueError(f"Unsupported embeddings provider: {provider}") @@ -220,6 +252,15 @@ def init_embeddings(provider, model): ) elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.OLLAMA: EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "nomic-embed-text") +elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.GOOGLE: + EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "models/embedding-001") +elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.VOYAGE: + EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "voyage-large-2") +elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.SHUTTLEAI: + # text-embedding-ada-002, text-embedding-3-small, text-embedding-3-large + EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "text-embedding-3-large") +elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.COHERE: + EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "embed-multilingual-v3.0") else: raise ValueError(f"Unsupported embeddings provider: {EMBEDDINGS_PROVIDER}") diff --git a/requirements.lite.txt b/requirements.lite.txt index 98e9fb93..552d218e 100644 --- a/requirements.lite.txt +++ b/requirements.lite.txt @@ -27,3 +27,6 @@ langchain-mongodb==0.1.3 cryptography==42.0.7 python-magic==0.4.27 python-pptx==0.6.23 +langchain-voyageai==0.1.1 +langchain-google-genai==1.0.8 +langchain-cohere==0.2.1 diff --git a/requirements.txt b/requirements.txt index edf53ce5..cc83b9d2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,3 +28,6 @@ langchain-mongodb==0.1.3 cryptography==42.0.7 python-magic==0.4.27 python-pptx==0.6.23 +langchain-voyageai==0.1.1 +langchain-google-genai==1.0.8 +langchain-cohere==0.2.1 From bdd5f89c781ae3f73e250935cc69dab23a9a1116 Mon Sep 17 00:00:00 2001 From: "M.Abdulrahman Alnaseer" <20760062+abdalrohman@users.noreply.github.com> Date: Mon, 12 Aug 2024 21:45:12 +0300 Subject: [PATCH 2/5] =?UTF-8?q?=F0=9F=94=A7=20fix:=20Enhance=20error=20han?= =?UTF-8?q?dling=20in=20embed=5Ffile=20with=20logging.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/main.py b/main.py index c82b550a..7c13962c 100644 --- a/main.py +++ b/main.py @@ -392,6 +392,7 @@ async def embed_file( ) try: + logger.info(f"Received file for embedding: filename={file.filename}, content_type={file.content_type}, file_id={file_id}") loader, known_type, file_ext = get_loader( file.filename, file.content_type, temp_file_path ) @@ -403,6 +404,7 @@ async def embed_file( if not result: response_status = False response_message = "Failed to process/store the file data." + logger.error(response_message, exc_info=True) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to process/store the file data.", @@ -410,6 +412,7 @@ async def embed_file( elif "error" in result: response_status = False response_message = "Failed to process/store the file data." + logger.error(response_message, exc_info=True) if isinstance(result["error"], str): response_message = result["error"] else: @@ -420,6 +423,7 @@ async def embed_file( except Exception as e: response_status = False response_message = f"Error during file processing: {str(e)}" + logger.error(response_message, exc_info=True) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Error during file processing: {str(e)}", From 39db32cf7e460af0838fdc38ab586b7475da63a7 Mon Sep 17 00:00:00 2001 From: "M.Abdulrahman Alnaseer" <20760062+abdalrohman@users.noreply.github.com> Date: Mon, 12 Aug 2024 23:22:00 +0300 Subject: [PATCH 3/5] =?UTF-8?q?=F0=9F=A7=B9=20chore:=20Organize=20and=20up?= =?UTF-8?q?date=20packages=20in=20requirements=20files?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Grouped packages by categories for better readability. - Updated package versions to the latest available from PyPI. - Applied changes to requirements.txt and requirements.lite.txt --- requirements.lite.txt | 69 ++++++++++++++++++++++++----------------- requirements.txt | 71 +++++++++++++++++++++++++------------------ 2 files changed, 83 insertions(+), 57 deletions(-) diff --git a/requirements.lite.txt b/requirements.lite.txt index 552d218e..8f4017a1 100644 --- a/requirements.lite.txt +++ b/requirements.lite.txt @@ -1,32 +1,45 @@ -langchain==0.1.12 -langchain_community==0.0.34 -langchain_openai==0.0.8 -langchain_core==0.1.45 -sqlalchemy==2.0.28 -python-dotenv==1.0.1 -fastapi==0.110.0 +# LangChain +langchain==0.2.12 +langchain_community==0.2.11 +langchain_openai==0.1.21 +langchain_core==0.2.29 +langchain-mongodb==0.1.8 +langchain-voyageai==0.1.1 +langchain-google-genai==1.0.8 +langchain-cohere==0.2.1 + +# API +fastapi==0.112.0 +uvicorn==0.30.5 +python-multipart==0.0.9 +aiofiles==24.1.0 + +# Database +sqlalchemy==2.0.32 psycopg2-binary==2.9.9 -pgvector==0.2.5 -uvicorn==0.28.0 -pypdf==4.1.0 -unstructured==0.12.6 -markdown==3.6 -networkx==3.2.1 -pandas==2.2.1 -openpyxl==3.1.2 +pgvector==0.3.2 +asyncpg==0.29.0 +pymongo==4.8.0 + +# Data Processing and Analysis +pandas==2.2.2 +openpyxl==3.1.5 +networkx==3.3 + +# File Handling and Parsing +pypdf==4.3.1 +unstructured==0.15.1 docx2txt==0.8 pypandoc==1.13 -PyJWT==2.8.0 -asyncpg==0.29.0 -python-multipart==0.0.9 -aiofiles==23.2.1 -rapidocr-onnxruntime==1.3.17 -opencv-python-headless==4.9.0.80 -pymongo==4.6.3 -langchain-mongodb==0.1.3 -cryptography==42.0.7 python-magic==0.4.27 -python-pptx==0.6.23 -langchain-voyageai==0.1.1 -langchain-google-genai==1.0.8 -langchain-cohere==0.2.1 +python-pptx==1.0.2 + +# Security +PyJWT==2.9.0 +cryptography==43.0.0 + +# Miscellaneous +python-dotenv==1.0.1 +markdown==3.6 +rapidocr-onnxruntime==1.3.24 +opencv-python-headless==4.10.0.84 diff --git a/requirements.txt b/requirements.txt index cc83b9d2..653a44d9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,33 +1,46 @@ -langchain==0.1.12 -langchain_community==0.0.34 -langchain_openai==0.0.8 -langchain_core==0.1.45 -sqlalchemy==2.0.28 -python-dotenv==1.0.1 -fastapi==0.110.0 +# LangChain +langchain==0.2.12 +langchain_community==0.2.11 +langchain_openai==0.1.21 +langchain_core==0.2.29 +langchain-mongodb==0.1.8 +langchain-voyageai==0.1.1 +langchain-google-genai==1.0.8 +langchain-cohere==0.2.1 + +# API +fastapi==0.112.0 +uvicorn==0.30.5 +python-multipart==0.0.9 +aiofiles==24.1.0 + +# Database +sqlalchemy==2.0.32 psycopg2-binary==2.9.9 -pgvector==0.2.5 -uvicorn==0.28.0 -pypdf==4.1.0 -unstructured==0.12.6 -markdown==3.6 -networkx==3.2.1 -pandas==2.2.1 -openpyxl==3.1.2 +pgvector==0.3.2 +asyncpg==0.29.0 +pymongo==4.8.0 + +# Data Processing and Analysis +pandas==2.2.2 +openpyxl==3.1.5 +networkx==3.3 + +# File Handling and Parsing +pypdf==4.3.1 +unstructured==0.15.1 docx2txt==0.8 pypandoc==1.13 -PyJWT==2.8.0 -asyncpg==0.29.0 -python-multipart==0.0.9 -sentence_transformers==2.5.1 -aiofiles==23.2.1 -rapidocr-onnxruntime==1.3.17 -opencv-python-headless==4.9.0.80 -pymongo==4.6.3 -langchain-mongodb==0.1.3 -cryptography==42.0.7 python-magic==0.4.27 -python-pptx==0.6.23 -langchain-voyageai==0.1.1 -langchain-google-genai==1.0.8 -langchain-cohere==0.2.1 +python-pptx==1.0.2 + +# Security +PyJWT==2.9.0 +cryptography==43.0.0 + +# Miscellaneous +python-dotenv==1.0.1 +markdown==3.6 +rapidocr-onnxruntime==1.3.24 +opencv-python-headless==4.10.0.84 +sentence_transformers==3.0.1 From da6b6d7d350039bffc2be9de10d6abc6c2e6567e Mon Sep 17 00:00:00 2001 From: "M.Abdulrahman Alnaseer" <20760062+abdalrohman@users.noreply.github.com> Date: Tue, 13 Aug 2024 11:29:45 +0300 Subject: [PATCH 4/5] =?UTF-8?q?=F0=9F=93=A6=20package:=20Add=20environment?= =?UTF-8?q?.yml=20for=20Conda=20environment=20setup?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - use: `conda env create -f environment.yml` --- environment.yml | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 environment.yml diff --git a/environment.yml b/environment.yml new file mode 100644 index 00000000..8ac0e63b --- /dev/null +++ b/environment.yml @@ -0,0 +1,8 @@ +name: rag_api +channels: + - defaults +dependencies: + - python=3.11 + - pip + - pip: + - -r requirements.lite.txt From 01c6b876e63a0fd9c7bec11f467a034e1d8dc945 Mon Sep 17 00:00:00 2001 From: "M.Abdulrahman Alnaseer" <20760062+abdalrohman@users.noreply.github.com> Date: Tue, 13 Aug 2024 12:00:41 +0300 Subject: [PATCH 5/5] =?UTF-8?q?=F0=9F=8F=97=EF=B8=8F=20restructure:=20Impr?= =?UTF-8?q?ove=20project=20organization=20and=20modularity?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Created a new `rag_api` directory to keep related parts of the code together. - Created subdirectories within `rag_api` to group related modules. This should make the code easier to understand and work with as the project grows! --- config.py | 335 ------------------ main.py | 98 ++--- rag_api/__init__.py | 0 rag_api/api/__init__.py | 0 middleware.py => rag_api/api/middleware.py | 7 +- models.py => rag_api/api/models.py | 3 +- .../api/pgvector_routes.py | 43 ++- rag_api/config/__init__.py | 2 + rag_api/config/app_config.py | 191 ++++++++++ constants.py => rag_api/config/constants.py | 2 +- rag_api/config/settings.py | 177 +++++++++ rag_api/db/__init__.py | 0 mongo.py => rag_api/db/mongo.py | 4 +- psql.py => rag_api/db/psql.py | 24 +- store.py => rag_api/db/store.py | 19 +- .../db/store_factory.py | 25 +- rag_api/utils/__init__.py | 0 parsers.py => rag_api/utils/parsers.py | 4 +- {utils => scripts}/docker/docker-build.sh | 0 {utils => scripts}/docker/docker-push.sh | 0 20 files changed, 499 insertions(+), 435 deletions(-) delete mode 100644 config.py create mode 100644 rag_api/__init__.py create mode 100644 rag_api/api/__init__.py rename middleware.py => rag_api/api/middleware.py (98%) rename models.py => rag_api/api/models.py (95%) rename pgvector_routes.py => rag_api/api/pgvector_routes.py (85%) create mode 100644 rag_api/config/__init__.py create mode 100644 rag_api/config/app_config.py rename constants.py => rag_api/config/constants.py (90%) create mode 100644 rag_api/config/settings.py create mode 100644 rag_api/db/__init__.py rename mongo.py => rag_api/db/mongo.py (89%) rename psql.py => rag_api/db/psql.py (78%) rename store.py => rag_api/db/store.py (98%) rename store_factory.py => rag_api/db/store_factory.py (81%) create mode 100644 rag_api/utils/__init__.py rename parsers.py => rag_api/utils/parsers.py (96%) rename {utils => scripts}/docker/docker-build.sh (100%) rename {utils => scripts}/docker/docker-push.sh (100%) diff --git a/config.py b/config.py deleted file mode 100644 index 270a8aac..00000000 --- a/config.py +++ /dev/null @@ -1,335 +0,0 @@ -# config.py -import os -import json -import logging -from enum import Enum -from datetime import datetime -from dotenv import find_dotenv, load_dotenv -from langchain_community.embeddings import ( - HuggingFaceEmbeddings, - HuggingFaceHubEmbeddings, - OllamaEmbeddings, -) -from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings -from starlette.middleware.base import BaseHTTPMiddleware -from store_factory import get_vector_store - -load_dotenv(find_dotenv()) - - -class VectorDBType(Enum): - PGVECTOR = "pgvector" - ATLAS_MONGO = "atlas-mongo" - - -class EmbeddingsProvider(Enum): - OPENAI = "openai" - AZURE = "azure" - HUGGINGFACE = "huggingface" - HUGGINGFACETEI = "huggingfacetei" - OLLAMA = "ollama" - GOOGLE = "google" - VOYAGE = "voyage" - SHUTTLEAI = "shuttleai" - COHERE = "cohere" - - -def get_env_variable( - var_name: str, default_value: str = None, required: bool = False -) -> str: - value = os.getenv(var_name) - if value is None: - if default_value is None and required: - raise ValueError(f"Environment variable '{var_name}' not found.") - return default_value - return value - - -RAG_HOST = os.getenv("RAG_HOST", "0.0.0.0") -RAG_PORT = int(os.getenv("RAG_PORT", 8000)) - -RAG_UPLOAD_DIR = get_env_variable("RAG_UPLOAD_DIR", "./uploads/") -if not os.path.exists(RAG_UPLOAD_DIR): - os.makedirs(RAG_UPLOAD_DIR, exist_ok=True) - -VECTOR_DB_TYPE = VectorDBType( - get_env_variable("VECTOR_DB_TYPE", VectorDBType.PGVECTOR.value) -) -POSTGRES_DB = get_env_variable("POSTGRES_DB", "mydatabase") -POSTGRES_USER = get_env_variable("POSTGRES_USER", "myuser") -POSTGRES_PASSWORD = get_env_variable("POSTGRES_PASSWORD", "mypassword") -DB_HOST = get_env_variable("DB_HOST", "db") -DB_PORT = get_env_variable("DB_PORT", "5432") -COLLECTION_NAME = get_env_variable("COLLECTION_NAME", "testcollection") -ATLAS_MONGO_DB_URI = get_env_variable( - "ATLAS_MONGO_DB_URI", "mongodb://127.0.0.1:27018/LibreChat" -) -MONGO_VECTOR_COLLECTION = get_env_variable( - "MONGO_VECTOR_COLLECTION", "vector_collection" -) - -CHUNK_SIZE = int(get_env_variable("CHUNK_SIZE", "1500")) -CHUNK_OVERLAP = int(get_env_variable("CHUNK_OVERLAP", "100")) - -env_value = get_env_variable("PDF_EXTRACT_IMAGES", "False").lower() -PDF_EXTRACT_IMAGES = True if env_value == "true" else False - -CONNECTION_STRING = f"postgresql+psycopg2://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{DB_HOST}:{DB_PORT}/{POSTGRES_DB}" -DSN = f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{DB_HOST}:{DB_PORT}/{POSTGRES_DB}" - -## Logging - -HTTP_RES = "http_res" -HTTP_REQ = "http_req" - -logger = logging.getLogger() - -debug_mode = get_env_variable("DEBUG_RAG_API", "False").lower() == "true" -console_json = get_env_variable("CONSOLE_JSON", "False").lower() == "true" - -if debug_mode: - logger.setLevel(logging.DEBUG) -else: - logger.setLevel(logging.INFO) - -if console_json: - - class JsonFormatter(logging.Formatter): - def __init__(self): - super(JsonFormatter, self).__init__() - - def format(self, record): - json_record = {} - - json_record["message"] = record.getMessage() - - if HTTP_REQ in record.__dict__: - json_record[HTTP_REQ] = record.__dict__[HTTP_REQ] - - if HTTP_RES in record.__dict__: - json_record[HTTP_RES] = record.__dict__[HTTP_RES] - - if record.levelno == logging.ERROR and record.exc_info: - json_record["exception"] = self.formatException(record.exc_info) - - timestamp = datetime.fromtimestamp(record.created) - json_record["timestamp"] = timestamp.isoformat() - - # add level - json_record["level"] = record.levelname - json_record["filename"] = record.filename - json_record["lineno"] = record.lineno - json_record["funcName"] = record.funcName - json_record["module"] = record.module - json_record["threadName"] = record.threadName - - return json.dumps(json_record) - - formatter = JsonFormatter() -else: - formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) - -handler = logging.StreamHandler() # or logging.FileHandler("app.log") -handler.setFormatter(formatter) -logger.addHandler(handler) - - -class LogMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request, call_next): - response = await call_next(request) - - logger_method = logger.info - - if str(request.url).endswith("/health"): - logger_method = logger.debug - - logger_method( - f"Request {request.method} {request.url} - {response.status_code}", - extra={ - HTTP_REQ: {"method": request.method, "url": str(request.url)}, - HTTP_RES: {"status_code": response.status_code}, - }, - ) - - return response - - -logging.getLogger("uvicorn.access").disabled = True - -## Credentials - -OPENAI_API_KEY = get_env_variable("OPENAI_API_KEY", "") -RAG_OPENAI_API_KEY = get_env_variable("RAG_OPENAI_API_KEY", OPENAI_API_KEY) -RAG_OPENAI_BASEURL = get_env_variable("RAG_OPENAI_BASEURL", None) -RAG_OPENAI_PROXY = get_env_variable("RAG_OPENAI_PROXY", None) -AZURE_OPENAI_API_KEY = get_env_variable("AZURE_OPENAI_API_KEY", "") -RAG_AZURE_OPENAI_API_VERSION = get_env_variable("RAG_AZURE_OPENAI_API_VERSION", None) -RAG_AZURE_OPENAI_API_KEY = get_env_variable( - "RAG_AZURE_OPENAI_API_KEY", AZURE_OPENAI_API_KEY -) -AZURE_OPENAI_ENDPOINT = get_env_variable("AZURE_OPENAI_ENDPOINT", "") -RAG_AZURE_OPENAI_ENDPOINT = get_env_variable( - "RAG_AZURE_OPENAI_ENDPOINT", AZURE_OPENAI_ENDPOINT -).rstrip("/") -HF_TOKEN = get_env_variable("HF_TOKEN", "") -OLLAMA_BASE_URL = get_env_variable("OLLAMA_BASE_URL", "http://ollama:11434") -GOOGLE_API_KEY = get_env_variable("GOOGLE_KEY", "") -VOYAGE_API_KEY = get_env_variable("VOYAGE_API_KEY", "") -SHUTTLEAI_KEY = get_env_variable("SHUTTLEAI_KEY", "") # use embeddings from shuttleai -COHERE_API_KEY = get_env_variable("COHERE_API_KEY", "") -## Embeddings - - -def init_embeddings(provider, model): - if provider == EmbeddingsProvider.OPENAI: - return OpenAIEmbeddings( - model=model, - api_key=RAG_OPENAI_API_KEY, - openai_api_base=RAG_OPENAI_BASEURL, - openai_proxy=RAG_OPENAI_PROXY, - ) - elif provider == EmbeddingsProvider.AZURE: - return AzureOpenAIEmbeddings( - azure_deployment=model, - api_key=RAG_AZURE_OPENAI_API_KEY, - azure_endpoint=RAG_AZURE_OPENAI_ENDPOINT, - api_version=RAG_AZURE_OPENAI_API_VERSION, - ) - elif provider == EmbeddingsProvider.HUGGINGFACE: - return HuggingFaceEmbeddings( - model_name=model, encode_kwargs={"normalize_embeddings": True} - ) - elif provider == EmbeddingsProvider.HUGGINGFACETEI: - return HuggingFaceHubEmbeddings(model=model) - elif provider == EmbeddingsProvider.OLLAMA: - return OllamaEmbeddings(model=model, base_url=OLLAMA_BASE_URL) - elif provider == EmbeddingsProvider.GOOGLE: - from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings - - return GoogleGenerativeAIEmbeddings( - model=model, - api_key=GOOGLE_API_KEY, - ) - elif provider == EmbeddingsProvider.VOYAGE: - from langchain_voyageai import VoyageAIEmbeddings - - return VoyageAIEmbeddings( - model=model, - ) - elif provider == EmbeddingsProvider.SHUTTLEAI: - return OpenAIEmbeddings( - model=model, - api_key=SHUTTLEAI_KEY, - openai_api_base="https://api.shuttleai.app/v1", - ) - elif provider == EmbeddingsProvider.COHERE: - from langchain_cohere import CohereEmbeddings - - return CohereEmbeddings( - model=model, - ) - else: - raise ValueError(f"Unsupported embeddings provider: {provider}") - - -EMBEDDINGS_PROVIDER = EmbeddingsProvider( - get_env_variable("EMBEDDINGS_PROVIDER", EmbeddingsProvider.OPENAI.value).lower() -) - -if EMBEDDINGS_PROVIDER == EmbeddingsProvider.OPENAI: - EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "text-embedding-3-small") -elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.AZURE: - EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "text-embedding-3-small") -elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.HUGGINGFACE: - EMBEDDINGS_MODEL = get_env_variable( - "EMBEDDINGS_MODEL", "sentence-transformers/all-MiniLM-L6-v2" - ) -elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.HUGGINGFACETEI: - EMBEDDINGS_MODEL = get_env_variable( - "EMBEDDINGS_MODEL", "http://huggingfacetei:3000" - ) -elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.OLLAMA: - EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "nomic-embed-text") -elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.GOOGLE: - EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "models/embedding-001") -elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.VOYAGE: - EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "voyage-large-2") -elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.SHUTTLEAI: - # text-embedding-ada-002, text-embedding-3-small, text-embedding-3-large - EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "text-embedding-3-large") -elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.COHERE: - EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "embed-multilingual-v3.0") -else: - raise ValueError(f"Unsupported embeddings provider: {EMBEDDINGS_PROVIDER}") - -embeddings = init_embeddings(EMBEDDINGS_PROVIDER, EMBEDDINGS_MODEL) - -logger.info(f"Initialized embeddings of type: {type(embeddings)}") - -# Vector store -if VECTOR_DB_TYPE == VectorDBType.PGVECTOR: - vector_store = get_vector_store( - connection_string=CONNECTION_STRING, - embeddings=embeddings, - collection_name=COLLECTION_NAME, - mode="async", - ) -elif VECTOR_DB_TYPE == VectorDBType.ATLAS_MONGO: - logger.warning("Using Atlas MongoDB as vector store is not fully supported yet.") - vector_store = get_vector_store( - connection_string=ATLAS_MONGO_DB_URI, - embeddings=embeddings, - collection_name=MONGO_VECTOR_COLLECTION, - mode="atlas-mongo", - ) -else: - raise ValueError(f"Unsupported vector store type: {VECTOR_DB_TYPE}") - -retriever = vector_store.as_retriever() - -known_source_ext = [ - "go", - "py", - "java", - "sh", - "bat", - "ps1", - "cmd", - "js", - "ts", - "css", - "cpp", - "hpp", - "h", - "c", - "cs", - "sql", - "log", - "ini", - "pl", - "pm", - "r", - "dart", - "dockerfile", - "env", - "php", - "hs", - "hsc", - "lua", - "nginxconf", - "conf", - "m", - "mm", - "plsql", - "perl", - "rb", - "rs", - "db2", - "scala", - "bash", - "swift", - "vue", - "svelte", -] diff --git a/main.py b/main.py index 7c13962c..18ad4222 100644 --- a/main.py +++ b/main.py @@ -1,76 +1,75 @@ -import os import hashlib -import aiofiles -import aiofiles.os -from typing import Iterable, List +import os +from contextlib import asynccontextmanager from shutil import copyfileobj +from typing import Iterable, List +import aiofiles +import aiofiles.os import uvicorn -from langchain.schema import Document -from contextlib import asynccontextmanager from dotenv import find_dotenv, load_dotenv -from fastapi.middleware.cors import CORSMiddleware -from langchain_core.runnables.config import run_in_executor -from langchain.text_splitter import RecursiveCharacterTextSplitter from fastapi import ( + Body, + FastAPI, File, Form, - Body, + HTTPException, Query, - status, - FastAPI, Request, UploadFile, - HTTPException, + status, ) +from fastapi.middleware.cors import CORSMiddleware +from langchain.schema import Document +from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.document_loaders import ( - WebBaseLoader, - TextLoader, - PyPDFLoader, CSVLoader, Docx2txtLoader, + PyPDFLoader, + TextLoader, UnstructuredEPubLoader, - UnstructuredMarkdownLoader, - UnstructuredXMLLoader, - UnstructuredRSTLoader, UnstructuredExcelLoader, + UnstructuredMarkdownLoader, UnstructuredPowerPointLoader, + UnstructuredRSTLoader, + UnstructuredXMLLoader, ) +from langchain_core.runnables.config import run_in_executor -from models import ( - StoreDocument, - QueryRequestBody, - DocumentResponse, - QueryMultipleBody, -) -from psql import PSQLDatabase, ensure_custom_id_index_on_embedding, pg_health_check -from pgvector_routes import router as pgvector_router -from parsers import process_documents, clean_text -from middleware import security_middleware -from mongo import mongo_health_check -from constants import ERROR_MESSAGES -from store import AsyncPgVector - -load_dotenv(find_dotenv()) - -from config import ( - logger, - debug_mode, - CHUNK_SIZE, +from rag_api.config import ( CHUNK_OVERLAP, - vector_store, - RAG_UPLOAD_DIR, - known_source_ext, + CHUNK_SIZE, PDF_EXTRACT_IMAGES, - LogMiddleware, RAG_HOST, RAG_PORT, - VectorDBType, - # RAG_EMBEDDING_MODEL, - # RAG_EMBEDDING_MODEL_DEVICE_TYPE, - # RAG_TEMPLATE, + RAG_UPLOAD_DIR, VECTOR_DB_TYPE, + LogMiddleware, + VectorDBType, + debug_mode, + known_source_ext, + logger, + vector_store, ) +from rag_api.api.middleware import security_middleware +from rag_api.api.models import ( + DocumentResponse, + QueryMultipleBody, + QueryRequestBody, + StoreDocument, +) +from rag_api.api.pgvector_routes import router as pgvector_router +from rag_api.config.constants import ERROR_MESSAGES +from rag_api.db.mongo import mongo_health_check +from rag_api.db.psql import ( + PSQLDatabase, + ensure_custom_id_index_on_embedding, + pg_health_check, +) +from rag_api.db.store import AsyncPgVector +from rag_api.utils.parsers import clean_text, process_documents + +load_dotenv(find_dotenv()) @asynccontextmanager @@ -318,7 +317,6 @@ def get_loader(filename: str, file_content_type: str, filepath: str): @app.post("/local/embed") async def embed_local_file(document: StoreDocument, request: Request): - # Check if the file exists if not os.path.exists(document.filepath): raise HTTPException( @@ -392,7 +390,9 @@ async def embed_file( ) try: - logger.info(f"Received file for embedding: filename={file.filename}, content_type={file.content_type}, file_id={file_id}") + logger.info( + f"Received file for embedding: filename={file.filename}, content_type={file.content_type}, file_id={file_id}" + ) loader, known_type, file_ext = get_loader( file.filename, file.content_type, temp_file_path ) diff --git a/rag_api/__init__.py b/rag_api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/rag_api/api/__init__.py b/rag_api/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/middleware.py b/rag_api/api/middleware.py similarity index 98% rename from middleware.py rename to rag_api/api/middleware.py index 76794478..d63c8b43 100644 --- a/middleware.py +++ b/rag_api/api/middleware.py @@ -1,12 +1,13 @@ import os -from fastapi import Request from datetime import datetime, timezone -from fastapi.responses import JSONResponse -from config import logger import jwt +from fastapi import Request +from fastapi.responses import JSONResponse from jwt import PyJWTError +from rag_api.config import logger + async def security_middleware(request: Request, call_next): async def next_middleware_call(): diff --git a/models.py b/rag_api/api/models.py similarity index 95% rename from models.py rename to rag_api/api/models.py index b584c992..dadbf1d2 100644 --- a/models.py +++ b/rag_api/api/models.py @@ -1,7 +1,8 @@ import hashlib from enum import Enum +from typing import List, Optional + from pydantic import BaseModel -from typing import Optional, List class DocumentResponse(BaseModel): diff --git a/pgvector_routes.py b/rag_api/api/pgvector_routes.py similarity index 85% rename from pgvector_routes.py rename to rag_api/api/pgvector_routes.py index 32bada0c..06fa508b 100644 --- a/pgvector_routes.py +++ b/rag_api/api/pgvector_routes.py @@ -1,8 +1,10 @@ from fastapi import APIRouter, HTTPException -from psql import PSQLDatabase + +from rag_api.db.psql import PSQLDatabase router = APIRouter() + async def check_index_exists(table_name: str, column_name: str) -> bool: pool = await PSQLDatabase.get_pool() async with pool.acquire() as conn: @@ -18,15 +20,20 @@ async def check_index_exists(table_name: str, column_name: str) -> bool: table_name, column_name, ) - return result[0]['exists'] + return result[0]["exists"] + @router.get("/test/check_index") async def check_file_id_index(table_name: str, column_name: str): if await check_index_exists(table_name, column_name): return {"message": f"Index on {column_name} exists in the table {table_name}."} else: - return HTTPException(status_code=404, detail=f"No index on {column_name} found in the table {table_name}.") - + return HTTPException( + status_code=404, + detail=f"No index on {column_name} found in the table {table_name}.", + ) + + @router.get("/db/tables") async def get_table_names(schema: str = "public"): pool = await PSQLDatabase.get_pool() @@ -40,9 +47,10 @@ async def get_table_names(schema: str = "public"): schema, ) # Extract table names from records - tables = [record['table_name'] for record in table_names] + tables = [record["table_name"] for record in table_names] return {"schema": schema, "tables": tables} + @router.get("/db/tables/columns") async def get_table_columns(table_name: str, schema: str = "public"): pool = await PSQLDatabase.get_pool() @@ -54,40 +62,45 @@ async def get_table_columns(table_name: str, schema: str = "public"): WHERE table_schema = $1 AND table_name = $2 ORDER BY ordinal_position; """, - schema, table_name, + schema, + table_name, ) - column_names = [col['column_name'] for col in columns] + column_names = [col["column_name"] for col in columns] return {"table_name": table_name, "columns": column_names} + @router.get("/records/all") async def get_all_records(table_name: str): # Validate that the table name is one of the expected ones to prevent SQL injection if table_name not in ["langchain_pg_collection", "langchain_pg_embedding"]: raise HTTPException(status_code=400, detail="Invalid table name") - + pool = await PSQLDatabase.get_pool() async with pool.acquire() as conn: # Use SQLAlchemy core or raw SQL queries to fetch all records records = await conn.fetch(f"SELECT * FROM {table_name};") - + # Convert records to JSON serializable format, assuming records can be directly serialized records_json = [dict(record) for record in records] - + return records_json + @router.get("/records") -async def get_records_filtered_by_custom_id(custom_id: str, table_name: str = "langchain_pg_embedding"): +async def get_records_filtered_by_custom_id( + custom_id: str, table_name: str = "langchain_pg_embedding" +): # Validate that the table name is one of the expected ones to prevent SQL injection if table_name not in ["langchain_pg_collection", "langchain_pg_embedding"]: raise HTTPException(status_code=400, detail="Invalid table name") - + pool = await PSQLDatabase.get_pool() async with pool.acquire() as conn: # Use parameterized queries to prevent SQL Injection query = f"SELECT * FROM {table_name} WHERE custom_id=$1;" records = await conn.fetch(query, custom_id) - + # Convert records to JSON serializable format, assuming the Record class has a dict method. records_json = [dict(record) for record in records] - - return records_json \ No newline at end of file + + return records_json diff --git a/rag_api/config/__init__.py b/rag_api/config/__init__.py new file mode 100644 index 00000000..e7cd2ea4 --- /dev/null +++ b/rag_api/config/__init__.py @@ -0,0 +1,2 @@ +from rag_api.config.app_config import * # noqa: F403 +from rag_api.config.settings import * # noqa: F403 diff --git a/rag_api/config/app_config.py b/rag_api/config/app_config.py new file mode 100644 index 00000000..ad1fb425 --- /dev/null +++ b/rag_api/config/app_config.py @@ -0,0 +1,191 @@ +# config.py +import json +import logging +from datetime import datetime + +from langchain_community.embeddings import ( + HuggingFaceEmbeddings, + HuggingFaceHubEmbeddings, + OllamaEmbeddings, +) +from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings +from starlette.middleware.base import BaseHTTPMiddleware + +from rag_api.config.settings import ( + ATLAS_MONGO_DB_URI, + COLLECTION_NAME, + CONNECTION_STRING, + EMBEDDINGS_MODEL, + EMBEDDINGS_PROVIDER, + GOOGLE_API_KEY, + HTTP_REQ, + HTTP_RES, + MONGO_VECTOR_COLLECTION, + OLLAMA_BASE_URL, + RAG_AZURE_OPENAI_API_KEY, + RAG_AZURE_OPENAI_API_VERSION, + RAG_AZURE_OPENAI_ENDPOINT, + RAG_OPENAI_API_KEY, + RAG_OPENAI_BASEURL, + RAG_OPENAI_PROXY, + SHUTTLEAI_KEY, + VECTOR_DB_TYPE, + EmbeddingsProvider, + VectorDBType, + get_env_variable, +) +from rag_api.db.store_factory import get_vector_store + +logger = logging.getLogger() + +debug_mode = get_env_variable("DEBUG_RAG_API", "False").lower() == "true" +console_json = get_env_variable("CONSOLE_JSON", "False").lower() == "true" + +if debug_mode: + logger.setLevel(logging.DEBUG) +else: + logger.setLevel(logging.INFO) + +if console_json: + + class JsonFormatter(logging.Formatter): + def __init__(self): + super(JsonFormatter, self).__init__() + + def format(self, record): + json_record = {} + + json_record["message"] = record.getMessage() + + if HTTP_REQ in record.__dict__: + json_record[HTTP_REQ] = record.__dict__[HTTP_REQ] + + if HTTP_RES in record.__dict__: + json_record[HTTP_RES] = record.__dict__[HTTP_RES] + + if record.levelno == logging.ERROR and record.exc_info: + json_record["exception"] = self.formatException(record.exc_info) + + timestamp = datetime.fromtimestamp(record.created) + json_record["timestamp"] = timestamp.isoformat() + + # add level + json_record["level"] = record.levelname + json_record["filename"] = record.filename + json_record["lineno"] = record.lineno + json_record["funcName"] = record.funcName + json_record["module"] = record.module + json_record["threadName"] = record.threadName + + return json.dumps(json_record) + + formatter = JsonFormatter() +else: + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + +handler = logging.StreamHandler() # or logging.FileHandler("app.log") +handler.setFormatter(formatter) +logger.addHandler(handler) + + +class LogMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + response = await call_next(request) + + logger_method = logger.info + + if str(request.url).endswith("/health"): + logger_method = logger.debug + + logger_method( + f"Request {request.method} {request.url} - {response.status_code}", + extra={ + HTTP_REQ: {"method": request.method, "url": str(request.url)}, + HTTP_RES: {"status_code": response.status_code}, + }, + ) + + return response + + +logging.getLogger("uvicorn.access").disabled = True + + +def init_embeddings(provider, model): + if provider == EmbeddingsProvider.OPENAI: + return OpenAIEmbeddings( + model=model, + api_key=RAG_OPENAI_API_KEY, + openai_api_base=RAG_OPENAI_BASEURL, + openai_proxy=RAG_OPENAI_PROXY, + ) + elif provider == EmbeddingsProvider.AZURE: + return AzureOpenAIEmbeddings( + azure_deployment=model, + api_key=RAG_AZURE_OPENAI_API_KEY, + azure_endpoint=RAG_AZURE_OPENAI_ENDPOINT, + api_version=RAG_AZURE_OPENAI_API_VERSION, + ) + elif provider == EmbeddingsProvider.HUGGINGFACE: + return HuggingFaceEmbeddings( + model_name=model, encode_kwargs={"normalize_embeddings": True} + ) + elif provider == EmbeddingsProvider.HUGGINGFACETEI: + return HuggingFaceHubEmbeddings(model=model) + elif provider == EmbeddingsProvider.OLLAMA: + return OllamaEmbeddings(model=model, base_url=OLLAMA_BASE_URL) + elif provider == EmbeddingsProvider.GOOGLE: + from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings + + return GoogleGenerativeAIEmbeddings( + model=model, + api_key=GOOGLE_API_KEY, + ) + elif provider == EmbeddingsProvider.VOYAGE: + from langchain_voyageai import VoyageAIEmbeddings + + return VoyageAIEmbeddings( + model=model, + ) + elif provider == EmbeddingsProvider.SHUTTLEAI: + return OpenAIEmbeddings( + model=model, + api_key=SHUTTLEAI_KEY, + openai_api_base="https://api.shuttleai.app/v1", + ) + elif provider == EmbeddingsProvider.COHERE: + from langchain_cohere import CohereEmbeddings + + return CohereEmbeddings( + model=model, + ) + else: + raise ValueError(f"Unsupported embeddings provider: {provider}") + + +embeddings = init_embeddings(EMBEDDINGS_PROVIDER, EMBEDDINGS_MODEL) + +logger.info(f"Initialized embeddings of type: {type(embeddings)}") + +# Vector store +if VECTOR_DB_TYPE == VectorDBType.PGVECTOR: + vector_store = get_vector_store( + connection_string=CONNECTION_STRING, + embeddings=embeddings, + collection_name=COLLECTION_NAME, + mode="async", + ) +elif VECTOR_DB_TYPE == VectorDBType.ATLAS_MONGO: + logger.warning("Using Atlas MongoDB as vector store is not fully supported yet.") + vector_store = get_vector_store( + connection_string=ATLAS_MONGO_DB_URI, + embeddings=embeddings, + collection_name=MONGO_VECTOR_COLLECTION, + mode="atlas-mongo", + ) +else: + raise ValueError(f"Unsupported vector store type: {VECTOR_DB_TYPE}") + +retriever = vector_store.as_retriever() diff --git a/constants.py b/rag_api/config/constants.py similarity index 90% rename from constants.py rename to rag_api/config/constants.py index 514e88bd..1eca4322 100644 --- a/constants.py +++ b/rag_api/config/constants.py @@ -13,4 +13,4 @@ def __str__(self) -> str: PANDOC_NOT_INSTALLED = "Pandoc is not installed on the server. Please contact your administrator for assistance." OPENAI_NOT_FOUND = lambda name="": f"OpenAI API was not found" OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama" - FILE_NOT_FOUND = "The specified file was not found." \ No newline at end of file + FILE_NOT_FOUND = "The specified file was not found." diff --git a/rag_api/config/settings.py b/rag_api/config/settings.py new file mode 100644 index 00000000..7b92b82e --- /dev/null +++ b/rag_api/config/settings.py @@ -0,0 +1,177 @@ +import os +from enum import Enum + +from dotenv import find_dotenv, load_dotenv + +load_dotenv(find_dotenv()) + + +def get_env_variable( + var_name: str, default_value: str = None, required: bool = False +) -> str: + """Retrieves an environment variable with optional default value and required flag.""" + value = os.getenv(var_name) + if value is None: + if default_value is None and required: + raise ValueError(f"Environment variable '{var_name}' not found.") + return default_value + return value + + +class VectorDBType(Enum): + PGVECTOR = "pgvector" + ATLAS_MONGO = "atlas-mongo" + + +class EmbeddingsProvider(Enum): + OPENAI = "openai" + AZURE = "azure" + HUGGINGFACE = "huggingface" + HUGGINGFACETEI = "huggingfacetei" + OLLAMA = "ollama" + GOOGLE = "google" + VOYAGE = "voyage" + SHUTTLEAI = "shuttleai" + COHERE = "cohere" + + +## Logging +HTTP_RES = "http_res" +HTTP_REQ = "http_req" + +# RAG Server Configuration +RAG_HOST = os.getenv("RAG_HOST", "0.0.0.0") +RAG_PORT = int(os.getenv("RAG_PORT", 8000)) + +# Upload Directory +RAG_UPLOAD_DIR = get_env_variable("RAG_UPLOAD_DIR", "./uploads/") + +# Vector Database Configuration +VECTOR_DB_TYPE = VectorDBType( + get_env_variable("VECTOR_DB_TYPE", VectorDBType.PGVECTOR.value) +) + +# Database Credentials (Adjust based on your database) +POSTGRES_DB = get_env_variable("POSTGRES_DB", "mydatabase") +POSTGRES_USER = get_env_variable("POSTGRES_USER", "myuser") +POSTGRES_PASSWORD = get_env_variable("POSTGRES_PASSWORD", "mypassword") +DB_HOST = get_env_variable("DB_HOST", "db") +DB_PORT = get_env_variable("DB_PORT", "5432") +COLLECTION_NAME = get_env_variable("COLLECTION_NAME", "testcollection") + +# MongoDB Atlas Configuration (if using) +ATLAS_MONGO_DB_URI = get_env_variable( + "ATLAS_MONGO_DB_URI", "mongodb://127.0.0.1:27018/LibreChat" +) +MONGO_VECTOR_COLLECTION = get_env_variable( + "MONGO_VECTOR_COLLECTION", "vector_collection" +) + +# Chunking Parameters +CHUNK_SIZE = int(get_env_variable("CHUNK_SIZE", "1500")) +CHUNK_OVERLAP = int(get_env_variable("CHUNK_OVERLAP", "100")) + +# PDF Extraction +PDF_EXTRACT_IMAGES = get_env_variable("PDF_EXTRACT_IMAGES", "False").lower() == "true" + +# Database Connection Strings +CONNECTION_STRING = f"postgresql+psycopg2://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{DB_HOST}:{DB_PORT}/{POSTGRES_DB}" +DSN = f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{DB_HOST}:{DB_PORT}/{POSTGRES_DB}" +print(CONNECTION_STRING) +# Credentials +OPENAI_API_KEY = get_env_variable("OPENAI_API_KEY", "") +RAG_OPENAI_API_KEY = get_env_variable("RAG_OPENAI_API_KEY", OPENAI_API_KEY) +RAG_OPENAI_BASEURL = get_env_variable("RAG_OPENAI_BASEURL", None) +RAG_OPENAI_PROXY = get_env_variable("RAG_OPENAI_PROXY", None) +AZURE_OPENAI_API_KEY = get_env_variable("AZURE_OPENAI_API_KEY", "") +RAG_AZURE_OPENAI_API_VERSION = get_env_variable("RAG_AZURE_OPENAI_API_VERSION", None) +RAG_AZURE_OPENAI_API_KEY = get_env_variable( + "RAG_AZURE_OPENAI_API_KEY", AZURE_OPENAI_API_KEY +) +AZURE_OPENAI_ENDPOINT = get_env_variable("AZURE_OPENAI_ENDPOINT", "") +RAG_AZURE_OPENAI_ENDPOINT = get_env_variable( + "RAG_AZURE_OPENAI_ENDPOINT", AZURE_OPENAI_ENDPOINT +).rstrip("/") +HF_TOKEN = get_env_variable("HF_TOKEN", "") +OLLAMA_BASE_URL = get_env_variable("OLLAMA_BASE_URL", "http://ollama:11434") +GOOGLE_API_KEY = get_env_variable("GOOGLE_KEY", "") +VOYAGE_API_KEY = get_env_variable("VOYAGE_API_KEY", "") +SHUTTLEAI_KEY = get_env_variable("SHUTTLEAI_KEY", "") # use embeddings from shuttleai +COHERE_API_KEY = get_env_variable("COHERE_API_KEY", "") + +# Embeddings Configuration +EMBEDDINGS_PROVIDER = EmbeddingsProvider( + get_env_variable("EMBEDDINGS_PROVIDER", EmbeddingsProvider.OPENAI.value).lower() +) + +if EMBEDDINGS_PROVIDER == EmbeddingsProvider.OPENAI: + EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "text-embedding-3-small") +elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.AZURE: + EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "text-embedding-3-small") +elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.HUGGINGFACE: + EMBEDDINGS_MODEL = get_env_variable( + "EMBEDDINGS_MODEL", "sentence-transformers/all-MiniLM-L6-v2" + ) +elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.HUGGINGFACETEI: + EMBEDDINGS_MODEL = get_env_variable( + "EMBEDDINGS_MODEL", "http://huggingfacetei:3000" + ) +elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.OLLAMA: + EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "nomic-embed-text") +elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.GOOGLE: + EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "models/embedding-001") +elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.VOYAGE: + EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "voyage-large-2") +elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.SHUTTLEAI: + # text-embedding-ada-002, text-embedding-3-small, text-embedding-3-large + EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "text-embedding-3-large") +elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.COHERE: + EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "embed-multilingual-v3.0") +else: + raise ValueError(f"Unsupported embeddings provider: {EMBEDDINGS_PROVIDER}") + +# Known Source File Extensions +known_source_ext = [ + "go", + "py", + "java", + "sh", + "bat", + "ps1", + "cmd", + "js", + "ts", + "css", + "cpp", + "hpp", + "h", + "c", + "cs", + "sql", + "log", + "ini", + "pl", + "pm", + "r", + "dart", + "dockerfile", + "env", + "php", + "hs", + "hsc", + "lua", + "nginxconf", + "conf", + "m", + "mm", + "plsql", + "perl", + "rb", + "rs", + "db2", + "scala", + "bash", + "swift", + "vue", + "svelte", +] diff --git a/rag_api/db/__init__.py b/rag_api/db/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mongo.py b/rag_api/db/mongo.py similarity index 89% rename from mongo.py rename to rag_api/db/mongo.py index 93c78ea3..8e06d8f2 100644 --- a/mongo.py +++ b/rag_api/db/mongo.py @@ -1,7 +1,9 @@ import logging + from pymongo import MongoClient from pymongo.errors import PyMongoError -from config import ATLAS_MONGO_DB_URI + +from rag_api.config import ATLAS_MONGO_DB_URI logger = logging.getLogger(__name__) diff --git a/psql.py b/rag_api/db/psql.py similarity index 78% rename from psql.py rename to rag_api/db/psql.py index 07b2fcd3..9dfcd8a6 100644 --- a/psql.py +++ b/rag_api/db/psql.py @@ -1,6 +1,7 @@ # db.py import asyncpg -from config import DSN, logger + +from rag_api.config import DSN, logger class PSQLDatabase: @@ -32,23 +33,32 @@ async def ensure_custom_id_index_on_embedding(): if not index_exists: # If the index does not exist, create it - await conn.execute(f""" + await conn.execute( + f""" CREATE INDEX IF NOT EXISTS {index_name} ON {table_name} ({column_name}); - """) - logger.debug(f"Created index '{index_name}' on '{table_name}({column_name})'") + """ + ) + logger.debug( + f"Created index '{index_name}' on '{table_name}({column_name})'" + ) else: - logger.debug(f"Index '{index_name}' already exists on '{table_name}({column_name})'") + logger.debug( + f"Index '{index_name}' already exists on '{table_name}({column_name})'" + ) async def check_index_exists(conn, index_name: str) -> bool: # Adjust the SQL query if necessary - result = await conn.fetchval(""" + result = await conn.fetchval( + """ SELECT EXISTS ( SELECT FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace WHERE c.relname = $1 AND n.nspname = 'public' -- Adjust schema if necessary ); - """, index_name) + """, + index_name, + ) return result diff --git a/store.py b/rag_api/db/store.py similarity index 98% rename from store.py rename to rag_api/db/store.py index f6fd1e0e..0c5a6bfb 100644 --- a/store.py +++ b/rag_api/db/store.py @@ -1,22 +1,16 @@ -from typing import Any, Optional -from sqlalchemy import delete +import copy +from typing import Any, List, Optional, Tuple + from langchain_community.vectorstores.pgvector import PGVector from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings from langchain_core.runnables.config import run_in_executor -from sqlalchemy.orm import Session - from langchain_mongodb import MongoDBAtlasVectorSearch -from langchain_core.embeddings import Embeddings -from typing import ( - List, - Optional, - Tuple, -) -import copy +from sqlalchemy import delete +from sqlalchemy.orm import Session class ExtendedPgVector(PGVector): - def get_all_ids(self) -> list[str]: with Session(self._bind) as session: results = session.query(self.EmbeddingStore.custom_id).all() @@ -63,7 +57,6 @@ def _delete_multiple( class AsyncPgVector(ExtendedPgVector): - async def get_all_ids(self) -> list[str]: return await run_in_executor(None, super().get_all_ids) diff --git a/store_factory.py b/rag_api/db/store_factory.py similarity index 81% rename from store_factory.py rename to rag_api/db/store_factory.py index 16b81ef6..0fc6ff6f 100644 --- a/store_factory.py +++ b/rag_api/db/store_factory.py @@ -1,9 +1,9 @@ from langchain_community.embeddings import OpenAIEmbeddings - -from store import AsyncPgVector, ExtendedPgVector -from store import AtlasMongoVector from pymongo import MongoClient +from rag_api.db.store import AsyncPgVector, AtlasMongoVector, ExtendedPgVector + + def get_vector_store( connection_string: str, embeddings: OpenAIEmbeddings, @@ -25,7 +25,9 @@ def get_vector_store( elif mode == "atlas-mongo": mongo_db = MongoClient(connection_string).get_database() mong_collection = mongo_db[collection_name] - return AtlasMongoVector(collection=mong_collection, embedding=embeddings, index_name=collection_name) + return AtlasMongoVector( + collection=mong_collection, embedding=embeddings, index_name=collection_name + ) else: raise ValueError("Invalid mode specified. Choose 'sync' or 'async'.") @@ -35,20 +37,25 @@ async def create_index_if_not_exists(conn, table_name: str, column_name: str): # Construct index name conventionally index_name = f"idx_{table_name}_{column_name}" # Check if index exists - exists = await conn.fetchval(f""" + exists = await conn.fetchval( + f""" SELECT EXISTS ( SELECT FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace WHERE c.relname = $1 AND n.nspname = 'public' -- Or specify your schema if different ); - """, index_name) + """, + index_name, + ) # Create the index if it does not exist if not exists: - await conn.execute(f""" + await conn.execute( + f""" CREATE INDEX CONCURRENTLY IF NOT EXISTS {index_name} ON public.{table_name} ({column_name}); - """) + """ + ) print(f"Index {index_name} created on {table_name}.{column_name}") else: - print(f"Index {index_name} already exists on {table_name}.{column_name}") \ No newline at end of file + print(f"Index {index_name} already exists on {table_name}.{column_name}") diff --git a/rag_api/utils/__init__.py b/rag_api/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/parsers.py b/rag_api/utils/parsers.py similarity index 96% rename from parsers.py rename to rag_api/utils/parsers.py index 72fb296a..daebdc9f 100644 --- a/parsers.py +++ b/rag_api/utils/parsers.py @@ -1,6 +1,8 @@ from typing import List, Optional + from langchain.schema import Document -from config import CHUNK_OVERLAP + +from rag_api.config import CHUNK_OVERLAP def clean_text(text: str) -> str: diff --git a/utils/docker/docker-build.sh b/scripts/docker/docker-build.sh similarity index 100% rename from utils/docker/docker-build.sh rename to scripts/docker/docker-build.sh diff --git a/utils/docker/docker-push.sh b/scripts/docker/docker-push.sh similarity index 100% rename from utils/docker/docker-push.sh rename to scripts/docker/docker-push.sh