Skip to content

Commit

Permalink
🔖 chore: Formatting (#25)
Browse files Browse the repository at this point in the history
* chore: pre-commit yaml

* chore: pre-commit test

* chore: readme
  • Loading branch information
danny-avila authored Apr 21, 2024
1 parent fa90e15 commit 4869933
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 46 deletions.
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
repos:
- repo: https://github.com/psf/black
rev: 24.4.0
hooks:
- id: black
language_version: python3
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,14 @@ In order to setup RDS Postgres with RAG API, you can follow these steps:
Notes:
* Even though you're logging in with a Master user, it doesn't have all the superuser privileges; that's why we cannot use the command: ```create role x with superuser;```
* If you do not enable the extension, the rag_api service will throw an error saying it cannot create the extension, for the reason noted above.

### Dev notes:

#### Installing pre-commit formatter

Run the following commands to install the pre-commit formatter, which uses the [black](https://github.com/psf/black) code formatter:

```bash
pip install pre-commit
pre-commit install
```
117 changes: 71 additions & 46 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@
)

from models import DocumentResponse, StoreDocument, QueryRequestBody, QueryMultipleBody
from psql import PSQLDatabase, ensure_custom_id_index_on_embedding, \
pg_health_check
from psql import PSQLDatabase, ensure_custom_id_index_on_embedding, pg_health_check
from middleware import security_middleware
from pgvector_routes import router as pgvector_router
from parsers import process_documents
Expand All @@ -45,7 +44,8 @@
vector_store,
RAG_UPLOAD_DIR,
known_source_ext,
PDF_EXTRACT_IMAGES, LogMiddleware,
PDF_EXTRACT_IMAGES,
LogMiddleware,
RAG_HOST,
RAG_PORT,
# RAG_EMBEDDING_MODEL,
Expand All @@ -62,6 +62,7 @@ async def lifespan(app: FastAPI):

yield


app = FastAPI(lifespan=lifespan)

app.add_middleware(
Expand Down Expand Up @@ -140,17 +141,19 @@ async def delete_documents(ids: list[str]):
raise HTTPException(status_code=404, detail="One or more IDs not found")

file_count = len(ids)
return {"message": f"Documents for {file_count} file{'s' if file_count > 1 else ''} deleted successfully"}
return {
"message": f"Documents for {file_count} file{'s' if file_count > 1 else ''} deleted successfully"
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


@app.post("/query")
async def query_embeddings_by_file_id(body: QueryRequestBody, request: Request):
if not hasattr(request.state, 'user'):
if not hasattr(request.state, "user"):
user_authorized = "public"
else:
user_authorized = request.state.user.get('id');
user_authorized = request.state.user.get("id")

authorized_documents = []
try:
Expand All @@ -162,23 +165,23 @@ async def query_embeddings_by_file_id(body: QueryRequestBody, request: Request):
vector_store.similarity_search_with_score_by_vector,
embedding,
k=body.k,
filter={"file_id": body.file_id}
filter={"file_id": body.file_id},
)
else:
documents = vector_store.similarity_search_with_score_by_vector(
embedding,
k=body.k,
filter={"file_id": body.file_id}
embedding, k=body.k, filter={"file_id": body.file_id}
)

document, score = documents[0]
doc_metadata = document.metadata
doc_user_id = doc_metadata.get('user_id')
doc_user_id = doc_metadata.get("user_id")

if doc_user_id is None or doc_user_id == user_authorized:
authorized_documents = documents
else:
logger.warn(f"Unauthorized access attempt by user {user_authorized} to a document with user_id {doc_user_id}")
logger.warn(
f"Unauthorized access attempt by user {user_authorized} to a document with user_id {doc_user_id}"
)

return authorized_documents
except Exception as e:
Expand All @@ -191,14 +194,15 @@ def generate_digest(page_content: str):
return hash_obj.hexdigest()


async def store_data_in_vector_db(data: Iterable[Document], file_id: str, user_id: str = '') -> bool:
async def store_data_in_vector_db(
data: Iterable[Document], file_id: str, user_id: str = ""
) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP
)
documents = text_splitter.split_documents(data)


# Preparing documents with page content and metadata for insertion.
# Preparing documents with page content and metadata for insertion.
docs = [
Document(
page_content=doc.page_content,
Expand All @@ -214,9 +218,11 @@ async def store_data_in_vector_db(data: Iterable[Document], file_id: str, user_i

try:
if isinstance(vector_store, AsyncPgVector):
ids = await vector_store.aadd_documents(docs, ids=[file_id]*len(documents))
ids = await vector_store.aadd_documents(
docs, ids=[file_id] * len(documents)
)
else:
ids = vector_store.add_documents(docs, ids=[file_id]*len(documents))
ids = vector_store.add_documents(docs, ids=[file_id] * len(documents))

return {"message": "Documents added successfully", "ids": ids}

Expand Down Expand Up @@ -273,13 +279,15 @@ async def embed_local_file(document: StoreDocument, request: Request):
detail=ERROR_MESSAGES.FILE_NOT_FOUND,
)

if not hasattr(request.state, 'user'):
if not hasattr(request.state, "user"):
user_id = "public"
else:
user_id = request.state.user.get('id');
user_id = request.state.user.get("id")

try:
loader, known_type = get_loader(document.filename, document.file_content_type, document.filepath)
loader, known_type = get_loader(
document.filename, document.file_content_type, document.filepath
)
data = loader.load()
result = await store_data_in_vector_db(data, document.file_id, user_id)

Expand Down Expand Up @@ -310,30 +318,36 @@ async def embed_local_file(document: StoreDocument, request: Request):


@app.post("/embed")
async def embed_file(request: Request, file_id: str = Form(...), file: UploadFile = File(...)):
async def embed_file(
request: Request, file_id: str = Form(...), file: UploadFile = File(...)
):
response_status = True
response_message = "File processed successfully."
known_type = None
if not hasattr(request.state, 'user'):
if not hasattr(request.state, "user"):
user_id = "public"
else:
user_id = request.state.user.get('id');
user_id = request.state.user.get("id")

temp_base_path = os.path.join(RAG_UPLOAD_DIR, user_id)
os.makedirs(temp_base_path, exist_ok=True)
temp_file_path = os.path.join(RAG_UPLOAD_DIR, user_id, file.filename)

try:
async with aiofiles.open(temp_file_path, 'wb') as temp_file:
async with aiofiles.open(temp_file_path, "wb") as temp_file:
chunk_size = 64 * 1024 # 64 KB
while content := await file.read(chunk_size):
await temp_file.write(content)
except Exception as e:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to save the uploaded file. Error: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to save the uploaded file. Error: {str(e)}",
)

try:
loader, known_type = get_loader(file.filename, file.content_type, temp_file_path)
loader, known_type = get_loader(
file.filename, file.content_type, temp_file_path
)
data = loader.load()
result = await store_data_in_vector_db(data, file_id, user_id)

Expand All @@ -344,11 +358,11 @@ async def embed_file(request: Request, file_id: str = Form(...), file: UploadFil
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to process/store the file data.",
)
elif 'error' in result:
elif "error" in result:
response_status = False
response_message = "Failed to process/store the file data."
if isinstance(result['error'], str):
response_message = result['error']
if isinstance(result["error"], str):
response_message = result["error"]
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
Expand All @@ -357,8 +371,10 @@ async def embed_file(request: Request, file_id: str = Form(...), file: UploadFil
except Exception as e:
response_status = False
response_message = f"Error during file processing: {str(e)}"
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Error during file processing: {str(e)}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Error during file processing: {str(e)}",
)
finally:
try:
await aiofiles.os.remove(temp_file_path)
Expand Down Expand Up @@ -386,7 +402,9 @@ async def load_document_context(id: str):
documents = vector_store.get_documents_by_ids(ids)

if not all(id in existing_ids for id in ids):
raise HTTPException(status_code=404, detail="The specified file_id was not found")
raise HTTPException(
status_code=404, detail="The specified file_id was not found"
)

return process_documents(documents)
except Exception as e:
Expand All @@ -398,23 +416,29 @@ async def load_document_context(id: str):


@app.post("/embed-upload")
async def embed_file_upload(request: Request, file_id: str = Form(...), uploaded_file: UploadFile = File(...)):
async def embed_file_upload(
request: Request, file_id: str = Form(...), uploaded_file: UploadFile = File(...)
):
temp_file_path = os.path.join(RAG_UPLOAD_DIR, uploaded_file.filename)

if not hasattr(request.state, 'user'):
if not hasattr(request.state, "user"):
user_id = "public"
else:
user_id = request.state.user.get('id');
user_id = request.state.user.get("id")

try:
with open(temp_file_path, 'wb') as temp_file:
with open(temp_file_path, "wb") as temp_file:
copyfileobj(uploaded_file.file, temp_file)
except Exception as e:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to save the uploaded file. Error: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to save the uploaded file. Error: {str(e)}",
)

try:
loader, known_type = get_loader(uploaded_file.filename, uploaded_file.content_type, temp_file_path)
loader, known_type = get_loader(
uploaded_file.filename, uploaded_file.content_type, temp_file_path
)

data = loader.load()
result = await store_data_in_vector_db(data, file_id, user_id)
Expand All @@ -425,8 +449,10 @@ async def embed_file_upload(request: Request, file_id: str = Form(...), uploaded
detail="Failed to process/store the file data.",
)
except Exception as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Error during file processing: {str(e)}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Error during file processing: {str(e)}",
)
finally:
os.remove(temp_file_path)

Expand All @@ -452,19 +478,18 @@ async def query_embeddings_by_file_ids(body: QueryMultipleBody):
vector_store.similarity_search_with_score_by_vector,
embedding,
k=body.k,
filter={"custom_id": {"$in": body.file_ids}}
filter={"custom_id": {"$in": body.file_ids}},
)
else:
documents = vector_store.similarity_search_with_score_by_vector(
embedding,
k=body.k,
filter={"custom_id": {"$in": body.file_ids}}
embedding, k=body.k, filter={"custom_id": {"$in": body.file_ids}}
)

return documents
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


if debug_mode:
app.include_router(router=pgvector_router)

Expand Down

0 comments on commit 4869933

Please sign in to comment.