From 7d8256fb5ef234b0bb1a4771bbe970edda746566 Mon Sep 17 00:00:00 2001 From: David Pugh Date: Sun, 20 Dec 2020 22:27:53 +0000 Subject: [PATCH] Fixing user class reload with tests --- src/fastapi_aad_auth/_base/state.py | 21 ++++++++- tests/unit/test_auth_state.py | 66 +++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 tests/unit/test_auth_state.py diff --git a/src/fastapi_aad_auth/_base/state.py b/src/fastapi_aad_auth/_base/state.py index 0afbf6b..5c909f8 100644 --- a/src/fastapi_aad_auth/_base/state.py +++ b/src/fastapi_aad_auth/_base/state.py @@ -1,5 +1,6 @@ """Authentication State Handler.""" from enum import Enum +import importlib import json from typing import List, Optional import uuid @@ -43,6 +44,11 @@ def permissions(self): permissions.append(scope) return permissions[:] + @property + def klass(self): + """Return the user klass information for loading from a session.""" + return f'{self.__class__.__module__}:{self.__class__.__name__}' + @validator('scopes', always=True, pre=True) def _validate_scopes(cls, value): if isinstance(value, str): @@ -52,14 +58,27 @@ def _validate_scopes(cls, value): class AuthenticationState(LoggingMixin, InheritableBaseModel): """Authentication State.""" + _logger = None session_state: str = str(uuid.uuid4()) state: AuthenticationOptions = AuthenticationOptions.unauthenticated user: Optional[User] = None - _logger = None class Config: # noqa: D106 underscore_attrs_are_private = True + @validator('user', always=True, pre=True) + def _validate_user_klass(cls, value): + if isinstance(value, dict): + klass = value.get('klass', None) + if klass: + module, name = klass.split(':') + mod = importlib.import_module(module) + klass = getattr(mod, name) + else: + klass = User + value = klass(**value) + return value + @root_validator(pre=True) def _validate_user(cls, values): if values.get('user', None) is None: diff --git a/tests/unit/test_auth_state.py b/tests/unit/test_auth_state.py new file mode 100644 index 0000000..306de9e --- /dev/null +++ b/tests/unit/test_auth_state.py @@ -0,0 +1,66 @@ +from typing import List +import unittest +import uuid + +from fastapi_aad_auth._base.state import AuthenticationState, User, AuthenticationOptions +from fastapi_aad_auth._base.validators import SessionValidator + + +class User2(User): + b: int = 2 + + +class User3(User2): + + @property + def permissions(self): + return [self.name, 'a'] + + +class AuthenticationStateTestCase(unittest.TestCase): + + def setUp(self): + self.serializer = SessionValidator.get_session_serializer(str(uuid.uuid4()), str(uuid.uuid4())) + + def test_create(self): + user = User(name='Joe Bloggs', email='joe.bloggs@gmail.com', username='joe.bloggs@gmail.com') + state = AuthenticationState(user=user, state=AuthenticationOptions.authenticated) + self.assertIsInstance(state.user, User) + self.assertEqual(state.user.name, user.name) + + def test_create_custom_user(self): + user = User2(name='Joe Bloggs', email='joe.bloggs@gmail.com', username='joe.bloggs@gmail.com') + state = AuthenticationState(user=user, state=AuthenticationOptions.authenticated) + self.assertIsInstance(state.user, User2) + self.assertEqual(state.user.name, user.name) + self.assertEqual(state.user.b, user.b) + self.assertEqual(state.user.b, 2) + + def test_load(self): + user = User(name='Joe Bloggs', email='joe.bloggs@gmail.com', username='joe.bloggs@gmail.com') + state = AuthenticationState(user=user, state=AuthenticationOptions.authenticated) + loaded_state = AuthenticationState.load(self.serializer, state.store(self.serializer)) + self.assertIsInstance(state.user, User) + self.assertEqual(state.user.name, user.name) + + def test_load_custom_user(self): + user = User2(name='Joe Bloggs', email='joe.bloggs@gmail.com', username='joe.bloggs@gmail.com') + state = AuthenticationState(user=user, state=AuthenticationOptions.authenticated) + loaded_state = AuthenticationState.load(self.serializer, state.store(self.serializer)) + self.assertIsInstance(state.user, User2) + self.assertEqual(state.user.name, user.name) + self.assertEqual(state.user.b, user.b) + self.assertEqual(state.user.b, 2) + + def test_load_custom_user_permissions(self): + user = User3(name='Joe Bloggs', email='joe.bloggs@gmail.com', username='joe.bloggs@gmail.com', b=4) + state = AuthenticationState(user=user, state=AuthenticationOptions.authenticated) + loaded_state = AuthenticationState.load(self.serializer, state.store(self.serializer)) + self.assertIsInstance(state.user, User3) + self.assertEqual(state.user.name, user.name) + self.assertEqual(state.user.b, user.b) + self.assertEqual(state.user.b, 4) + self.assertEqual(state.user.permissions, ['Joe Bloggs', 'a']) + + +