Skip to content

Commit

Permalink
refactor(API): only return proof.predictions in detail endpoint (#605)
Browse files Browse the repository at this point in the history
  • Loading branch information
raphodn authored Dec 5, 2024
1 parent 1b4dcb5 commit 81f7f0f
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 24 deletions.
2 changes: 1 addition & 1 deletion open_prices/api/prices/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def get_queryset(self):
elif self.request.method in ["PATCH", "DELETE"]:
# only return prices owned by the current user
if self.request.user.is_authenticated:
return Price.objects.filter(owner=self.request.user.user_id)
return self.queryset.filter(owner=self.request.user.user_id)
return self.queryset

def get_serializer_class(self):
Expand Down
8 changes: 8 additions & 0 deletions open_prices/api/proofs/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ class Meta:
exclude = ["location", "source"]


class ProofHalfFullSerializer(ProofSerializer):
location = LocationSerializer()

class Meta:
model = Proof
exclude = ["source"] # ProofSerializer.Meta.exclude


class ProofFullSerializer(ProofSerializer):
location = LocationSerializer()
predictions = ProofPredictionSerializer(many=True, read_only=True)
Expand Down
30 changes: 16 additions & 14 deletions open_prices/api/proofs/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,27 +69,18 @@ def setUpTestData(cls):

def test_proof_list(self):
# anonymous
# thanks to select_related and prefetch_related, we only have 3
# queries:
# thanks to select_related, we only have 2 queries:
# - 1 to count the number of proofs of the user
# - 1 to get the proofs and their associated locations (select_related)
# - 1 to get the associated proof predictions (prefetch_related)
with self.assertNumQueries(3):
with self.assertNumQueries(2):
response = self.client.get(self.url)
self.assertEqual(response.status_code, 200)
data = response.data
self.assertEqual(data["total"], 3)
self.assertEqual(len(data["items"]), 3)
item = data["items"][0]
self.assertEqual(item["id"], self.proof.id) # default order
self.assertIn("predictions", item)
self.assertEqual(len(item["predictions"]), 1)
prediction = item["predictions"][0]
self.assertEqual(prediction["type"], self.proof_prediction.type)
self.assertEqual(prediction["model_name"], self.proof_prediction.model_name)
self.assertEqual(
prediction["model_version"], self.proof_prediction.model_version
)
self.assertNotIn("predictions", item) # not returned in "list"


class ProofListOrderApiTest(TestCase):
Expand All @@ -100,7 +91,7 @@ def setUpTestData(cls):
cls.proof = ProofFactory(
**PROOF, price_count=15, owner=cls.user_session.user.user_id
)
ProofFactory(price_count=0)
ProofFactory(type=proof_constants.TYPE_PRICE_TAG, price_count=0)
ProofFactory(
type=proof_constants.TYPE_PRICE_TAG,
price_count=50,
Expand All @@ -122,7 +113,7 @@ def setUpTestData(cls):
cls.proof = ProofFactory(
**PROOF, price_count=15, owner=cls.user_session.user.user_id
)
ProofFactory(price_count=0)
ProofFactory(type=proof_constants.TYPE_PRICE_TAG, price_count=0)
ProofFactory(
type=proof_constants.TYPE_PRICE_TAG,
price_count=50,
Expand Down Expand Up @@ -150,6 +141,9 @@ def setUpTestData(cls):
cls.proof = ProofFactory(
**PROOF, price_count=15, owner=cls.user_session_1.user.user_id
)
cls.proof_prediction = ProofPredictionFactory(
proof=cls.proof, type="CLASSIFICATION"
)
cls.url = reverse("api:proofs-detail", args=[cls.proof.id])

def test_proof_detail(self):
Expand All @@ -162,6 +156,14 @@ def test_proof_detail(self):
response = self.client.get(self.url)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data["id"], self.proof.id)
self.assertIn("predictions", response.data) # returned in "detail"
self.assertEqual(len(response.data["predictions"]), 1)
prediction = response.data["predictions"][0]
self.assertEqual(prediction["type"], self.proof_prediction.type)
self.assertEqual(prediction["model_name"], self.proof_prediction.model_name)
self.assertEqual(
prediction["model_version"], self.proof_prediction.model_version
)


class ProofCreateApiTest(TestCase):
Expand Down
18 changes: 9 additions & 9 deletions open_prices/api/proofs/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from open_prices.api.proofs.serializers import (
ProofCreateSerializer,
ProofFullSerializer,
ProofHalfFullSerializer,
ProofProcessWithGeminiSerializer,
ProofUpdateSerializer,
ProofUploadSerializer,
Expand Down Expand Up @@ -41,23 +42,22 @@ class ProofViewSet(
ordering = ["created"]

def get_queryset(self):
queryset = self.queryset
if self.request.method in ["GET"]:
# Select all proofs along with their locations using a select
# related query (1 single query)
# Then prefetch all the predictions related to the proof using
# a prefetch related query (only 1 query for all proofs)
return self.queryset.select_related("location").prefetch_related(
"predictions"
)
queryset = queryset.select_related("location")
if self.action == "retrieve":
queryset = queryset.prefetch_related("predictions")
elif self.request.method in ["PATCH", "DELETE"]:
# only return proofs owned by the current user
if self.request.user.is_authenticated:
return self.queryset.filter(owner=self.request.user.user_id)
return self.queryset
queryset = queryset.filter(owner=self.request.user.user_id)
return queryset

def get_serializer_class(self):
if self.request.method == "PATCH":
return ProofUpdateSerializer
elif self.action == "list":
return ProofHalfFullSerializer
return self.serializer_class

def destroy(self, request: Request, *args, **kwargs) -> Response:
Expand Down

0 comments on commit 81f7f0f

Please sign in to comment.