Skip to content

Commit

Permalink
feat: add user.price_count to keep track of number of prices (#143)
Browse files Browse the repository at this point in the history
* New field User.price_count

* New task to increment user price_count of price create
  • Loading branch information
raphodn authored Jan 14, 2024
1 parent 502539c commit 0f13566
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 59 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Add user price_count field
Revision ID: 868640c5012e
Revises: 13bb81a35e60
Create Date: 2024-01-14 13:26:10.269120
"""
from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "868640c5012e"
down_revision: Union[str, None] = "13bb81a35e60"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"users",
sa.Column("price_count", sa.Integer(), server_default="0", nullable=False),
)
op.create_index(
op.f("ix_users_price_count"), "users", ["price_count"], unique=False
)
# Set the price_count to the number of prices for each user
op.execute(
"""
UPDATE users
SET price_count = (
SELECT COUNT(*)
FROM prices
WHERE prices.owner = users.user_id
)
"""
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_users_price_count"), table_name="users")
op.drop_column("users", "price_count")
# ### end Alembic commands ###
5 changes: 3 additions & 2 deletions app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,9 @@ def create_price(
)

db_price = crud.create_price(db, price=price, user=current_user)
background_tasks.add_task(tasks.create_price_product, db, db_price)
background_tasks.add_task(tasks.create_price_location, db, db_price)
background_tasks.add_task(tasks.create_price_product, db, price=db_price)
background_tasks.add_task(tasks.create_price_location, db, price=db_price)
background_tasks.add_task(tasks.increment_user_price_count, db, user=current_user)
return db_price


Expand Down
11 changes: 11 additions & 0 deletions app/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,17 @@ def update_user_last_used_field(db: Session, user: UserBase) -> UserBase | None:
return update_user(db, user, {"last_used": func.now()})


def increment_user_price_count(db: Session, user: UserBase):
"""Increment the price count of a user.
This is used to keep track of the number of prices linked to a user.
"""
user.price_count += 1
db.commit()
db.refresh(user)
return user


def delete_user(db: Session, user_id: UserBase):
db_user = get_user_by_user_id(db, user_id=user_id)
if db_user:
Expand Down
1 change: 1 addition & 0 deletions app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class User(Base):
token = Column(String, unique=True, index=True)

last_used = Column(DateTime(timezone=True))
price_count = Column(Integer, nullable=False, server_default="0", index=True)

created = Column(DateTime(timezone=True), server_default=func.now())

Expand Down
1 change: 1 addition & 0 deletions app/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class UserBase(BaseModel):

user_id: str
token: str
price_count: int = 0


class ProductCreate(BaseModel):
Expand Down
96 changes: 39 additions & 57 deletions app/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,33 @@

import tqdm
from openfoodfacts import DatasetType, Flavor, ProductDataset
from openfoodfacts.images import generate_image_url
from openfoodfacts.types import JSONType
from openfoodfacts.utils import get_logger
from sqlalchemy import or_, select
from sqlalchemy.orm import Session

from app import crud
from app.config import settings
from app.models import Product
from app.schemas import LocationCreate, PriceBase, ProductCreate
from app.schemas import LocationCreate, PriceBase, ProductCreate, UserBase
from app.utils import (
OFF_FIELDS,
fetch_location_openstreetmap_details,
fetch_product_openfoodfacts_details,
generate_openfoodfacts_main_image_url,
normalize_product_fields,
)

logger = get_logger(__name__)


# Users
# ------------------------------------------------------------------------------
def increment_user_price_count(db: Session, user: UserBase):
crud.increment_user_price_count(db, user=user)


# Products
# ------------------------------------------------------------------------------
def create_price_product(db: Session, price: PriceBase):
# The price may not have a product code, if it's the price of a
# barcode-less product
Expand All @@ -47,58 +54,6 @@ def create_price_product(db: Session, price: PriceBase):
crud.increment_product_price_count(db, product=db_product)


def create_price_location(db: Session, price: PriceBase):
if price.location_osm_id and price.location_osm_type:
# get or create the corresponding location
location = LocationCreate(
osm_id=price.location_osm_id, osm_type=price.location_osm_type
)
db_location, created = crud.get_or_create_location(
db, location=location, init_price_count=1
)
# link the location to the price
crud.set_price_location(db, price=price, location=db_location)
# fetch data from OpenStreetMap if created
if created:
location_openstreetmap_details = fetch_location_openstreetmap_details(
location=db_location
)
if location_openstreetmap_details:
crud.update_location(
db, location=db_location, update_dict=location_openstreetmap_details
)
else:
# Increment the price count of the location
crud.increment_location_price_count(db, location=db_location)


def generate_main_image_url(code: str, images: JSONType, lang: str) -> str | None:
"""Generate the URL of the main image of a product.
:param code: The code of the product
:param images: The images of the product
:param lang: The main language of the product
:return: The URL of the main image of the product or None if no image is
available.
"""
image_key = None
if f"front_{lang}" in images:
image_key = f"front_{lang}"
else:
for key in (k for k in images if k.startswith("front_")):
image_key = key
break

if image_key:
image_rev = images[image_key]["rev"]
image_id = f"{image_key}.{image_rev}.400"
return generate_image_url(
code, image_id=image_id, flavor=Flavor.off, environment=settings.environment
)

return None


def import_product_db(db: Session, batch_size: int = 1000):
"""Import from DB JSONL dump to insert/update product table.
Expand Down Expand Up @@ -155,7 +110,7 @@ def import_product_db(db: Session, batch_size: int = 1000):
item[key] = product[key] if key in product else None

item = normalize_product_fields(item)
item["image_url"] = generate_main_image_url(
item["image_url"] = generate_openfoodfacts_main_image_url(
product_code, images, product["lang"]
)
db.add(Product(**item))
Expand All @@ -164,7 +119,7 @@ def import_product_db(db: Session, batch_size: int = 1000):

else:
item = {key: product[key] if key in product else None for key in OFF_FIELDS}
item["image_url"] = generate_main_image_url(
item["image_url"] = generate_openfoodfacts_main_image_url(
product_code, images, product["lang"]
)
item = normalize_product_fields(item)
Expand All @@ -189,3 +144,30 @@ def import_product_db(db: Session, batch_size: int = 1000):
db.commit()
logger.info(f"Products: {added_count} added, {updated_count} updated")
buffer_len = 0


# Locations
# ------------------------------------------------------------------------------
def create_price_location(db: Session, price: PriceBase):
if price.location_osm_id and price.location_osm_type:
# get or create the corresponding location
location = LocationCreate(
osm_id=price.location_osm_id, osm_type=price.location_osm_type
)
db_location, created = crud.get_or_create_location(
db, location=location, init_price_count=1
)
# link the location to the price
crud.set_price_location(db, price=price, location=db_location)
# fetch data from OpenStreetMap if created
if created:
location_openstreetmap_details = fetch_location_openstreetmap_details(
location=db_location
)
if location_openstreetmap_details:
crud.update_location(
db, location=db_location, update_dict=location_openstreetmap_details
)
else:
# Increment the price count of the location
crud.increment_location_price_count(db, location=db_location)
30 changes: 30 additions & 0 deletions app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import sentry_sdk
from openfoodfacts import API, APIVersion, Country, Flavor
from openfoodfacts.images import generate_image_url
from openfoodfacts.types import JSONType
from openfoodfacts.utils import get_logger
from OSMPythonTools.nominatim import Nominatim
Expand Down Expand Up @@ -75,6 +76,35 @@ def normalize_product_fields(product: JSONType) -> JSONType:
return product


def generate_openfoodfacts_main_image_url(
code: str, images: JSONType, lang: str
) -> str | None:
"""Generate the URL of the main image of a product.
:param code: The code of the product
:param images: The images of the product
:param lang: The main language of the product
:return: The URL of the main image of the product or None if no image is
available.
"""
image_key = None
if f"front_{lang}" in images:
image_key = f"front_{lang}"
else:
for key in (k for k in images if k.startswith("front_")):
image_key = key
break

if image_key:
image_rev = images[image_key]["rev"]
image_id = f"{image_key}.{image_rev}.400"
return generate_image_url(
code, image_id=image_id, flavor=Flavor.off, environment=settings.environment
)

return None


def fetch_product_openfoodfacts_details(product: ProductBase) -> JSONType | None:
product = {}
try:
Expand Down

0 comments on commit 0f13566

Please sign in to comment.