Skip to content

Commit

Permalink
refactor(API): allow anyone to access proof data (#606)
Browse files Browse the repository at this point in the history
  • Loading branch information
raphodn authored Dec 5, 2024
1 parent 0ecca7b commit 1b4dcb5
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 53 deletions.
53 changes: 15 additions & 38 deletions open_prices/api/proofs/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,30 +69,17 @@ def setUpTestData(cls):

def test_proof_list(self):
# anonymous
response = self.client.get(self.url)
self.assertEqual(response.status_code, 403)
# wrong token
response = self.client.get(
self.url, headers={"Authorization": f"Bearer {self.user_session.token}X"}
)
self.assertEqual(response.status_code, 403)
# authenticated
# thanks to select_related and prefetch_related, we only have 6
# thanks to select_related and prefetch_related, we only have 3
# queries:
# - 1 to get the fetch the user session
# - 1 to update the session
# - 1 to get the user
# - 1 to count the number of proofs of the user
# - 1 to get the proofs and their associated locations (select_related)
# - 1 to get the associated proof predictions (prefetch_related)
with self.assertNumQueries(6):
response = self.client.get(
self.url, headers={"Authorization": f"Bearer {self.user_session.token}"}
)
with self.assertNumQueries(3):
response = self.client.get(self.url)
self.assertEqual(response.status_code, 200)
data = response.data
self.assertEqual(data["total"], 2) # only user's proofs
self.assertEqual(len(data["items"]), 2)
self.assertEqual(data["total"], 3)
self.assertEqual(len(data["items"]), 3)
item = data["items"][0]
self.assertEqual(item["id"], self.proof.id) # default order
self.assertIn("predictions", item)
Expand Down Expand Up @@ -122,10 +109,8 @@ def setUpTestData(cls):

def test_proof_list_order_by(self):
url = self.url + "?order_by=-price_count"
response = self.client.get(
url, headers={"Authorization": f"Bearer {self.user_session.token}"}
)
self.assertEqual(response.data["total"], 2)
response = self.client.get(url)
self.assertEqual(response.data["total"], 3)
self.assertEqual(response.data["items"][0]["price_count"], 50)


Expand All @@ -146,12 +131,16 @@ def setUpTestData(cls):

def test_proof_list_filter_by_type(self):
url = self.url + "?type=RECEIPT"
response = self.client.get(
url, headers={"Authorization": f"Bearer {self.user_session.token}"}
)
response = self.client.get(url)
self.assertEqual(response.data["total"], 1)
self.assertEqual(response.data["items"][0]["price_count"], 15)

def test_proof_list_filter_by_owner(self):
url = self.url + f"?owner={self.user_session.user.user_id}"
response = self.client.get(url)
self.assertEqual(response.data["total"], 2)
self.assertEqual(response.data["items"][0]["price_count"], 15)


class ProofDetailApiTest(TestCase):
@classmethod
Expand All @@ -166,23 +155,11 @@ def setUpTestData(cls):
def test_proof_detail(self):
# 404
url = reverse("api:proofs-detail", args=[999])
response = self.client.get(
url, headers={"Authorization": f"Bearer {self.user_session_1.token}"}
)
response = self.client.get(url)
self.assertEqual(response.status_code, 404)
self.assertEqual(response.data["detail"], "No Proof matches the given query.")
# anonymous
response = self.client.get(self.url)
self.assertEqual(response.status_code, 403)
# wrong token
response = self.client.get(
self.url, headers={"Authorization": f"Bearer {self.user_session_1.token}X"}
)
self.assertEqual(response.status_code, 403)
# authenticated
response = self.client.get(
self.url, headers={"Authorization": f"Bearer {self.user_session_1.token}"}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data["id"], self.proof.id)

Expand Down
30 changes: 15 additions & 15 deletions open_prices/api/proofs/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from rest_framework import filters, mixins, status, viewsets
from rest_framework.decorators import action
from rest_framework.parsers import MultiPartParser
from rest_framework.permissions import IsAuthenticated
from rest_framework.permissions import IsAuthenticatedOrReadOnly
from rest_framework.request import Request
from rest_framework.response import Response

Expand All @@ -31,28 +31,28 @@ class ProofViewSet(
viewsets.GenericViewSet,
):
authentication_classes = [CustomAuthentication]
permission_classes = [IsAuthenticated]
permission_classes = [IsAuthenticatedOrReadOnly]
http_method_names = ["get", "post", "patch", "delete"] # disable "put"
queryset = Proof.objects.none()
queryset = Proof.objects.all()
serializer_class = ProofFullSerializer
filter_backends = [DjangoFilterBackend, filters.OrderingFilter]
filterset_class = ProofFilter
ordering_fields = ["date", "price_count", "created"]
ordering = ["created"]

def get_queryset(self):
# only return proofs owned by the current user
if self.request.user.is_authenticated:
queryset = Proof.objects.filter(owner=self.request.user.user_id)
if self.request.method in ["GET"]:
# Select all proofs along with their locations using a select
# related query (1 single query)
# Then prefetch all the predictions related to the proof using
# a prefetch related query (only 1 query for all proofs)
return queryset.select_related("location").prefetch_related(
"predictions"
)
return queryset
if self.request.method in ["GET"]:
# Select all proofs along with their locations using a select
# related query (1 single query)
# Then prefetch all the predictions related to the proof using
# a prefetch related query (only 1 query for all proofs)
return self.queryset.select_related("location").prefetch_related(
"predictions"
)
elif self.request.method in ["PATCH", "DELETE"]:
# only return proofs owned by the current user
if self.request.user.is_authenticated:
return self.queryset.filter(owner=self.request.user.user_id)
return self.queryset

def get_serializer_class(self):
Expand Down

0 comments on commit 1b4dcb5

Please sign in to comment.