diff --git a/src/dowc/accounts/authentication.py b/src/dowc/accounts/authentication.py index f7d5b61..c00d314 100644 --- a/src/dowc/accounts/authentication.py +++ b/src/dowc/accounts/authentication.py @@ -1,7 +1,19 @@ +import logging +from typing import Dict + +from django.contrib.auth import get_user_model from django.utils.translation import ugettext_lazy as _ +from requests.models import Request from rest_framework import exceptions -from rest_framework.authentication import TokenAuthentication as _TokenAuthentication +from rest_framework.authentication import ( + TokenAuthentication as _TokenAuthentication, + get_authorization_header, +) +from zgw_auth_backend.authentication import ZGWAuthentication as _ZGWAuthentication +from zgw_auth_backend.zgw import ZGWAuth + +logger = logging.getLogger(__name__) class ApplicationTokenAuthentication(_TokenAuthentication): @@ -16,3 +28,72 @@ def authenticate_credentials(self, key): raise exceptions.AuthenticationFailed(_("Invalid token.")) return (None, token) + + +class ZGWAuthentication(_ZGWAuthentication): + """ + Taken from zgw_auth_backend and adapted to further suit our needs. + We want to include first and last names and check every authentication + if an update is needed to reflect changes done to their first + and last name. + + """ + + def authenticate(self, request: Request): + auth = get_authorization_header(request).split() + + if not auth or auth[0].lower() != b"bearer": + return None + + if len(auth) == 1: + msg = _("Invalid bearer header. No credentials provided.") + raise exceptions.AuthenticationFailed(msg) + elif len(auth) > 2: + msg = _( + "Invalid bearer header. Credentials string should not contain spaces." + ) + raise exceptions.AuthenticationFailed(msg) + + auth = ZGWAuth(auth[1].decode("utf-8")) + + user_id = auth.payload.get("user_id") + if not user_id: + msg = _("Invalid 'user_id' claim. The 'user_id' should not be empty.") + raise exceptions.AuthenticationFailed(msg) + + email = auth.payload.get("email", "") + return self.authenticate_user_id(user_id, email, auth.payload) + + def authenticate_user_id(self, username: str, email: str, payload: Dict): + UserModel = get_user_model() + fields = {UserModel.USERNAME_FIELD: username} + user, created = UserModel._default_manager.get_or_create(**fields) + if created: + msg = "Created user object for username %s" % username + logger.info(msg) + + if email: + email_field = UserModel.get_email_field_name() + email_value = getattr(user, email_field) + if not email_value or email_value != email: + setattr(user, email_field, email) + user.save() + msg = "Set email to %s of user with username %s" % (email, username) + logger.info(msg) + + extra_user_info_fields = ["first_name", "last_name"] + data = { + field: value + for field, value in payload.items() + if field in extra_user_info_fields + } + for field, value in data.items(): + if not getattr(user, field) == value: + setattr(user, field, value) + try: + user.save(update_fields=[field]) + except ValueError: + logger.error(exc_info=True) + continue + + return (user, None) diff --git a/src/dowc/api/serializers.py b/src/dowc/api/serializers.py index 86309fd..0593275 100644 --- a/src/dowc/api/serializers.py +++ b/src/dowc/api/serializers.py @@ -151,6 +151,7 @@ def get_magic_url(self, obj) -> str: if obj.purpose in [DocFileTypes.read, DocFileTypes.write]: fn, fext = os.path.splitext(obj.document.name) + print("wut?") if scheme_name := EXTENSION_HANDLER.get(fext, ""): command_argument = { DocFileTypes.read: ":ofv|u|", diff --git a/src/dowc/api/tests/test_auth.py b/src/dowc/api/tests/test_auth.py index cb945b3..748236d 100644 --- a/src/dowc/api/tests/test_auth.py +++ b/src/dowc/api/tests/test_auth.py @@ -1,11 +1,25 @@ """ Test that authorization is required for the API endpoints. + +Test that authorization creates or gets the user. """ import uuid +from unittest.mock import patch from rest_framework import status -from rest_framework.reverse import reverse +from rest_framework.reverse import reverse, reverse_lazy from rest_framework.test import APITestCase +from zds_client import ClientAuth +from zgw_auth_backend.models import ApplicationCredentials +from zgw_consumers.api_models.base import factory +from zgw_consumers.api_models.documenten import Document +from zgw_consumers.constants import APITypes +from zgw_consumers.models import Service +from zgw_consumers.test import generate_oas_component, mock_service_oas_get + +from dowc.accounts.models import User +from dowc.accounts.tests.factories import UserFactory +from dowc.core.constants import DocFileTypes class AuthTests(APITestCase): @@ -40,3 +54,63 @@ def test_invalid_token(self): with self.subTest(method=method, path=path): response = getattr(self.client, method)(path, **headers) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + + def test_create_user_during_authentication(self): + drc_url = "https://some.drc.nl/api/v1/" + Service.objects.create(api_type=APITypes.drc, api_root=drc_url) + list_url = reverse_lazy("documentfile-list") + + # Create mock url for drc object + _uuid = str(uuid.uuid4()) + doc_url = f"{drc_url}enkelvoudiginformatieobjecten/{_uuid}" + + # No users exist + self.assertEqual(User.objects.count(), 0) + data = { + "drc_url": doc_url, + "purpose": DocFileTypes.read, + "info_url": "http://www.some-referer-url.com/", + "user_id": "some-user", + } + ApplicationCredentials.objects.create(client_id="dummy", secret="secret") + auth = ClientAuth("dummy", "secret", user_id="some-user").credentials() + + response = self.client.post( + list_url, data, HTTP_AUTHORIZATION=auth["Authorization"] + ) + + self.assertEqual(User.objects.get().username, "some-user") + + def test_update_user_during_authentication(self): + drc_url = "https://some.drc.nl/api/v1/" + Service.objects.create(api_type=APITypes.drc, api_root=drc_url) + list_url = reverse_lazy("documentfile-list") + + # Create mock url for drc object + _uuid = str(uuid.uuid4()) + doc_url = f"{drc_url}enkelvoudiginformatieobjecten/{_uuid}" + + # User exists + user = UserFactory.create( + username="some-user", first_name="First", last_name="Last" + ) + self.assertEqual(User.objects.count(), 1) + data = { + "drc_url": doc_url, + "purpose": DocFileTypes.read, + "info_url": "http://www.some-referer-url.com/", + } + ApplicationCredentials.objects.create(client_id="dummy", secret="secret") + auth = ClientAuth( + "dummy", + "secret", + user_id="some-user", + first_name="some other first", + last_name="some other last", + ).credentials() + + response = self.client.post( + list_url, data, HTTP_AUTHORIZATION=auth["Authorization"] + ) + + self.assertEqual(User.objects.get().first_name, "some other first") diff --git a/src/dowc/conf/includes/base.py b/src/dowc/conf/includes/base.py index 1aa10d7..32f2103 100644 --- a/src/dowc/conf/includes/base.py +++ b/src/dowc/conf/includes/base.py @@ -383,7 +383,7 @@ REST_FRAMEWORK = { "DEFAULT_AUTHENTICATION_CLASSES": ( "rest_framework.authentication.TokenAuthentication", - "zgw_auth_backend.authentication.ZGWAuthentication", + "dowc.accounts.authentication.ZGWAuthentication", ), "DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.IsAuthenticated",), "DEFAULT_FILTER_BACKENDS": [