diff --git a/src/fastapi_aad_auth/_base/provider.py b/src/fastapi_aad_auth/_base/provider.py index a6546cb..95aeac0 100644 --- a/src/fastapi_aad_auth/_base/provider.py +++ b/src/fastapi_aad_auth/_base/provider.py @@ -1,5 +1,6 @@ from typing import List, Optional +from pydantic import PrivateAttr from starlette.requests import Request from starlette.responses import RedirectResponse from starlette.routing import Route @@ -7,7 +8,7 @@ from fastapi_aad_auth._base.authenticators import SessionAuthenticator from fastapi_aad_auth._base.validators import Validator from fastapi_aad_auth.mixins import LoggingMixin -from fastapi_aad_auth.utilities import urls +from fastapi_aad_auth.utilities import InheritableBaseSettings, urls class Provider(LoggingMixin): @@ -71,3 +72,9 @@ def redirect_url(self): if self._redirect_url is None: self._redirect_url = self._build_oauth_url(self.oauth_base_route, 'redirect') return self._redirect_url + + +class ProviderConfig(InheritableBaseSettings): + """Configuration for a provider.""" + + _provider_klass: type = PrivateAttr(Provider) diff --git a/src/fastapi_aad_auth/_base/state.py b/src/fastapi_aad_auth/_base/state.py index e8207bf..0afbf6b 100644 --- a/src/fastapi_aad_auth/_base/state.py +++ b/src/fastapi_aad_auth/_base/state.py @@ -6,11 +6,12 @@ from itsdangerous import URLSafeSerializer from itsdangerous.exc import BadSignature -from pydantic import BaseModel, Field, root_validator, validator +from pydantic import Field, root_validator, validator from starlette.authentication import AuthCredentials, SimpleUser, UnauthenticatedUser from fastapi_aad_auth.errors import AuthenticationError from fastapi_aad_auth.mixins import LoggingMixin +from fastapi_aad_auth.utilities import InheritableBaseModel, InheritablePropertyBaseModel SESSION_STORE_KEY = 'auth' @@ -23,7 +24,7 @@ class AuthenticationOptions(Enum): authenticated = 1 -class User(BaseModel): +class User(InheritablePropertyBaseModel): """User Model.""" name: str = Field(..., description='Full name') email: str = Field(..., description='User email') @@ -49,7 +50,7 @@ def _validate_scopes(cls, value): return value -class AuthenticationState(LoggingMixin, BaseModel): +class AuthenticationState(LoggingMixin, InheritableBaseModel): """Authentication State.""" session_state: str = str(uuid.uuid4()) state: AuthenticationOptions = AuthenticationOptions.unauthenticated diff --git a/src/fastapi_aad_auth/config.py b/src/fastapi_aad_auth/config.py index 6b426f3..0527065 100644 --- a/src/fastapi_aad_auth/config.py +++ b/src/fastapi_aad_auth/config.py @@ -1,10 +1,11 @@ """fastapi_aad_auth configuration options.""" -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional import uuid from pkg_resources import resource_filename from pydantic import BaseSettings as _BaseSettings, DirectoryPath, Field, FilePath, SecretStr, validator +from fastapi_aad_auth._base.provider import ProviderConfig from fastapi_aad_auth.providers.aad import AADConfig from fastapi_aad_auth.utilities import bool_from_env, DeprecatableFieldsMixin, DeprecatedField, expand_doc, klass_from_str @@ -138,7 +139,7 @@ class Config(BaseSettings): """ enabled: bool = Field(True, description="Enable authentication", env='FASTAPI_AUTH_ENABLED') - providers: List[Union[AADConfig]] = Field(None, description="The provider configurations to use") + providers: List[ProviderConfig] = Field(None, description="The provider configurations to use") aad: Optional[AADConfig] = DeprecatedField(None, description='AAD Configuration information', deprecated_in='0.2.0', replaced_by='Config.providers') auth_session: AuthSessionConfig = Field(None, description="The configuration for encoding the authentication information in the session") routing: RoutingConfig = Field(None, description="Configuration for routing") @@ -158,6 +159,7 @@ def _validate_providers(cls, value, values): value = [] if enabled: value.append(AADConfig(_env_file=cls.Config.env_file)) + return value @validator('aad', always=True, pre=True) diff --git a/src/fastapi_aad_auth/providers/aad.py b/src/fastapi_aad_auth/providers/aad.py index 653c075..e843e7a 100644 --- a/src/fastapi_aad_auth/providers/aad.py +++ b/src/fastapi_aad_auth/providers/aad.py @@ -15,7 +15,7 @@ from starlette.requests import Request from fastapi_aad_auth._base.authenticators import SessionAuthenticator -from fastapi_aad_auth._base.provider import Provider +from fastapi_aad_auth._base.provider import Provider, ProviderConfig from fastapi_aad_auth._base.state import User from fastapi_aad_auth._base.validators import SessionValidator, TokenValidator from fastapi_aad_auth.errors import ConfigurationError @@ -340,7 +340,7 @@ def get_login_button(self, post_redirect='/'): @expand_doc -class AADConfig(BaseSettings): +class AADConfig(ProviderConfig): """Configuration for the AAD application. Includes expected claims, application registration, etc. diff --git a/src/fastapi_aad_auth/utilities/__init__.py b/src/fastapi_aad_auth/utilities/__init__.py index ba8ea64..6ff9603 100644 --- a/src/fastapi_aad_auth/utilities/__init__.py +++ b/src/fastapi_aad_auth/utilities/__init__.py @@ -7,6 +7,7 @@ from fastapi_aad_auth.utilities import logging # noqa: F401 from fastapi_aad_auth.utilities import urls # noqa: F401 +from fastapi_aad_auth.utilities.basemodel import InheritableBaseModel, InheritableBaseSettings, InheritablePropertyBaseModel, InheritablePropertyBaseSettings, PropertyBaseModel, PropertyBaseSettings # noqa: F401 from fastapi_aad_auth.utilities.deprecate import DeprecatableFieldsMixin, deprecate, deprecate_module, DeprecatedField, is_deprecated # noqa: F401 diff --git a/src/fastapi_aad_auth/utilities/basemodel.py b/src/fastapi_aad_auth/utilities/basemodel.py new file mode 100644 index 0000000..f3ab4ee --- /dev/null +++ b/src/fastapi_aad_auth/utilities/basemodel.py @@ -0,0 +1,103 @@ +"""Provide inheritable property dictable basemodel. + +Implements pydantic work arounds for: +* https://github.com/samuelcolvin/pydantic/issues/265 +* https://github.com/samuelcolvin/pydantic/issues/935 + +""" +from functools import wraps + +from pydantic import BaseModel, BaseSettings +from pydantic.validators import dict_validator + + +class InheritableMixin: + """BaseModel that will Validate with inheritance rather than the original Class.""" + + @classmethod + def get_validators(cls): + """Get the validator for the object.""" + yield cls.validate + + @classmethod + def validate(cls, value): + """Validate the class as itself.""" + if isinstance(value, cls): + return value + else: + return cls(**dict_validator(value)) + + +class PropertyMixin: + """BaseModel with Properties in dict. + + A Pydantic BaseModel that includes properties in it's dict() result + enabling a mix of both fields and properties + """ + + @classmethod + def get_properties(cls): + """Get the properties.""" + return [prop for prop in dir(cls) if cls._is_property(prop)] + + @classmethod + def _is_property(cls, prop): + return isinstance(getattr(cls, prop), property) \ + and prop not in ("__values__", "fields") + + @wraps(BaseModel.dict) + def dict(self, + *, + include=None, + exclude=None, + by_alias: bool = False, + skip_defaults: bool = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False,): + """Return the object as a dictionary.""" + attribs = super().dict( # type: ignore + include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none + ) + props = self.get_properties() + # Include and exclude properties + if include: + props = [prop for prop in props if prop in include] + if exclude: + props = [prop for prop in props if prop not in exclude] + + # Update the attribute dict with the properties + if props: + attribs.update({prop: getattr(self, prop) for prop in props}) + + return attribs + + +class InheritableBaseSettings(InheritableMixin, BaseSettings): + """A Pydantic BaseSettings that allows inheritance.""" + + +class PropertyBaseSettings(PropertyMixin, BaseSettings): + """A Pydantic BaseSettings that allows roperties in the dict.""" + + +class InheritablePropertyBaseSettings(InheritableMixin, PropertyBaseSettings): + """A Pydantic BaseSettings that allows inheritance and properties in the dict.""" + + +class InheritableBaseModel(InheritableMixin, BaseModel): + """A Pydantic BaseModel that allows inheritance.""" + + +class PropertyBaseModel(PropertyMixin, BaseModel): + """A Pydantic BaseModel that allows roperties in the dict.""" + + +class InheritablePropertyBaseModel(InheritableMixin, PropertyBaseModel): + """A Pydantic BaseModel that allows inheritance and properties in the dict."""