Skip to content

Commit

Permalink
✨ Retrieve extra user info from claims from zgw_auth_backend
Browse files Browse the repository at this point in the history
  • Loading branch information
damm89 committed Mar 25, 2024
1 parent 7f86c65 commit c3ea97d
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 3 deletions.
83 changes: 82 additions & 1 deletion src/dowc/accounts/authentication.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
import logging
from typing import Dict

from django.contrib.auth import get_user_model
from django.utils.translation import ugettext_lazy as _

from requests.models import Request
from rest_framework import exceptions
from rest_framework.authentication import TokenAuthentication as _TokenAuthentication
from rest_framework.authentication import (
TokenAuthentication as _TokenAuthentication,
get_authorization_header,
)
from zgw_auth_backend.authentication import ZGWAuthentication as _ZGWAuthentication
from zgw_auth_backend.zgw import ZGWAuth

logger = logging.getLogger(__name__)


class ApplicationTokenAuthentication(_TokenAuthentication):
Expand All @@ -16,3 +28,72 @@ def authenticate_credentials(self, key):
raise exceptions.AuthenticationFailed(_("Invalid token."))

return (None, token)


class ZGWAuthentication(_ZGWAuthentication):
"""
Taken from zgw_auth_backend and adapted to further suit our needs.
We want to include first and last names and check every authentication
if an update is needed to reflect changes done to their first
and last name.
"""

def authenticate(self, request: Request):
auth = get_authorization_header(request).split()

if not auth or auth[0].lower() != b"bearer":
return None

if len(auth) == 1:
msg = _("Invalid bearer header. No credentials provided.")
raise exceptions.AuthenticationFailed(msg)
elif len(auth) > 2:
msg = _(
"Invalid bearer header. Credentials string should not contain spaces."
)
raise exceptions.AuthenticationFailed(msg)

auth = ZGWAuth(auth[1].decode("utf-8"))

user_id = auth.payload.get("user_id")
if not user_id:
msg = _("Invalid 'user_id' claim. The 'user_id' should not be empty.")
raise exceptions.AuthenticationFailed(msg)

email = auth.payload.get("email", "")
return self.authenticate_user_id(user_id, email, auth.payload)

def authenticate_user_id(self, username: str, email: str, payload: Dict):
UserModel = get_user_model()
fields = {UserModel.USERNAME_FIELD: username}
user, created = UserModel._default_manager.get_or_create(**fields)
if created:
msg = "Created user object for username %s" % username
logger.info(msg)

if email:
email_field = UserModel.get_email_field_name()
email_value = getattr(user, email_field)
if not email_value or email_value != email:
setattr(user, email_field, email)
user.save()
msg = "Set email to %s of user with username %s" % (email, username)
logger.info(msg)

extra_user_info_fields = ["first_name", "last_name"]
data = {
field: value
for field, value in payload.items()
if field in extra_user_info_fields
}
for field, value in data.items():
if not getattr(user, field) == value:
setattr(user, field, value)
try:
user.save(update_fields=[field])
except ValueError:
logger.error(exc_info=True)
continue

return (user, None)
1 change: 1 addition & 0 deletions src/dowc/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def get_magic_url(self, obj) -> str:

if obj.purpose in [DocFileTypes.read, DocFileTypes.write]:
fn, fext = os.path.splitext(obj.document.name)
print("wut?")
if scheme_name := EXTENSION_HANDLER.get(fext, ""):
command_argument = {
DocFileTypes.read: ":ofv|u|",
Expand Down
76 changes: 75 additions & 1 deletion src/dowc/api/tests/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
"""
Test that authorization is required for the API endpoints.
Test that authorization creates or gets the user.
"""
import uuid
from unittest.mock import patch

from rest_framework import status
from rest_framework.reverse import reverse
from rest_framework.reverse import reverse, reverse_lazy
from rest_framework.test import APITestCase
from zds_client import ClientAuth
from zgw_auth_backend.models import ApplicationCredentials
from zgw_consumers.api_models.base import factory
from zgw_consumers.api_models.documenten import Document
from zgw_consumers.constants import APITypes
from zgw_consumers.models import Service
from zgw_consumers.test import generate_oas_component, mock_service_oas_get

from dowc.accounts.models import User
from dowc.accounts.tests.factories import UserFactory
from dowc.core.constants import DocFileTypes


class AuthTests(APITestCase):
Expand Down Expand Up @@ -40,3 +54,63 @@ def test_invalid_token(self):
with self.subTest(method=method, path=path):
response = getattr(self.client, method)(path, **headers)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

def test_create_user_during_authentication(self):
drc_url = "https://some.drc.nl/api/v1/"
Service.objects.create(api_type=APITypes.drc, api_root=drc_url)
list_url = reverse_lazy("documentfile-list")

# Create mock url for drc object
_uuid = str(uuid.uuid4())
doc_url = f"{drc_url}enkelvoudiginformatieobjecten/{_uuid}"

# No users exist
self.assertEqual(User.objects.count(), 0)
data = {
"drc_url": doc_url,
"purpose": DocFileTypes.read,
"info_url": "http://www.some-referer-url.com/",
"user_id": "some-user",
}
ApplicationCredentials.objects.create(client_id="dummy", secret="secret")
auth = ClientAuth("dummy", "secret", user_id="some-user").credentials()

response = self.client.post(
list_url, data, HTTP_AUTHORIZATION=auth["Authorization"]
)

self.assertEqual(User.objects.get().username, "some-user")

def test_update_user_during_authentication(self):
drc_url = "https://some.drc.nl/api/v1/"
Service.objects.create(api_type=APITypes.drc, api_root=drc_url)
list_url = reverse_lazy("documentfile-list")

# Create mock url for drc object
_uuid = str(uuid.uuid4())
doc_url = f"{drc_url}enkelvoudiginformatieobjecten/{_uuid}"

# User exists
user = UserFactory.create(
username="some-user", first_name="First", last_name="Last"
)
self.assertEqual(User.objects.count(), 1)
data = {
"drc_url": doc_url,
"purpose": DocFileTypes.read,
"info_url": "http://www.some-referer-url.com/",
}
ApplicationCredentials.objects.create(client_id="dummy", secret="secret")
auth = ClientAuth(
"dummy",
"secret",
user_id="some-user",
first_name="some other first",
last_name="some other last",
).credentials()

response = self.client.post(
list_url, data, HTTP_AUTHORIZATION=auth["Authorization"]
)

self.assertEqual(User.objects.get().first_name, "some other first")
2 changes: 1 addition & 1 deletion src/dowc/conf/includes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@
REST_FRAMEWORK = {
"DEFAULT_AUTHENTICATION_CLASSES": (
"rest_framework.authentication.TokenAuthentication",
"zgw_auth_backend.authentication.ZGWAuthentication",
"dowc.accounts.authentication.ZGWAuthentication",
),
"DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.IsAuthenticated",),
"DEFAULT_FILTER_BACKENDS": [
Expand Down

0 comments on commit c3ea97d

Please sign in to comment.