Skip to content

Commit

Permalink
feat: on Price create, create (or get) Location, and link them (#36)
Browse files Browse the repository at this point in the history
* On Price create, background task to create Location

* Create relationship between Price & Location

* Link new or existing location to new price

* Prices: return location_id
  • Loading branch information
raphodn authored Nov 21, 2023
1 parent 32562ad commit 003de11
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Add Price.location relationship
Revision ID: 1c8431a64d3a
Revises: b3b951e016d0
Create Date: 2023-11-16 15:59:03.881443
"""
from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "1c8431a64d3a"
down_revision: Union[str, None] = "b3b951e016d0"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("prices", sa.Column("location_id", sa.Integer(), nullable=True))
op.create_foreign_key(None, "prices", "locations", ["location_id"], ["id"])
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, "prices", type_="foreignkey")
op.drop_column("prices", "location_id")
# ### end Alembic commands ###
14 changes: 12 additions & 2 deletions app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,15 @@
from typing import Annotated

import requests
from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile, status
from fastapi import (
BackgroundTasks,
Depends,
FastAPI,
HTTPException,
Request,
UploadFile,
status,
)
from fastapi.responses import HTMLResponse, PlainTextResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.templating import Jinja2Templates
Expand All @@ -14,7 +22,7 @@
from openfoodfacts.utils import get_logger
from sqlalchemy.orm import Session

from app import crud, schemas
from app import crud, schemas, tasks
from app.config import settings
from app.db import session
from app.utils import init_sentry
Expand Down Expand Up @@ -136,6 +144,7 @@ async def get_price(
@app.post("/prices", response_model=schemas.PriceBase)
async def create_price(
price: schemas.PriceCreate,
background_tasks: BackgroundTasks,
current_user: schemas.UserBase = Depends(get_current_user),
db: Session = Depends(get_db),
):
Expand All @@ -162,6 +171,7 @@ async def create_price(
detail="Proof does not belong to current user",
)
db_price = crud.create_price(db, price=price, user=current_user)
background_tasks.add_task(tasks.create_price_location, db, db_price)
return db_price


Expand Down
59 changes: 55 additions & 4 deletions app/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,20 @@
from sqlalchemy.sql import func

from app import config
from app.models import Price, Proof, User
from app.schemas import PriceCreate, PriceFilter, UserBase


from app.enums import LocationOSMType
from app.models import Location, Price, Proof, User
from app.schemas import (
LocationBase,
LocationCreate,
PriceBase,
PriceCreate,
PriceFilter,
UserBase,
)


# Users
# ------------------------------------------------------------------------------
def get_user(db: Session, user_id: str):
return db.query(User).filter(User.user_id == user_id).first()

Expand Down Expand Up @@ -56,6 +66,8 @@ def delete_user(db: Session, user_id: UserBase):
return False


# Prices
# ------------------------------------------------------------------------------
def get_prices_query(filters: PriceFilter | None = None):
"""Useful for pagination."""
query = select(Price)
Expand All @@ -76,6 +88,15 @@ def create_price(db: Session, price: PriceCreate, user: UserBase):
return db_price


def set_price_location(db: Session, price: PriceBase, location: LocationBase):
price.location_id = location.id
db.commit()
db.refresh(price)
return price


# Proofs
# ------------------------------------------------------------------------------
def get_proof(db: Session, proof_id: int):
return db.query(Proof).filter(Proof.id == proof_id).first()

Expand Down Expand Up @@ -140,3 +161,33 @@ def create_proof_file(file: UploadFile) -> tuple[str, str]:

file_path = f"{current_dir_id_str}/{file_stem}{extension}"
return (file_path, mimetype)


# Locations
# ------------------------------------------------------------------------------
def get_location_by_osm_id_and_type(
db: Session, osm_id: int, osm_type: LocationOSMType
):
return (
db.query(Location)
.filter(Location.osm_id == osm_id)
.filter(Location.osm_type == osm_type)
.first()
)


def create_location(db: Session, location: LocationCreate):
db_location = Location(**location.model_dump())
db.add(db_location)
db.commit()
db.refresh(db_location)
return db_location


def get_or_create_location(db: Session, location: LocationCreate):
db_location = get_location_by_osm_id_and_type(
db, osm_id=location.osm_id, osm_type=location.osm_type
)
if not db_location:
db_location = create_location(db, location=location)
return db_location
4 changes: 4 additions & 0 deletions app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class Location(Base):
osm_lat = Column(Numeric(precision=11, scale=7))
osm_lon = Column(Numeric(precision=11, scale=7))

prices: Mapped[list["Price"]] = relationship(back_populates="location")

created = Column(DateTime(timezone=True), server_default=func.now())
updated = Column(DateTime(timezone=True), onupdate=func.now())

Expand Down Expand Up @@ -76,6 +78,8 @@ class Price(Base):

location_osm_id = Column(BigInteger, index=True)
location_osm_type = Column(ChoiceType(LocationOSMType))
location_id: Mapped[int] = mapped_column(ForeignKey("locations.id"), nullable=True)
location: Mapped[Location] = relationship(back_populates="prices")

date = Column(Date)

Expand Down
1 change: 1 addition & 0 deletions app/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def serialize_currency(self, currency: Currency, _info):


class PriceBase(PriceCreate):
location_id: int | None
# owner: str
created: datetime

Expand Down
15 changes: 15 additions & 0 deletions app/tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from sqlalchemy.orm import Session

from app import crud
from app.schemas import LocationCreate, PriceBase


def create_price_location(db: Session, price: PriceBase):
if price.location_osm_id and price.location_osm_type:
# get or create the corresponding location
location = LocationCreate(
osm_id=price.location_osm_id, osm_type=price.location_osm_type
)
db_location = crud.get_or_create_location(db, location=location)
# link the location to the price
crud.set_price_location(db, price=price, location=db_location)

0 comments on commit 003de11

Please sign in to comment.