diff --git a/open_prices/api/prices/views.py b/open_prices/api/prices/views.py index 0bebf2de..2d1d79ed 100644 --- a/open_prices/api/prices/views.py +++ b/open_prices/api/prices/views.py @@ -43,7 +43,7 @@ def get_queryset(self): elif self.request.method in ["PATCH", "DELETE"]: # only return prices owned by the current user if self.request.user.is_authenticated: - return Price.objects.filter(owner=self.request.user.user_id) + return self.queryset.filter(owner=self.request.user.user_id) return self.queryset def get_serializer_class(self): diff --git a/open_prices/api/proofs/serializers.py b/open_prices/api/proofs/serializers.py index 626f3f1d..197bae15 100644 --- a/open_prices/api/proofs/serializers.py +++ b/open_prices/api/proofs/serializers.py @@ -30,6 +30,14 @@ class Meta: exclude = ["location", "source"] +class ProofHalfFullSerializer(ProofSerializer): + location = LocationSerializer() + + class Meta: + model = Proof + exclude = ["source"] # ProofSerializer.Meta.exclude + + class ProofFullSerializer(ProofSerializer): location = LocationSerializer() predictions = ProofPredictionSerializer(many=True, read_only=True) diff --git a/open_prices/api/proofs/tests.py b/open_prices/api/proofs/tests.py index a2ef9c82..2c1f777b 100644 --- a/open_prices/api/proofs/tests.py +++ b/open_prices/api/proofs/tests.py @@ -69,12 +69,10 @@ def setUpTestData(cls): def test_proof_list(self): # anonymous - # thanks to select_related and prefetch_related, we only have 3 - # queries: + # thanks to select_related, we only have 2 queries: # - 1 to count the number of proofs of the user # - 1 to get the proofs and their associated locations (select_related) - # - 1 to get the associated proof predictions (prefetch_related) - with self.assertNumQueries(3): + with self.assertNumQueries(2): response = self.client.get(self.url) self.assertEqual(response.status_code, 200) data = response.data @@ -82,14 +80,7 @@ def test_proof_list(self): self.assertEqual(len(data["items"]), 3) item = data["items"][0] self.assertEqual(item["id"], self.proof.id) # default order - self.assertIn("predictions", item) - self.assertEqual(len(item["predictions"]), 1) - prediction = item["predictions"][0] - self.assertEqual(prediction["type"], self.proof_prediction.type) - self.assertEqual(prediction["model_name"], self.proof_prediction.model_name) - self.assertEqual( - prediction["model_version"], self.proof_prediction.model_version - ) + self.assertNotIn("predictions", item) # not returned in "list" class ProofListOrderApiTest(TestCase): @@ -100,7 +91,7 @@ def setUpTestData(cls): cls.proof = ProofFactory( **PROOF, price_count=15, owner=cls.user_session.user.user_id ) - ProofFactory(price_count=0) + ProofFactory(type=proof_constants.TYPE_PRICE_TAG, price_count=0) ProofFactory( type=proof_constants.TYPE_PRICE_TAG, price_count=50, @@ -122,7 +113,7 @@ def setUpTestData(cls): cls.proof = ProofFactory( **PROOF, price_count=15, owner=cls.user_session.user.user_id ) - ProofFactory(price_count=0) + ProofFactory(type=proof_constants.TYPE_PRICE_TAG, price_count=0) ProofFactory( type=proof_constants.TYPE_PRICE_TAG, price_count=50, @@ -150,6 +141,9 @@ def setUpTestData(cls): cls.proof = ProofFactory( **PROOF, price_count=15, owner=cls.user_session_1.user.user_id ) + cls.proof_prediction = ProofPredictionFactory( + proof=cls.proof, type="CLASSIFICATION" + ) cls.url = reverse("api:proofs-detail", args=[cls.proof.id]) def test_proof_detail(self): @@ -162,6 +156,14 @@ def test_proof_detail(self): response = self.client.get(self.url) self.assertEqual(response.status_code, 200) self.assertEqual(response.data["id"], self.proof.id) + self.assertIn("predictions", response.data) # returned in "detail" + self.assertEqual(len(response.data["predictions"]), 1) + prediction = response.data["predictions"][0] + self.assertEqual(prediction["type"], self.proof_prediction.type) + self.assertEqual(prediction["model_name"], self.proof_prediction.model_name) + self.assertEqual( + prediction["model_version"], self.proof_prediction.model_version + ) class ProofCreateApiTest(TestCase): diff --git a/open_prices/api/proofs/views.py b/open_prices/api/proofs/views.py index 6cbb33b5..9352f35b 100644 --- a/open_prices/api/proofs/views.py +++ b/open_prices/api/proofs/views.py @@ -12,6 +12,7 @@ from open_prices.api.proofs.serializers import ( ProofCreateSerializer, ProofFullSerializer, + ProofHalfFullSerializer, ProofProcessWithGeminiSerializer, ProofUpdateSerializer, ProofUploadSerializer, @@ -41,23 +42,22 @@ class ProofViewSet( ordering = ["created"] def get_queryset(self): + queryset = self.queryset if self.request.method in ["GET"]: - # Select all proofs along with their locations using a select - # related query (1 single query) - # Then prefetch all the predictions related to the proof using - # a prefetch related query (only 1 query for all proofs) - return self.queryset.select_related("location").prefetch_related( - "predictions" - ) + queryset = queryset.select_related("location") + if self.action == "retrieve": + queryset = queryset.prefetch_related("predictions") elif self.request.method in ["PATCH", "DELETE"]: # only return proofs owned by the current user if self.request.user.is_authenticated: - return self.queryset.filter(owner=self.request.user.user_id) - return self.queryset + queryset = queryset.filter(owner=self.request.user.user_id) + return queryset def get_serializer_class(self): if self.request.method == "PATCH": return ProofUpdateSerializer + elif self.action == "list": + return ProofHalfFullSerializer return self.serializer_class def destroy(self, request: Request, *args, **kwargs) -> Response: