From 97397344b1e7ef7af17bc7a86445a651815a4f29 Mon Sep 17 00:00:00 2001 From: Raphael Odini Date: Wed, 15 Nov 2023 21:32:08 +0100 Subject: [PATCH] fix: stop using global db, use get_db instead (#30) * Replace global db with get_db method --- app/api.py | 72 ++++++++++++++++++++++++++++---------------------- pyproject.toml | 3 +-- 2 files changed, 41 insertions(+), 34 deletions(-) diff --git a/app/api.py b/app/api.py index a4363db0..55faa0f6 100644 --- a/app/api.py +++ b/app/api.py @@ -4,15 +4,7 @@ from typing import Annotated import requests -from fastapi import ( - Depends, - FastAPI, - HTTPException, - Request, - Response, - UploadFile, - status, -) +from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile, status from fastapi.responses import HTMLResponse, PlainTextResponse from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from fastapi.templating import Jinja2Templates @@ -20,6 +12,7 @@ from fastapi_pagination import Page, add_pagination from fastapi_pagination.ext.sqlalchemy import paginate from openfoodfacts.utils import get_logger +from sqlalchemy.orm import Session from app import crud, schemas from app.config import settings @@ -49,6 +42,16 @@ init_sentry(settings.sentry_dns) +# App database +# ------------------------------------------------------------------------------ +def get_db(): + db = session() + try: + yield db + finally: + db.close() + + # Authentication helpers # ------------------------------------------------------------------------------ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth") @@ -58,9 +61,13 @@ async def create_token(user_id: str): return f"{user_id}__U{str(uuid.uuid4())}" -async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]): +async def get_current_user( + token: Annotated[str, Depends(oauth2_scheme)], db: Session = Depends(get_db) +): if token and "__U" in token: - current_user: schemas.UserBase = crud.update_user_last_used_field(db, token=token) # type: ignore + current_user: schemas.UserBase = crud.update_user_last_used_field( + db, token=token + ) if current_user: return current_user raise HTTPException( @@ -70,19 +77,6 @@ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]): ) -# App startup & shutdown -# ------------------------------------------------------------------------------ -@app.on_event("startup") -async def startup(): - global db - db = session() - - -@app.on_event("shutdown") -async def shutdown(): - db.close() - - # Routes # ------------------------------------------------------------------------------ @app.get("/", response_class=HTMLResponse) @@ -95,7 +89,8 @@ def main_page(request: Request): @app.post("/auth") async def authentication( - form_data: Annotated[OAuth2PasswordRequestForm, Depends()], response: Response + form_data: Annotated[OAuth2PasswordRequestForm, Depends()], + db: Session = Depends(get_db), ): """ Authentication: provide username/password and get a bearer token in return @@ -117,7 +112,7 @@ async def authentication( if r.status_code == 200: token = await create_token(form_data.username) user: schemas.UserBase = {"user_id": form_data.username, "token": token} # type: ignore - crud.create_user(db, user=user) # type: ignore + crud.create_user(db, user=user) return {"access_token": token, "token_type": "bearer"} elif r.status_code == 403: await asyncio.sleep(2) # prevents brute-force @@ -132,7 +127,10 @@ async def authentication( @app.get("/prices", response_model=Page[schemas.PriceBase]) -async def get_price(filters: schemas.PriceFilter = FilterDepends(schemas.PriceFilter)): +async def get_price( + filters: schemas.PriceFilter = FilterDepends(schemas.PriceFilter), + db: Session = Depends(get_db), +): return paginate(db, crud.get_prices_query(filters=filters)) @@ -140,8 +138,10 @@ async def get_price(filters: schemas.PriceFilter = FilterDepends(schemas.PriceFi async def create_price( price: schemas.PriceCreate, current_user: schemas.UserBase = Depends(get_current_user), + db: Session = Depends(get_db), ): - """Create a new price. + """ + Create a new price. This endpoint requires authentication. """ @@ -162,7 +162,7 @@ async def create_price( status_code=status.HTTP_403_FORBIDDEN, detail="Proof does not belong to current user", ) - db_price = crud.create_price(db, price=price, user=current_user) # type: ignore + db_price = crud.create_price(db, price=price, user=current_user) return db_price @@ -170,11 +170,15 @@ async def create_price( def upload_proof( file: UploadFile, current_user: schemas.UserBase = Depends(get_current_user), + db: Session = Depends(get_db), ): - """Upload a proof file. + """ + Upload a proof file. The POST request must be a multipart/form-data request with a file field named "file". + + This endpoint requires authentication. """ file_path, mimetype = crud.create_proof_file(file) db_proof = crud.create_proof(db, file_path, mimetype, user=current_user) @@ -182,8 +186,12 @@ def upload_proof( @app.get("/proofs", response_model=list[schemas.ProofBase]) -def get_user_proofs(current_user: schemas.UserBase = Depends(get_current_user)): - """Get all the proofs uploaded by the current user. +def get_user_proofs( + current_user: schemas.UserBase = Depends(get_current_user), + db: Session = Depends(get_db), +): + """ + Get all the proofs uploaded by the current user. This endpoint requires authentication. """ diff --git a/pyproject.toml b/pyproject.toml index 0906f5f0..d046bd93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "open-prices" version = "0.1.0" -description = "An open database of food prices" +description = "An open database of product prices" authors = ["Open Food Facts "] license = "AGPL-3.0 licence" readme = "README.md" @@ -24,7 +24,6 @@ uvicorn = "~0.23.2" fastapi-pagination = "^0.12.12" fastapi-filter = "^1.0.0" - [tool.poetry.group.dev.dependencies] pytest = "^7.4.3" black = "^23.11.0"