Skip to content

Commit

Permalink
fix: process one image at a time with Gemini
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 committed Dec 22, 2024
1 parent 5d7862f commit 697cc8a
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 33 deletions.
6 changes: 3 additions & 3 deletions open_prices/api/proofs/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from open_prices.api.utils import get_source_from_request
from open_prices.common.authentication import CustomAuthentication
from open_prices.common.constants import PriceTagStatus
from open_prices.proofs.ml import extract_from_price_tags
from open_prices.proofs.ml import extract_from_price_tag
from open_prices.proofs.models import PriceTag, Proof
from open_prices.proofs.utils import store_file

Expand Down Expand Up @@ -126,8 +126,8 @@ def upload(self, request: Request) -> Response:
def process_with_gemini(self, request: Request) -> Response:
files = request.FILES.getlist("files")
sample_files = [PIL.Image.open(file.file) for file in files]
res = extract_from_price_tags(sample_files)
return Response(res, status=status.HTTP_200_OK)
labels = [extract_from_price_tag(sample_file) for sample_file in sample_files]
return Response({"labels": labels}, status=status.HTTP_200_OK)


class PriceTagViewSet(
Expand Down
51 changes: 21 additions & 30 deletions open_prices/proofs/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class Products(enum.Enum):
CUCUMBERS = "en:cucumbers"
DATES = "en:dates"
ENDIVES = "en:endives"
FENNEL_BULBS = "en:fennel-bulbs"
FIGS = "en:figs"
GARLIC = "en:garlic"
GINGER = "en:ginger"
Expand Down Expand Up @@ -157,40 +158,35 @@ class Label(typing.TypedDict):
product_name: str


class Labels(typing.TypedDict):
labels: list[Label]
def extract_from_price_tag(image: Image.Image) -> Label:
"""Extract price tag information from an image.

def extract_from_price_tags(images: Image.Image) -> Labels:
"""Extract price tag information from a list of images."""
:param image: the input Pillow image
:return: the extracted information as a dictionary
"""

# Gemini model max payload size is 20MB
# To prevent the payload from being too large, we resize the images before
# upload
resized_images = []
max_size = 1024
for image in images:
if image.width > max_size or image.height > max_size:
resized_image = image.copy()
resized_image.thumbnail((max_size, max_size))
resized_images.append(resized_image)
else:
resized_images.append(image)
if image.width > max_size or image.height > max_size:
image = image.copy()
image.thumbnail((max_size, max_size))

response = model.generate_content(
[
(
f"Here are {len(resized_images)} pictures containing a label. "
"For each picture of a label, please extract all the following attributes: "
"Here is one picture containing a label. "
"Please extract all the following attributes: "
"the product category matching product name, the origin category matching country of origin, the price, "
"is the product organic, the unit (per KILOGRAM or per UNIT) and the barcode (valid EAN-13 usually). "
f"I expect a list of {len(resized_images)} labels in your reply, no more, no less. "
"If you cannot decode an attribute, set it to an empty string"
)
]
+ resized_images,
"I expect a single JSON in your reply, no more, no less. "
"If you cannot decode an attribute, set it to an empty string."
),
image,
],
generation_config=genai.GenerationConfig(
response_mime_type="application/json", response_schema=Labels
response_mime_type="application/json", response_schema=Label
),
)
return json.loads(response.text)
Expand Down Expand Up @@ -285,7 +281,7 @@ def run_and_save_price_tag_extraction(
logger.error("Proof file not found: %s", proof.file_path_full)
return []

cropped_images = []
predictions = []
for price_tag in price_tags:
y_min, x_min, y_max, x_max = price_tag.bounding_box
image = Image.open(proof.file_path_full)
Expand All @@ -296,12 +292,7 @@ def run_and_save_price_tag_extraction(
y_max * image.height,
)
cropped_image = image.crop((left, top, right, bottom))
cropped_images.append(cropped_image)

labels = extract_from_price_tags(cropped_images)

predictions = []
for price_tag, label in zip(price_tags, labels["labels"]):
label = extract_from_price_tag(cropped_image)
prediction = PriceTagPrediction.objects.create(
price_tag=price_tag,
type=constants.PRICE_TAG_EXTRACTION_TYPE,
Expand Down Expand Up @@ -351,8 +342,8 @@ def update_price_tag_extraction(price_tag_id: int) -> PriceTagPrediction:
y_max * image.height,
)
cropped_image = image.crop((left, top, right, bottom))
gemini_output = extract_from_price_tags([cropped_image])
price_tag_prediction.data = gemini_output["labels"][0]
gemini_output = extract_from_price_tag(cropped_image)
price_tag_prediction.data = gemini_output
price_tag_prediction.model_name = GEMINI_MODEL_NAME
price_tag_prediction.model_version = GEMINI_MODEL_VERSION
price_tag_prediction.save()
Expand Down

0 comments on commit 697cc8a

Please sign in to comment.