Revamped the project structure – everything's tidier now! #66

Open
wants to merge 5 commits into base: main
8 changes: 8 additions & 0 deletions environment.yml
@@ -0,0 +1,8 @@
name: rag_api
channels:
- defaults
dependencies:
- python=3.11
- pip
- pip:
- -r requirements.lite.txt
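A note for anyone reproducing this locally: the file pins only python and pip through conda and hands everything else to pip, so the environment should come up with `conda env create -f environment.yml` followed by `conda activate rag_api`; the nested `pip:` entry then installs whatever requirements.lite.txt lists.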
100 changes: 52 additions & 48 deletions main.py
@@ -1,76 +1,75 @@
import os
import hashlib
import aiofiles
import aiofiles.os
from typing import Iterable, List
import os
from contextlib import asynccontextmanager
from shutil import copyfileobj
from typing import Iterable, List

import aiofiles
import aiofiles.os
import uvicorn
from langchain.schema import Document
from contextlib import asynccontextmanager
from dotenv import find_dotenv, load_dotenv
from fastapi.middleware.cors import CORSMiddleware
from langchain_core.runnables.config import run_in_executor
from langchain.text_splitter import RecursiveCharacterTextSplitter
from fastapi import (
Body,
FastAPI,
File,
Form,
Body,
HTTPException,
Query,
status,
FastAPI,
Request,
UploadFile,
HTTPException,
status,
)
from fastapi.middleware.cors import CORSMiddleware
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import (
WebBaseLoader,
TextLoader,
PyPDFLoader,
CSVLoader,
Docx2txtLoader,
PyPDFLoader,
TextLoader,
UnstructuredEPubLoader,
UnstructuredMarkdownLoader,
UnstructuredXMLLoader,
UnstructuredRSTLoader,
UnstructuredExcelLoader,
UnstructuredMarkdownLoader,
UnstructuredPowerPointLoader,
UnstructuredRSTLoader,
UnstructuredXMLLoader,
)
from langchain_core.runnables.config import run_in_executor

from models import (
StoreDocument,
QueryRequestBody,
DocumentResponse,
QueryMultipleBody,
)
from psql import PSQLDatabase, ensure_custom_id_index_on_embedding, pg_health_check
from pgvector_routes import router as pgvector_router
from parsers import process_documents, clean_text
from middleware import security_middleware
from mongo import mongo_health_check
from constants import ERROR_MESSAGES
from store import AsyncPgVector

load_dotenv(find_dotenv())

from config import (
logger,
debug_mode,
CHUNK_SIZE,
from rag_api.config import (
CHUNK_OVERLAP,
vector_store,
RAG_UPLOAD_DIR,
known_source_ext,
CHUNK_SIZE,
PDF_EXTRACT_IMAGES,
LogMiddleware,
RAG_HOST,
RAG_PORT,
VectorDBType,
# RAG_EMBEDDING_MODEL,
# RAG_EMBEDDING_MODEL_DEVICE_TYPE,
# RAG_TEMPLATE,
RAG_UPLOAD_DIR,
VECTOR_DB_TYPE,
LogMiddleware,
VectorDBType,
debug_mode,
known_source_ext,
logger,
vector_store,
)
from rag_api.api.middleware import security_middleware
from rag_api.api.models import (
DocumentResponse,
QueryMultipleBody,
QueryRequestBody,
StoreDocument,
)
from rag_api.api.pgvector_routes import router as pgvector_router
from rag_api.config.constants import ERROR_MESSAGES
from rag_api.db.mongo import mongo_health_check
from rag_api.db.psql import (
PSQLDatabase,
ensure_custom_id_index_on_embedding,
pg_health_check,
)
from rag_api.db.store import AsyncPgVector
from rag_api.utils.parsers import clean_text, process_documents

load_dotenv(find_dotenv())


@asynccontextmanager
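
For orientation, the new import paths imply roughly this package layout (reconstructed from the imports and renames in this diff; the file lists per subpackage are inferred from the imports above, not exhaustive):

rag_api/
    __init__.py
    api/        middleware.py, models.py, pgvector_routes.py
    config/     __init__.py, app_config.py, settings.py, constants.py
    db/         psql.py, mongo.py, store.py
    utils/      parsers.py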
@@ -318,7 +317,6 @@ def get_loader(filename: str, file_content_type: str, filepath: str):

@app.post("/local/embed")
async def embed_local_file(document: StoreDocument, request: Request):

# Check if the file exists
if not os.path.exists(document.filepath):
raise HTTPException(
@@ -392,6 +390,9 @@ async def embed_file(
)

try:
logger.info(
f"Received file for embedding: filename={file.filename}, content_type={file.content_type}, file_id={file_id}"
)
loader, known_type, file_ext = get_loader(
file.filename, file.content_type, temp_file_path
)
@@ -403,13 +404,15 @@ async def embed_file(
if not result:
response_status = False
response_message = "Failed to process/store the file data."
logger.error(response_message, exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to process/store the file data.",
)
elif "error" in result:
response_status = False
response_message = "Failed to process/store the file data."
logger.error(response_message, exc_info=True)
if isinstance(result["error"], str):
response_message = result["error"]
else:
@@ -420,6 +423,7 @@
except Exception as e:
response_status = False
response_message = f"Error during file processing: {str(e)}"
logger.error(response_message, exc_info=True)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Error during file processing: {str(e)}",
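One reviewer-level note on the logging added above: exc_info=True only attaches a traceback when an exception is actually being handled, so in the first two branches (which raise an HTTPException rather than handle one) it will log "NoneType: None" after the message. A minimal sketch of the pattern under a plain logging setup (standalone names here are illustrative, not the project's config module):

import logging

logger = logging.getLogger("rag_api.sketch")
logging.basicConfig(level=logging.INFO)

def handle_result(result: dict | None) -> dict:
    try:
        if not result:
            raise ValueError("Failed to process/store the file data.")
        if "error" in result:
            raise ValueError(str(result["error"]))
        return result
    except Exception as e:
        # Inside an except block, exc_info=True attaches the active traceback.
        logger.error(f"Error during file processing: {e}", exc_info=True)
        raise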
Empty file added rag_api/__init__.py
Empty file added rag_api/api/__init__.py
7 changes: 4 additions & 3 deletions middleware.py → rag_api/api/middleware.py
@@ -1,12 +1,13 @@
import os
from fastapi import Request
from datetime import datetime, timezone
from fastapi.responses import JSONResponse
from config import logger

import jwt
from fastapi import Request
from fastapi.responses import JSONResponse
from jwt import PyJWTError

from rag_api.config import logger


async def security_middleware(request: Request, call_next):
async def next_middleware_call():
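The middleware body is collapsed in this diff, so here is only a sketch of the bearer-token pattern these imports suggest; the JWT_SECRET variable name and the HS256 algorithm are assumptions, not confirmed by the change:

import os

import jwt
from fastapi import Request
from fastapi.responses import JSONResponse
from jwt import PyJWTError

JWT_SECRET = os.getenv("JWT_SECRET", "")  # assumed variable name

async def security_middleware(request: Request, call_next):
    auth = request.headers.get("Authorization", "")
    if not auth.startswith("Bearer "):
        return JSONResponse(status_code=401, content={"detail": "Unauthorized"})
    try:
        jwt.decode(auth.removeprefix("Bearer "), JWT_SECRET, algorithms=["HS256"])
    except PyJWTError:
        return JSONResponse(status_code=401, content={"detail": "Invalid token"})
    return await call_next(request)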
3 changes: 2 additions & 1 deletion models.py → rag_api/api/models.py
@@ -1,7 +1,8 @@
import hashlib
from enum import Enum
from typing import List, Optional

from pydantic import BaseModel
from typing import Optional, List


class DocumentResponse(BaseModel):
43 changes: 28 additions & 15 deletions pgvector_routes.py → rag_api/api/pgvector_routes.py
@@ -1,8 +1,10 @@
from fastapi import APIRouter, HTTPException
from psql import PSQLDatabase

from rag_api.db.psql import PSQLDatabase

router = APIRouter()


async def check_index_exists(table_name: str, column_name: str) -> bool:
pool = await PSQLDatabase.get_pool()
async with pool.acquire() as conn:
@@ -18,15 +20,20 @@ async def check_index_exists(table_name: str, column_name: str) -> bool:
table_name,
column_name,
)
return result[0]['exists']
return result[0]["exists"]


@router.get("/test/check_index")
async def check_file_id_index(table_name: str, column_name: str):
if await check_index_exists(table_name, column_name):
return {"message": f"Index on {column_name} exists in the table {table_name}."}
else:
return HTTPException(status_code=404, detail=f"No index on {column_name} found in the table {table_name}.")

return HTTPException(
status_code=404,
detail=f"No index on {column_name} found in the table {table_name}.",
)


@router.get("/db/tables")
async def get_table_names(schema: str = "public"):
pool = await PSQLDatabase.get_pool()
@@ -40,9 +47,10 @@ async def get_table_names(schema: str = "public"):
schema,
)
# Extract table names from records
tables = [record['table_name'] for record in table_names]
tables = [record["table_name"] for record in table_names]
return {"schema": schema, "tables": tables}


@router.get("/db/tables/columns")
async def get_table_columns(table_name: str, schema: str = "public"):
pool = await PSQLDatabase.get_pool()
@@ -54,40 +62,45 @@ async def get_table_columns(table_name: str, schema: str = "public"):
WHERE table_schema = $1 AND table_name = $2
ORDER BY ordinal_position;
""",
schema, table_name,
schema,
table_name,
)
column_names = [col['column_name'] for col in columns]
column_names = [col["column_name"] for col in columns]
return {"table_name": table_name, "columns": column_names}


@router.get("/records/all")
async def get_all_records(table_name: str):
# Validate that the table name is one of the expected ones to prevent SQL injection
if table_name not in ["langchain_pg_collection", "langchain_pg_embedding"]:
raise HTTPException(status_code=400, detail="Invalid table name")

pool = await PSQLDatabase.get_pool()
async with pool.acquire() as conn:
# Use SQLAlchemy core or raw SQL queries to fetch all records
records = await conn.fetch(f"SELECT * FROM {table_name};")

# Convert records to JSON serializable format, assuming records can be directly serialized
records_json = [dict(record) for record in records]

return records_json
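
Because Postgres can bind values (the $1 placeholders) but not identifiers, allowlist-then-interpolate is the right guard here; a hypothetical sketch of the same check factored into a helper, not part of this PR:

from fastapi import HTTPException

ALLOWED_TABLES = {"langchain_pg_collection", "langchain_pg_embedding"}

def safe_table(table_name: str) -> str:
    # Identifiers can't be passed as $1-style parameters, so validate
    # against a fixed allowlist before interpolating into the SQL string.
    if table_name not in ALLOWED_TABLES:
        raise HTTPException(status_code=400, detail="Invalid table name")
    return table_name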


@router.get("/records")
async def get_records_filtered_by_custom_id(custom_id: str, table_name: str = "langchain_pg_embedding"):
async def get_records_filtered_by_custom_id(
custom_id: str, table_name: str = "langchain_pg_embedding"
):
# Validate that the table name is one of the expected ones to prevent SQL injection
if table_name not in ["langchain_pg_collection", "langchain_pg_embedding"]:
raise HTTPException(status_code=400, detail="Invalid table name")

pool = await PSQLDatabase.get_pool()
async with pool.acquire() as conn:
# Use parameterized queries to prevent SQL Injection
query = f"SELECT * FROM {table_name} WHERE custom_id=$1;"
records = await conn.fetch(query, custom_id)

# Convert records to JSON serializable format, assuming the Record class has a dict method.
records_json = [dict(record) for record in records]
return records_json

return records_json
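
These diagnostic routes are easy to smoke-test over HTTP once the service is running; a quick sketch with httpx, assuming the API listens on http://localhost:8000 (adjust to your RAG_HOST/RAG_PORT settings):

import httpx

BASE = "http://localhost:8000"  # assumed local address

# Check that the custom_id index exists on the embeddings table.
resp = httpx.get(
    f"{BASE}/test/check_index",
    params={"table_name": "langchain_pg_embedding", "column_name": "custom_id"},
)
print(resp.status_code, resp.json())

# List the tables in the public schema.
print(httpx.get(f"{BASE}/db/tables").json())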
2 changes: 2 additions & 0 deletions rag_api/config/__init__.py
@@ -0,0 +1,2 @@
from rag_api.config.app_config import * # noqa: F403
from rag_api.config.settings import * # noqa: F403
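
These star re-exports keep the old flat imports working: from rag_api.config import logger, CHUNK_SIZE, vector_store resolves exactly as from config import ... did before the move, and the noqa: F403 markers silence flake8's star-import warning.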