Skip to content

Commit

Permalink
feat(PricesStats): Added group_by stat api point
Browse files Browse the repository at this point in the history
  • Loading branch information
TTalex committed Dec 10, 2024
1 parent d0c9193 commit bf1d397
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 3 deletions.
7 changes: 7 additions & 0 deletions open_prices/api/prices/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ class PriceFilter(django_filters.FilterSet):
product_id__isnull = django_filters.BooleanFilter(
field_name="product_id", lookup_expr="isnull"
)
product_labels_tags__contains = django_filters.CharFilter(
field_name="product__labels_tags", lookup_expr="icontains"
)
product_categories_tags__contains = django_filters.CharFilter(
field_name="product__categories_tags", lookup_expr="icontains"
)
labels_tags__contains = django_filters.CharFilter(
field_name="labels_tags", lookup_expr="icontains"
)
Expand Down Expand Up @@ -51,4 +57,5 @@ class Meta:
"date",
"proof_id",
"owner",
"proof__type",
]
20 changes: 20 additions & 0 deletions open_prices/api/prices/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,23 @@ class PriceStatsSerializer(serializers.Serializer):
price__min = serializers.DecimalField(max_digits=10, decimal_places=2)
price__max = serializers.DecimalField(max_digits=10, decimal_places=2)
price__avg = serializers.DecimalField(max_digits=10, decimal_places=2)
price__sum = serializers.DecimalField(max_digits=10, decimal_places=2)


class GroupedPriceStatsQuerySerializer(serializers.Serializer):
group_by = serializers.CharField(
required=True, help_text="Field by which to group the statistics"
)


class GroupedPriceStatsSerializer(PriceStatsSerializer):
# Override representation to dynamically include the group field
def to_representation(self, instance):
representation = super().to_representation(instance)

# Add the grouping field dynamically
for key in instance:
if key not in representation: # It's likely the group field
representation[key] = instance[key]

return representation
42 changes: 41 additions & 1 deletion open_prices/api/prices/views.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from django.core.exceptions import FieldError
from django_filters.rest_framework import DjangoFilterBackend
from drf_spectacular.utils import extend_schema
from drf_spectacular.utils import OpenApiParameter, extend_schema
from rest_framework import filters, mixins, status, viewsets
from rest_framework.decorators import action
from rest_framework.permissions import IsAuthenticatedOrReadOnly
Expand All @@ -8,6 +9,8 @@

from open_prices.api.prices.filters import PriceFilter
from open_prices.api.prices.serializers import (
GroupedPriceStatsQuerySerializer,
GroupedPriceStatsSerializer,
PriceCreateSerializer,
PriceFullSerializer,
PriceStatsSerializer,
Expand Down Expand Up @@ -79,3 +82,40 @@ def create(self, request: Request, *args, **kwargs):
def stats(self, request: Request) -> Response:
qs = self.filter_queryset(self.get_queryset())
return Response(qs.calculate_stats(), status=200)

@extend_schema(
request=GroupedPriceStatsQuerySerializer,
responses=GroupedPriceStatsSerializer(many=True),
filters=True,
parameters=[
OpenApiParameter(
name="group_by",
description="Field by which to group the statistics",
required=True,
type=str,
location=OpenApiParameter.QUERY,
)
],
)
@action(detail=False, methods=["GET"])
def grouped_stats(self, request: Request) -> Response:
qs = self.filter_queryset(self.get_queryset())

# Validate and parse query parameters using the serializer
serializer = GroupedPriceStatsQuerySerializer(data=request.query_params)
serializer.is_valid(raise_exception=True)
group_by = serializer.validated_data["group_by"]

try:
data = qs.calculate_grouped_stats(group_by)
except FieldError:
return Response(
{"detail": f"Invalid group_by field: {group_by}"},
status=status.HTTP_400_BAD_REQUEST,
)

# Apply pagination
paginator = self.paginator # Use the default pagination class
paginated_data = paginator.paginate_queryset(data, request, view=self)

return paginator.get_paginated_response(paginated_data)
32 changes: 30 additions & 2 deletions open_prices/prices/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

from django.core.validators import MinValueValidator, ValidationError
from django.db import models
from django.db.models import Avg, Count, F, Max, Min, signals
from django.db.models.functions import Cast
from django.db.models import Avg, Count, F, Max, Min, Sum, signals
from django.db.models.functions import Cast, TruncMonth, TruncWeek, TruncYear
from django.dispatch import receiver
from django.utils import timezone
from openfoodfacts.taxonomy import (
Expand Down Expand Up @@ -54,6 +54,34 @@ def calculate_stats(self):
Avg("price"),
output_field=models.DecimalField(max_digits=10, decimal_places=2),
),
price__sum=Sum("price"),
)

def calculate_grouped_stats(self, group_by):
group_by_list = group_by.split(",")
if (
"month" in group_by_list
or "year" in group_by_list
or "week" in group_by_list
):
queryset = self.annotate(
month=TruncMonth("date"), year=TruncYear("date"), week=TruncWeek("date")
)
else:
queryset = self
return (
queryset.values(*group_by_list)
.annotate(
price__count=Count("pk"),
price__min=Min("price"),
price__max=Max("price"),
price__avg=Cast(
Avg("price"),
output_field=models.DecimalField(max_digits=10, decimal_places=2),
),
price__sum=Sum("price"),
)
.order_by(*group_by_list)
)


Expand Down

0 comments on commit bf1d397

Please sign in to comment.