Skip to content

Commit

Permalink
Fixing user class reload with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
djpugh committed Dec 20, 2020
1 parent 3e6d059 commit 7d8256f
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 1 deletion.
21 changes: 20 additions & 1 deletion src/fastapi_aad_auth/_base/state.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Authentication State Handler."""
from enum import Enum
import importlib
import json
from typing import List, Optional
import uuid
Expand Down Expand Up @@ -43,6 +44,11 @@ def permissions(self):
permissions.append(scope)
return permissions[:]

@property
def klass(self):
"""Return the user klass information for loading from a session."""
return f'{self.__class__.__module__}:{self.__class__.__name__}'

@validator('scopes', always=True, pre=True)
def _validate_scopes(cls, value):
if isinstance(value, str):
Expand All @@ -52,14 +58,27 @@ def _validate_scopes(cls, value):

class AuthenticationState(LoggingMixin, InheritableBaseModel):
"""Authentication State."""
_logger = None
session_state: str = str(uuid.uuid4())
state: AuthenticationOptions = AuthenticationOptions.unauthenticated
user: Optional[User] = None
_logger = None

class Config: # noqa: D106
underscore_attrs_are_private = True

@validator('user', always=True, pre=True)
def _validate_user_klass(cls, value):
if isinstance(value, dict):
klass = value.get('klass', None)
if klass:
module, name = klass.split(':')
mod = importlib.import_module(module)
klass = getattr(mod, name)
else:
klass = User
value = klass(**value)
return value

@root_validator(pre=True)
def _validate_user(cls, values):
if values.get('user', None) is None:
Expand Down
66 changes: 66 additions & 0 deletions tests/unit/test_auth_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from typing import List
import unittest
import uuid

from fastapi_aad_auth._base.state import AuthenticationState, User, AuthenticationOptions
from fastapi_aad_auth._base.validators import SessionValidator


class User2(User):
b: int = 2


class User3(User2):

@property
def permissions(self):
return [self.name, 'a']


class AuthenticationStateTestCase(unittest.TestCase):

def setUp(self):
self.serializer = SessionValidator.get_session_serializer(str(uuid.uuid4()), str(uuid.uuid4()))

def test_create(self):
user = User(name='Joe Bloggs', email='[email protected]', username='[email protected]')
state = AuthenticationState(user=user, state=AuthenticationOptions.authenticated)
self.assertIsInstance(state.user, User)
self.assertEqual(state.user.name, user.name)

def test_create_custom_user(self):
user = User2(name='Joe Bloggs', email='[email protected]', username='[email protected]')
state = AuthenticationState(user=user, state=AuthenticationOptions.authenticated)
self.assertIsInstance(state.user, User2)
self.assertEqual(state.user.name, user.name)
self.assertEqual(state.user.b, user.b)
self.assertEqual(state.user.b, 2)

def test_load(self):
user = User(name='Joe Bloggs', email='[email protected]', username='[email protected]')
state = AuthenticationState(user=user, state=AuthenticationOptions.authenticated)
loaded_state = AuthenticationState.load(self.serializer, state.store(self.serializer))
self.assertIsInstance(state.user, User)
self.assertEqual(state.user.name, user.name)

def test_load_custom_user(self):
user = User2(name='Joe Bloggs', email='[email protected]', username='[email protected]')
state = AuthenticationState(user=user, state=AuthenticationOptions.authenticated)
loaded_state = AuthenticationState.load(self.serializer, state.store(self.serializer))
self.assertIsInstance(state.user, User2)
self.assertEqual(state.user.name, user.name)
self.assertEqual(state.user.b, user.b)
self.assertEqual(state.user.b, 2)

def test_load_custom_user_permissions(self):
user = User3(name='Joe Bloggs', email='[email protected]', username='[email protected]', b=4)
state = AuthenticationState(user=user, state=AuthenticationOptions.authenticated)
loaded_state = AuthenticationState.load(self.serializer, state.store(self.serializer))
self.assertIsInstance(state.user, User3)
self.assertEqual(state.user.name, user.name)
self.assertEqual(state.user.b, user.b)
self.assertEqual(state.user.b, 4)
self.assertEqual(state.user.permissions, ['Joe Bloggs', 'a'])



0 comments on commit 7d8256f

Please sign in to comment.