From 583fa319217c3d147f57fef3e016088836de6622 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Tue, 19 Mar 2024 15:26:29 -0400 Subject: [PATCH] feat: index `custom_id` for faster queries, check on startup, add PSQL helper routers, add multiple ID query --- README.md | Bin 6304 -> 6402 bytes config.py | 7 +-- main.py | 130 ++++++++++++++++++++++++++++----------------- models.py | 9 +++- pgvector_routes.py | 76 ++++++++++++++++++++++++++ psql.py | 49 +++++++++++++++++ requirements.txt | 1 + store_factory.py | 22 ++++++++ 8 files changed, 241 insertions(+), 53 deletions(-) create mode 100644 pgvector_routes.py create mode 100644 psql.py diff --git a/README.md b/README.md index 6c168c1ac616549c7e6a0892c4bd0fab2ef13e59..830321e60e0be5505ddaff10e9c3b8a05c7956cf 100644 GIT binary patch delta 106 zcmZ2r*krWffEa%~LlA=_gF8b!kPKk(oLnd%JNb@SpRo>u0z)E09zzO{E@8+AlBqym wB0~~G4v?k5PyiGw2Fj!}6ao3g4229ilNCe1=k>3{Z4)wzw%H04^yP@Bjb+ delta 22 ecmZoNT41>0fY{_B0h!5f#QGTZHlGzYWds0dH3(b) diff --git a/config.py b/config.py index 1c0112c0..b1938ca2 100644 --- a/config.py +++ b/config.py @@ -30,10 +30,11 @@ def get_env_variable(var_name: str, default_value: str = None) -> str: PDF_EXTRACT_IMAGES = True if env_value == "true" else False CONNECTION_STRING = f"postgresql+psycopg2://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{DB_HOST}:{DB_PORT}/{POSTGRES_DB}" +DSN = f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{DB_HOST}:{DB_PORT}/{POSTGRES_DB}" logger = logging.getLogger() -debug_mode = get_env_variable("DEBUG", "False").lower() == "true" +debug_mode = get_env_variable("DEBUG_RAG_API", "False").lower() == "true" if debug_mode: logger.setLevel(logging.DEBUG) else: @@ -47,13 +48,13 @@ def get_env_variable(var_name: str, default_value: str = None) -> str: OPENAI_API_KEY = get_env_variable("OPENAI_API_KEY") embeddings = OpenAIEmbeddings() -pgvector_store = get_vector_store( +vector_store = get_vector_store( connection_string=CONNECTION_STRING, embeddings=embeddings, collection_name=COLLECTION_NAME, mode="async", ) -retriever = pgvector_store.as_retriever() +retriever = vector_store.as_retriever() known_source_ext = [ "go", diff --git a/main.py b/main.py index 45d030f7..62567dc4 100644 --- a/main.py +++ b/main.py @@ -1,46 +1,57 @@ import os - import hashlib +from langchain.schema import Document +from contextlib import asynccontextmanager from dotenv import find_dotenv, load_dotenv from fastapi import FastAPI, HTTPException, status from fastapi.middleware.cors import CORSMiddleware -from langchain.schema import Document from langchain_core.runnables.config import run_in_executor +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain_community.document_loaders import ( + WebBaseLoader, + TextLoader, + PyPDFLoader, + CSVLoader, + Docx2txtLoader, + UnstructuredEPubLoader, + UnstructuredMarkdownLoader, + UnstructuredXMLLoader, + UnstructuredRSTLoader, + UnstructuredExcelLoader, +) +from models import DocumentResponse, StoreDocument, QueryRequestBody, QueryMultipleBody +from psql import PSQLDatabase, ensure_custom_id_index_on_embedding from middleware import security_middleware -from models import DocumentModel, DocumentResponse, StoreDocument, QueryRequestBody +from pgvector_routes import router as pgvector_router +from constants import ERROR_MESSAGES from store import AsyncPgVector load_dotenv(find_dotenv()) from config import ( - PDF_EXTRACT_IMAGES, + debug_mode, CHUNK_SIZE, CHUNK_OVERLAP, - pgvector_store, + vector_store, known_source_ext, + PDF_EXTRACT_IMAGES, # RAG_EMBEDDING_MODEL, # RAG_EMBEDDING_MODEL_DEVICE_TYPE, # RAG_TEMPLATE, ) -from langchain_community.document_loaders import ( - WebBaseLoader, - TextLoader, - PyPDFLoader, - CSVLoader, - Docx2txtLoader, - UnstructuredEPubLoader, - UnstructuredMarkdownLoader, - UnstructuredXMLLoader, - UnstructuredRSTLoader, - UnstructuredExcelLoader, -) -from langchain.text_splitter import RecursiveCharacterTextSplitter +app = FastAPI() -from constants import ERROR_MESSAGES +@asynccontextmanager +async def lifespan(app: FastAPI): + # Startup logic goes here + await PSQLDatabase.get_pool() # Initialize the pool + await ensure_custom_id_index_on_embedding() + + yield # The application is now up and serving requests -app = FastAPI() +app = FastAPI(lifespan=lifespan) app.add_middleware( CORSMiddleware, @@ -56,19 +67,13 @@ app.state.CHUNK_OVERLAP = CHUNK_OVERLAP app.state.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES -def get_env_variable(var_name: str) -> str: - value = os.getenv(var_name) - if value is None: - raise ValueError(f"Environment variable '{var_name}' not found.") - return value - @app.get("/ids") async def get_all_ids(): try: - if isinstance(pgvector_store, AsyncPgVector): - ids = await pgvector_store.get_all_ids() + if isinstance(vector_store, AsyncPgVector): + ids = await vector_store.get_all_ids() else: - ids = pgvector_store.get_all_ids() + ids = vector_store.get_all_ids() return list(set(ids)) except Exception as e: @@ -78,12 +83,12 @@ async def get_all_ids(): @app.get("/documents", response_model=list[DocumentResponse]) async def get_documents_by_ids(ids: list[str]): try: - if isinstance(pgvector_store, AsyncPgVector): - existing_ids = await pgvector_store.get_all_ids() - documents = await pgvector_store.get_documents_by_ids(ids) + if isinstance(vector_store, AsyncPgVector): + existing_ids = await vector_store.get_all_ids() + documents = await vector_store.get_documents_by_ids(ids) else: - existing_ids = pgvector_store.get_all_ids() - documents = pgvector_store.get_documents_by_ids(ids) + existing_ids = vector_store.get_all_ids() + documents = vector_store.get_documents_by_ids(ids) if not all(id in existing_ids for id in ids): raise HTTPException(status_code=404, detail="One or more IDs not found") @@ -98,12 +103,12 @@ async def get_documents_by_ids(ids: list[str]): @app.delete("/documents") async def delete_documents(ids: list[str]): try: - if isinstance(pgvector_store, AsyncPgVector): - existing_ids = await pgvector_store.get_all_ids() - await pgvector_store.delete(ids=ids) + if isinstance(vector_store, AsyncPgVector): + existing_ids = await vector_store.get_all_ids() + await vector_store.delete(ids=ids) else: - existing_ids = pgvector_store.get_all_ids() - pgvector_store.delete(ids=ids) + existing_ids = vector_store.get_all_ids() + vector_store.delete(ids=ids) if not all(id in existing_ids for id in ids): raise HTTPException(status_code=404, detail="One or more IDs not found") @@ -117,22 +122,22 @@ async def delete_documents(ids: list[str]): async def query_embeddings_by_file_id(body: QueryRequestBody): try: # Get the embedding of the query text - embedding = pgvector_store.embedding_function.embed_query(body.query) + embedding = vector_store.embedding_function.embed_query(body.query) # Perform similarity search with the query embedding and filter by the file_id in metadata - if isinstance(pgvector_store, AsyncPgVector): + if isinstance(vector_store, AsyncPgVector): documents = await run_in_executor( None, - pgvector_store.similarity_search_with_score_by_vector, + vector_store.similarity_search_with_score_by_vector, embedding, k=body.k, - filter={"file_id": body.file_id} + filter={"custom_id": body.file_id} ) else: - documents = pgvector_store.similarity_search_with_score_by_vector( + documents = vector_store.similarity_search_with_score_by_vector( embedding, k=body.k, - filter={"file_id": body.file_id} + filter={"custom_id": body.file_id} ) return documents @@ -165,10 +170,10 @@ async def store_data_in_vector_db(data, file_id, overwrite: bool = False) -> boo ] try: - if isinstance(pgvector_store, AsyncPgVector): - ids = await pgvector_store.aadd_documents(docs, ids=[file_id]*len(documents)) + if isinstance(vector_store, AsyncPgVector): + ids = await vector_store.aadd_documents(docs, ids=[file_id]*len(documents)) else: - ids = pgvector_store.add_documents(docs, ids=[file_id]*len(documents)) + ids = vector_store.add_documents(docs, ids=[file_id]*len(documents)) return {"message": "Documents added successfully", "ids": ids} @@ -258,3 +263,32 @@ async def embed_file(document: StoreDocument): status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), ) + +@app.post("/query_multiple") +async def query_embeddings_by_file_ids(body: QueryMultipleBody): + try: + # Get the embedding of the query text + embedding = vector_store.embedding_function.embed_query(body.query) + + # Perform similarity search with the query embedding and filter by the file_ids in metadata + if isinstance(vector_store, AsyncPgVector): + documents = await run_in_executor( + None, + vector_store.similarity_search_with_score_by_vector, + embedding, + k=body.k, + filter={"custom_id": {"$in": body.file_ids}} + ) + else: + documents = vector_store.similarity_search_with_score_by_vector( + embedding, + k=body.k, + filter={"custom_id": {"$in": body.file_ids}} + ) + + return documents + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +if debug_mode: + app.include_router(router=pgvector_router) diff --git a/models.py b/models.py index 71403d0c..a8dd04cc 100644 --- a/models.py +++ b/models.py @@ -1,7 +1,7 @@ import hashlib from enum import Enum -from typing import Optional from pydantic import BaseModel +from typing import Optional, List class DocumentResponse(BaseModel): page_content: str @@ -29,4 +29,9 @@ class QueryRequestBody(BaseModel): class CleanupMethod(str, Enum): incremental = "incremental" - full = "full" \ No newline at end of file + full = "full" + +class QueryMultipleBody(BaseModel): + query: str + file_ids: List[str] + k: int = 4 \ No newline at end of file diff --git a/pgvector_routes.py b/pgvector_routes.py new file mode 100644 index 00000000..b1994e8f --- /dev/null +++ b/pgvector_routes.py @@ -0,0 +1,76 @@ +from fastapi import APIRouter, HTTPException +from psql import PSQLDatabase + +router = APIRouter() + +async def check_index_exists(table_name: str, column_name: str) -> bool: + pool = await PSQLDatabase.get_pool() + async with pool.acquire() as conn: + result = await conn.fetch( + """ + SELECT EXISTS ( + SELECT 1 + FROM pg_indexes + WHERE tablename = $1 + AND indexdef LIKE '%' || $2 || '%' + ); + """, + table_name, + column_name, + ) + return result[0]['exists'] + +@router.get("/test/check_index") +async def check_file_id_index(table_name: str, column_name: str): + if await check_index_exists(table_name, column_name): + return {"message": f"Index on {column_name} exists in the table {table_name}."} + else: + return HTTPException(status_code=404, detail=f"No index on {column_name} found in the table {table_name}.") + +@router.get("/db/tables") +async def get_table_names(schema: str = "public"): + pool = await PSQLDatabase.get_pool() + async with pool.acquire() as conn: + table_names = await conn.fetch( + """ + SELECT table_name + FROM information_schema.tables + WHERE table_schema = $1 + """, + schema, + ) + # Extract table names from records + tables = [record['table_name'] for record in table_names] + return {"schema": schema, "tables": tables} + +@router.get("/db/tables/columns") +async def get_table_columns(table_name: str, schema: str = "public"): + pool = await PSQLDatabase.get_pool() + async with pool.acquire() as conn: + columns = await conn.fetch( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = $1 AND table_name = $2 + ORDER BY ordinal_position; + """, + schema, table_name, + ) + column_names = [col['column_name'] for col in columns] + return {"table_name": table_name, "columns": column_names} + +@router.get("/records/all") +async def get_all_records(table_name: str): + # Validate that the table name is one of the expected ones to prevent SQL injection + if table_name not in ["langchain_pg_collection", "langchain_pg_embedding"]: + raise HTTPException(status_code=400, detail="Invalid table name") + + pool = await PSQLDatabase.get_pool() + async with pool.acquire() as conn: + # Use SQLAlchemy core or raw SQL queries to fetch all records + records = await conn.fetch(f"SELECT * FROM {table_name};") + + # Convert records to JSON serializable format, assuming records can be directly serialized + records_json = [dict(record) for record in records] + + return records_json diff --git a/psql.py b/psql.py new file mode 100644 index 00000000..e84ae81c --- /dev/null +++ b/psql.py @@ -0,0 +1,49 @@ +# db.py +import asyncpg +from config import DSN, logger + +class PSQLDatabase: + pool = None + + @classmethod + async def get_pool(cls): + if cls.pool is None: + cls.pool = await asyncpg.create_pool(dsn=DSN) + return cls.pool + + @classmethod + async def close_pool(cls): + if cls.pool is not None: + await cls.pool.close() + cls.pool = None + +async def ensure_custom_id_index_on_embedding(): + table_name = "langchain_pg_embedding" + column_name = "custom_id" + # You might want to standardize the index naming convention + index_name = f"idx_{table_name}_{column_name}" + + pool = await PSQLDatabase.get_pool() + async with pool.acquire() as conn: + # Check if the index exists + index_exists = await check_index_exists(conn, index_name) + + if not index_exists: + # If the index does not exist, create it + await conn.execute(f""" + CREATE INDEX IF NOT EXISTS {index_name} ON {table_name} ({column_name}); + """) + logger.debug(f"Created index '{index_name}' on '{table_name}({column_name})'") + else: + logger.debug(f"Index '{index_name}' already exists on '{table_name}({column_name})'") + +async def check_index_exists(conn, index_name: str) -> bool: + # Adjust the SQL query if necessary + result = await conn.fetchval(""" + SELECT EXISTS ( + SELECT FROM pg_class c + JOIN pg_namespace n ON n.oid = c.relnamespace + WHERE c.relname = $1 AND n.nspname = 'public' -- Adjust schema if necessary + ); + """, index_name) + return result \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 8433bee9..1a1ac8c5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,3 +17,4 @@ openpyxl==3.1.2 docx2txt==0.8 pypandoc==1.13 python-jose==3.3.0 +asyncpg==0.29.0 diff --git a/store_factory.py b/store_factory.py index 9fd0c432..7ea3e76c 100644 --- a/store_factory.py +++ b/store_factory.py @@ -23,3 +23,25 @@ def get_vector_store( ) else: raise ValueError("Invalid mode specified. Choose 'sync' or 'async'.") + +async def create_index_if_not_exists(conn, table_name: str, column_name: str): + # Construct index name conventionally + index_name = f"idx_{table_name}_{column_name}" + # Check if index exists + exists = await conn.fetchval(f""" + SELECT EXISTS ( + SELECT FROM pg_class c + JOIN pg_namespace n ON n.oid = c.relnamespace + WHERE c.relname = $1 + AND n.nspname = 'public' -- Or specify your schema if different + ); + """, index_name) + # Create the index if it does not exist + if not exists: + await conn.execute(f""" + CREATE INDEX CONCURRENTLY IF NOT EXISTS {index_name} + ON public.{table_name} ({column_name}); + """) + print(f"Index {index_name} created on {table_name}.{column_name}") + else: + print(f"Index {index_name} already exists on {table_name}.{column_name}") \ No newline at end of file