diff --git a/open_prices/api/proofs/tests.py b/open_prices/api/proofs/tests.py index 91256b43..a2ef9c82 100644 --- a/open_prices/api/proofs/tests.py +++ b/open_prices/api/proofs/tests.py @@ -69,30 +69,17 @@ def setUpTestData(cls): def test_proof_list(self): # anonymous - response = self.client.get(self.url) - self.assertEqual(response.status_code, 403) - # wrong token - response = self.client.get( - self.url, headers={"Authorization": f"Bearer {self.user_session.token}X"} - ) - self.assertEqual(response.status_code, 403) - # authenticated - # thanks to select_related and prefetch_related, we only have 6 + # thanks to select_related and prefetch_related, we only have 3 # queries: - # - 1 to get the fetch the user session - # - 1 to update the session - # - 1 to get the user # - 1 to count the number of proofs of the user # - 1 to get the proofs and their associated locations (select_related) # - 1 to get the associated proof predictions (prefetch_related) - with self.assertNumQueries(6): - response = self.client.get( - self.url, headers={"Authorization": f"Bearer {self.user_session.token}"} - ) + with self.assertNumQueries(3): + response = self.client.get(self.url) self.assertEqual(response.status_code, 200) data = response.data - self.assertEqual(data["total"], 2) # only user's proofs - self.assertEqual(len(data["items"]), 2) + self.assertEqual(data["total"], 3) + self.assertEqual(len(data["items"]), 3) item = data["items"][0] self.assertEqual(item["id"], self.proof.id) # default order self.assertIn("predictions", item) @@ -122,10 +109,8 @@ def setUpTestData(cls): def test_proof_list_order_by(self): url = self.url + "?order_by=-price_count" - response = self.client.get( - url, headers={"Authorization": f"Bearer {self.user_session.token}"} - ) - self.assertEqual(response.data["total"], 2) + response = self.client.get(url) + self.assertEqual(response.data["total"], 3) self.assertEqual(response.data["items"][0]["price_count"], 50) @@ -146,12 +131,16 @@ def setUpTestData(cls): def test_proof_list_filter_by_type(self): url = self.url + "?type=RECEIPT" - response = self.client.get( - url, headers={"Authorization": f"Bearer {self.user_session.token}"} - ) + response = self.client.get(url) self.assertEqual(response.data["total"], 1) self.assertEqual(response.data["items"][0]["price_count"], 15) + def test_proof_list_filter_by_owner(self): + url = self.url + f"?owner={self.user_session.user.user_id}" + response = self.client.get(url) + self.assertEqual(response.data["total"], 2) + self.assertEqual(response.data["items"][0]["price_count"], 15) + class ProofDetailApiTest(TestCase): @classmethod @@ -166,23 +155,11 @@ def setUpTestData(cls): def test_proof_detail(self): # 404 url = reverse("api:proofs-detail", args=[999]) - response = self.client.get( - url, headers={"Authorization": f"Bearer {self.user_session_1.token}"} - ) + response = self.client.get(url) self.assertEqual(response.status_code, 404) self.assertEqual(response.data["detail"], "No Proof matches the given query.") # anonymous response = self.client.get(self.url) - self.assertEqual(response.status_code, 403) - # wrong token - response = self.client.get( - self.url, headers={"Authorization": f"Bearer {self.user_session_1.token}X"} - ) - self.assertEqual(response.status_code, 403) - # authenticated - response = self.client.get( - self.url, headers={"Authorization": f"Bearer {self.user_session_1.token}"} - ) self.assertEqual(response.status_code, 200) self.assertEqual(response.data["id"], self.proof.id) diff --git a/open_prices/api/proofs/views.py b/open_prices/api/proofs/views.py index 35eeff34..6cbb33b5 100644 --- a/open_prices/api/proofs/views.py +++ b/open_prices/api/proofs/views.py @@ -4,7 +4,7 @@ from rest_framework import filters, mixins, status, viewsets from rest_framework.decorators import action from rest_framework.parsers import MultiPartParser -from rest_framework.permissions import IsAuthenticated +from rest_framework.permissions import IsAuthenticatedOrReadOnly from rest_framework.request import Request from rest_framework.response import Response @@ -31,9 +31,9 @@ class ProofViewSet( viewsets.GenericViewSet, ): authentication_classes = [CustomAuthentication] - permission_classes = [IsAuthenticated] + permission_classes = [IsAuthenticatedOrReadOnly] http_method_names = ["get", "post", "patch", "delete"] # disable "put" - queryset = Proof.objects.none() + queryset = Proof.objects.all() serializer_class = ProofFullSerializer filter_backends = [DjangoFilterBackend, filters.OrderingFilter] filterset_class = ProofFilter @@ -41,18 +41,18 @@ class ProofViewSet( ordering = ["created"] def get_queryset(self): - # only return proofs owned by the current user - if self.request.user.is_authenticated: - queryset = Proof.objects.filter(owner=self.request.user.user_id) - if self.request.method in ["GET"]: - # Select all proofs along with their locations using a select - # related query (1 single query) - # Then prefetch all the predictions related to the proof using - # a prefetch related query (only 1 query for all proofs) - return queryset.select_related("location").prefetch_related( - "predictions" - ) - return queryset + if self.request.method in ["GET"]: + # Select all proofs along with their locations using a select + # related query (1 single query) + # Then prefetch all the predictions related to the proof using + # a prefetch related query (only 1 query for all proofs) + return self.queryset.select_related("location").prefetch_related( + "predictions" + ) + elif self.request.method in ["PATCH", "DELETE"]: + # only return proofs owned by the current user + if self.request.user.is_authenticated: + return self.queryset.filter(owner=self.request.user.user_id) return self.queryset def get_serializer_class(self):