Skip to content

Commit

Permalink
refactor: add optional JWT auth, add basic logger, rename endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
danny-avila committed Mar 18, 2024
1 parent c03a9ec commit 30d09b3
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 31 deletions.
8 changes: 8 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# config.py

import os
import logging
from dotenv import find_dotenv, load_dotenv
from langchain_openai import OpenAIEmbeddings
from store_factory import get_vector_store
Expand Down Expand Up @@ -30,6 +31,13 @@ def get_env_variable(var_name: str, default_value: str = None) -> str:

CONNECTION_STRING = f"postgresql+psycopg2://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{DB_HOST}:{DB_PORT}/{POSTGRES_DB}"

# Build the output handler first, then attach it to the root logger.
# `logger`, `formatter` and `handler` stay module-level names so other
# modules can keep importing them from config.
handler = logging.StreamHandler() # or logging.FileHandler("app.log")
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)

# getLogger() with no name returns the root logger, so the DEBUG level and
# the handler below apply to every logger that propagates to it.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logger.addHandler(handler)

OPENAI_API_KEY = get_env_variable("OPENAI_API_KEY")
embeddings = OpenAIEmbeddings()

Expand Down
43 changes: 12 additions & 31 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from langchain.schema import Document
from langchain_core.runnables.config import run_in_executor

from middleware import security_middleware
from models import DocumentModel, DocumentResponse, StoreDocument, QueryRequestBody
from store import AsyncPgVector

Expand Down Expand Up @@ -49,6 +50,8 @@
allow_headers=["*"],
)

app.middleware("http")(security_middleware)

app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
Expand All @@ -59,30 +62,7 @@ def get_env_variable(var_name: str) -> str:
raise ValueError(f"Environment variable '{var_name}' not found.")
return value

@app.post("/add-documents/")
async def add_documents(documents: list[DocumentModel]):
    """Add the posted documents to the vector store and return their ids.

    Each incoming model is converted to a ``Document`` whose metadata carries
    the model's id (as ``file_id``) and its digest; keys from the model's own
    metadata are merged last. Any failure is reported as an HTTP 500 whose
    detail is the exception text.
    """
    try:
        converted = []
        for item in documents:
            converted.append(
                Document(
                    page_content=item.page_content,
                    metadata={
                        "file_id": item.id,
                        "digest": item.generate_digest(),
                        **(item.metadata or {}),
                    },
                )
            )

        doc_ids = [item.id for item in documents]

        # The async store exposes an awaitable variant; the sync store is
        # called directly.
        if isinstance(pgvector_store, AsyncPgVector):
            ids = await pgvector_store.aadd_documents(converted, ids=doc_ids)
        else:
            ids = pgvector_store.add_documents(converted, ids=doc_ids)

        return {"message": "Documents added successfully", "ids": ids}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

@app.get("/get-all-ids/")
@app.get("/ids")
async def get_all_ids():
try:
if isinstance(pgvector_store, AsyncPgVector):
Expand All @@ -95,7 +75,7 @@ async def get_all_ids():
raise HTTPException(status_code=500, detail=str(e))


@app.post("/get-documents-by-ids/", response_model=list[DocumentResponse])
@app.get("/documents", response_model=list[DocumentResponse])
async def get_documents_by_ids(ids: list[str]):
try:
if isinstance(pgvector_store, AsyncPgVector):
Expand All @@ -115,7 +95,7 @@ async def get_documents_by_ids(ids: list[str]):
raise HTTPException(status_code=500, detail=str(e))


@app.delete("/delete-documents/")
@app.delete("/documents")
async def delete_documents(ids: list[str]):
try:
if isinstance(pgvector_store, AsyncPgVector):
Expand All @@ -128,11 +108,12 @@ async def delete_documents(ids: list[str]):
if not all(id in existing_ids for id in ids):
raise HTTPException(status_code=404, detail="One or more IDs not found")

return {"message": f"{len(ids)} documents deleted successfully"}
file_count = len(ids)
return {"message": f"Documents for {file_count} file{'s' if file_count > 1 else ''} deleted successfully"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

@app.post("/query-embeddings-by-file-id/")
@app.post("/query")
async def query_embeddings_by_file_id(body: QueryRequestBody):
try:
# Get the embedding of the query text
Expand Down Expand Up @@ -238,14 +219,14 @@ def get_loader(filename: str, file_content_type: str, filepath: str):

return loader, known_type

@app.post("/doc")
async def store_doc(document: StoreDocument):
@app.post("/embed")
async def embed_file(document: StoreDocument):

# Check if the file exists
if not os.path.exists(document.filepath):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.FILE_NOT_FOUND(),
detail=ERROR_MESSAGES.FILE_NOT_FOUND,
)

try:
Expand Down
41 changes: 41 additions & 0 deletions middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os
from datetime import datetime, timezone

from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse
from jose import jwt, JWTError

from config import logger

async def security_middleware(request: Request, call_next):
    """Optionally enforce JWT bearer authentication on incoming requests.

    Behavior:
      * ``/docs`` and ``/openapi.json`` are always passed through.
      * If the ``JWT_SECRET`` environment variable is unset or empty, a
        warning is logged and the request is passed through unauthenticated.
      * Otherwise the request must carry ``Authorization: Bearer <token>``
        with an HS256 token signed with ``JWT_SECRET``. The decoded payload
        is stored on ``request.state.user``.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request to the rest of the app
            and returns its response.

    Returns:
        The downstream response, or a 401 JSON response of the form
        ``{"detail": "..."}`` when authentication fails.
    """

    def unauthorized(detail: str) -> JSONResponse:
        # Auth failures are *returned*, not raised: an HTTPException raised
        # from an HTTP middleware is not routed through the app's exception
        # handlers and would reach the client as a 500 instead of a 401.
        return JSONResponse(status_code=401, content={"detail": detail})

    # Interactive docs and the OpenAPI schema stay reachable without a token.
    if request.url.path in ("/docs", "/openapi.json"):
        return await call_next(request)

    jwt_secret = os.getenv('JWT_SECRET')
    if not jwt_secret:
        # Authentication is opt-in; without a secret every request is allowed.
        logger.warning("JWT_SECRET not found in environment variables")
        return await call_next(request)

    authorization = request.headers.get('Authorization')
    if not authorization or not authorization.startswith('Bearer '):
        return unauthorized("Missing or invalid Authorization header")

    # Slice off the scheme rather than splitting on spaces, so extra
    # whitespace after "Bearer" does not produce an empty token.
    token = authorization[len('Bearer '):].strip()

    try:
        payload = jwt.decode(token, jwt_secret, algorithms=['HS256'])
    except JWTError as e:
        return unauthorized(f"Invalid token: {str(e)}")

    # Explicit expiry check, kept from the original implementation.
    exp_timestamp = payload.get('exp')
    if exp_timestamp:
        exp_datetime = datetime.fromtimestamp(exp_timestamp, tz=timezone.utc)
        if datetime.now(tz=timezone.utc) > exp_datetime:
            return unauthorized("Token has expired")

    request.state.user = payload
    logger.debug("%s - %s", request.url.path, payload)

    return await call_next(request)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ pandas==2.2.1
openpyxl==3.1.2
docx2txt==0.8
pypandoc==1.13
python-jose==3.3.0

0 comments on commit 30d09b3

Please sign in to comment.