Commit
feat: index custom_id for faster queries, check on startup, add PSQL helper routers, add multiple ID query
danny-avila committed Mar 19, 2024
1 parent efab978 commit 583fa31
Showing 8 changed files with 241 additions and 53 deletions.
Binary file modified README.md
Binary file not shown.
7 changes: 4 additions & 3 deletions config.py
@@ -30,10 +30,11 @@ def get_env_variable(var_name: str, default_value: str = None) -> str:
PDF_EXTRACT_IMAGES = True if env_value == "true" else False

CONNECTION_STRING = f"postgresql+psycopg2://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{DB_HOST}:{DB_PORT}/{POSTGRES_DB}"
+DSN = f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{DB_HOST}:{DB_PORT}/{POSTGRES_DB}"

logger = logging.getLogger()

-debug_mode = get_env_variable("DEBUG", "False").lower() == "true"
+debug_mode = get_env_variable("DEBUG_RAG_API", "False").lower() == "true"
if debug_mode:
    logger.setLevel(logging.DEBUG)
else:
@@ -47,13 +48,13 @@ def get_env_variable(var_name: str, default_value: str = None) -> str:
OPENAI_API_KEY = get_env_variable("OPENAI_API_KEY")
embeddings = OpenAIEmbeddings()

-pgvector_store = get_vector_store(
+vector_store = get_vector_store(
    connection_string=CONNECTION_STRING,
    embeddings=embeddings,
    collection_name=COLLECTION_NAME,
    mode="async",
)
-retriever = pgvector_store.as_retriever()
+retriever = vector_store.as_retriever()

known_source_ext = [
    "go",
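One detail worth calling out in the config.py change above: the file now carries two URLs for the same database. CONNECTION_STRING keeps the postgresql+psycopg2 dialect prefix that SQLAlchemy (behind the langchain vector store) expects, while the new DSN is the plain postgresql:// form required by the asyncpg pool introduced in psql.py below. A minimal sketch of the distinction, assuming a reachable Postgres instance; the credentials are placeholders:

import asyncio

import asyncpg

async def main():
    # asyncpg only accepts plain "postgresql://" DSNs; passing the SQLAlchemy-style
    # "postgresql+psycopg2://" CONNECTION_STRING here would raise an error,
    # which is why config.py defines DSN separately.
    conn = await asyncpg.connect(dsn="postgresql://myuser:mypassword@localhost:5432/mydb")
    print(await conn.fetchval("SELECT version();"))
    await conn.close()

# asyncio.run(main())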
130 changes: 82 additions & 48 deletions main.py
@@ -1,46 +1,57 @@
import os

import hashlib
+from contextlib import asynccontextmanager
from dotenv import find_dotenv, load_dotenv
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from langchain.schema import Document
from langchain_core.runnables.config import run_in_executor
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import (
+    WebBaseLoader,
+    TextLoader,
+    PyPDFLoader,
+    CSVLoader,
+    Docx2txtLoader,
+    UnstructuredEPubLoader,
+    UnstructuredMarkdownLoader,
+    UnstructuredXMLLoader,
+    UnstructuredRSTLoader,
+    UnstructuredExcelLoader,
+)

-from models import DocumentModel, DocumentResponse, StoreDocument, QueryRequestBody
+from models import DocumentResponse, StoreDocument, QueryRequestBody, QueryMultipleBody
+from psql import PSQLDatabase, ensure_custom_id_index_on_embedding
+from middleware import security_middleware
+from pgvector_routes import router as pgvector_router
from constants import ERROR_MESSAGES
from store import AsyncPgVector

load_dotenv(find_dotenv())

from config import (
-    PDF_EXTRACT_IMAGES,
    debug_mode,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
-    pgvector_store,
+    vector_store,
    known_source_ext,
+    PDF_EXTRACT_IMAGES,
    # RAG_EMBEDDING_MODEL,
    # RAG_EMBEDDING_MODEL_DEVICE_TYPE,
    # RAG_TEMPLATE,
)

-from langchain_community.document_loaders import (
-    WebBaseLoader,
-    TextLoader,
-    PyPDFLoader,
-    CSVLoader,
-    Docx2txtLoader,
-    UnstructuredEPubLoader,
-    UnstructuredMarkdownLoader,
-    UnstructuredXMLLoader,
-    UnstructuredRSTLoader,
-    UnstructuredExcelLoader,
-)
-from langchain.text_splitter import RecursiveCharacterTextSplitter

+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    # Startup logic goes here
+    await PSQLDatabase.get_pool()  # Initialize the pool
+    await ensure_custom_id_index_on_embedding()
+
+    yield  # The application is now up and serving requests

-app = FastAPI()
+app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
@@ -56,19 +67,13 @@
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES

-def get_env_variable(var_name: str) -> str:
-    value = os.getenv(var_name)
-    if value is None:
-        raise ValueError(f"Environment variable '{var_name}' not found.")
-    return value
-
@app.get("/ids")
async def get_all_ids():
try:
if isinstance(pgvector_store, AsyncPgVector):
ids = await pgvector_store.get_all_ids()
if isinstance(vector_store, AsyncPgVector):
ids = await vector_store.get_all_ids()
else:
ids = pgvector_store.get_all_ids()
ids = vector_store.get_all_ids()

return list(set(ids))
except Exception as e:
@@ -78,12 +83,12 @@ async def get_all_ids():
@app.get("/documents", response_model=list[DocumentResponse])
async def get_documents_by_ids(ids: list[str]):
try:
if isinstance(pgvector_store, AsyncPgVector):
existing_ids = await pgvector_store.get_all_ids()
documents = await pgvector_store.get_documents_by_ids(ids)
if isinstance(vector_store, AsyncPgVector):
existing_ids = await vector_store.get_all_ids()
documents = await vector_store.get_documents_by_ids(ids)
else:
existing_ids = pgvector_store.get_all_ids()
documents = pgvector_store.get_documents_by_ids(ids)
existing_ids = vector_store.get_all_ids()
documents = vector_store.get_documents_by_ids(ids)

if not all(id in existing_ids for id in ids):
raise HTTPException(status_code=404, detail="One or more IDs not found")
@@ -98,12 +103,12 @@ async def get_documents_by_ids(ids: list[str]):
@app.delete("/documents")
async def delete_documents(ids: list[str]):
    try:
-        if isinstance(pgvector_store, AsyncPgVector):
-            existing_ids = await pgvector_store.get_all_ids()
-            await pgvector_store.delete(ids=ids)
+        if isinstance(vector_store, AsyncPgVector):
+            existing_ids = await vector_store.get_all_ids()
+            await vector_store.delete(ids=ids)
        else:
-            existing_ids = pgvector_store.get_all_ids()
-            pgvector_store.delete(ids=ids)
+            existing_ids = vector_store.get_all_ids()
+            vector_store.delete(ids=ids)

        if not all(id in existing_ids for id in ids):
            raise HTTPException(status_code=404, detail="One or more IDs not found")
@@ -117,22 +122,22 @@ async def delete_documents(ids: list[str]):
async def query_embeddings_by_file_id(body: QueryRequestBody):
    try:
        # Get the embedding of the query text
-        embedding = pgvector_store.embedding_function.embed_query(body.query)
+        embedding = vector_store.embedding_function.embed_query(body.query)

        # Perform similarity search with the query embedding and filter by the file_id in metadata
-        if isinstance(pgvector_store, AsyncPgVector):
+        if isinstance(vector_store, AsyncPgVector):
            documents = await run_in_executor(
                None,
-                pgvector_store.similarity_search_with_score_by_vector,
+                vector_store.similarity_search_with_score_by_vector,
                embedding,
                k=body.k,
-                filter={"file_id": body.file_id}
+                filter={"custom_id": body.file_id}
            )
        else:
-            documents = pgvector_store.similarity_search_with_score_by_vector(
+            documents = vector_store.similarity_search_with_score_by_vector(
                embedding,
                k=body.k,
-                filter={"file_id": body.file_id}
+                filter={"custom_id": body.file_id}
            )

        return documents
@@ -165,10 +170,10 @@ async def store_data_in_vector_db(data, file_id, overwrite: bool = False) -> bool:
    ]

    try:
-        if isinstance(pgvector_store, AsyncPgVector):
-            ids = await pgvector_store.aadd_documents(docs, ids=[file_id]*len(documents))
+        if isinstance(vector_store, AsyncPgVector):
+            ids = await vector_store.aadd_documents(docs, ids=[file_id]*len(documents))
        else:
-            ids = pgvector_store.add_documents(docs, ids=[file_id]*len(documents))
+            ids = vector_store.add_documents(docs, ids=[file_id]*len(documents))

        return {"message": "Documents added successfully", "ids": ids}

@@ -258,3 +263,32 @@ async def embed_file(document: StoreDocument):
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

@app.post("/query_multiple")
async def query_embeddings_by_file_ids(body: QueryMultipleBody):
try:
# Get the embedding of the query text
embedding = vector_store.embedding_function.embed_query(body.query)

# Perform similarity search with the query embedding and filter by the file_ids in metadata
if isinstance(vector_store, AsyncPgVector):
documents = await run_in_executor(
None,
vector_store.similarity_search_with_score_by_vector,
embedding,
k=body.k,
filter={"custom_id": {"$in": body.file_ids}}
)
else:
documents = vector_store.similarity_search_with_score_by_vector(
embedding,
k=body.k,
filter={"custom_id": {"$in": body.file_ids}}
)

return documents
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

if debug_mode:
app.include_router(router=pgvector_router)
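For reference, here is a sketch of how a client might exercise the new /query_multiple endpoint; the host, port, and file IDs are placeholders, and the response shape assumes the default JSON serialization of the (document, score) pairs returned above:

import requests

payload = {
    "query": "What does the contract say about termination?",
    "file_ids": ["file-abc123", "file-def456"],  # custom_id values to search across
    "k": 4,  # top-k results; matches the QueryMultipleBody default
}
resp = requests.post("http://localhost:8000/query_multiple", json=payload)
resp.raise_for_status()
for doc, score in resp.json():
    print(round(score, 3), doc["page_content"][:80])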
9 changes: 7 additions & 2 deletions models.py
@@ -1,7 +1,7 @@
import hashlib
from enum import Enum
-from typing import Optional
from pydantic import BaseModel
+from typing import Optional, List

class DocumentResponse(BaseModel):
    page_content: str
Expand Down Expand Up @@ -29,4 +29,9 @@ class QueryRequestBody(BaseModel):

class CleanupMethod(str, Enum):
    incremental = "incremental"
-    full = "full"
\ No newline at end of file
+    full = "full"
+
+class QueryMultipleBody(BaseModel):
+    query: str
+    file_ids: List[str]
+    k: int = 4
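A quick sanity check of the new request model (field names as in the diff above; this assumes models.py is importable from the working directory):

from models import QueryMultipleBody

body = QueryMultipleBody(query="termination clause", file_ids=["file-abc123", "file-def456"])
assert body.k == 4  # k falls back to its default when omitted
print(body.query, body.file_ids, body.k)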
76 changes: 76 additions & 0 deletions pgvector_routes.py
@@ -0,0 +1,76 @@
from fastapi import APIRouter, HTTPException
from psql import PSQLDatabase

router = APIRouter()

async def check_index_exists(table_name: str, column_name: str) -> bool:
    pool = await PSQLDatabase.get_pool()
    async with pool.acquire() as conn:
        result = await conn.fetch(
            """
            SELECT EXISTS (
                SELECT 1
                FROM pg_indexes
                WHERE tablename = $1
                AND indexdef LIKE '%' || $2 || '%'
            );
            """,
            table_name,
            column_name,
        )
        return result[0]['exists']

@router.get("/test/check_index")
async def check_file_id_index(table_name: str, column_name: str):
if await check_index_exists(table_name, column_name):
return {"message": f"Index on {column_name} exists in the table {table_name}."}
else:
return HTTPException(status_code=404, detail=f"No index on {column_name} found in the table {table_name}.")

@router.get("/db/tables")
async def get_table_names(schema: str = "public"):
pool = await PSQLDatabase.get_pool()
async with pool.acquire() as conn:
table_names = await conn.fetch(
"""
SELECT table_name
FROM information_schema.tables
WHERE table_schema = $1
""",
schema,
)
# Extract table names from records
tables = [record['table_name'] for record in table_names]
return {"schema": schema, "tables": tables}

@router.get("/db/tables/columns")
async def get_table_columns(table_name: str, schema: str = "public"):
pool = await PSQLDatabase.get_pool()
async with pool.acquire() as conn:
columns = await conn.fetch(
"""
SELECT column_name
FROM information_schema.columns
WHERE table_schema = $1 AND table_name = $2
ORDER BY ordinal_position;
""",
schema, table_name,
)
column_names = [col['column_name'] for col in columns]
return {"table_name": table_name, "columns": column_names}

@router.get("/records/all")
async def get_all_records(table_name: str):
# Validate that the table name is one of the expected ones to prevent SQL injection
if table_name not in ["langchain_pg_collection", "langchain_pg_embedding"]:
raise HTTPException(status_code=400, detail="Invalid table name")

pool = await PSQLDatabase.get_pool()
async with pool.acquire() as conn:
# Use SQLAlchemy core or raw SQL queries to fetch all records
records = await conn.fetch(f"SELECT * FROM {table_name};")

# Convert records to JSON serializable format, assuming records can be directly serialized
records_json = [dict(record) for record in records]

return records_json
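Since these helper routes are only mounted when debug_mode is true (i.e. DEBUG_RAG_API=true after the config.py change above), a quick way to poke at them is with a few GET requests; the host and port are placeholders:

import requests

base = "http://localhost:8000"  # placeholder; wherever the RAG API is served

print(requests.get(f"{base}/db/tables").json())
print(requests.get(f"{base}/db/tables/columns", params={"table_name": "langchain_pg_embedding"}).json())
print(requests.get(f"{base}/test/check_index", params={
    "table_name": "langchain_pg_embedding",
    "column_name": "custom_id",
}).json())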
49 changes: 49 additions & 0 deletions psql.py
@@ -0,0 +1,49 @@
# psql.py
import asyncpg
from config import DSN, logger

class PSQLDatabase:
    pool = None

    @classmethod
    async def get_pool(cls):
        if cls.pool is None:
            cls.pool = await asyncpg.create_pool(dsn=DSN)
        return cls.pool

    @classmethod
    async def close_pool(cls):
        if cls.pool is not None:
            await cls.pool.close()
            cls.pool = None

async def ensure_custom_id_index_on_embedding():
    table_name = "langchain_pg_embedding"
    column_name = "custom_id"
    # You might want to standardize the index naming convention
    index_name = f"idx_{table_name}_{column_name}"

    pool = await PSQLDatabase.get_pool()
    async with pool.acquire() as conn:
        # Check if the index exists
        index_exists = await check_index_exists(conn, index_name)

        if not index_exists:
            # If the index does not exist, create it
            await conn.execute(f"""
                CREATE INDEX IF NOT EXISTS {index_name} ON {table_name} ({column_name});
            """)
            logger.debug(f"Created index '{index_name}' on '{table_name}({column_name})'")
        else:
            logger.debug(f"Index '{index_name}' already exists on '{table_name}({column_name})'")

async def check_index_exists(conn, index_name: str) -> bool:
    # Adjust the SQL query if necessary
    result = await conn.fetchval("""
        SELECT EXISTS (
            SELECT FROM pg_class c
            JOIN pg_namespace n ON n.oid = c.relnamespace
            WHERE c.relname = $1 AND n.nspname = 'public'  -- Adjust schema if necessary
        );
    """, index_name)
    return result
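To confirm the startup index is actually picked up for the custom_id filters used in main.py, one option is to run EXPLAIN through the same pool. A hedged sketch, assuming the embedding table is populated; the ID literal is a placeholder, and Postgres will only prefer the index once the table is large enough:

import asyncio
from psql import PSQLDatabase

async def explain_custom_id_lookup():
    pool = await PSQLDatabase.get_pool()
    async with pool.acquire() as conn:
        plan = await conn.fetch(
            "EXPLAIN SELECT * FROM langchain_pg_embedding "
            "WHERE custom_id = 'file-abc123';"
        )
        for row in plan:
            # Expect an Index Scan using idx_langchain_pg_embedding_custom_id
            print(row[0])

# asyncio.run(explain_custom_id_lookup())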
1 change: 1 addition & 0 deletions requirements.txt
@@ -17,3 +17,4 @@ openpyxl==3.1.2
docx2txt==0.8
pypandoc==1.13
python-jose==3.3.0
+asyncpg==0.29.0