Skip to content

Commit

Permalink
fix: fix script that runs ML models for proofs without predictions (#627
Browse files Browse the repository at this point in the history
)
  • Loading branch information
raphael0202 authored Dec 16, 2024
1 parent 59f9cd6 commit c825d2c
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 27 deletions.
71 changes: 49 additions & 22 deletions open_prices/proofs/management/commands/run_ml_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,52 +2,79 @@

import tqdm
from django.core.management.base import BaseCommand
from django.db.models import Q
from openfoodfacts.utils import get_logger

from open_prices.proofs.ml.image_classifier import (
from open_prices.proofs.ml import (
PRICE_TAG_DETECTOR_MODEL_NAME,
PROOF_CLASSIFICATION_MODEL_NAME,
run_and_save_proof_prediction,
)
from open_prices.proofs.models import Proof

# Initializing root logger
get_logger()


class Command(BaseCommand):
help = """Run ML models on images with proof predictions, and save the predictions
in DB."""
_allowed_types = ["proof_classification"]
_allowed_types = ["proof_classification", "price_tag_detection"]

def add_arguments(self, parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--limit", type=int, help="Limit the number of proofs to process."
)
parser.add_argument("type", type=str, help="Type of model to run.", nargs="+")
parser.add_argument(
"--types",
type=str,
help="Type of model to run. Supported values are `proof_classification` "
"and `price_tag_detection`",
)

def handle(self, *args, **options) -> None: # type: ignore
self.stdout.write(
"Running ML models on images without proof predictions for this model..."
)
limit = options["limit"]
types = options["type"]
types_str = options["types"]

if types_str:
types = types_str.split(",")
else:
types = self._allowed_types

if not all(t in self._allowed_types for t in types):
raise ValueError(
f"Invalid type(s) provided: {types}, allowed: {self._allowed_types}"
f"Invalid type(s) provided: '{types}', allowed: {self._allowed_types}"
)

exclusion_filters_list = []
if "proof_classification" in types:
# Get proofs that don't have a proof prediction with
# model_name = PROOF_CLASSIFICATION_MODEL_NAME by performing an
# outer join on the Proof and Prediction tables.
proofs = (
Proof.objects.filter(predictions__model_name__isnull=True)
| Proof.objects.exclude(
predictions__model_name=PROOF_CLASSIFICATION_MODEL_NAME
)
).distinct()

if limit:
proofs = proofs[:limit]

for proof in tqdm.tqdm(proofs):
self.stdout.write(f"Processing proof {proof.id}...")
run_and_save_proof_prediction(proof.id)
self.stdout.write("Done.")
exclusion_filters_list.append(
Q(predictions__model_name=PROOF_CLASSIFICATION_MODEL_NAME)
)

if "price_tag_detection" in types:
exclusion_filters_list.append(
Q(predictions__model_name=PRICE_TAG_DETECTOR_MODEL_NAME)
)

exclusion_filter = exclusion_filters_list.pop()
for remaining_filter in exclusion_filters_list:
exclusion_filter &= remaining_filter
# Get proofs that don't have a proof prediction with
# one of the model by performing an
# outer join on the Proof and Prediction tables.
proofs = (
Proof.objects.filter(predictions__model_name__isnull=True)
| Proof.objects.exclude(exclusion_filter)
).distinct()

if limit:
proofs = proofs[:limit]

for proof in tqdm.tqdm(proofs):
self.stdout.write(f"Processing proof {proof.id}...")
run_and_save_proof_prediction(proof.id)
self.stdout.write("Done.")
55 changes: 51 additions & 4 deletions open_prices/proofs/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,21 +86,46 @@ def detect_price_tags(
)


def run_and_save_price_tag_detection(image: Image, proof: Proof) -> None:
def run_and_save_price_tag_detection(
image: Image, proof: Proof, overwrite: bool = False
) -> ProofPrediction | None:
"""Run the price tag object detection model and save the prediction
in ProofPrediction table.
:param image: the image to run the model on
:param proof: the Proof instance to associate the ProofPrediction with
:param overwrite: whether to overwrite existing prediction, defaults to
False
:return: the ProofPrediction instance created, or None if the prediction
already exists and overwrite is False
"""

if ProofPrediction.objects.filter(
proof=proof, model_name=PRICE_TAG_DETECTOR_MODEL_NAME
).exists():
if overwrite:
logger.info(
"Overwriting existing price tag detection for proof %s", proof.id
)
ProofPrediction.objects.filter(
proof=proof, model_name=PRICE_TAG_DETECTOR_MODEL_NAME
).delete()
else:
logger.debug(
"Proof %s already has a prediction for model %s",
proof.id,
PRICE_TAG_DETECTOR_MODEL_NAME,
)
return None

result = detect_price_tags(image)
detections = result.to_list()
if detections:
max_confidence = max(detections, key=lambda x: x["score"])["score"]
else:
max_confidence = None

ProofPrediction.objects.create(
return ProofPrediction.objects.create(
proof=proof,
type=constants.PROOF_PREDICTION_OBJECT_DETECTION_TYPE,
model_name=PRICE_TAG_DETECTOR_MODEL_NAME,
Expand All @@ -111,18 +136,40 @@ def run_and_save_price_tag_detection(image: Image, proof: Proof) -> None:
)


def run_and_save_proof_type_prediction(image: Image, proof: Proof) -> None:
def run_and_save_proof_type_prediction(
image: Image, proof: Proof, overwrite: bool = False
) -> ProofPrediction | None:
"""Run the proof type classifier model and save the prediction in
ProofPrediction table.
:param image: the image to run the model on
:param proof: the Proof instance to associate the ProofPrediction with
:param overwrite: whether to overwrite existing prediction, defaults to
False
:return: the ProofPrediction instance created, or None if the prediction
already exists and overwrite is False
"""
if ProofPrediction.objects.filter(
proof=proof, model_name=PROOF_CLASSIFICATION_MODEL_NAME
).exists():
if overwrite:
logger.info("Overwriting existing type prediction for proof %s", proof.id)
ProofPrediction.objects.filter(
proof=proof, model_name=PROOF_CLASSIFICATION_MODEL_NAME
).delete()
else:
logger.debug(
"Proof %s already has a prediction for model %s",
proof.id,
PROOF_CLASSIFICATION_MODEL_NAME,
)
return None

prediction = predict_proof_type(image)

max_confidence = max(prediction, key=lambda x: x[1])[1]
proof_type = max(prediction, key=lambda x: x[1])[0]
ProofPrediction.objects.create(
return ProofPrediction.objects.create(
proof=proof,
type=constants.PROOF_PREDICTION_CLASSIFICATION_TYPE,
model_name=PROOF_CLASSIFICATION_MODEL_NAME,
Expand Down
33 changes: 32 additions & 1 deletion open_prices/proofs/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,16 @@
from open_prices.locations.factories import LocationFactory
from open_prices.prices.factories import PriceFactory
from open_prices.proofs import constants as proof_constants
from open_prices.proofs.factories import ProofFactory
from open_prices.proofs.factories import ProofFactory, ProofPredictionFactory
from open_prices.proofs.ml import (
PRICE_TAG_DETECTOR_MODEL_NAME,
PRICE_TAG_DETECTOR_MODEL_VERSION,
PROOF_CLASSIFICATION_MODEL_NAME,
PROOF_CLASSIFICATION_MODEL_VERSION,
ObjectDetectionRawResult,
run_and_save_price_tag_detection,
run_and_save_proof_prediction,
run_and_save_proof_type_prediction,
)
from open_prices.proofs.models import Proof
from open_prices.proofs.utils import fetch_and_save_ocr_data, select_proof_image_dir
Expand Down Expand Up @@ -474,6 +480,31 @@ def test_run_and_save_proof_prediction_proof(self):
price_tag_prediction.delete()
proof.delete()

def test_run_and_save_proof_type_prediction_already_exists(self):
image = Image.new("RGB", (100, 100), "white")

proof = ProofFactory()
ProofPredictionFactory(
proof=proof,
type=proof_constants.PROOF_PREDICTION_CLASSIFICATION_TYPE,
model_name=PROOF_CLASSIFICATION_MODEL_NAME,
model_version=PROOF_CLASSIFICATION_MODEL_VERSION,
)
result = run_and_save_proof_type_prediction(image, proof)
self.assertIsNone(result)

def test_run_and_save_price_tag_detection_already_exists(self):
image = Image.new("RGB", (100, 100), "white")
proof = ProofFactory()
ProofPredictionFactory(
proof=proof,
type=proof_constants.PROOF_PREDICTION_OBJECT_DETECTION_TYPE,
model_name=PRICE_TAG_DETECTOR_MODEL_NAME,
model_version=PRICE_TAG_DETECTOR_MODEL_VERSION,
)
result = run_and_save_price_tag_detection(image, proof)
self.assertIsNone(result)


class TestSelectProofImageDir(TestCase):
def test_select_proof_image_dir_no_dir(self):
Expand Down

0 comments on commit c825d2c

Please sign in to comment.