Skip to content

Commit

Permalink
fix: stop using global db, use get_db instead (#30)
Browse files Browse the repository at this point in the history
* Replace global db with get_db method
  • Loading branch information
raphodn authored Nov 15, 2023
1 parent 9b52406 commit 9739734
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 34 deletions.
72 changes: 40 additions & 32 deletions app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,15 @@
from typing import Annotated

import requests
from fastapi import (
Depends,
FastAPI,
HTTPException,
Request,
Response,
UploadFile,
status,
)
from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile, status
from fastapi.responses import HTMLResponse, PlainTextResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.templating import Jinja2Templates
from fastapi_filter import FilterDepends
from fastapi_pagination import Page, add_pagination
from fastapi_pagination.ext.sqlalchemy import paginate
from openfoodfacts.utils import get_logger
from sqlalchemy.orm import Session

from app import crud, schemas
from app.config import settings
Expand Down Expand Up @@ -49,6 +42,16 @@
init_sentry(settings.sentry_dns)


# App database
# ------------------------------------------------------------------------------
def get_db():
db = session()
try:
yield db
finally:
db.close()


# Authentication helpers
# ------------------------------------------------------------------------------
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth")
Expand All @@ -58,9 +61,13 @@ async def create_token(user_id: str):
return f"{user_id}__U{str(uuid.uuid4())}"


async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
async def get_current_user(
token: Annotated[str, Depends(oauth2_scheme)], db: Session = Depends(get_db)
):
if token and "__U" in token:
current_user: schemas.UserBase = crud.update_user_last_used_field(db, token=token) # type: ignore
current_user: schemas.UserBase = crud.update_user_last_used_field(
db, token=token
)
if current_user:
return current_user
raise HTTPException(
Expand All @@ -70,19 +77,6 @@ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
)


# App startup & shutdown
# ------------------------------------------------------------------------------
@app.on_event("startup")
async def startup():
global db
db = session()


@app.on_event("shutdown")
async def shutdown():
db.close()


# Routes
# ------------------------------------------------------------------------------
@app.get("/", response_class=HTMLResponse)
Expand All @@ -95,7 +89,8 @@ def main_page(request: Request):

@app.post("/auth")
async def authentication(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()], response: Response
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: Session = Depends(get_db),
):
"""
Authentication: provide username/password and get a bearer token in return
Expand All @@ -117,7 +112,7 @@ async def authentication(
if r.status_code == 200:
token = await create_token(form_data.username)
user: schemas.UserBase = {"user_id": form_data.username, "token": token} # type: ignore
crud.create_user(db, user=user) # type: ignore
crud.create_user(db, user=user)
return {"access_token": token, "token_type": "bearer"}
elif r.status_code == 403:
await asyncio.sleep(2) # prevents brute-force
Expand All @@ -132,16 +127,21 @@ async def authentication(


@app.get("/prices", response_model=Page[schemas.PriceBase])
async def get_price(filters: schemas.PriceFilter = FilterDepends(schemas.PriceFilter)):
async def get_price(
filters: schemas.PriceFilter = FilterDepends(schemas.PriceFilter),
db: Session = Depends(get_db),
):
return paginate(db, crud.get_prices_query(filters=filters))


@app.post("/prices", response_model=schemas.PriceBase)
async def create_price(
price: schemas.PriceCreate,
current_user: schemas.UserBase = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""Create a new price.
"""
Create a new price.
This endpoint requires authentication.
"""
Expand All @@ -162,28 +162,36 @@ async def create_price(
status_code=status.HTTP_403_FORBIDDEN,
detail="Proof does not belong to current user",
)
db_price = crud.create_price(db, price=price, user=current_user) # type: ignore
db_price = crud.create_price(db, price=price, user=current_user)
return db_price


@app.post("/proofs/upload", response_model=schemas.ProofBase)
def upload_proof(
file: UploadFile,
current_user: schemas.UserBase = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""Upload a proof file.
"""
Upload a proof file.
The POST request must be a multipart/form-data request with a file field
named "file".
This endpoint requires authentication.
"""
file_path, mimetype = crud.create_proof_file(file)
db_proof = crud.create_proof(db, file_path, mimetype, user=current_user)
return db_proof


@app.get("/proofs", response_model=list[schemas.ProofBase])
def get_user_proofs(current_user: schemas.UserBase = Depends(get_current_user)):
"""Get all the proofs uploaded by the current user.
def get_user_proofs(
current_user: schemas.UserBase = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""
Get all the proofs uploaded by the current user.
This endpoint requires authentication.
"""
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[tool.poetry]
name = "open-prices"
version = "0.1.0"
description = "An open database of food prices"
description = "An open database of product prices"
authors = ["Open Food Facts <[email protected]>"]
license = "AGPL-3.0 licence"
readme = "README.md"
Expand All @@ -24,7 +24,6 @@ uvicorn = "~0.23.2"
fastapi-pagination = "^0.12.12"
fastapi-filter = "^1.0.0"


[tool.poetry.group.dev.dependencies]
pytest = "^7.4.3"
black = "^23.11.0"
Expand Down

0 comments on commit 9739734

Please sign in to comment.