From d813e0d2c8e4616ce004c6e48bbc3503c28d96b8 Mon Sep 17 00:00:00 2001 From: Raphael Odini Date: Sun, 22 Dec 2024 14:15:44 +0100 Subject: [PATCH] Re-add extract_from_price_tags for history --- open_prices/proofs/ml.py | 46 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/open_prices/proofs/ml.py b/open_prices/proofs/ml.py index 2ed002fe..0a359297 100644 --- a/open_prices/proofs/ml.py +++ b/open_prices/proofs/ml.py @@ -158,6 +158,10 @@ class Label(typing.TypedDict): product_name: str +class Labels(typing.TypedDict): + labels: list[Label] + + def extract_from_price_tag(image: Image.Image) -> Label: """Extract price tag information from an image. @@ -192,6 +196,48 @@ def extract_from_price_tag(image: Image.Image) -> Label: return json.loads(response.text) +def extract_from_price_tags(images: Image.Image) -> Labels: + """ + Extract price tag information from a list of images. + + Warning: + Gemini sometimes skips some images when prediction price tag labels, + leading to mismatch between price tag and predictions. + Use extract_from_price_tag instead. + """ + + # Gemini model max payload size is 20MB + # To prevent the payload from being too large, we resize the images before + # upload + resized_images = [] + max_size = 1024 + for image in images: + if image.width > max_size or image.height > max_size: + resized_image = image.copy() + resized_image.thumbnail((max_size, max_size)) + resized_images.append(resized_image) + else: + resized_images.append(image) + + response = model.generate_content( + [ + ( + f"Here are {len(resized_images)} pictures containing a label. " + "For each picture of a label, please extract all the following attributes: " + "the product category matching product name, the origin category matching country of origin, the price, " + "is the product organic, the unit (per KILOGRAM or per UNIT) and the barcode (valid EAN-13 usually). " + f"I expect a list of {len(resized_images)} labels in your reply, no more, no less. " + "If you cannot decode an attribute, set it to an empty string" + ) + ] + + resized_images, + generation_config=genai.GenerationConfig( + response_mime_type="application/json", response_schema=Labels + ), + ) + return json.loads(response.text) + + def predict_proof_type( image: Image.Image, model_name: str = PROOF_CLASSIFICATION_MODEL_NAME,