Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: save Gemini prediction in price_tag_predictions table #630

Merged
merged 1 commit into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions open_prices/api/proofs/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from open_prices.api.utils import get_source_from_request
from open_prices.common.authentication import CustomAuthentication
from open_prices.common.constants import PriceTagStatus
from open_prices.common.gemini import handle_bulk_labels
from open_prices.proofs.ml import extract_from_price_tags
from open_prices.proofs.models import PriceTag, Proof
from open_prices.proofs.utils import store_file

Expand Down Expand Up @@ -125,7 +125,7 @@ def upload(self, request: Request) -> Response:
def process_with_gemini(self, request: Request) -> Response:
files = request.FILES.getlist("files")
sample_files = [PIL.Image.open(file.file) for file in files]
res = handle_bulk_labels(sample_files)
res = extract_from_price_tags(sample_files)
return Response(res, status=status.HTTP_200_OK)


Expand Down
141 changes: 0 additions & 141 deletions open_prices/common/gemini.py
Original file line number Diff line number Diff line change
@@ -1,141 +0,0 @@
import enum
import json

import google.generativeai as genai
import typing_extensions as typing
from django.conf import settings

genai.configure(api_key=settings.GOOGLE_GEMINI_API_KEY)
model = genai.GenerativeModel(model_name="gemini-1.5-flash")


# TODO: what about orther categories ?
class Products(enum.Enum):
OTHER = "other"
APPLES = "en:apples"
APRICOTS = "en:apricots"
ARTICHOKES = "en:artichokes"
ASPARAGUS = "en:asparagus"
AUBERGINES = "en:aubergines"
AVOCADOS = "en:avocados"
BANANAS = "en:bananas"
BEET = "en:beet"
BERRIES = "en:berries"
BLACKBERRIES = "en:blackberries"
BLUEBERRIES = "en:blueberries"
BOK_CHOY = "en:bok-choy"
BROCCOLI = "en:broccoli"
CABBAGES = "en:cabbages"
CARROTS = "en:carrots"
CAULIFLOWERS = "en:cauliflowers"
CELERY = "en:celery"
CELERY_STALK = "en:celery-stalk"
CEP_MUSHROOMS = "en:cep-mushrooms"
CHANTERELLES = "en:chanterelles"
CHERRIES = "en:cherries"
CHERRY_TOMATOES = "en:cherry-tomatoes"
CHICKPEAS = "en:chickpeas"
CHIVES = "en:chives"
CLEMENTINES = "en:clementines"
COCONUTS = "en:coconuts"
CRANBERRIES = "en:cranberries"
CUCUMBERS = "en:cucumbers"
DATES = "en:dates"
ENDIVES = "en:endives"
FIGS = "en:figs"
GARLIC = "en:garlic"
GINGER = "en:ginger"
GRAPEFRUITS = "en:grapefruits"
GRAPES = "en:grapes"
GREEN_BEANS = "en:green-beans"
KIWIS = "en:kiwis"
KAKIS = "en:kakis"
LEEKS = "en:leeks"
LEMONS = "en:lemons"
LETTUCES = "en:lettuces"
LIMES = "en:limes"
LYCHEES = "en:lychees"
MANDARIN_ORANGES = "en:mandarin-oranges"
MANGOES = "en:mangoes"
MELONS = "en:melons"
MUSHROOMS = "en:mushrooms"
NECTARINES = "en:nectarines"
ONIONS = "en:onions"
ORANGES = "en:oranges"
PAPAYAS = "en:papayas"
PASSION_FRUITS = "en:passion-fruits"
PEACHES = "en:peaches"
PEARS = "en:pears"
PEAS = "en:peas"
PEPPERS = "en:peppers"
PINEAPPLE = "en:pineapple"
PLUMS = "en:plums"
POMEGRANATES = "en:pomegranates"
POMELOS = "en:pomelos"
POTATOES = "en:potatoes"
PUMPKINS = "en:pumpkins"
RADISHES = "en:radishes"
RASPBERRIES = "en:raspberries"
RHUBARBS = "en:rhubarbs"
SCALLIONS = "en:scallions"
SHALLOTS = "en:shallots"
SPINACHS = "en:spinachs"
SPROUTS = "en:sprouts"
STRAWBERRIES = "en:strawberries"
TOMATOES = "en:tomatoes"
TURNIP = "en:turnip"
WATERMELONS = "en:watermelons"
WALNUTS = "en:walnuts"
ZUCCHINI = "en:zucchini"


# TODO: what about other origins ?
class Origin(enum.Enum):
FRANCE = "en:france"
ITALY = "en:italy"
SPAIN = "en:spain"
POLAND = "en:poland"
CHINA = "en:china"
BELGIUM = "en:belgium"
MOROCCO = "en:morocco"
PERU = "en:peru"
PORTUGAL = "en:portugal"
MEXICO = "en:mexico"
OTHER = "other"
UNKNOWN = "unknown"


class Unit(enum.Enum):
KILOGRAM = "KILOGRAM"
UNIT = "UNIT"


class Label(typing.TypedDict):
product: Products
price: float
origin: Origin
unit: Unit
organic: bool
barcode: str


class Labels(typing.TypedDict):
labels: list[Label]


def handle_bulk_labels(images):
response = model.generate_content(
[
"Here are "
+ str(len(images))
+ " pictures containing a label. For each picture of a label, please extract all the following attributes: the product category matching product name, the origin category matching country of origin, the price, is the product organic, the unit (per KILOGRAM or per UNIT) and the barcode. I expect a list of "
+ str(len(images))
+ " labels in your reply, no more, no less. If you cannot decode an attribute, set it to an empty string"
]
+ images,
generation_config=genai.GenerationConfig(
response_mime_type="application/json", response_schema=Labels
),
)
vals = json.loads(response.text)
return vals
8 changes: 6 additions & 2 deletions open_prices/proofs/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@
PROOF_PREDICTION_OBJECT_DETECTION_TYPE = "OBJECT_DETECTION"
PROOF_PREDICTION_CLASSIFICATION_TYPE = "CLASSIFICATION"
PROOF_PREDICTION_RECEIPT_EXTRACTION_TYPE = "RECEIPT_EXTRACTION"
PROOF_PREDICTION_PRICE_TAG_EXTRACTION_TYPE = "PRICE_TAG_EXTRACTION"
PROOF_PREDICTION_LIST = [
PROOF_PREDICTION_OBJECT_DETECTION_TYPE,
PROOF_PREDICTION_CLASSIFICATION_TYPE,
PROOF_PREDICTION_RECEIPT_EXTRACTION_TYPE,
PROOF_PREDICTION_PRICE_TAG_EXTRACTION_TYPE,
]

PROOF_TYPE_CHOICES = [(key, key) for key in PROOF_PREDICTION_LIST]

PRICE_TAG_EXTRACTION_TYPE = "PRICE_TAG_EXTRACTION"

PRICE_TAG_PREDICTION_TYPE_CHOICES = [
(PRICE_TAG_EXTRACTION_TYPE, PRICE_TAG_EXTRACTION_TYPE)
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Generated by Django 5.1.4 on 2024-12-17 14:01

import django.db.models.deletion
import django.utils.timezone
from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("proofs", "0007_pricetag"),
]

operations = [
migrations.AlterField(
model_name="proofprediction",
name="type",
field=models.CharField(
choices=[
("OBJECT_DETECTION", "OBJECT_DETECTION"),
("CLASSIFICATION", "CLASSIFICATION"),
("RECEIPT_EXTRACTION", "RECEIPT_EXTRACTION"),
],
max_length=20,
verbose_name="The type of the prediction",
),
),
migrations.CreateModel(
name="PriceTagPrediction",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
(
"type",
models.CharField(
choices=[("PRICE_TAG_EXTRACTION", "PRICE_TAG_EXTRACTION")],
help_text="The type of the prediction",
max_length=20,
),
),
(
"model_name",
models.CharField(
help_text="The name of the model that generated the prediction",
max_length=30,
),
),
(
"model_version",
models.CharField(
help_text="The specific version of the model that generated the prediction",
max_length=30,
),
),
(
"created",
models.DateTimeField(
default=django.utils.timezone.now,
help_text="When the prediction was created in DB",
),
),
(
"data",
models.JSONField(
default=dict,
help_text="a dict representing the data of the prediction. This field is model-specific.",
),
),
(
"price_tag",
models.ForeignKey(
help_text="The price tag this prediction belongs to",
on_delete=django.db.models.deletion.CASCADE,
related_name="predictions",
to="proofs.pricetag",
),
),
],
options={
"verbose_name": "Price Tag Prediction",
"verbose_name_plural": "Price Tag Predictions",
"db_table": "price_tag_predictions",
},
),
]
Loading
Loading