diff --git a/alembic/versions/20231116_1559_1c8431a64d3a_add_price_location_relationship.py b/alembic/versions/20231116_1559_1c8431a64d3a_add_price_location_relationship.py new file mode 100644 index 00000000..e6336dac --- /dev/null +++ b/alembic/versions/20231116_1559_1c8431a64d3a_add_price_location_relationship.py @@ -0,0 +1,32 @@ +"""Add Price.location relationship + +Revision ID: 1c8431a64d3a +Revises: b3b951e016d0 +Create Date: 2023-11-16 15:59:03.881443 + +""" +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "1c8431a64d3a" +down_revision: Union[str, None] = "b3b951e016d0" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("prices", sa.Column("location_id", sa.Integer(), nullable=True)) + op.create_foreign_key(None, "prices", "locations", ["location_id"], ["id"]) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, "prices", type_="foreignkey") + op.drop_column("prices", "location_id") + # ### end Alembic commands ### diff --git a/app/api.py b/app/api.py index 54007ef1..16d10a63 100644 --- a/app/api.py +++ b/app/api.py @@ -4,7 +4,15 @@ from typing import Annotated import requests -from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile, status +from fastapi import ( + BackgroundTasks, + Depends, + FastAPI, + HTTPException, + Request, + UploadFile, + status, +) from fastapi.responses import HTMLResponse, PlainTextResponse from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from fastapi.templating import Jinja2Templates @@ -14,7 +22,7 @@ from openfoodfacts.utils import get_logger from sqlalchemy.orm import Session -from app import crud, schemas +from app import crud, schemas, tasks from app.config import settings from app.db import session from app.utils import init_sentry @@ -136,6 +144,7 @@ async def get_price( @app.post("/prices", response_model=schemas.PriceBase) async def create_price( price: schemas.PriceCreate, + background_tasks: BackgroundTasks, current_user: schemas.UserBase = Depends(get_current_user), db: Session = Depends(get_db), ): @@ -162,6 +171,7 @@ async def create_price( detail="Proof does not belong to current user", ) db_price = crud.create_price(db, price=price, user=current_user) + background_tasks.add_task(tasks.create_price_location, db, db_price) return db_price diff --git a/app/crud.py b/app/crud.py index a8b51a25..7b4d57e6 100644 --- a/app/crud.py +++ b/app/crud.py @@ -8,10 +8,20 @@ from sqlalchemy.sql import func from app import config -from app.models import Price, Proof, User -from app.schemas import PriceCreate, PriceFilter, UserBase - - +from app.enums import LocationOSMType +from app.models import Location, Price, Proof, User +from app.schemas import ( + LocationBase, + LocationCreate, + PriceBase, + PriceCreate, + PriceFilter, + UserBase, +) + + +# Users +# ------------------------------------------------------------------------------ def get_user(db: Session, user_id: str): return db.query(User).filter(User.user_id == user_id).first() @@ -56,6 +66,8 @@ def delete_user(db: Session, user_id: UserBase): return False +# Prices +# ------------------------------------------------------------------------------ def get_prices_query(filters: PriceFilter | None = None): """Useful for pagination.""" query = select(Price) @@ -76,6 +88,15 @@ def create_price(db: Session, price: PriceCreate, user: UserBase): return db_price +def set_price_location(db: Session, price: PriceBase, location: LocationBase): + price.location_id = location.id + db.commit() + db.refresh(price) + return price + + +# Proofs +# ------------------------------------------------------------------------------ def get_proof(db: Session, proof_id: int): return db.query(Proof).filter(Proof.id == proof_id).first() @@ -140,3 +161,33 @@ def create_proof_file(file: UploadFile) -> tuple[str, str]: file_path = f"{current_dir_id_str}/{file_stem}{extension}" return (file_path, mimetype) + + +# Locations +# ------------------------------------------------------------------------------ +def get_location_by_osm_id_and_type( + db: Session, osm_id: int, osm_type: LocationOSMType +): + return ( + db.query(Location) + .filter(Location.osm_id == osm_id) + .filter(Location.osm_type == osm_type) + .first() + ) + + +def create_location(db: Session, location: LocationCreate): + db_location = Location(**location.model_dump()) + db.add(db_location) + db.commit() + db.refresh(db_location) + return db_location + + +def get_or_create_location(db: Session, location: LocationCreate): + db_location = get_location_by_osm_id_and_type( + db, osm_id=location.osm_id, osm_type=location.osm_type + ) + if not db_location: + db_location = create_location(db, location=location) + return db_location diff --git a/app/models.py b/app/models.py index 6ea42522..762c6645 100644 --- a/app/models.py +++ b/app/models.py @@ -45,6 +45,8 @@ class Location(Base): osm_lat = Column(Numeric(precision=11, scale=7)) osm_lon = Column(Numeric(precision=11, scale=7)) + prices: Mapped[list["Price"]] = relationship(back_populates="location") + created = Column(DateTime(timezone=True), server_default=func.now()) updated = Column(DateTime(timezone=True), onupdate=func.now()) @@ -76,6 +78,8 @@ class Price(Base): location_osm_id = Column(BigInteger, index=True) location_osm_type = Column(ChoiceType(LocationOSMType)) + location_id: Mapped[int] = mapped_column(ForeignKey("locations.id"), nullable=True) + location: Mapped[Location] = relationship(back_populates="prices") date = Column(Date) diff --git a/app/schemas.py b/app/schemas.py index 63287a75..edb22846 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -62,6 +62,7 @@ def serialize_currency(self, currency: Currency, _info): class PriceBase(PriceCreate): + location_id: int | None # owner: str created: datetime diff --git a/app/tasks.py b/app/tasks.py new file mode 100644 index 00000000..8cc08c03 --- /dev/null +++ b/app/tasks.py @@ -0,0 +1,15 @@ +from sqlalchemy.orm import Session + +from app import crud +from app.schemas import LocationCreate, PriceBase + + +def create_price_location(db: Session, price: PriceBase): + if price.location_osm_id and price.location_osm_type: + # get or create the corresponding location + location = LocationCreate( + osm_id=price.location_osm_id, osm_type=price.location_osm_type + ) + db_location = crud.get_or_create_location(db, location=location) + # link the location to the price + crud.set_price_location(db, price=price, location=db_location)