From bac7c3a590e4567a4a91a2ec7fd84a47014ca416 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Bournhonesque?= Date: Tue, 17 Dec 2024 15:18:18 +0100 Subject: [PATCH] feat: save Gemini prediction in price_tag_predictions table --- open_prices/api/proofs/views.py | 4 +- open_prices/common/gemini.py | 141 ----------- open_prices/proofs/constants.py | 8 +- ...proofprediction_type_pricetagprediction.py | 91 ++++++++ open_prices/proofs/ml.py | 220 +++++++++++++++++- open_prices/proofs/models.py | 41 ++++ open_prices/proofs/tests.py | 8 +- 7 files changed, 358 insertions(+), 155 deletions(-) create mode 100644 open_prices/proofs/migrations/0008_alter_proofprediction_type_pricetagprediction.py diff --git a/open_prices/api/proofs/views.py b/open_prices/api/proofs/views.py index 267e3abf..78aa80ca 100644 --- a/open_prices/api/proofs/views.py +++ b/open_prices/api/proofs/views.py @@ -24,7 +24,7 @@ from open_prices.api.utils import get_source_from_request from open_prices.common.authentication import CustomAuthentication from open_prices.common.constants import PriceTagStatus -from open_prices.common.gemini import handle_bulk_labels +from open_prices.proofs.ml import extract_from_price_tags from open_prices.proofs.models import PriceTag, Proof from open_prices.proofs.utils import store_file @@ -125,7 +125,7 @@ def upload(self, request: Request) -> Response: def process_with_gemini(self, request: Request) -> Response: files = request.FILES.getlist("files") sample_files = [PIL.Image.open(file.file) for file in files] - res = handle_bulk_labels(sample_files) + res = extract_from_price_tags(sample_files) return Response(res, status=status.HTTP_200_OK) diff --git a/open_prices/common/gemini.py b/open_prices/common/gemini.py index ef4efc86..e69de29b 100644 --- a/open_prices/common/gemini.py +++ b/open_prices/common/gemini.py @@ -1,141 +0,0 @@ -import enum -import json - -import google.generativeai as genai -import typing_extensions as typing -from django.conf import settings - -genai.configure(api_key=settings.GOOGLE_GEMINI_API_KEY) -model = genai.GenerativeModel(model_name="gemini-1.5-flash") - - -# TODO: what about orther categories ? -class Products(enum.Enum): - OTHER = "other" - APPLES = "en:apples" - APRICOTS = "en:apricots" - ARTICHOKES = "en:artichokes" - ASPARAGUS = "en:asparagus" - AUBERGINES = "en:aubergines" - AVOCADOS = "en:avocados" - BANANAS = "en:bananas" - BEET = "en:beet" - BERRIES = "en:berries" - BLACKBERRIES = "en:blackberries" - BLUEBERRIES = "en:blueberries" - BOK_CHOY = "en:bok-choy" - BROCCOLI = "en:broccoli" - CABBAGES = "en:cabbages" - CARROTS = "en:carrots" - CAULIFLOWERS = "en:cauliflowers" - CELERY = "en:celery" - CELERY_STALK = "en:celery-stalk" - CEP_MUSHROOMS = "en:cep-mushrooms" - CHANTERELLES = "en:chanterelles" - CHERRIES = "en:cherries" - CHERRY_TOMATOES = "en:cherry-tomatoes" - CHICKPEAS = "en:chickpeas" - CHIVES = "en:chives" - CLEMENTINES = "en:clementines" - COCONUTS = "en:coconuts" - CRANBERRIES = "en:cranberries" - CUCUMBERS = "en:cucumbers" - DATES = "en:dates" - ENDIVES = "en:endives" - FIGS = "en:figs" - GARLIC = "en:garlic" - GINGER = "en:ginger" - GRAPEFRUITS = "en:grapefruits" - GRAPES = "en:grapes" - GREEN_BEANS = "en:green-beans" - KIWIS = "en:kiwis" - KAKIS = "en:kakis" - LEEKS = "en:leeks" - LEMONS = "en:lemons" - LETTUCES = "en:lettuces" - LIMES = "en:limes" - LYCHEES = "en:lychees" - MANDARIN_ORANGES = "en:mandarin-oranges" - MANGOES = "en:mangoes" - MELONS = "en:melons" - MUSHROOMS = "en:mushrooms" - NECTARINES = "en:nectarines" - ONIONS = "en:onions" - ORANGES = "en:oranges" - PAPAYAS = "en:papayas" - PASSION_FRUITS = "en:passion-fruits" - PEACHES = "en:peaches" - PEARS = "en:pears" - PEAS = "en:peas" - PEPPERS = "en:peppers" - PINEAPPLE = "en:pineapple" - PLUMS = "en:plums" - POMEGRANATES = "en:pomegranates" - POMELOS = "en:pomelos" - POTATOES = "en:potatoes" - PUMPKINS = "en:pumpkins" - RADISHES = "en:radishes" - RASPBERRIES = "en:raspberries" - RHUBARBS = "en:rhubarbs" - SCALLIONS = "en:scallions" - SHALLOTS = "en:shallots" - SPINACHS = "en:spinachs" - SPROUTS = "en:sprouts" - STRAWBERRIES = "en:strawberries" - TOMATOES = "en:tomatoes" - TURNIP = "en:turnip" - WATERMELONS = "en:watermelons" - WALNUTS = "en:walnuts" - ZUCCHINI = "en:zucchini" - - -# TODO: what about other origins ? -class Origin(enum.Enum): - FRANCE = "en:france" - ITALY = "en:italy" - SPAIN = "en:spain" - POLAND = "en:poland" - CHINA = "en:china" - BELGIUM = "en:belgium" - MOROCCO = "en:morocco" - PERU = "en:peru" - PORTUGAL = "en:portugal" - MEXICO = "en:mexico" - OTHER = "other" - UNKNOWN = "unknown" - - -class Unit(enum.Enum): - KILOGRAM = "KILOGRAM" - UNIT = "UNIT" - - -class Label(typing.TypedDict): - product: Products - price: float - origin: Origin - unit: Unit - organic: bool - barcode: str - - -class Labels(typing.TypedDict): - labels: list[Label] - - -def handle_bulk_labels(images): - response = model.generate_content( - [ - "Here are " - + str(len(images)) - + " pictures containing a label. For each picture of a label, please extract all the following attributes: the product category matching product name, the origin category matching country of origin, the price, is the product organic, the unit (per KILOGRAM or per UNIT) and the barcode. I expect a list of " - + str(len(images)) - + " labels in your reply, no more, no less. If you cannot decode an attribute, set it to an empty string" - ] - + images, - generation_config=genai.GenerationConfig( - response_mime_type="application/json", response_schema=Labels - ), - ) - vals = json.loads(response.text) - return vals diff --git a/open_prices/proofs/constants.py b/open_prices/proofs/constants.py index b3f4a461..48e9cf1a 100644 --- a/open_prices/proofs/constants.py +++ b/open_prices/proofs/constants.py @@ -18,12 +18,16 @@ PROOF_PREDICTION_OBJECT_DETECTION_TYPE = "OBJECT_DETECTION" PROOF_PREDICTION_CLASSIFICATION_TYPE = "CLASSIFICATION" PROOF_PREDICTION_RECEIPT_EXTRACTION_TYPE = "RECEIPT_EXTRACTION" -PROOF_PREDICTION_PRICE_TAG_EXTRACTION_TYPE = "PRICE_TAG_EXTRACTION" PROOF_PREDICTION_LIST = [ PROOF_PREDICTION_OBJECT_DETECTION_TYPE, PROOF_PREDICTION_CLASSIFICATION_TYPE, PROOF_PREDICTION_RECEIPT_EXTRACTION_TYPE, - PROOF_PREDICTION_PRICE_TAG_EXTRACTION_TYPE, ] PROOF_TYPE_CHOICES = [(key, key) for key in PROOF_PREDICTION_LIST] + +PRICE_TAG_EXTRACTION_TYPE = "PRICE_TAG_EXTRACTION" + +PRICE_TAG_PREDICTION_TYPE_CHOICES = [ + (PRICE_TAG_EXTRACTION_TYPE, PRICE_TAG_EXTRACTION_TYPE) +] diff --git a/open_prices/proofs/migrations/0008_alter_proofprediction_type_pricetagprediction.py b/open_prices/proofs/migrations/0008_alter_proofprediction_type_pricetagprediction.py new file mode 100644 index 00000000..08895799 --- /dev/null +++ b/open_prices/proofs/migrations/0008_alter_proofprediction_type_pricetagprediction.py @@ -0,0 +1,91 @@ +# Generated by Django 5.1.4 on 2024-12-17 14:01 + +import django.db.models.deletion +import django.utils.timezone +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("proofs", "0007_pricetag"), + ] + + operations = [ + migrations.AlterField( + model_name="proofprediction", + name="type", + field=models.CharField( + choices=[ + ("OBJECT_DETECTION", "OBJECT_DETECTION"), + ("CLASSIFICATION", "CLASSIFICATION"), + ("RECEIPT_EXTRACTION", "RECEIPT_EXTRACTION"), + ], + max_length=20, + verbose_name="The type of the prediction", + ), + ), + migrations.CreateModel( + name="PriceTagPrediction", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "type", + models.CharField( + choices=[("PRICE_TAG_EXTRACTION", "PRICE_TAG_EXTRACTION")], + help_text="The type of the prediction", + max_length=20, + ), + ), + ( + "model_name", + models.CharField( + help_text="The name of the model that generated the prediction", + max_length=30, + ), + ), + ( + "model_version", + models.CharField( + help_text="The specific version of the model that generated the prediction", + max_length=30, + ), + ), + ( + "created", + models.DateTimeField( + default=django.utils.timezone.now, + help_text="When the prediction was created in DB", + ), + ), + ( + "data", + models.JSONField( + default=dict, + help_text="a dict representing the data of the prediction. This field is model-specific.", + ), + ), + ( + "price_tag", + models.ForeignKey( + help_text="The price tag this prediction belongs to", + on_delete=django.db.models.deletion.CASCADE, + related_name="predictions", + to="proofs.pricetag", + ), + ), + ], + options={ + "verbose_name": "Price Tag Prediction", + "verbose_name_plural": "Price Tag Predictions", + "db_table": "price_tag_predictions", + }, + ), + ] diff --git a/open_prices/proofs/ml.py b/open_prices/proofs/ml.py index abc8674b..fef2e9af 100644 --- a/open_prices/proofs/ml.py +++ b/open_prices/proofs/ml.py @@ -1,13 +1,17 @@ +import enum +import json import logging from pathlib import Path +import google.generativeai as genai +import typing_extensions as typing from django.conf import settings from openfoodfacts.ml.image_classification import ImageClassifier from openfoodfacts.ml.object_detection import ObjectDetectionRawResult, ObjectDetector from PIL import Image from . import constants -from .models import PriceTag, Proof, ProofPrediction +from .models import PriceTag, PriceTagPrediction, Proof, ProofPrediction logger = logging.getLogger(__name__) @@ -28,6 +32,143 @@ PRICE_TAG_DETECTOR_MODEL_VERSION = "price_tag_detection-1.0" PRICE_TAG_DETECTOR_TRITON_VERSION = "1" PRICE_TAG_DETECTOR_IMAGE_SIZE = 960 +GEMINI_MODEL_NAME = "gemini" +GEMINI_MODEL_VERSION = "gemini-1.5-flash" + +genai.configure(api_key=settings.GOOGLE_GEMINI_API_KEY) +model = genai.GenerativeModel(model_name=GEMINI_MODEL_VERSION) + + +# TODO: what about other categories? +class Products(enum.Enum): + OTHER = "other" + APPLES = "en:apples" + APRICOTS = "en:apricots" + ARTICHOKES = "en:artichokes" + ASPARAGUS = "en:asparagus" + AUBERGINES = "en:aubergines" + AVOCADOS = "en:avocados" + BANANAS = "en:bananas" + BEET = "en:beet" + BERRIES = "en:berries" + BLACKBERRIES = "en:blackberries" + BLUEBERRIES = "en:blueberries" + BOK_CHOY = "en:bok-choy" + BROCCOLI = "en:broccoli" + CABBAGES = "en:cabbages" + CARROTS = "en:carrots" + CAULIFLOWERS = "en:cauliflowers" + CELERY = "en:celery" + CELERY_STALK = "en:celery-stalk" + CEP_MUSHROOMS = "en:cep-mushrooms" + CHANTERELLES = "en:chanterelles" + CHERRIES = "en:cherries" + CHERRY_TOMATOES = "en:cherry-tomatoes" + CHICKPEAS = "en:chickpeas" + CHIVES = "en:chives" + CLEMENTINES = "en:clementines" + COCONUTS = "en:coconuts" + CRANBERRIES = "en:cranberries" + CUCUMBERS = "en:cucumbers" + DATES = "en:dates" + ENDIVES = "en:endives" + FIGS = "en:figs" + GARLIC = "en:garlic" + GINGER = "en:ginger" + GRAPEFRUITS = "en:grapefruits" + GRAPES = "en:grapes" + GREEN_BEANS = "en:green-beans" + KIWIS = "en:kiwis" + KAKIS = "en:kakis" + LEEKS = "en:leeks" + LEMONS = "en:lemons" + LETTUCES = "en:lettuces" + LIMES = "en:limes" + LYCHEES = "en:lychees" + MANDARIN_ORANGES = "en:mandarin-oranges" + MANGOES = "en:mangoes" + MELONS = "en:melons" + MUSHROOMS = "en:mushrooms" + NECTARINES = "en:nectarines" + ONIONS = "en:onions" + ORANGES = "en:oranges" + PAPAYAS = "en:papayas" + PASSION_FRUITS = "en:passion-fruits" + PEACHES = "en:peaches" + PEARS = "en:pears" + PEAS = "en:peas" + PEPPERS = "en:peppers" + PINEAPPLE = "en:pineapple" + PLUMS = "en:plums" + POMEGRANATES = "en:pomegranates" + POMELOS = "en:pomelos" + POTATOES = "en:potatoes" + PUMPKINS = "en:pumpkins" + RADISHES = "en:radishes" + RASPBERRIES = "en:raspberries" + RHUBARBS = "en:rhubarbs" + SCALLIONS = "en:scallions" + SHALLOTS = "en:shallots" + SPINACHS = "en:spinachs" + SPROUTS = "en:sprouts" + STRAWBERRIES = "en:strawberries" + TOMATOES = "en:tomatoes" + TURNIP = "en:turnip" + WATERMELONS = "en:watermelons" + WALNUTS = "en:walnuts" + ZUCCHINI = "en:zucchini" + + +# TODO: what about other origins? +class Origin(enum.Enum): + FRANCE = "en:france" + ITALY = "en:italy" + SPAIN = "en:spain" + POLAND = "en:poland" + CHINA = "en:china" + BELGIUM = "en:belgium" + MOROCCO = "en:morocco" + PERU = "en:peru" + PORTUGAL = "en:portugal" + MEXICO = "en:mexico" + OTHER = "other" + UNKNOWN = "unknown" + + +class Unit(enum.Enum): + KILOGRAM = "KILOGRAM" + UNIT = "UNIT" + + +class Label(typing.TypedDict): + product: Products + price: float + origin: Origin + unit: Unit + organic: bool + barcode: str + + +class Labels(typing.TypedDict): + labels: list[Label] + + +def extract_from_price_tags(images: Image.Image) -> Labels: + """Extract price tag information from a list of images.""" + response = model.generate_content( + [ + "Here are " + + str(len(images)) + + " pictures containing a label. For each picture of a label, please extract all the following attributes: the product category matching product name, the origin category matching country of origin, the price, is the product organic, the unit (per KILOGRAM or per UNIT) and the barcode. I expect a list of " + + str(len(images)) + + " labels in your reply, no more, no less. If you cannot decode an attribute, set it to an empty string" + ] + + images, + generation_config=genai.GenerationConfig( + response_mime_type="application/json", response_schema=Labels + ), + ) + return json.loads(response.text) def predict_proof_type( @@ -86,8 +227,55 @@ def detect_price_tags( ) +def run_and_save_price_tag_extraction( + price_tags: list[PriceTag], proof: Proof +) -> list[PriceTagPrediction]: + """Extract information from price tags using the Gemini model and save the + predictions in the database. + + :param price_tags: the list of PriceTag instances to extract information + from + :param proof: the Proof instance associated with the price tags + :return: the list of PriceTagPrediction instances created + """ + if proof.file_path_full is None or not Path(proof.file_path_full).exists(): + logger.error("Proof file not found: %s", proof.file_path_full) + return [] + + cropped_images = [] + for price_tag in price_tags: + y_min, x_min, y_max, x_max = price_tag.bounding_box + image = Image.open(proof.file_path_full) + (left, right, top, bottom) = ( + x_min * image.width, + x_max * image.width, + y_min * image.height, + y_max * image.height, + ) + cropped_image = image.crop((left, top, right, bottom)) + cropped_images.append(cropped_image) + + labels = extract_from_price_tags(cropped_images) + + predictions = [] + for price_tag, label in zip(price_tags, labels["labels"]): + prediction = PriceTagPrediction.objects.create( + price_tag=price_tag, + type=constants.PRICE_TAG_EXTRACTION_TYPE, + model_name=GEMINI_MODEL_NAME, + model_version=GEMINI_MODEL_VERSION, + data=label, + ) + predictions.append(prediction) + + return predictions + + def create_price_tags_from_proof_prediction( - proof: Proof, proof_prediction: ProofPrediction, threshold: float = 0.5 + proof: Proof, + proof_prediction: ProofPrediction, + threshold: float = 0.5, + run_extraction: bool = True, ) -> list[PriceTag]: """Create price tags from a proof prediction containing price tag object detections. @@ -97,6 +285,8 @@ def create_price_tags_from_proof_prediction( price tag detections :param threshold: the minimum confidence threshold for a detection to be considered valid, defaults to 0.5 + :param run_extraction: whether to run the price tag extraction model on the + detected price tags, defaults to True :return: the list of PriceTag instances created """ if proof_prediction.model_name != PRICE_TAG_DETECTOR_MODEL_NAME: @@ -118,11 +308,15 @@ def create_price_tags_from_proof_prediction( updated_by=None, ) created.append(price_tag) + + if run_extraction: + run_and_save_price_tag_extraction(created, proof) + return created def run_and_save_price_tag_detection( - image: Image, proof: Proof, overwrite: bool = False + image: Image, proof: Proof, overwrite: bool = False, run_extraction: bool = True ) -> ProofPrediction | None: """Run the price tag object detection model and save the prediction in ProofPrediction table. @@ -131,6 +325,8 @@ def run_and_save_price_tag_detection( :param proof: the Proof instance to associate the ProofPrediction with :param overwrite: whether to overwrite existing prediction, defaults to False + :param run_extraction: whether to run the price tag extraction model on the + detected price tags, defaults to True :return: the ProofPrediction instance created, or None if the prediction already exists and overwrite is False """ @@ -156,7 +352,9 @@ def run_and_save_price_tag_detection( "Creating price tags from existing prediction for proof %s", proof.id, ) - create_price_tags_from_proof_prediction(proof, proof_prediction) + create_price_tags_from_proof_prediction( + proof, proof_prediction, run_extraction=run_extraction + ) return None result = detect_price_tags(image) @@ -175,7 +373,9 @@ def run_and_save_price_tag_detection( value=None, max_confidence=max_confidence, ) - create_price_tags_from_proof_prediction(proof, proof_prediction) + create_price_tags_from_proof_prediction( + proof, proof_prediction, run_extraction=run_extraction + ) return proof_prediction @@ -228,7 +428,9 @@ def run_and_save_proof_type_prediction( ) -def run_and_save_proof_prediction(proof_id: int) -> None: +def run_and_save_proof_prediction( + proof_id: int, run_price_tag_extraction: bool = True +) -> None: """Run all ML models on a specific proof, and save the predictions in DB. Currently, the following models are run: @@ -237,6 +439,8 @@ def run_and_save_proof_prediction(proof_id: int) -> None: - price tag detection model (objecct detector) :param proof_id: the ID of the proof to be classified + :param run_price_tag_extraction: whether to run the price tag extraction + model on the detected price tags, defaults to True """ proof = Proof.objects.filter(id=proof_id).first() if not proof: @@ -255,4 +459,6 @@ def run_and_save_proof_prediction(proof_id: int) -> None: image = Image.open(file_path_full) run_and_save_proof_type_prediction(image, proof) - run_and_save_price_tag_detection(image, proof) + run_and_save_price_tag_detection( + image, proof, run_extraction=run_price_tag_extraction + ) diff --git a/open_prices/proofs/models.py b/open_prices/proofs/models.py index 901eb12e..4a0f4de2 100644 --- a/open_prices/proofs/models.py +++ b/open_prices/proofs/models.py @@ -536,3 +536,44 @@ def clean(self, *args, **kwargs): def save(self, *args, **kwargs): self.full_clean() super().save(*args, **kwargs) + + +class PriceTagPrediction(models.Model): + """A machine learning prediction for a price tag.""" + + price_tag = models.ForeignKey( + PriceTag, + on_delete=models.CASCADE, + related_name="predictions", + help_text="The price tag this prediction belongs to", + ) + type = models.CharField( + max_length=20, + choices=proof_constants.PRICE_TAG_PREDICTION_TYPE_CHOICES, + help_text="The type of the prediction", + ) + model_name = models.CharField( + max_length=30, + help_text="The name of the model that generated the prediction", + ) + model_version = models.CharField( + max_length=30, + help_text="The specific version of the model that generated the prediction", + ) + created = models.DateTimeField( + default=timezone.now, help_text="When the prediction was created in DB" + ) + data = models.JSONField( + null=False, + blank=False, + help_text="a dict representing the data of the prediction. This field is model-specific.", + default=dict, + ) + + class Meta: + db_table = "price_tag_predictions" + verbose_name = "Price Tag Prediction" + verbose_name_plural = "Price Tag Predictions" + + def __str__(self): + return f"{self.model_name} - {self.model_version} - {self.price_tag}" diff --git a/open_prices/proofs/tests.py b/open_prices/proofs/tests.py index ac577387..f02d3aec 100644 --- a/open_prices/proofs/tests.py +++ b/open_prices/proofs/tests.py @@ -429,7 +429,9 @@ def test_run_and_save_proof_prediction_proof(self): return_value=detect_price_tags_response, ) as mock_detect_price_tags, ): - run_and_save_proof_prediction(proof.id) + run_and_save_proof_prediction( + proof.id, run_price_tag_extraction=False + ) mock_predict_proof_type.assert_called_once() mock_detect_price_tags.assert_called_once() @@ -524,7 +526,7 @@ def test_run_and_save_price_tag_detection_already_exists(self): ] }, ) - result = run_and_save_price_tag_detection(image, proof) + result = run_and_save_price_tag_detection(image, proof, run_extraction=False) self.assertIsNone(result) price_tags = PriceTag.objects.filter(proof=proof).all() self.assertEqual(len(price_tags), 2) @@ -560,7 +562,7 @@ def create_price_tags_from_proof_prediction(self): ) before = timezone.now() results = create_price_tags_from_proof_prediction( - proof, proof_prediction, threshold=0.4 + proof, proof_prediction, threshold=0.4, run_extraction=False ) after = timezone.now() self.assertEqual(len(results), 2)