Skip to content

Commit

Permalink
feat: add upload route and RAG_UPLOAD_DIR env var
Browse files Browse the repository at this point in the history
  • Loading branch information
danny-avila committed Mar 19, 2024
1 parent 2fe80c9 commit f1ab1ab
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 8 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
.env
__pycache__
uploads/
myenv/
myenv/
venv/
Binary file modified README.md
Binary file not shown.
6 changes: 5 additions & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ def get_env_variable(var_name: str, default_value: str = None) -> str:
return default_value
return value

# Directory where uploaded files are staged before embedding.
# Created eagerly at import time so request handlers can open files under it
# without checking. exist_ok=True makes this idempotent, so the prior
# os.path.exists() guard was redundant (and race-prone between check and create).
RAG_UPLOAD_DIR = get_env_variable("RAG_UPLOAD_DIR", "./uploads/")
os.makedirs(RAG_UPLOAD_DIR, exist_ok=True)

# PostgreSQL connection settings. No defaults are passed, so these come back
# as None when unset — presumably required; verify against the DB setup code.
POSTGRES_DB = get_env_variable("POSTGRES_DB")
POSTGRES_USER = get_env_variable("POSTGRES_USER")
POSTGRES_PASSWORD = get_env_variable("POSTGRES_PASSWORD")
Expand All @@ -25,7 +29,7 @@ def get_env_variable(var_name: str, default_value: str = None) -> str:

# Text-splitter parameters (characters per chunk / overlap between chunks).
CHUNK_SIZE = int(get_env_variable("CHUNK_SIZE", "1500"))
CHUNK_OVERLAP = int(get_env_variable("CHUNK_OVERLAP", "100"))

# Whether the PDF loader should also extract embedded images.
# Direct comparison replaces the redundant `True if ... else False` form.
env_value = get_env_variable("PDF_EXTRACT_IMAGES", "False").lower()
PDF_EXTRACT_IMAGES = env_value == "true"

Expand Down
47 changes: 41 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
import hashlib
from shutil import copyfileobj
from langchain.schema import Document
from contextlib import asynccontextmanager
from dotenv import find_dotenv, load_dotenv
from fastapi import FastAPI, HTTPException, status
from fastapi import FastAPI, File, Form, UploadFile, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from langchain_core.runnables.config import run_in_executor
from langchain.text_splitter import RecursiveCharacterTextSplitter
Expand Down Expand Up @@ -34,6 +35,7 @@
CHUNK_SIZE,
CHUNK_OVERLAP,
vector_store,
RAG_UPLOAD_DIR,
known_source_ext,
PDF_EXTRACT_IMAGES,
# RAG_EMBEDDING_MODEL,
Expand Down Expand Up @@ -121,29 +123,26 @@ async def delete_documents(ids: list[str]):
@app.post("/query")
async def query_embeddings_by_file_id(body: QueryRequestBody):
    """Return the top-k stored documents (with similarity scores) most similar
    to ``body.query``, restricted to chunks whose metadata ``file_id`` matches
    ``body.file_id``.

    Raises HTTPException(500) if embedding or the similarity search fails.
    """
    try:
        # Embed the query text with the store's own embedding function so the
        # query vector lives in the same space as the stored chunks.
        embedding = vector_store.embedding_function.embed_query(body.query)

        # Single source of truth for the metadata filter — the duplicated
        # filter kwargs in the two branches previously risked drifting apart.
        metadata_filter = {"file_id": body.file_id}

        if isinstance(vector_store, AsyncPgVector):
            # The pgvector search call is synchronous; run it in the default
            # thread executor so it doesn't block the event loop.
            documents = await run_in_executor(
                None,
                vector_store.similarity_search_with_score_by_vector,
                embedding,
                k=body.k,
                filter=metadata_filter,
            )
        else:
            documents = vector_store.similarity_search_with_score_by_vector(
                embedding,
                k=body.k,
                filter=metadata_filter,
            )

        return documents
    except Exception as e:
        # Boundary handler: surface any backend failure as a 500 with detail.
        raise HTTPException(status_code=500, detail=str(e))


def generate_digest(page_content: str):
hash_obj = hashlib.md5(page_content.encode())
Expand Down Expand Up @@ -264,6 +263,42 @@ async def embed_file(document: StoreDocument):
detail=ERROR_MESSAGES.DEFAULT(e),
)

@app.post("/embed-upload")
async def embed_file_upload(file_id: str = Form(...), uploaded_file: UploadFile = File(...)):
    """Accept a multipart file upload, stage it under RAG_UPLOAD_DIR, load it
    with the loader matching its name/content-type, and store the resulting
    chunks in the vector DB keyed by ``file_id``.

    Raises HTTPException(500) when the file cannot be saved or stored, and
    HTTPException(400) for errors while loading/processing the file.
    """
    # basename() strips any client-supplied directory components so a crafted
    # filename (e.g. "../../etc/passwd") cannot escape RAG_UPLOAD_DIR.
    safe_name = os.path.basename(uploaded_file.filename)
    temp_file_path = os.path.join(RAG_UPLOAD_DIR, safe_name)

    try:
        with open(temp_file_path, 'wb') as temp_file:
            copyfileobj(uploaded_file.file, temp_file)
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                            detail=f"Failed to save the uploaded file. Error: {str(e)}")

    try:
        loader, known_type = get_loader(uploaded_file.filename, uploaded_file.content_type, temp_file_path)

        data = loader.load()
        result = await store_data_in_vector_db(data, file_id)

        if not result:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to process/store the file data.",
            )
    except HTTPException:
        # Re-raise deliberate HTTP errors (e.g. the 500 above) untouched;
        # previously the broad handler re-wrapped them as a 400.
        raise
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
                            detail=f"Error during file processing: {str(e)}")
    finally:
        # Always clean up the staged file; guard so cleanup itself cannot
        # raise if the file is already gone.
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)

    return {
        "status": True,
        "message": "File processed successfully.",
        "file_id": file_id,
        "filename": uploaded_file.filename,
        "known_type": known_type,
    }

@app.post("/query_multiple")
async def query_embeddings_by_file_ids(body: QueryMultipleBody):
try:
Expand Down
17 changes: 17 additions & 0 deletions pgvector_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,20 @@ async def get_all_records(table_name: str):
records_json = [dict(record) for record in records]

return records_json

@router.get("/records")
async def get_records_filtered_by_custom_id(custom_id: str, table_name: str = "langchain_pg_embedding"):
    """Return all rows of *table_name* whose ``custom_id`` column matches.

    Rows are converted to plain dicts so FastAPI can serialize them.
    """
    # Whitelist the table name: it is interpolated directly into the SQL
    # below, so only known pgvector tables may pass through.
    allowed_tables = ("langchain_pg_collection", "langchain_pg_embedding")
    if table_name not in allowed_tables:
        raise HTTPException(status_code=400, detail="Invalid table name")

    pool = await PSQLDatabase.get_pool()
    async with pool.acquire() as conn:
        # custom_id travels as a bound parameter ($1), never via string
        # formatting, which keeps the query injection-safe.
        rows = await conn.fetch(
            f"SELECT * FROM {table_name} WHERE custom_id=$1;",
            custom_id,
        )

    # asyncpg Record supports dict() conversion for JSON serialization.
    return [dict(row) for row in rows]
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ docx2txt==0.8
pypandoc==1.13
python-jose==3.3.0
asyncpg==0.29.0
python-multipart==0.0.9

0 comments on commit f1ab1ab

Please sign in to comment.