Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: create price tags from the object detector model #629

Merged
merged 1 commit into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 50 additions & 7 deletions open_prices/proofs/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from PIL import Image

from . import constants
from .models import Proof, ProofPrediction
from .models import PriceTag, Proof, ProofPrediction

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -86,6 +86,41 @@ def detect_price_tags(
)


def create_price_tags_from_proof_prediction(
    proof: Proof, proof_prediction: ProofPrediction, threshold: float = 0.5
) -> list[PriceTag]:
    """Turn the object detections stored in a proof prediction into PriceTag
    rows.

    One PriceTag is created per detected object whose score is greater than
    or equal to ``threshold``. Nothing is created (and an error is logged)
    when the prediction was not produced by the price tag detector model.

    :param proof: the Proof instance the new PriceTag instances are attached
        to
    :param proof_prediction: the ProofPrediction instance holding the
        detections under ``data["objects"]``
    :param threshold: the minimum score a detection needs to be kept,
        defaults to 0.5
    :return: the PriceTag instances created, in detection order
    """
    model_name = proof_prediction.model_name
    if model_name != PRICE_TAG_DETECTOR_MODEL_NAME:
        logger.error(
            "Proof prediction model %s is not a price tag detector",
            model_name,
        )
        return []

    # Lazy filter: each detection is examined right before its row is
    # created, exactly as a single pass over the detections would do.
    confident_detections = (
        detection
        for detection in proof_prediction.data["objects"]
        if detection["score"] >= threshold
    )

    price_tags = []
    for detection in confident_detections:
        price_tags.append(
            PriceTag.objects.create(
                proof=proof,
                bounding_box=detection["bounding_box"],
                model_version=proof_prediction.model_version,
                status=None,
                created_by=None,
                updated_by=None,
            )
        )
    return price_tags


def run_and_save_price_tag_detection(
image: Image, proof: Proof, overwrite: bool = False
) -> ProofPrediction | None:
Expand All @@ -100,22 +135,28 @@ def run_and_save_price_tag_detection(
already exists and overwrite is False
"""

if ProofPrediction.objects.filter(
proof_prediction = ProofPrediction.objects.filter(
proof=proof, model_name=PRICE_TAG_DETECTOR_MODEL_NAME
).exists():
).first()

if proof_prediction:
if overwrite:
logger.info(
"Overwriting existing price tag detection for proof %s", proof.id
)
ProofPrediction.objects.filter(
proof=proof, model_name=PRICE_TAG_DETECTOR_MODEL_NAME
).delete()
proof_prediction.delete()
else:
logger.debug(
"Proof %s already has a prediction for model %s",
proof.id,
PRICE_TAG_DETECTOR_MODEL_NAME,
)
if not PriceTag.objects.filter(proof=proof).exists():
logger.debug(
"Creating price tags from existing prediction for proof %s",
proof.id,
)
create_price_tags_from_proof_prediction(proof, proof_prediction)
return None

result = detect_price_tags(image)
Expand All @@ -125,7 +166,7 @@ def run_and_save_price_tag_detection(
else:
max_confidence = None

return ProofPrediction.objects.create(
proof_prediction = ProofPrediction.objects.create(
proof=proof,
type=constants.PROOF_PREDICTION_OBJECT_DETECTION_TYPE,
model_name=PRICE_TAG_DETECTOR_MODEL_NAME,
Expand All @@ -134,6 +175,8 @@ def run_and_save_price_tag_detection(
value=None,
max_confidence=max_confidence,
)
create_price_tags_from_proof_prediction(proof, proof_prediction)
return proof_prediction


def run_and_save_proof_type_prediction(
Expand Down
75 changes: 72 additions & 3 deletions open_prices/proofs/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
from django.core.exceptions import ValidationError
from django.test import TestCase
from django.utils import timezone
from PIL import Image

from open_prices.locations import constants as location_constants
Expand All @@ -25,11 +26,12 @@
PROOF_CLASSIFICATION_MODEL_NAME,
PROOF_CLASSIFICATION_MODEL_VERSION,
ObjectDetectionRawResult,
create_price_tags_from_proof_prediction,
run_and_save_price_tag_detection,
run_and_save_proof_prediction,
run_and_save_proof_type_prediction,
)
from open_prices.proofs.models import Proof
from open_prices.proofs.models import PriceTag, Proof
from open_prices.proofs.utils import fetch_and_save_ocr_data, select_proof_image_dir

LOCATION_OSM_NODE_652825274 = {
Expand Down Expand Up @@ -412,7 +414,9 @@ def test_run_and_save_proof_prediction_proof(self):

# change temporarily settings.IMAGE_DIR
with self.settings(IMAGE_DIR=NEW_IMAGE_DIR):
proof = ProofFactory(file_path=file_path)
proof = ProofFactory(
file_path=file_path, type=proof_constants.TYPE_PRICE_TAG
)

# Patch predict_proof_type to return a fixed response
with (
Expand Down Expand Up @@ -499,15 +503,80 @@ def test_run_and_save_proof_type_prediction_already_exists(self):

def test_run_and_save_price_tag_detection_already_exists(self):
image = Image.new("RGB", (100, 100), "white")
proof = ProofFactory()
proof = ProofFactory(type=proof_constants.TYPE_PRICE_TAG)
ProofPredictionFactory(
proof=proof,
type=proof_constants.PROOF_PREDICTION_OBJECT_DETECTION_TYPE,
model_name=PRICE_TAG_DETECTOR_MODEL_NAME,
model_version=PRICE_TAG_DETECTOR_MODEL_VERSION,
data={
"objects": [
{
"label": "price_tag",
"score": 0.98,
"bounding_box": [0.5, 0.5, 1.0, 1.0],
},
{
"label": "price_tag",
"score": 0.8,
"bounding_box": [0.1, 0.1, 0.2, 0.2],
},
]
},
)
result = run_and_save_price_tag_detection(image, proof)
self.assertIsNone(result)
price_tags = PriceTag.objects.filter(proof=proof).all()
self.assertEqual(len(price_tags), 2)
self.assertEqual(price_tags[0].bounding_box, [0.5, 0.5, 1.0, 1.0])
self.assertEqual(price_tags[1].bounding_box, [0.1, 0.1, 0.2, 0.2])

def test_create_price_tags_from_proof_prediction(self):
    """Only detections scoring at or above the threshold become price tags.

    The method name starts with ``test`` so that the test runner collects
    it; without that prefix the assertions below are never executed.
    """
    proof = ProofFactory(type=proof_constants.TYPE_PRICE_TAG)
    proof_prediction = ProofPredictionFactory(
        proof=proof,
        type=proof_constants.PROOF_PREDICTION_OBJECT_DETECTION_TYPE,
        model_name=PRICE_TAG_DETECTOR_MODEL_NAME,
        model_version=PRICE_TAG_DETECTOR_MODEL_VERSION,
        data={
            "objects": [
                {
                    "label": "price_tag",
                    "score": 0.98,
                    "bounding_box": [0.5, 0.5, 1.0, 1.0],
                },
                {
                    "label": "price_tag",
                    "score": 0.45,
                    "bounding_box": [0.1, 0.1, 0.2, 0.2],
                },
                {
                    # Strictly below the threshold used below, so this
                    # detection must not produce a price tag. (A score equal
                    # to the threshold is kept, since the filter is `>=`.)
                    "label": "price_tag",
                    "score": 0.3,
                    "bounding_box": [0.1, 0.1, 0.2, 0.2],
                },
            ]
        },
    )
    before = timezone.now()
    results = create_price_tags_from_proof_prediction(
        proof, proof_prediction, threshold=0.4
    )
    after = timezone.now()
    self.assertEqual(len(results), 2)
    self.assertEqual(PriceTag.objects.filter(proof=proof).count(), 2)

    price_tag_1 = results[0]
    self.assertEqual(price_tag_1.bounding_box, [0.5, 0.5, 1.0, 1.0])
    # Inclusive bounds: on a coarse clock the row timestamp may be equal to
    # one of the two surrounding readings.
    self.assertGreaterEqual(price_tag_1.created, before)
    self.assertLessEqual(price_tag_1.created, after)
    self.assertGreaterEqual(price_tag_1.updated, before)
    self.assertLessEqual(price_tag_1.updated, after)
    self.assertIsNone(price_tag_1.status)
    self.assertIsNone(price_tag_1.created_by)
    self.assertIsNone(price_tag_1.updated_by)
    self.assertEqual(price_tag_1.model_version, PRICE_TAG_DETECTOR_MODEL_VERSION)


class TestSelectProofImageDir(TestCase):
Expand Down
Loading