Skip to content

Commit

Permalink
feat: locations GET endpoint (#138)
Browse files Browse the repository at this point in the history
* New endpoint GET /locations. Add tests

* Add pagination

* Add filters
  • Loading branch information
raphodn authored Jan 13, 2024
1 parent 415336c commit 300697f
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 13 deletions.
26 changes: 21 additions & 5 deletions app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ def get_current_user_optional(

# Routes
# ------------------------------------------------------------------------------
@app.get("/api/v1/status")
def status_endpoint():
return {"status": "running"}


@app.post("/api/v1/auth", tags=["Auth"])
Expand Down Expand Up @@ -199,6 +202,8 @@ def authentication(
)


# Routes: Prices
# ------------------------------------------------------------------------------
def price_transformer(
prices: list[Price], current_user: schemas.UserBase | None = None
) -> list[Price]:
Expand Down Expand Up @@ -277,6 +282,8 @@ def create_price(
return db_price


# Routes: Proofs
# ------------------------------------------------------------------------------
@app.post(
"/api/v1/proofs/upload",
response_model=schemas.ProofBase,
Expand Down Expand Up @@ -327,6 +334,8 @@ def get_user_proofs(
return crud.get_user_proofs(db, user=current_user)


# Routes: Products
# ------------------------------------------------------------------------------
@app.get(
"/api/v1/products", response_model=Page[schemas.ProductBase], tags=["Products"]
)
Expand Down Expand Up @@ -367,6 +376,18 @@ def get_product_by_id(product_id: int, db: Session = Depends(get_db)):
return db_product


# Routes: Locations
# ------------------------------------------------------------------------------
@app.get(
"/api/v1/locations", response_model=Page[schemas.LocationBase], tags=["Locations"]
)
def get_locations(
filters: schemas.LocationFilter = FilterDepends(schemas.LocationFilter),
db: Session = Depends(get_db),
):
return paginate(db, crud.get_locations_query(filters=filters))


@app.get(
"/api/v1/locations/osm/{location_osm_type}/{location_osm_id}",
response_model=schemas.LocationBase,
Expand Down Expand Up @@ -401,9 +422,4 @@ def get_location_by_id(location_id: int, db: Session = Depends(get_db)):
return db_location


@app.get("/api/v1/status")
def status_endpoint():
return {"status": "running"}


add_pagination(app)
14 changes: 14 additions & 0 deletions app/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from app.schemas import (
LocationBase,
LocationCreate,
LocationFilter,
PriceBase,
PriceCreate,
PriceFilter,
Expand Down Expand Up @@ -289,6 +290,19 @@ def create_proof_file(file: UploadFile) -> tuple[str, str]:

# Locations
# ------------------------------------------------------------------------------
def get_locations_query(filters: LocationFilter | None = None):
"""Useful for pagination."""
query = select(Location)
if filters:
query = filters.filter(query)
query = filters.sort(query)
return query


def get_locations(db: Session):
return db.execute(get_locations_query()).all()


def get_location_by_id(db: Session, id: int):
return db.query(Location).filter(Location.id == id).first()

Expand Down
5 changes: 2 additions & 3 deletions app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ class Product(Base):
unique_scans_n = Column(Integer, nullable=False, server_default="0")

prices: Mapped[list["Price"]] = relationship(back_populates="product")
price_count = Column(Integer, nullable=False, server_default="0", index=True)

created = Column(DateTime(timezone=True), server_default=func.now())
updated = Column(DateTime(timezone=True), onupdate=func.now())
price_count = Column(Integer, nullable=False, server_default="0", index=True)

__tablename__ = "products"

Expand Down Expand Up @@ -84,15 +84,14 @@ class Proof(Base):
mimetype = Column(String, index=True)

type = Column(ChoiceType(ProofTypeEnum))
is_public = Column(Boolean, nullable=False, server_default="true", index=True)

prices: Mapped[list["Price"]] = relationship(back_populates="proof")

owner = Column(String, index=True)

created = Column(DateTime(timezone=True), server_default=func.now(), index=True)

is_public = Column(Boolean, nullable=False, server_default="true", index=True)

__tablename__ = "proofs"


Expand Down
12 changes: 11 additions & 1 deletion app/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)

from app.enums import CurrencyEnum, LocationOSMEnum, PricePerEnum, ProofTypeEnum
from app.models import Price, Product
from app.models import Location, Price, Product


class UserBase(BaseModel):
Expand Down Expand Up @@ -364,3 +364,13 @@ class ProductFilter(Filter):

class Constants(Filter.Constants):
model = Product


class LocationFilter(Filter):
osm_name__like: Optional[str] | None = None
osm_address_country__like: Optional[str] | None = None

order_by: Optional[list[str]] | None = None

class Constants(Filter.Constants):
model = Location
63 changes: 59 additions & 4 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,22 @@ def override_get_db():
unique_scans_n=0,
)
LOCATION = LocationCreate(osm_id=3344841823, osm_type="NODE")
LOCATION_1 = LocationCreate(
osm_id=652825274,
osm_type="NODE",
osm_name="Monoprix",
osm_address_postcode="38000",
osm_address_city="Grenoble",
osm_address_country="France",
)
LOCATION_2 = LocationCreate(
osm_id=6509705997,
osm_type="NODE",
osm_name="Carrefour",
osm_address_postcode="1000",
osm_address_city="Bruxelles - Brussel",
osm_address_country="België / Belgique / Belgien",
)
PRICE_1 = PriceCreate(
product_code="8001505005707",
product_name="PATE NOCCIOLATA BIO 700G",
Expand Down Expand Up @@ -113,6 +129,12 @@ def clean_products(db_session):
db_session.commit()


@pytest.fixture(scope="function")
def clean_locations(db_session):
db_session.query(crud.Location).delete()
db_session.commit()


# Test prices
# ------------------------------------------------------------------------------
def test_create_price(db_session, user, clean_prices):
Expand Down Expand Up @@ -529,19 +551,19 @@ def test_get_products_pagination(clean_products):

# assert len(crud.get_products(db_session)) == 3

# # 3 prices with the same source
# # 3 products with the same source
# response = client.get("/api/v1/products?source=off")
# assert response.status_code == 200
# assert len(response.json()["items"]) == 3
# # 1 price with a specific product_name
# # 1 product with a specific product_name
# response = client.get("/api/v1/products?product_name__like=châtaignes")
# assert response.status_code == 200
# assert len(response.json()["items"]) == 1
# # 2 prices with the same brand
# # 2 products with the same brand
# response = client.get("/api/v1/products?brands__like=Clément Faugier")
# assert response.status_code == 200
# assert len(response.json()["items"]) == 2
# # 2 prices with a positive unique_scans_n
# # 2 products with a positive unique_scans_n
# response = client.get("/api/v1/products?unique_scans_n__gte=1")
# assert response.status_code == 200
# assert len(response.json()["items"]) == 2
Expand All @@ -568,6 +590,39 @@ def test_get_product(db_session, clean_products):

# Test locations
# ------------------------------------------------------------------------------
def test_get_locations(db_session, clean_locations):
crud.create_location(db_session, LOCATION_1)
crud.create_location(db_session, LOCATION_2)

assert len(crud.get_locations(db_session)) == 2
response = client.get("/api/v1/locations")
assert response.status_code == 200
assert len(response.json()["items"]) == 2


def test_get_locations_pagination(clean_locations):
response = client.get("/api/v1/locations")
assert response.status_code == 200
for key in ["items", "total", "page", "size", "pages"]:
assert key in response.json()


# def test_get_locations_filters(db_session, clean_locations):
# crud.create_location(db_session, LOCATION_1)
# crud.create_location(db_session, LOCATION_2)

# assert len(crud.get_locations(db_session)) == 2

# # 1 location Monoprix
# response = client.get("/api/v1/locations?osm_name__like=Monoprix")
# assert response.status_code == 200
# assert len(response.json()["items"]) == 1
# # 1 location in France
# response = client.get("/api/v1/locations?osm_address_country__like=France") # noqa
# assert response.status_code == 200
# assert len(response.json()["items"]) == 1


def test_get_location(location):
# by id: location exists
response = client.get(f"/api/v1/locations/{location.id}")
Expand Down

0 comments on commit 300697f

Please sign in to comment.