Commit cb1fb7a
feat: save proof prediction for all proofs in a new table (#588)
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
raphael0202 and dependabot[bot] authored Dec 4, 2024
1 parent a5c0df4 commit cb1fb7a
Showing 12 changed files with 597 additions and 102 deletions.
13 changes: 13 additions & 0 deletions .env
@@ -40,3 +40,16 @@ POSTGRES_EXPOSE=127.0.0.1:5432
ENVIRONMENT=net

GUNICORN_WORKERS=1

# We use the special `host.docker.internal` hostname to access the localhost (=your laptop) from
# the docker container.
# It works because we added the special `host.docker.internal:host-gateway`
# host in dev.yml for all services.
# Triton is the ML inference server used at Open Food Facts
TRITON_URI=host.docker.internal:5004

# By default, don't enable ML predictions, as we don't necessarily have a Triton
# server running.
# During local development, to enable ML predictions, set this to True and make sure
# you have Triton running on port 5004.
ENABLE_ML_PREDICTIONS=False
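Before setting ENABLE_ML_PREDICTIONS=True locally, it can help to verify that Triton is actually reachable at TRITON_URI. A minimal sketch (not part of this commit), assuming the standard tritonclient Python package is installed:

# check_triton.py -- illustrative only
import os

import tritonclient.grpc as grpcclient  # NVIDIA Triton's gRPC client

# Same default as config/settings.py
triton_uri = os.getenv("TRITON_URI", "localhost:5004")
client = grpcclient.InferenceServerClient(url=triton_uri)

try:
    ready = client.is_server_ready()
except Exception as exc:  # connection errors surface as exceptions
    print(f"Cannot reach Triton at {triton_uri}: {exc}")
else:
    print(f"Triton at {triton_uri} ready: {ready}")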
6 changes: 6 additions & 0 deletions .github/workflows/container-deploy.yml
@@ -38,12 +38,17 @@ jobs:
echo "SSH_HOST=10.1.0.200" >> $GITHUB_ENV
echo "ENVIRONMENT=net" >> $GITHUB_ENV
echo "CSRF_TRUSTED_ORIGINS=https://prices.openfoodfacts.net" >> $GITHUB_ENV
# The Triton server is in the same datacenter as the staging server, so we use the internal IP
echo "TRITON_URI=10.1.0.200:5504" >> $GITHUB_ENV
- name: Set various variables for production deployment
if: matrix.env == 'open-prices-org'
run: |
echo "SSH_HOST=10.1.0.201" >> $GITHUB_ENV
echo "ENVIRONMENT=org" >> $GITHUB_ENV
echo "CSRF_TRUSTED_ORIGINS=https://prices.openfoodfacts.org" >> $GITHUB_ENV
# The Triton server is in the Moji datacenter, so we use the stunnel client running
# in the OVH datacenter to access it
echo "TRITON_URI=10.1.0.101:5504" >> $GITHUB_ENV
- name: Wait for docker image container build workflow
uses: tomchv/[email protected]
id: wait-build
@@ -133,6 +138,7 @@ jobs:
echo "ENVIRONMENT=${{ env.ENVIRONMENT }}" >> .env
echo "GOOGLE_CLOUD_VISION_API_KEY=${{ secrets.GOOGLE_CLOUD_VISION_API_KEY }}" >> .env
echo "GOOGLE_GEMINI_API_KEY=${{ secrets.GOOGLE_GEMINI_API_KEY }}" >> .env
echo "TRITON_URI=${{ env.TRITON_URI }}" >> .env
- name: Create Docker volumes
uses: appleboy/ssh-action@master
5 changes: 2 additions & 3 deletions Makefile
@@ -160,9 +160,8 @@ cli: guard-args
${DOCKER_COMPOSE} run --rm --no-deps api python3 manage.py ${args}


# TODO: migrate to Django
add-db-revision: guard-message
${DOCKER_COMPOSE} run --rm --no-deps api alembic revision --autogenerate -m "${message}"
makemigrations: guard-args
${DOCKER_COMPOSE} run --rm --no-deps api python3 manage.py makemigrations ${args}
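For reference, the new target is used like the other manage.py wrappers, e.g. `make makemigrations args="proofs"`, which runs `python3 manage.py makemigrations proofs` inside the api container.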

#---------#
# Cleanup #
6 changes: 6 additions & 0 deletions config/settings.py
@@ -289,3 +289,9 @@
# ------------------------------------------------------------------------------

GOOGLE_GEMINI_API_KEY = os.getenv("GOOGLE_GEMINI_API_KEY")

# Triton Inference Server (ML)
# ------------------------------------------------------------------------------

TRITON_URI = os.getenv("TRITON_URI", "localhost:5004")
ENABLE_ML_PREDICTIONS = os.getenv("ENABLE_ML_PREDICTIONS") == "True"
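Application code is expected to gate ML work on these two settings. A hypothetical helper (not taken from this commit) showing that pattern, assuming the tritonclient package is installed:

# Illustrative only: how code can consume the new settings.
import tritonclient.grpc as grpcclient
from django.conf import settings


def get_triton_client():
    # Return None when ML predictions are disabled (the default),
    # otherwise a client pointed at the configured Triton server.
    if not settings.ENABLE_ML_PREDICTIONS:
        return None
    return grpcclient.InferenceServerClient(url=settings.TRITON_URI)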
2 changes: 2 additions & 0 deletions docker-compose.yml
@@ -16,6 +16,8 @@ x-api-common: &api-common
- LOG_LEVEL
- GOOGLE_CLOUD_VISION_API_KEY
- GOOGLE_GEMINI_API_KEY
- TRITON_URI
- ENABLE_ML_PREDICTIONS
networks:
- default

3 changes: 3 additions & 0 deletions docker/dev.yml
@@ -22,6 +22,9 @@ x-api-base: &api-base
# mount tests
- ./tests:/opt/open-prices/tests
- ./manage.py:/opt/open-prices/manage.py
# Allow the container to access the host's network, so that the Triton server is accessible
extra_hosts:
- "host.docker.internal:host-gateway"

services:
api:
13 changes: 13 additions & 0 deletions open_prices/proofs/constants.py
@@ -10,3 +10,16 @@
TYPE_SINGLE_SHOP_LIST = [TYPE_PRICE_TAG, TYPE_RECEIPT, TYPE_SHOP_IMPORT]
TYPE_SHOPPING_SESSION_LIST = [TYPE_RECEIPT, TYPE_GDPR_REQUEST]
TYPE_MULTIPLE_SHOP_LIST = [TYPE_GDPR_REQUEST]

PROOF_PREDICTION_OBJECT_DETECTION_TYPE = "OBJECT_DETECTION"
PROOF_PREDICTION_CLASSIFICATION_TYPE = "CLASSIFICATION"
PROOF_PREDICTION_RECEIPT_EXTRACTION_TYPE = "RECEIPT_EXTRACTION"
PROOF_PREDICTION_PRICE_TAG_EXTRACTION_TYPE = "PRICE_TAG_EXTRACTION"
PROOF_PREDICTION_LIST = [
PROOF_PREDICTION_OBJECT_DETECTION_TYPE,
PROOF_PREDICTION_CLASSIFICATION_TYPE,
PROOF_PREDICTION_RECEIPT_EXTRACTION_TYPE,
PROOF_PREDICTION_PRICE_TAG_EXTRACTION_TYPE,
]

PROOF_TYPE_CHOICES = [(key, key) for key in PROOF_PREDICTION_LIST]
101 changes: 101 additions & 0 deletions open_prices/proofs/migrations/0006_add_proof_prediction_table.py
@@ -0,0 +1,101 @@
# Generated by Django 5.1 on 2024-12-01 17:41

import django.db.models.deletion
import django.utils.timezone
from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("proofs", "0005_proof_receipt_price_count_proof_receipt_price_total"),
]

operations = [
migrations.CreateModel(
name="ProofPrediction",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
(
"type",
models.CharField(
choices=[
("OBJECT_DETECTION", "OBJECT_DETECTION"),
("CLASSIFICATION", "CLASSIFICATION"),
("RECEIPT_EXTRACTION", "RECEIPT_EXTRACTION"),
("PRICE_TAG_EXTRACTION", "PRICE_TAG_EXTRACTION"),
],
max_length=20,
verbose_name="The type of the prediction",
),
),
(
"model_name",
models.CharField(
max_length=30,
verbose_name="The name of the model that generated the prediction",
),
),
(
"model_version",
models.CharField(
max_length=30,
verbose_name="The specific version of the model that generated the prediction",
),
),
(
"created",
models.DateTimeField(
default=django.utils.timezone.now,
verbose_name="When the prediction was created in DB",
),
),
(
"data",
models.JSONField(
blank=True,
null=True,
verbose_name="a dict representing the data of the prediction. This field is model-specific.",
),
),
(
"value",
models.CharField(
blank=True,
max_length=30,
null=True,
verbose_name="The predicted value, only for classification models, null otherwise.",
),
),
(
"max_confidence",
models.FloatField(
blank=True,
null=True,
verbose_name="The maximum confidence of the prediction, may be null for some models.For object detection models, this is the confidence of the most confident object.For classification models, this is the confidence of the predicted class.",
),
),
(
"proof",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="predictions",
to="proofs.proof",
verbose_name="The proof this prediction belongs to",
),
),
],
options={
"verbose_name": "Proof Prediction",
"verbose_name_plural": "Proof Predictions",
"db_table": "proof_predictions",
},
),
]
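For reference, a usage sketch of the new model (the import paths, the example model_name/model_version and the data payload are assumptions for illustration, not taken from this commit):

# Illustrative only: writing and reading rows of the new proof_predictions table.
from open_prices.proofs import constants
from open_prices.proofs.models import Proof, ProofPrediction  # assumed module path

proof = Proof.objects.first()
ProofPrediction.objects.create(
    proof=proof,
    type=constants.PROOF_PREDICTION_OBJECT_DETECTION_TYPE,
    model_name="price_tag_detection",  # hypothetical model name
    model_version="1.0",  # hypothetical version string
    data={"objects": []},  # model-specific payload (JSONField)
    value=None,  # only set for classification models
    max_confidence=None,  # may be null for some models
)
# The related_name="predictions" on the FK gives a reverse accessor on Proof:
proof.predictions.filter(type=constants.PROOF_PREDICTION_OBJECT_DETECTION_TYPE).count()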
