Skip to content

Commit

Permalink
refactor: rename schemas to clarify (#146)
Browse files Browse the repository at this point in the history
* New UserCreate schema. Split UserBase

* Rename *Base to *Full
  • Loading branch information
raphodn authored Jan 14, 2024
1 parent 0f13566 commit e0e3896
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 81 deletions.
40 changes: 22 additions & 18 deletions app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def create_token(user_id: str):

def get_current_user(
token: Annotated[str, Depends(oauth2_scheme)], db: Session = Depends(get_db)
) -> schemas.UserBase:
) -> schemas.UserCreate:
"""Get the current user if authenticated.
This function is used as a dependency in endpoints that require
Expand All @@ -119,7 +119,7 @@ def get_current_user(
def get_current_user_optional(
token: Annotated[str, Depends(oauth2_scheme_no_error)],
db: Session = Depends(get_db),
) -> schemas.UserBase | None:
) -> schemas.UserCreate | None:
"""Get the current user if authenticated, None otherwise.
This function is used as a dependency in endpoints that require
Expand Down Expand Up @@ -181,7 +181,7 @@ def authentication(
r = requests.post(settings.oauth2_server_url, data=data) # type: ignore
if r.status_code == 200:
token = create_token(form_data.username)
user = schemas.UserBase(user_id=form_data.username, token=token)
user = schemas.UserCreate(user_id=form_data.username, token=token)
db_user, created = crud.get_or_create_user(db, user=user)
user = crud.update_user_last_used_field(db, user=db_user)
# set the cookie if requested
Expand All @@ -206,7 +206,7 @@ def authentication(
# Routes: Prices
# ------------------------------------------------------------------------------
def price_transformer(
prices: list[Price], current_user: schemas.UserBase | None = None
prices: list[Price], current_user: schemas.UserCreate | None = None
) -> list[Price]:
"""Transformer function used to remove the file_path of private proofs.
Expand All @@ -229,11 +229,15 @@ def price_transformer(
return prices


@app.get("/api/v1/prices", response_model=Page[schemas.PriceFull], tags=["Prices"])
@app.get(
"/api/v1/prices",
response_model=Page[schemas.PriceFullWithRelations],
tags=["Prices"],
)
def get_price(
filters: schemas.PriceFilter = FilterDepends(schemas.PriceFilter),
db: Session = Depends(get_db),
current_user: schemas.UserBase | None = Depends(get_current_user_optional),
current_user: schemas.UserCreate | None = Depends(get_current_user_optional),
):
return paginate(
db,
Expand All @@ -244,14 +248,14 @@ def get_price(

@app.post(
"/api/v1/prices",
response_model=schemas.PriceBase,
response_model=schemas.PriceFull,
status_code=status.HTTP_201_CREATED,
tags=["Prices"],
)
def create_price(
price: schemas.PriceCreateWithValidation,
background_tasks: BackgroundTasks,
current_user: schemas.UserBase = Depends(get_current_user),
current_user: schemas.UserCreate = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""
Expand Down Expand Up @@ -288,7 +292,7 @@ def create_price(
# ------------------------------------------------------------------------------
@app.post(
"/api/v1/proofs/upload",
response_model=schemas.ProofBase,
response_model=schemas.ProofFull,
status_code=status.HTTP_201_CREATED,
tags=["Proofs"],
)
Expand All @@ -300,7 +304,7 @@ def upload_proof(
description="if true, the proof is public and is included in the API response. "
"Set false only for RECEIPT proofs that contain personal information.",
),
current_user: schemas.UserBase = Depends(get_current_user),
current_user: schemas.UserCreate = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""
Expand All @@ -323,9 +327,9 @@ def upload_proof(
return db_proof


@app.get("/api/v1/proofs", response_model=list[schemas.ProofBase], tags=["Proofs"])
@app.get("/api/v1/proofs", response_model=list[schemas.ProofFull], tags=["Proofs"])
def get_user_proofs(
current_user: schemas.UserBase = Depends(get_current_user),
current_user: schemas.UserCreate = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""
Expand All @@ -339,7 +343,7 @@ def get_user_proofs(
# Routes: Products
# ------------------------------------------------------------------------------
@app.get(
"/api/v1/products", response_model=Page[schemas.ProductBase], tags=["Products"]
"/api/v1/products", response_model=Page[schemas.ProductFull], tags=["Products"]
)
def get_products(
filters: schemas.ProductFilter = FilterDepends(schemas.ProductFilter),
Expand All @@ -350,7 +354,7 @@ def get_products(

@app.get(
"/api/v1/products/code/{product_code}",
response_model=schemas.ProductBase,
response_model=schemas.ProductFull,
tags=["Products"],
)
def get_product_by_code(product_code: str, db: Session = Depends(get_db)):
Expand All @@ -365,7 +369,7 @@ def get_product_by_code(product_code: str, db: Session = Depends(get_db)):

@app.get(
"/api/v1/products/{product_id}",
response_model=schemas.ProductBase,
response_model=schemas.ProductFull,
tags=["Products"],
)
def get_product_by_id(product_id: int, db: Session = Depends(get_db)):
Expand All @@ -381,7 +385,7 @@ def get_product_by_id(product_id: int, db: Session = Depends(get_db)):
# Routes: Locations
# ------------------------------------------------------------------------------
@app.get(
"/api/v1/locations", response_model=Page[schemas.LocationBase], tags=["Locations"]
"/api/v1/locations", response_model=Page[schemas.LocationFull], tags=["Locations"]
)
def get_locations(
filters: schemas.LocationFilter = FilterDepends(schemas.LocationFilter),
Expand All @@ -392,7 +396,7 @@ def get_locations(

@app.get(
"/api/v1/locations/osm/{location_osm_type}/{location_osm_id}",
response_model=schemas.LocationBase,
response_model=schemas.LocationFull,
tags=["Locations"],
)
def get_location_by_osm(
Expand All @@ -411,7 +415,7 @@ def get_location_by_osm(

@app.get(
"/api/v1/locations/{location_id}",
response_model=schemas.LocationBase,
response_model=schemas.LocationFull,
tags=["Locations"],
)
def get_location_by_id(location_id: int, db: Session = Depends(get_db)):
Expand Down
40 changes: 20 additions & 20 deletions app/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@
from app.enums import LocationOSMEnum, ProofTypeEnum
from app.models import Location, Price, Product, Proof, User
from app.schemas import (
LocationBase,
LocationCreate,
LocationFilter,
PriceBase,
LocationFull,
PriceCreate,
PriceFilter,
ProductBase,
PriceFull,
ProductCreate,
ProductFilter,
UserBase,
ProductFull,
UserCreate,
)


Expand All @@ -38,7 +38,7 @@ def get_user_by_token(db: Session, token: str):
return db.query(User).filter(User.token == token).first()


def create_user(db: Session, user: UserBase) -> User:
def create_user(db: Session, user: UserCreate) -> User:
"""Create a user in the database.
:param db: the database session
Expand All @@ -52,7 +52,7 @@ def create_user(db: Session, user: UserBase) -> User:
return db_user


def get_or_create_user(db: Session, user: UserBase):
def get_or_create_user(db: Session, user: UserCreate):
created = False
db_user = get_user_by_user_id(db, user_id=user.user_id)
if not db_user:
Expand All @@ -61,19 +61,19 @@ def get_or_create_user(db: Session, user: UserBase):
return db_user, created


def update_user(db: Session, user: UserBase, update_dict: dict):
def update_user(db: Session, user: UserCreate, update_dict: dict):
for key, value in update_dict.items():
setattr(user, key, value)
db.commit()
db.refresh(user)
return user


def update_user_last_used_field(db: Session, user: UserBase) -> UserBase | None:
def update_user_last_used_field(db: Session, user: UserCreate) -> UserCreate | None:
return update_user(db, user, {"last_used": func.now()})


def increment_user_price_count(db: Session, user: UserBase):
def increment_user_price_count(db: Session, user: UserCreate):
"""Increment the price count of a user.
This is used to keep track of the number of prices linked to a user.
Expand All @@ -84,7 +84,7 @@ def increment_user_price_count(db: Session, user: UserBase):
return user


def delete_user(db: Session, user_id: UserBase):
def delete_user(db: Session, user_id: UserCreate):
db_user = get_user_by_user_id(db, user_id=user_id)
if db_user:
db.delete(db_user)
Expand Down Expand Up @@ -154,15 +154,15 @@ def get_or_create_product(
return db_product, created


def update_product(db: Session, product: ProductBase, update_dict: dict):
def update_product(db: Session, product: ProductFull, update_dict: dict):
for key, value in update_dict.items():
setattr(product, key, value)
db.commit()
db.refresh(product)
return product


def increment_product_price_count(db: Session, product: ProductBase):
def increment_product_price_count(db: Session, product: ProductFull):
"""Increment the price count of a product.
This is used to keep track of the number of prices linked to a product.
Expand Down Expand Up @@ -199,7 +199,7 @@ def get_prices(db: Session, filters: PriceFilter | None = None):
return db.execute(get_prices_query(filters=filters)).all()


def create_price(db: Session, price: PriceCreate, user: UserBase):
def create_price(db: Session, price: PriceCreate, user: UserCreate):
db_price = Price(**price.model_dump(), owner=user.user_id)
db.add(db_price)
db.commit()
Expand All @@ -208,8 +208,8 @@ def create_price(db: Session, price: PriceCreate, user: UserBase):


def link_price_product(
db: Session, price: PriceBase, product: ProductBase
) -> PriceBase:
db: Session, price: PriceFull, product: ProductFull
) -> PriceFull:
"""Link the product DB object to the price DB object and return the updated
price."""
price.product_id = product.id
Expand All @@ -218,7 +218,7 @@ def link_price_product(
return price


def set_price_location(db: Session, price: PriceBase, location: LocationBase):
def set_price_location(db: Session, price: PriceFull, location: LocationFull):
price.location_id = location.id
db.commit()
db.refresh(price)
Expand All @@ -231,7 +231,7 @@ def get_proof(db: Session, proof_id: int):
return db.query(Proof).filter(Proof.id == proof_id).first()


def get_user_proofs(db: Session, user: UserBase):
def get_user_proofs(db: Session, user: UserCreate):
return db.query(Proof).filter(Proof.owner == user.user_id).all()


Expand All @@ -240,7 +240,7 @@ def create_proof(
file_path: str,
mimetype: str,
type: ProofTypeEnum,
user: UserBase,
user: UserCreate,
is_public: bool = True,
):
"""Create a proof in the database.
Expand Down Expand Up @@ -383,15 +383,15 @@ def get_or_create_location(
return db_location, created


def update_location(db: Session, location: LocationBase, update_dict: dict):
def update_location(db: Session, location: LocationFull, update_dict: dict):
for key, value in update_dict.items():
setattr(location, key, value)
db.commit()
db.refresh(location)
return location


def increment_location_price_count(db: Session, location: LocationBase):
def increment_location_price_count(db: Session, location: LocationFull):
"""Increment the price count of a location.
This is used to keep track of the number of prices linked to a location.
Expand Down
Loading

0 comments on commit e0e3896

Please sign in to comment.