Skip to content

Commit

Permalink
feat: create price tags from the object detector model (#629)
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 authored Dec 17, 2024
1 parent fe3f745 commit a0c4741
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 10 deletions.
57 changes: 50 additions & 7 deletions open_prices/proofs/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from PIL import Image

from . import constants
from .models import Proof, ProofPrediction
from .models import PriceTag, Proof, ProofPrediction

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -86,6 +86,41 @@ def detect_price_tags(
)


def create_price_tags_from_proof_prediction(
proof: Proof, proof_prediction: ProofPrediction, threshold: float = 0.5
) -> list[PriceTag]:
"""Create price tags from a proof prediction containing price tag object
detections.
:param proof: the Proof instance to associate the PriceTag instances with
:param proof_prediction: the ProofPrediction instance containing the
price tag detections
:param threshold: the minimum confidence threshold for a detection to be
considered valid, defaults to 0.5
:return: the list of PriceTag instances created
"""
if proof_prediction.model_name != PRICE_TAG_DETECTOR_MODEL_NAME:
logger.error(
"Proof prediction model %s is not a price tag detector",
proof_prediction.model_name,
)
return []

created = []
for detected_object in proof_prediction.data["objects"]:
if detected_object["score"] >= threshold:
price_tag = PriceTag.objects.create(
proof=proof,
bounding_box=detected_object["bounding_box"],
model_version=proof_prediction.model_version,
status=None,
created_by=None,
updated_by=None,
)
created.append(price_tag)
return created


def run_and_save_price_tag_detection(
image: Image, proof: Proof, overwrite: bool = False
) -> ProofPrediction | None:
Expand All @@ -100,22 +135,28 @@ def run_and_save_price_tag_detection(
already exists and overwrite is False
"""

if ProofPrediction.objects.filter(
proof_prediction = ProofPrediction.objects.filter(
proof=proof, model_name=PRICE_TAG_DETECTOR_MODEL_NAME
).exists():
).first()

if proof_prediction:
if overwrite:
logger.info(
"Overwriting existing price tag detection for proof %s", proof.id
)
ProofPrediction.objects.filter(
proof=proof, model_name=PRICE_TAG_DETECTOR_MODEL_NAME
).delete()
proof_prediction.delete()
else:
logger.debug(
"Proof %s already has a prediction for model %s",
proof.id,
PRICE_TAG_DETECTOR_MODEL_NAME,
)
if not PriceTag.objects.filter(proof=proof).exists():
logger.debug(
"Creating price tags from existing prediction for proof %s",
proof.id,
)
create_price_tags_from_proof_prediction(proof, proof_prediction)
return None

result = detect_price_tags(image)
Expand All @@ -125,7 +166,7 @@ def run_and_save_price_tag_detection(
else:
max_confidence = None

return ProofPrediction.objects.create(
proof_prediction = ProofPrediction.objects.create(
proof=proof,
type=constants.PROOF_PREDICTION_OBJECT_DETECTION_TYPE,
model_name=PRICE_TAG_DETECTOR_MODEL_NAME,
Expand All @@ -134,6 +175,8 @@ def run_and_save_price_tag_detection(
value=None,
max_confidence=max_confidence,
)
create_price_tags_from_proof_prediction(proof, proof_prediction)
return proof_prediction


def run_and_save_proof_type_prediction(
Expand Down
75 changes: 72 additions & 3 deletions open_prices/proofs/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
from django.core.exceptions import ValidationError
from django.test import TestCase
from django.utils import timezone
from PIL import Image

from open_prices.locations import constants as location_constants
Expand All @@ -25,11 +26,12 @@
PROOF_CLASSIFICATION_MODEL_NAME,
PROOF_CLASSIFICATION_MODEL_VERSION,
ObjectDetectionRawResult,
create_price_tags_from_proof_prediction,
run_and_save_price_tag_detection,
run_and_save_proof_prediction,
run_and_save_proof_type_prediction,
)
from open_prices.proofs.models import Proof
from open_prices.proofs.models import PriceTag, Proof
from open_prices.proofs.utils import fetch_and_save_ocr_data, select_proof_image_dir

LOCATION_OSM_NODE_652825274 = {
Expand Down Expand Up @@ -412,7 +414,9 @@ def test_run_and_save_proof_prediction_proof(self):

# change temporarily settings.IMAGE_DIR
with self.settings(IMAGE_DIR=NEW_IMAGE_DIR):
proof = ProofFactory(file_path=file_path)
proof = ProofFactory(
file_path=file_path, type=proof_constants.TYPE_PRICE_TAG
)

# Patch predict_proof_type to return a fixed response
with (
Expand Down Expand Up @@ -499,15 +503,80 @@ def test_run_and_save_proof_type_prediction_already_exists(self):

def test_run_and_save_price_tag_detection_already_exists(self):
image = Image.new("RGB", (100, 100), "white")
proof = ProofFactory()
proof = ProofFactory(type=proof_constants.TYPE_PRICE_TAG)
ProofPredictionFactory(
proof=proof,
type=proof_constants.PROOF_PREDICTION_OBJECT_DETECTION_TYPE,
model_name=PRICE_TAG_DETECTOR_MODEL_NAME,
model_version=PRICE_TAG_DETECTOR_MODEL_VERSION,
data={
"objects": [
{
"label": "price_tag",
"score": 0.98,
"bounding_box": [0.5, 0.5, 1.0, 1.0],
},
{
"label": "price_tag",
"score": 0.8,
"bounding_box": [0.1, 0.1, 0.2, 0.2],
},
]
},
)
result = run_and_save_price_tag_detection(image, proof)
self.assertIsNone(result)
price_tags = PriceTag.objects.filter(proof=proof).all()
self.assertEqual(len(price_tags), 2)
self.assertEqual(price_tags[0].bounding_box, [0.5, 0.5, 1.0, 1.0])
self.assertEqual(price_tags[1].bounding_box, [0.1, 0.1, 0.2, 0.2])

def create_price_tags_from_proof_prediction(self):
proof = ProofFactory(type=proof_constants.TYPE_PRICE_TAG)
proof_prediction = ProofPredictionFactory(
proof=proof,
type=proof_constants.PROOF_PREDICTION_OBJECT_DETECTION_TYPE,
model_name=PRICE_TAG_DETECTOR_MODEL_NAME,
model_version=PRICE_TAG_DETECTOR_MODEL_VERSION,
data={
"objects": [
{
"label": "price_tag",
"score": 0.98,
"bounding_box": [0.5, 0.5, 1.0, 1.0],
},
{
"label": "price_tag",
"score": 0.45,
"bounding_box": [0.1, 0.1, 0.2, 0.2],
},
{
"label": "price_tag",
"score": 0.4,
"bounding_box": [0.1, 0.1, 0.2, 0.2],
},
]
},
)
before = timezone.now()
results = create_price_tags_from_proof_prediction(
proof, proof_prediction, threshold=0.4
)
after = timezone.now()
self.assertEqual(len(results), 2)
price_tags = PriceTag.objects.filter(proof=proof).all()
self.assertEqual(len(price_tags), 2)

price_tag_1 = results[0]
self.assertEqual(price_tag_1.bounding_box, [0.5, 0.5, 1.0, 1.0])
self.assertGreater(price_tag_1.created, before)
self.assertLess(price_tag_1.created, after)
self.assertGreater(price_tag_1.updated, before)
self.assertLess(price_tag_1.updated, after)
self.assertEqual(price_tag_1.status, None)
self.assertEqual(price_tag_1.created_by, None)
self.assertEqual(price_tag_1.updated_by, None)
self.assertEqual(price_tag_1.model_version, PRICE_TAG_DETECTOR_MODEL_VERSION)


class TestSelectProofImageDir(TestCase):
Expand Down

0 comments on commit a0c4741

Please sign in to comment.