Skip to content

Commit

Permalink
feat: add relationship objects in response of GET /prices (#92)
Browse files Browse the repository at this point in the history
* New schema PriceFull with product & location objects. Return this for GET /prices

* Add test

* Optimize query with sqlalchemy joinedload
  • Loading branch information
raphodn authored Dec 18, 2023
1 parent e7b0c14 commit 2156690
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 3 deletions.
2 changes: 1 addition & 1 deletion app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def authentication(
)


@app.get("/api/v1/prices", response_model=Page[schemas.PriceBase], tags=["Prices"])
@app.get("/api/v1/prices", response_model=Page[schemas.PriceFull], tags=["Prices"])
def get_price(
filters: schemas.PriceFilter = FilterDepends(schemas.PriceFilter),
db: Session = Depends(get_db),
Expand Down
10 changes: 8 additions & 2 deletions app/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from fastapi import UploadFile
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.sql import func

from app import config
Expand Down Expand Up @@ -105,9 +105,15 @@ def update_product(db: Session, product: ProductBase, update_dict: dict):

# Prices
# ------------------------------------------------------------------------------
def get_prices_query(filters: PriceFilter | None = None):
def get_prices_query(
with_join_product=True, with_join_location=True, filters: PriceFilter | None = None
):
"""Useful for pagination."""
query = select(Price)
if with_join_product:
query = query.options(joinedload(Price.product))
if with_join_location:
query = query.options(joinedload(Price.location))
if filters:
query = filters.filter(query)
return query
Expand Down
5 changes: 5 additions & 0 deletions app/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@ class PriceBase(PriceCreate):
created: datetime.datetime


class PriceFull(PriceBase):
product: ProductBase | None
location: LocationBase | None


class ProofCreate(BaseModel):
model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)

Expand Down
2 changes: 2 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ def test_get_prices():
assert len(response.json()["items"]) == 3
for price_field in ["product_id", "location_id", "proof_id"]:
assert price_field in response.json()["items"][0]
for price_relationship in ["product", "location"]:
assert price_relationship in response.json()["items"][0]


def test_get_prices_pagination():
Expand Down

0 comments on commit 2156690

Please sign in to comment.