diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst index 8ae4e47..e40def6 100644 --- a/docs/source/advanced.rst +++ b/docs/source/advanced.rst @@ -18,7 +18,7 @@ Customising the User Model ~~~~~~~~~~~~~~~~~~~~~~~~~~ The authentication state user can be processed within the application methods - the ``Depends`` part of the api route returns an -:class:`~fastapi_aad_auth.oauth.state.AuthenticationState` object - ``auth_state`` in the ``testapp`` (see :ref:`testing`). +:class:`~fastapi_aad_auth._base.state.AuthenticationState` object - ``auth_state`` in the ``testapp`` (see :ref:`testing`). .. literalinclude:: ../../tests/testapp/server.py :language: python @@ -28,26 +28,25 @@ The authentication state user can be processed within the application methods - The associated user is then available at ``auth_state.user`` -The :class:`~fastapi_aad_auth.oauth.aad.AADOAuthBackend` object takes a ``user_klass`` argument: +The :class:`~fastapi_aad_auth.auth.Authenticator` object takes a ``user_klass`` argument: -.. literalinclude:: ../../src/fastapi_aad_auth/oauth/aad.py +.. literalinclude:: ../../src/fastapi_aad_auth/auth.py :language: python :linenos: - :start-at: class AADOAuthBackend + :start-at: class Authenticator :end-before: """Initialise which defaults to the really basic :class:`~fastapi_aad_auth.oauth.state.User` class, but any object with the same interface should work, so you can add e.g. database calls etc. to validate/persist/check the user and any other desired behaviours. -You can customise this when initialising the :class:`~fastapi_aad_auth.auth.AADAuth` object by setting +You can customise this when initialising the :class:`~fastapi_aad_auth.auth.Authenticator` object by setting the :class:`~fastapi_aad_auth.config.Config` ``user_klass`` variable (this can also be done by the -associated environment variable):: +associated environment variable, or in the argument, which overrides all other settings):: - from fastapi_aad_auth import AADAuth, Config + from fastapi_aad_auth import Authenticator, Config config = Config() - config.user_klass = MyUserClass - auth = AADAuth(config) + auth = Authenticator(config, user_klass=MyUserClass) diff --git a/docs/source/conf.py b/docs/source/conf.py index df8a5ed..6b3b767 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -155,7 +155,7 @@ 'repo_name': 'fastapi_aad_auth', # Visible levels of the global TOC; -1 means unlimited - 'globaltoc_depth': -1, + 'globaltoc_depth': 1, # If False, expand all TOC entries 'globaltoc_collapse': False, # If True, show hidden TOC entries diff --git a/docs/source/deprecations.rst b/docs/source/deprecations.rst new file mode 100644 index 0000000..ab0fe53 --- /dev/null +++ b/docs/source/deprecations.rst @@ -0,0 +1,13 @@ +API Deprecations +**************** + +``0.2.0``: + + Refactoring towards a more extensible structure for the Authentication Backend + * :class:`fastapi_aad_auth.oauth.state.AuthenticationState` - replaced by :class:`fastapi_aad_auth._base.state.AuthenticationState` + * :class:`fastapi_aad_auth.oauth.aad.AADOAuthBackend` - replaced by :class:`fastapi_aad_auth.providers.aad.AADProvider` + * :py:attr:`fastapi_aad_auth.config.RoutingConfig.login_path` - replaced by provider based usage of :py:attr:`fastapi_aad_auth.config.RoutingConfig.oauth_base_route`, see :ref:`config-aad-appreg` for how to configure the app registration + * :py:attr:`fastapi_aad_auth.config.RoutingConfig.login_redirect_path` - replaced by provider based usage of :py:attr:`fastapi_aad_auth.config.RoutingConfig.oauth_base_route`, see :ref:`config-aad-appreg` for how to configure the app registration + * :py:attr:`fastapi_aad_auth.config.Config.aad` - replaced by providers in :py:attr:`fastapi_aad_auth.config.Config.providers` + * :class:`fastapi_aad_auth.auth.AADAuth` - replaced by :class:`fastapi_aad_auth.auth.Authenticator` + * :py:meth:`fastapi_aad_auth.auth.AADAuth.api_auth_scheme` - replaced by :py:meth:`fastapi_aad_auth._base.backend.BaseOAuthBackend.requires_auth` (includes ``allow_session`` boolean flag) diff --git a/docs/source/figures/App-Registration-Redirect-URIs.PNG b/docs/source/figures/App-Registration-Redirect-URIs.PNG index 063c5a3..c5c1be8 100644 Binary files a/docs/source/figures/App-Registration-Redirect-URIs.PNG and b/docs/source/figures/App-Registration-Redirect-URIs.PNG differ diff --git a/docs/source/index.rst b/docs/source/index.rst index 3ba94a3..ac3e303 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -50,16 +50,23 @@ The repository is open source (MIT Licensed) on |github| `Github Basic Usage Advanced Usage + Deprecations + .. toctree:: :caption: API - :maxdepth: 4 + :maxdepth: 1 + module/fastapi_aad_auth._base module/fastapi_aad_auth.auth module/fastapi_aad_auth.config module/fastapi_aad_auth.errors + module/fastapi_aad_auth.mixins module/fastapi_aad_auth.oauth + module/fastapi_aad_auth.providers module/fastapi_aad_auth.ui + module/fastapi_aad_auth.utilities + .. toctree:: :caption: Changes and Contributing diff --git a/docs/source/module/fastapi_aad_auth._base.authenticators.rst b/docs/source/module/fastapi_aad_auth._base.authenticators.rst new file mode 100644 index 0000000..82df6df --- /dev/null +++ b/docs/source/module/fastapi_aad_auth._base.authenticators.rst @@ -0,0 +1,10 @@ +fastapi_aad_auth._base.authenticators +************************************* + +.. automodule:: fastapi_aad_auth._base.authenticators + +.. toctree:: + :caption: Sub-modules: + :maxdepth: 1 + + fastapi_aad_auth._base.authenticators.session diff --git a/docs/source/module/fastapi_aad_auth._base.authenticators.session.rst b/docs/source/module/fastapi_aad_auth._base.authenticators.session.rst new file mode 100644 index 0000000..c8162e5 --- /dev/null +++ b/docs/source/module/fastapi_aad_auth._base.authenticators.session.rst @@ -0,0 +1,5 @@ +fastapi_aad_auth._base.authenticators.session +********************************************* + +.. automodule:: fastapi_aad_auth._base.authenticators.session + :members: \ No newline at end of file diff --git a/docs/source/module/fastapi_aad_auth._base.backend.rst b/docs/source/module/fastapi_aad_auth._base.backend.rst new file mode 100644 index 0000000..5d7ce79 --- /dev/null +++ b/docs/source/module/fastapi_aad_auth._base.backend.rst @@ -0,0 +1,5 @@ +fastapi_aad_auth._base.backend +****************************** + +.. automodule:: fastapi_aad_auth._base.backend + :members: \ No newline at end of file diff --git a/docs/source/module/fastapi_aad_auth._base.provider.rst b/docs/source/module/fastapi_aad_auth._base.provider.rst new file mode 100644 index 0000000..705401c --- /dev/null +++ b/docs/source/module/fastapi_aad_auth._base.provider.rst @@ -0,0 +1,5 @@ +fastapi_aad_auth._base.provider +******************************* + +.. automodule:: fastapi_aad_auth._base.provider + :members: \ No newline at end of file diff --git a/docs/source/module/fastapi_aad_auth._base.rst b/docs/source/module/fastapi_aad_auth._base.rst new file mode 100644 index 0000000..b242257 --- /dev/null +++ b/docs/source/module/fastapi_aad_auth._base.rst @@ -0,0 +1,14 @@ +fastapi_aad_auth._base +********************** + +.. automodule:: fastapi_aad_auth._base + +.. toctree:: + :caption: Sub-modules: + :maxdepth: 1 + + fastapi_aad_auth._base.backend + fastapi_aad_auth._base.provider + fastapi_aad_auth._base.authenticators + fastapi_aad_auth._base.state + fastapi_aad_auth._base.validators \ No newline at end of file diff --git a/docs/source/module/fastapi_aad_auth._base.state.rst b/docs/source/module/fastapi_aad_auth._base.state.rst new file mode 100644 index 0000000..556c044 --- /dev/null +++ b/docs/source/module/fastapi_aad_auth._base.state.rst @@ -0,0 +1,5 @@ +fastapi_aad_auth._base.state +**************************** + +.. automodule:: fastapi_aad_auth._base.state + :members: \ No newline at end of file diff --git a/docs/source/module/fastapi_aad_auth._base.validators.base.rst b/docs/source/module/fastapi_aad_auth._base.validators.base.rst new file mode 100644 index 0000000..cb1c219 --- /dev/null +++ b/docs/source/module/fastapi_aad_auth._base.validators.base.rst @@ -0,0 +1,5 @@ +fastapi_aad_auth._base.validators.base +************************************** + +.. automodule:: fastapi_aad_auth._base.validators.base + :members: \ No newline at end of file diff --git a/docs/source/module/fastapi_aad_auth._base.validators.rst b/docs/source/module/fastapi_aad_auth._base.validators.rst new file mode 100644 index 0000000..4214180 --- /dev/null +++ b/docs/source/module/fastapi_aad_auth._base.validators.rst @@ -0,0 +1,12 @@ +fastapi_aad_auth._base.validators +********************************* + +.. automodule:: fastapi_aad_auth._base.validators + +.. toctree:: + :caption: Sub-modules: + :maxdepth: 1 + + fastapi_aad_auth._base.validators.base + fastapi_aad_auth._base.validators.session + fastapi_aad_auth._base.validators.token diff --git a/docs/source/module/fastapi_aad_auth._base.validators.session.rst b/docs/source/module/fastapi_aad_auth._base.validators.session.rst new file mode 100644 index 0000000..11b6a56 --- /dev/null +++ b/docs/source/module/fastapi_aad_auth._base.validators.session.rst @@ -0,0 +1,5 @@ +fastapi_aad_auth._base.validators.session +***************************************** + +.. automodule:: fastapi_aad_auth._base.validators.session + :members: \ No newline at end of file diff --git a/docs/source/module/fastapi_aad_auth._base.validators.token.rst b/docs/source/module/fastapi_aad_auth._base.validators.token.rst new file mode 100644 index 0000000..dfa9dee --- /dev/null +++ b/docs/source/module/fastapi_aad_auth._base.validators.token.rst @@ -0,0 +1,5 @@ +fastapi_aad_auth._base.validators.token +*************************************** + +.. automodule:: fastapi_aad_auth._base.validators.token + :members: \ No newline at end of file diff --git a/docs/source/module/fastapi_aad_auth.mixins.logging.rst b/docs/source/module/fastapi_aad_auth.mixins.logging.rst new file mode 100644 index 0000000..ddae766 --- /dev/null +++ b/docs/source/module/fastapi_aad_auth.mixins.logging.rst @@ -0,0 +1,5 @@ +fastapi_aad_auth.mixins.logging +******************************* + +.. automodule:: fastapi_aad_auth.mixins.logging + :members: diff --git a/docs/source/module/fastapi_aad_auth.mixins.not_authenticated.rst b/docs/source/module/fastapi_aad_auth.mixins.not_authenticated.rst new file mode 100644 index 0000000..286977e --- /dev/null +++ b/docs/source/module/fastapi_aad_auth.mixins.not_authenticated.rst @@ -0,0 +1,5 @@ +fastapi_aad_auth.mixins.not_authenticated +***************************************** + +.. automodule:: fastapi_aad_auth.mixins.not_authenticated + :members: diff --git a/docs/source/module/fastapi_aad_auth.mixins.rst b/docs/source/module/fastapi_aad_auth.mixins.rst new file mode 100644 index 0000000..f9335fa --- /dev/null +++ b/docs/source/module/fastapi_aad_auth.mixins.rst @@ -0,0 +1,12 @@ +fastapi_aad_auth.mixins +*********************** + +.. automodule:: fastapi_aad_auth.mixins + :members: + +.. toctree:: + :caption: Sub-modules: + :maxdepth: 1 + + fastapi_aad_auth.mixins.logging + fastapi_aad_auth.mixins.not_authenticated diff --git a/docs/source/module/fastapi_aad_auth.oauth._base.rst b/docs/source/module/fastapi_aad_auth.oauth._base.rst deleted file mode 100644 index 90921a1..0000000 --- a/docs/source/module/fastapi_aad_auth.oauth._base.rst +++ /dev/null @@ -1,5 +0,0 @@ -fastapi_aad_auth.oauth._base -**************************** - -.. automodule:: fastapi_aad_auth.oauth._base - :members: \ No newline at end of file diff --git a/docs/source/module/fastapi_aad_auth.oauth.aad.rst b/docs/source/module/fastapi_aad_auth.oauth.aad.rst index d0c6549..0b0c909 100644 --- a/docs/source/module/fastapi_aad_auth.oauth.aad.rst +++ b/docs/source/module/fastapi_aad_auth.oauth.aad.rst @@ -1,5 +1,5 @@ fastapi_aad_auth.oauth.aad ************************** - + .. automodule:: fastapi_aad_auth.oauth.aad - :members: \ No newline at end of file + :members: diff --git a/docs/source/module/fastapi_aad_auth.oauth.authenticators.rst b/docs/source/module/fastapi_aad_auth.oauth.authenticators.rst deleted file mode 100644 index 26d773e..0000000 --- a/docs/source/module/fastapi_aad_auth.oauth.authenticators.rst +++ /dev/null @@ -1,10 +0,0 @@ -fastapi_aad_auth.oauth.authenticators -************************************* - -.. automodule:: fastapi_aad_auth.oauth.authenticators - -.. toctree:: - :caption: Sub-modules: - :maxdepth: 1 - - fastapi_aad_auth.oauth.authenticators.session diff --git a/docs/source/module/fastapi_aad_auth.oauth.authenticators.session.rst b/docs/source/module/fastapi_aad_auth.oauth.authenticators.session.rst deleted file mode 100644 index 98f12df..0000000 --- a/docs/source/module/fastapi_aad_auth.oauth.authenticators.session.rst +++ /dev/null @@ -1,5 +0,0 @@ -fastapi_aad_auth.oauth.authenticators.session -********************************************* - -.. automodule:: fastapi_aad_auth.oauth.authenticators.session - :members: \ No newline at end of file diff --git a/docs/source/module/fastapi_aad_auth.oauth.rst b/docs/source/module/fastapi_aad_auth.oauth.rst index 9978bd4..960ee66 100644 --- a/docs/source/module/fastapi_aad_auth.oauth.rst +++ b/docs/source/module/fastapi_aad_auth.oauth.rst @@ -7,8 +7,5 @@ fastapi_aad_auth.oauth :caption: Sub-modules: :maxdepth: 1 - fastapi_aad_auth.oauth._base fastapi_aad_auth.oauth.aad - fastapi_aad_auth.oauth.authenticators fastapi_aad_auth.oauth.state - fastapi_aad_auth.oauth.validators \ No newline at end of file diff --git a/docs/source/module/fastapi_aad_auth.oauth.state.rst b/docs/source/module/fastapi_aad_auth.oauth.state.rst index aedd426..4111829 100644 --- a/docs/source/module/fastapi_aad_auth.oauth.state.rst +++ b/docs/source/module/fastapi_aad_auth.oauth.state.rst @@ -1,5 +1,5 @@ fastapi_aad_auth.oauth.state **************************** - + .. automodule:: fastapi_aad_auth.oauth.state - :members: \ No newline at end of file + :members: diff --git a/docs/source/module/fastapi_aad_auth.oauth.validators.rst b/docs/source/module/fastapi_aad_auth.oauth.validators.rst deleted file mode 100644 index fd4dcf0..0000000 --- a/docs/source/module/fastapi_aad_auth.oauth.validators.rst +++ /dev/null @@ -1,11 +0,0 @@ -fastapi_aad_auth.oauth.validators -********************************* - -.. automodule:: fastapi_aad_auth.oauth.validators - -.. toctree:: - :caption: Sub-modules: - :maxdepth: 1 - - fastapi_aad_auth.oauth.validators.session - fastapi_aad_auth.oauth.validators.token diff --git a/docs/source/module/fastapi_aad_auth.oauth.validators.session.rst b/docs/source/module/fastapi_aad_auth.oauth.validators.session.rst deleted file mode 100644 index 622124f..0000000 --- a/docs/source/module/fastapi_aad_auth.oauth.validators.session.rst +++ /dev/null @@ -1,5 +0,0 @@ -fastapi_aad_auth.oauth.validators.session -***************************************** - -.. automodule:: fastapi_aad_auth.oauth.validators.session - :members: \ No newline at end of file diff --git a/docs/source/module/fastapi_aad_auth.oauth.validators.token.rst b/docs/source/module/fastapi_aad_auth.oauth.validators.token.rst deleted file mode 100644 index 237d76f..0000000 --- a/docs/source/module/fastapi_aad_auth.oauth.validators.token.rst +++ /dev/null @@ -1,5 +0,0 @@ -fastapi_aad_auth.oauth.validators.token -*************************************** - -.. automodule:: fastapi_aad_auth.oauth.validators.token - :members: \ No newline at end of file diff --git a/docs/source/module/fastapi_aad_auth.providers.aad.rst b/docs/source/module/fastapi_aad_auth.providers.aad.rst new file mode 100644 index 0000000..4daeea3 --- /dev/null +++ b/docs/source/module/fastapi_aad_auth.providers.aad.rst @@ -0,0 +1,5 @@ +fastapi_aad_auth.providers.aad +****************************** + +.. automodule:: fastapi_aad_auth.providers.aad + :members: \ No newline at end of file diff --git a/docs/source/module/fastapi_aad_auth.providers.rst b/docs/source/module/fastapi_aad_auth.providers.rst new file mode 100644 index 0000000..af14fad --- /dev/null +++ b/docs/source/module/fastapi_aad_auth.providers.rst @@ -0,0 +1,11 @@ +fastapi_aad_auth.providers +************************** + +.. automodule:: fastapi_aad_auth.providers + :members: + +.. toctree:: + :caption: Sub-modules: + :maxdepth: 1 + + fastapi_aad_auth.providers.aad diff --git a/docs/source/module/fastapi_aad_auth.utilities.logging.rst b/docs/source/module/fastapi_aad_auth.utilities.logging.rst new file mode 100644 index 0000000..10a8e0c --- /dev/null +++ b/docs/source/module/fastapi_aad_auth.utilities.logging.rst @@ -0,0 +1,5 @@ +fastapi_aad_auth.utilities.logging +********************************** + +.. automodule:: fastapi_aad_auth.utilities.logging + :members: diff --git a/docs/source/module/fastapi_aad_auth.utilities.rst b/docs/source/module/fastapi_aad_auth.utilities.rst new file mode 100644 index 0000000..4cf188f --- /dev/null +++ b/docs/source/module/fastapi_aad_auth.utilities.rst @@ -0,0 +1,12 @@ +fastapi_aad_auth.utilities +************************** + +.. automodule:: fastapi_aad_auth.utilities + :members: + +.. toctree:: + :caption: Sub-modules: + :maxdepth: 1 + + fastapi_aad_auth.utilities.logging + fastapi_aad_auth.utilities.urls diff --git a/docs/source/module/fastapi_aad_auth.utilities.urls.rst b/docs/source/module/fastapi_aad_auth.utilities.urls.rst new file mode 100644 index 0000000..464afbe --- /dev/null +++ b/docs/source/module/fastapi_aad_auth.utilities.urls.rst @@ -0,0 +1,5 @@ +fastapi_aad_auth.utilities.urls +******************************* + +.. automodule:: fastapi_aad_auth.utilities.urls + :members: diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 5164b33..cffcc96 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -25,11 +25,16 @@ e.g. for local development:: Configure a web platform for the UI based redirection (or whatever else is set in the config for the redirect path):: - https:///login/oauth/redirect + https:///oauth/aad/redirect e.g. for local development:: - http://localhost:8000/login/oauth/redirect + http://localhost:8000/oauth/aad/redirect + +.. warning:: + This is new behaviour that will be the default for version ``0.2.0``. To enable this now, you need to set the + :class:`~fastapi_aad_auth.config.RoutingConfig` ``login_path`` and ``login_redirect_path`` variables to the empty string or ``None`` + Youu also need to decide whether the application is multi-tenant or single-tenant @@ -41,14 +46,14 @@ Youu also need to decide whether the application is multi-tenant or single-tenan On the "Expose an API tab", you need to set the Application ID URI .. figure:: figures/App-Registration-App-ID.PNG - :alt: Overview of redirect URI configuration for local testing + :alt: Overview of app id URI An example configuration for api Scopes for testing an application and add scopes as configured for the application (e.g. the default ``openid`` scope is needed) .. figure:: figures/App-Registration-Scopes.PNG - :alt: Overview of redirect URI configuration for local testing + :alt: Overview of app scopes An example configuration for api Scopes for testing an application @@ -103,7 +108,7 @@ You can use it for fastapi routes:: router = APIRouter() @router.get('/hello') - async def hello_world(auth_state: AuthenticationState =D epends(auth_provider.api_auth_scheme)): + async def hello_world(auth_state: AuthenticationState =D epends(auth_provider.auth_backend.requires_auth(allow_session=True))): print(auth_state) return {'hello': 'world'} @@ -124,14 +129,7 @@ This middleware will set the request.user object and request.credentials object: return PlainTextResponse(f'Hello, you') -You can set the swagger_ui_init_oauth using auth_provider.api_auth_scheme.init_oauth:: - - from fastapi import FastAPI - app = FastAPI(... - swagger_ui_init_oauth=auth_provider.api_auth_scheme.init_oauth) - - -To add the required middleware to the fastapi app use:: +The :class:``fastapi.FastAPI`` ``swagger_ui_init_oauth`` variable is set automatically, along with the routing and required middleware using:: auth_provider.configure_app(app) diff --git a/src/fastapi_aad_auth/__init__.py b/src/fastapi_aad_auth/__init__.py index bc3e09a..3369999 100644 --- a/src/fastapi_aad_auth/__init__.py +++ b/src/fastapi_aad_auth/__init__.py @@ -1,6 +1,7 @@ -from fastapi_aad_auth.auth import AADAuth # noqa F401 +from fastapi_aad_auth.auth import Authenticator # noqa F401 from fastapi_aad_auth.config import Config # noqa F401 -from fastapi_aad_auth.oauth import AuthenticationState # noqa F401 -from ._version import get_versions +from fastapi_aad_auth._base.state import AuthenticationState # noqa F401 +from fastapi_aad_auth._version import get_versions + __version__ = get_versions()['version'] del get_versions diff --git a/src/fastapi_aad_auth/_base/__init__.py b/src/fastapi_aad_auth/_base/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/fastapi_aad_auth/_base/authenticators/__init__.py b/src/fastapi_aad_auth/_base/authenticators/__init__.py new file mode 100644 index 0000000..293040a --- /dev/null +++ b/src/fastapi_aad_auth/_base/authenticators/__init__.py @@ -0,0 +1 @@ +from fastapi_aad_auth._base.authenticators.session import SessionAuthenticator # noqa: F401 diff --git a/src/fastapi_aad_auth/_base/authenticators/session.py b/src/fastapi_aad_auth/_base/authenticators/session.py new file mode 100644 index 0000000..6330301 --- /dev/null +++ b/src/fastapi_aad_auth/_base/authenticators/session.py @@ -0,0 +1,102 @@ +"""Base Session Authenticator for interactive (UI) sessions.""" +from starlette.requests import Request +from starlette.responses import RedirectResponse + +from fastapi_aad_auth._base.state import AuthenticationState +from fastapi_aad_auth.mixins import LoggingMixin + + +class SessionAuthenticator(LoggingMixin): + """Authenticator for interactive (UI) sessions.""" + + def __init__(self, session_validator, token_validator): + """Initialise the session authenticator.""" + self._session_validator = session_validator + self._token_validator = token_validator + super().__init__() + + def redirect_if_authenticated(self, auth_state, redirect='/'): + """Redirect to a target if authenticated.""" + if auth_state.is_authenticated(): + self.logger.info(f'Logged in, redirecting to {redirect}') + else: + redirect = '/login' + return RedirectResponse(redirect) + + def redirect_to_provider_login(self, auth_state, request): + """Redirect to the provider login.""" + self.logger.debug(f'state {auth_state}') + auth_state.save_to_session(self._session_validator._session_serializer, request.session) + authorization_url = self._get_authorization_url(request, auth_state.session_state) + return RedirectResponse(authorization_url) + + def _get_authorization_url(self, request, session_state): + raise NotImplementedError('Implement in specific subclass') + + def process_login_request(self, request, force=False, redirect='/'): + """Process the provider login request.""" + self.logger.debug(f'Logging in - request url {request.url}') + auth_state = self._session_validator.get_state_from_session(request) + if auth_state.is_authenticated() and not force: + self.logger.debug(f'Authenticated - redirecting {auth_state}') + response = self.redirect_if_authenticated(auth_state) + else: + # Set the redirect parameter here + self._session_validator.set_post_auth_redirect(request, request.query_params.get('redirect', redirect)) + self.logger.debug(f'No Auth state - redirecting to provider login {auth_state}') + response = self.redirect_to_provider_login(auth_state, request) + return response + + def process_login_callback(self, request): + """Process the provider login callback.""" + code = request.query_params.get('code', None) + state = request.query_params.get('state', None) + if state is None or code is None: + return # not authenticated + auth_state = self._session_validator.get_state_from_session(request) + auth_state.check_session_state(state) + token = self._process_code(request, auth_state, code) + user = self._get_user_from_token(token) + authenticated_state = AuthenticationState.authenticate_as(user, self._session_validator._session_serializer, request.session) + redirect = self._session_validator.pop_post_auth_redirect(request) + return self.redirect_if_authenticated(authenticated_state, redirect=redirect) + + def _process_code(self, request, auth_state, code): + raise NotImplementedError('Implement in subclass') + + def get_access_token(self, user): + """Get the access token for the user.""" + raise NotImplementedError('Implement in subclass') + + def get_access_token_from_request(self, request: Request): + """Get the access token from a request object.""" + auth_state = self._session_validator.get_state_from_session(request) + if auth_state.is_authenticated(): + return self.get_access_token(auth_state.user)['access_token'] + return None + + def get_user_from_request(self, request: Request): + """Get the user from a request object.""" + auth_state = self._session_validator.get_state_from_session(request) + return auth_state.user + + def _get_user_from_token(self, token, options=None): + validated_claims = self._token_validator.validate_token(token, options=options) + return self._token_validator._get_user_from_claims(validated_claims) + + def get_login_button(self, url, post_redirect='/'): + """Get a UI login button.""" + url = self._add_redirect_to_url(url, post_redirect) + return f'Sign in' + + def logout(self, request): + """Process a logout request if any special behaviour required.""" + pass + + def pop_post_auth_redirect(self, *args, **kwargs): + """Clear post-authentication redirects.""" + return self._session_validator.pop_post_auth_redirect(*args, **kwargs) + + def set_post_auth_redirect(self, *args, **kwargs): + """Set post-authentication redirects.""" + self._session_validator.set_post_auth_redirect(*args, **kwargs) diff --git a/src/fastapi_aad_auth/_base/backend.py b/src/fastapi_aad_auth/_base/backend.py new file mode 100644 index 0000000..85bd29a --- /dev/null +++ b/src/fastapi_aad_auth/_base/backend.py @@ -0,0 +1,89 @@ +"""Base OAuthBackend with token and session validators.""" +from typing import List, Optional + +from fastapi.security import OAuth2AuthorizationCodeBearer +from starlette.authentication import AuthCredentials, AuthenticationBackend, UnauthenticatedUser +from starlette.requests import Request + +from fastapi_aad_auth._base.state import AuthenticationState +from fastapi_aad_auth._base.validators import SessionValidator, TokenValidator, Validator +from fastapi_aad_auth.mixins import LoggingMixin, NotAuthenticatedMixin +from fastapi_aad_auth.utilities import deprecate + + +class BaseOAuthBackend(NotAuthenticatedMixin, LoggingMixin, AuthenticationBackend): + """Base OAuthBackend with token and session validators.""" + + def __init__(self, validators: List[Validator]): + """Initialise the validators.""" + super().__init__() + self.validators = validators[:] + + async def authenticate(self, request): + """Authenticate a request. + + Required by starlette authentication middleware + """ + state = self.check(request, allow_session=True) + if state is None: + return AuthCredentials([]), UnauthenticatedUser() + return state.credentials, state.authenticated_user + + def is_authenticated(self, request: Request): + """Check if a request is authenticated.""" + state = self.check(request, allow_session=True) + return state is not None and state.is_authenticated() + + async def __call__(self, request: Request) -> Optional[AuthenticationState]: + """Check/validate a request.""" + return self.check(request) + + def check(self, request: Request, allow_session=True) -> Optional[AuthenticationState]: + """Check/validate a request.""" + state = None + for validator in self.validators: + if not allow_session and isinstance(validator, SessionValidator): + self.logger.info('Skipping Session Validator as allow_session is False') + continue + state = validator.check(request) + self.logger.debug(f'Authentication state {state} from validator {validator}') + if state is not None: + break + self.logger.info(f'Identified state {state}') + return state + + def _iter_validators(self): + """Iterate over authentication validators.""" + for validator in self.validators: + yield validator + + def requires_auth(self, allow_session: bool = False): + """Require authentication, use with fastapi Depends.""" + # This is a bit horrible, but is needed for fastapi to get this into OpenAPI (or similar) - it needs to be an OAuth2 object + # We create this here "dynamically" for each endpoint, as we allow customisation on whether a session is permissible + + class OAuthValidator(OAuth2AuthorizationCodeBearer): + """OAuthValidator for API Auth.""" + + def __init__(self_): + """Initialise the validator.""" + token_validators = [u for u in self.validators if isinstance(u, TokenValidator)] + super().__init__(authorizationUrl=token_validators[0].model.flows.authorizationCode.authorizationUrl, + tokenUrl=token_validators[0].model.flows.authorizationCode.tokenUrl, + scopes=token_validators[0].model.flows.authorizationCode.scopes, + refreshUrl=token_validators[0].model.flows.authorizationCode.refreshUrl) + + async def __call__(self_, request: Request): + """Validate a request.""" + state = self.check(request, allow_session) + if state is None or not state.is_authenticated(): + raise self.not_authenticated + return state + + return OAuthValidator() + + @property # type: ignore + @deprecate('0.2.0', replaced_by=f'{__name__}:BaseOAuthBackend.requires_auth') + def api_auth_scheme(self): + """Get the API Authentication Schema.""" + return self.requires_auth() diff --git a/src/fastapi_aad_auth/_base/provider.py b/src/fastapi_aad_auth/_base/provider.py new file mode 100644 index 0000000..19e66e6 --- /dev/null +++ b/src/fastapi_aad_auth/_base/provider.py @@ -0,0 +1,73 @@ +from typing import List, Optional + +from starlette.requests import Request +from starlette.responses import RedirectResponse +from starlette.routing import Route + +from fastapi_aad_auth._base.authenticators import SessionAuthenticator +from fastapi_aad_auth._base.validators import Validator +from fastapi_aad_auth.mixins import LoggingMixin +from fastapi_aad_auth.utilities import urls + + +class Provider(LoggingMixin): + """Authentication Provider.""" + name: Optional[str] = None + + def __init__(self, validators: List[Validator], authenticator: SessionAuthenticator, enabled: bool = True, oauth_base_route: str = '/oauth'): + """Initialise the authentication provider.""" + self.validators = validators + self.authenticator = authenticator + self.enabled = enabled + self.oauth_base_route = oauth_base_route + self._login_url = None + self._redirect_url = None + super().__init__() + + def get_routes(self, noauth_redirect='/'): + """Get the authenticator routes.""" + + async def login(request: Request): + self.logger.debug(f'Logging in with {self.name} - request url {request.url}') + if self.enabled: + self.logger.debug(f'Auth {request.auth}') + return self.authenticator.process_login_request(request) + else: + self.logger.debug('Auth not enabled') + return RedirectResponse(noauth_redirect) + + async def login_callback(request: Request): + self.logger.info(f'Processing login callback for {self.name}') + self.logger.debug(f'request url {request.url}') + if self.enabled: + return self.authenticator.process_login_callback(request) + else: + self.logger.debug('Auth not enabled') + return RedirectResponse(noauth_redirect) + + routes = [Route(self.login_url, + endpoint=login, methods=['GET'], name=f'oauth_login_{self.name}'), + Route(self.redirect_url, + endpoint=login_callback, methods=['GET'], name=f'oauth_login_{self.name}_callback')] + return routes + + def _build_oauth_url(self, oauth_base_route, route): + return urls.append(oauth_base_route, self.name, route) + + def logout(self, request): + """Logout from the authenticator.""" + pass + + @property + def login_url(self): + """Get the login url.""" + if self._login_url is None: + self._login_url = self._build_oauth_url(self.oauth_base_route, 'login') + return self._login_url + + @property + def redirect_url(self): + """Get the login redirect url.""" + if self._redirect_url is None: + self._redirect_url = self._build_oauth_url(self.oauth_base_route, 'redirect') + return self._redirect_url diff --git a/src/fastapi_aad_auth/_base/state.py b/src/fastapi_aad_auth/_base/state.py new file mode 100644 index 0000000..9f9f177 --- /dev/null +++ b/src/fastapi_aad_auth/_base/state.py @@ -0,0 +1,133 @@ +"""Authentication State Handler.""" +from enum import Enum +import json +from typing import List, Optional +import uuid + +from itsdangerous import URLSafeSerializer +from itsdangerous.exc import BadSignature +from pydantic import BaseModel, root_validator +from starlette.authentication import AuthCredentials, AuthenticationError, SimpleUser, UnauthenticatedUser + +from fastapi_aad_auth.mixins import LoggingMixin + + +SESSION_STORE_KEY = 'auth' + + +class AuthenticationOptions(Enum): + """Authentication Options.""" + unauthenticated = 0 + not_allowed = -1 + authenticated = 1 + + +class User(BaseModel): + """User Model.""" + name: str + email: str + username: str + roles: Optional[List[str]] = None + groups: Optional[List[str]] = None + + @property + def permissions(self): + """User Permissions.""" + return [] + + +class AuthenticationState(LoggingMixin, BaseModel): + """Authentication State.""" + session_state: str = str(uuid.uuid4()) + state: AuthenticationOptions = AuthenticationOptions.unauthenticated + user: Optional[User] = None + _logger = None + + class Config: # noqa: D106 + underscore_attrs_are_private = True + + @root_validator(pre=True) + def _validate_user(cls, values): + if values.get('user', None) is None: + values['state'] = AuthenticationOptions.unauthenticated + return values + + def check_session_state(self, session_state): + """Check state against session state.""" + if session_state != self.session_state: + raise AuthenticationError("Session states do not match") + return True + + def store(self, serializer): + """Store in serializer.""" + return serializer.dumps(self.json()) + + @classmethod + def load(cls, serializer: URLSafeSerializer, encoded_state: Optional[str] = None): + """Load from encoded state. + + Args: + serializer: Serializer object containing the en/decoding secrets + Keyword Args: + encoded_state: The encoded state to be decoded + """ + if encoded_state: + try: + state = json.loads(serializer.loads(encoded_state)) + loaded_state = cls(**state) + except BadSignature: + loaded_state = cls() + else: + loaded_state = cls() + return loaded_state + + @classmethod + def logout(cls, serializer: URLSafeSerializer, session): + """Clear the sessions state.""" + state = cls.load_from_session(serializer, session) + state.user = None + state.state = AuthenticationOptions.unauthenticated + session[SESSION_STORE_KEY] = state.store(serializer) + + @classmethod + def load_from_session(cls, serializer: URLSafeSerializer, session): + """Load from a session.""" + return cls.load(serializer, session.get(SESSION_STORE_KEY, None)) + + def save_to_session(self, serializer: URLSafeSerializer, session): + """Save to a session.""" + session[SESSION_STORE_KEY] = self.store(serializer) + return session + + def is_authenticated(self): + """Check if the state is authenticated.""" + return self.user is not None and self.state == AuthenticationOptions.authenticated + + @property + def authenticated_user(self): + """Get the authenticated user.""" + if self.is_authenticated() and self.user: + if isinstance(self.user, User): + return SimpleUser(self.user.email) + return UnauthenticatedUser() + + @property + def credentials(self): + """Get the credentials object.""" + if self.user and self.is_authenticated(): + return AuthCredentials(['authenticated'] + self.user.permissions) + else: + return AuthCredentials() + + @classmethod + def authenticate_as(cls, user, serializer, session): + """Store the authenticated user.""" + state = cls(user=user, state=AuthenticationOptions.authenticated) + if serializer is not None and session is not None: + state.save_to_session(serializer, session) + return state + + @classmethod + def as_unauthenticated(cls, serializer, session): + """Store as an un-authenticated user.""" + return cls.authenticate_as(None, serializer, session) diff --git a/src/fastapi_aad_auth/_base/validators/__init__.py b/src/fastapi_aad_auth/_base/validators/__init__.py new file mode 100644 index 0000000..7a29b7a --- /dev/null +++ b/src/fastapi_aad_auth/_base/validators/__init__.py @@ -0,0 +1,4 @@ + +from fastapi_aad_auth._base.validators.base import Validator # noqa: F401 +from fastapi_aad_auth._base.validators.session import SessionValidator # noqa: F401 +from fastapi_aad_auth._base.validators.token import TokenValidator # noqa: F401 diff --git a/src/fastapi_aad_auth/_base/validators/base.py b/src/fastapi_aad_auth/_base/validators/base.py new file mode 100644 index 0000000..16954ea --- /dev/null +++ b/src/fastapi_aad_auth/_base/validators/base.py @@ -0,0 +1,26 @@ +from abc import abstractmethod + +from starlette.requests import Request + +from fastapi_aad_auth._base.state import AuthenticationState +from fastapi_aad_auth.mixins import LoggingMixin, NotAuthenticatedMixin + + +class Validator(NotAuthenticatedMixin, LoggingMixin): + """Base Validator Class.""" + + @abstractmethod + def check(self, request: Request) -> AuthenticationState: + """Check a request.""" + raise NotImplementedError('Implement in subclass') + + async def __call__(self, request: Request) -> AuthenticationState: # type: ignore + """Validate the request authentication. + + Returns an AuthenticationState object or raises an Unauthorized error + """ + result = self.check(request) + self.logger.info(f'Identified state {result}') + if not result.is_authenticated(): + raise self.not_authenticated + return result diff --git a/src/fastapi_aad_auth/oauth/validators/session.py b/src/fastapi_aad_auth/_base/validators/session.py similarity index 68% rename from src/fastapi_aad_auth/oauth/validators/session.py rename to src/fastapi_aad_auth/_base/validators/session.py index e199ab3..c3cd58a 100644 --- a/src/fastapi_aad_auth/oauth/validators/session.py +++ b/src/fastapi_aad_auth/_base/validators/session.py @@ -1,22 +1,14 @@ -"""Validator for interactive (UI) sessions.""" - -import logging - +"""Session based validator for interactive (UI) sessions.""" from itsdangerous import URLSafeSerializer -from fastapi_aad_auth.oauth.state import AuthenticationState +from fastapi_aad_auth._base.state import AuthenticationState +from fastapi_aad_auth._base.validators import Validator -logger = logging.getLogger(__name__) REDIRECT_KEY = 'requested' -def get_session_serializer(secret, salt): - """Get or Initialise the session serializer.""" - return URLSafeSerializer(secret, salt=salt) - - -class SessionValidator: +class SessionValidator(Validator): """Validator for session based authentication.""" def __init__(self, session_serializer: URLSafeSerializer, *args, **kwargs): @@ -35,7 +27,7 @@ def check(self, request): state = AuthenticationState.load_from_session(self._session_serializer, request.session) except Exception: state = AuthenticationState.as_unauthenticated(self._session_serializer, request.session) - logger.exception('Error authenticating via session') + self.logger.exception('Error authenticating via session') return state def pop_post_auth_redirect(self, request): @@ -45,3 +37,12 @@ def pop_post_auth_redirect(self, request): def set_post_auth_redirect(self, request, redirect='/'): """Set post-authentication redirects.""" request.session[REDIRECT_KEY] = redirect + + @staticmethod + def get_session_serializer(secret, salt): + """Get or Initialise the session serializer.""" + return URLSafeSerializer(secret, salt=salt) + + def logout(self, request): + """Process a logout request.""" + AuthenticationState.logout(self._session_serializer, request.session) diff --git a/src/fastapi_aad_auth/_base/validators/token.py b/src/fastapi_aad_auth/_base/validators/token.py new file mode 100644 index 0000000..5229a4e --- /dev/null +++ b/src/fastapi_aad_auth/_base/validators/token.py @@ -0,0 +1,113 @@ +"""Base validator for token based authentication.""" +from authlib.jose import errors as jwt_errors +from fastapi.security import OAuth2AuthorizationCodeBearer +from fastapi.security.utils import get_authorization_scheme_param +from pydantic import BaseModel +from starlette.middleware.authentication import AuthenticationError +from starlette.requests import Request + +from fastapi_aad_auth._base.state import AuthenticationState, User +from fastapi_aad_auth._base.validators.base import Validator + + +class InitOAuth(BaseModel): + """OAuth information for openapi docs.""" + clientId: str + scopes: str + usePkceWithAuthorizationCodeGrant: bool + + +class TokenValidator(Validator, OAuth2AuthorizationCodeBearer): # type: ignore + """Validator for token based authentication.""" + + def __init__( + self, + client_id: str, + authorizationUrl: str, + tokenUrl: str, + api_audience: str = None, + scheme_name: str = None, + scopes: dict = None, + auto_error: bool = False, + enabled: bool = True, + use_pkce: bool = True, + user_klass: type = User + ): + """Initialise validator for token based authentication.""" + super().__init__(authorizationUrl=authorizationUrl, tokenUrl=tokenUrl, refreshUrl=api_audience, scheme_name=scheme_name, scopes=scopes, auto_error=auto_error) + self.client_id = client_id + self.enabled = enabled + if api_audience is None: + api_audience = f"api://{client_id}" + self.api_audience = api_audience + self._use_pkce = use_pkce + self._user_klass = user_klass + + def check(self, request: Request): + """Check the authentication from the request.""" + token = self.get_token(request) + if token is None: + return AuthenticationState.as_unauthenticated(None, None) + claims = self.validate_token(token) + user = self._get_user_from_claims(claims) + return AuthenticationState.authenticate_as(user, None, None) + + def get_token(self, request: Request): + """Get the token from the request.""" + authorization = request.headers.get("Authorization") + scheme, param = get_authorization_scheme_param(authorization) + if not authorization or scheme.lower() != "bearer": + if self.auto_error: + raise self.not_authenticated + else: + return None # pragma: nocover + return param + + @property + def init_oauth(self): + """Get the openapi docs config.""" + return InitOAuth(clientId=self.client_id, scopes=f'{self.api_audience}/openid', usePkceWithAuthorizationCodeGrant=self._use_pkce).dict() + + def _validate_claims(self, claims, options=None): + if options is None: + options = self._claims_options + claims.options = options + try: + claims.validate() + except jwt_errors.ExpiredTokenError as e: + self.logger.error(f'Expired token:\n\t{self._compare_claims(claims)}') + raise AuthenticationError(f"Token is expired {e.args}") + except jwt_errors.InvalidClaimError as e: + self.logger.error(f'Invalid claims:\n\t{self._compare_claims(claims)}') + raise AuthenticationError(f"Invalid claims {e.args}") + except jwt_errors.MissingClaimError as e: + self.logger.error(f'Missing claims:\n\t{self._compare_claims(claims)}') + raise AuthenticationError(f"Missing claims {e.args}") + except Exception as e: + self.logger.exception('Unable to parse error') + raise AuthenticationError(f"Unable to parse authentication token {e.args}") + return claims + + @property + def _claims_options(self): + options = {"sub": {"essential": True}, + "aud": {"essential": True, "values": [self.api_audience]}, + "exp": {"essential": True}, + "nbf": {"essential": True}, + "iat": {"essential": True}} + return options + + def _decode_token(self, token): + raise NotImplementedError('Implement in base class') + + def validate_token(self, token, options=None): + """Validate provided token.""" + claims = self._decode_token(token) + return self._validate_claims(claims, options) + + @staticmethod + def _compare_claims(claims): + return '\n\t'.join([f'{key}: {value} - {claims.options.get(key, None)}' for key, value in claims.items()]) + + def _get_user_from_claims(self, claims): + raise NotImplementedError('Implement in sub class') diff --git a/src/fastapi_aad_auth/auth.py b/src/fastapi_aad_auth/auth.py index 23fc393..4511dd4 100644 --- a/src/fastapi_aad_auth/auth.py +++ b/src/fastapi_aad_auth/auth.py @@ -1,8 +1,7 @@ -"""Base AAD Authentication Handler.""" +"""Authenticator Class.""" from functools import wraps -import logging from pathlib import Path -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from fastapi import Depends, FastAPI from starlette.authentication import requires @@ -14,112 +13,189 @@ from starlette.staticfiles import StaticFiles from starlette.templating import Jinja2Templates +from fastapi_aad_auth._base.backend import BaseOAuthBackend +from fastapi_aad_auth._base.state import AuthenticationState +from fastapi_aad_auth._base.validators import SessionValidator from fastapi_aad_auth.config import Config from fastapi_aad_auth.errors import base_error_handler, ConfigurationError -from fastapi_aad_auth.oauth import AADOAuthBackend, AuthenticationState +from fastapi_aad_auth.mixins import LoggingMixin +from fastapi_aad_auth.utilities import deprecate -logger = logging.getLogger(__name__) - _BASE_ROUTES = ['openapi', 'swagger_ui_html', 'swagger_ui_redirect', 'redoc_html'] -class AADAuth: - """AAD Authenticator Class. - - Generates and handles adding AAD authentication, routing and middleware +class Authenticator(LoggingMixin): + """Authenticator class. - Includes a decorator for signifying authentication required on fastapi routes, and a basic Login UI with AAD link + Creates the key components based on the provided configurations """ - def __init__(self, config: Config = None, add_to_base_routes: bool = True): + def __init__(self, config: Config = None, add_to_base_routes: bool = True, base_context: Dict[str, Any] = None, user_klass: Optional[type] = None): """Initialise the AAD config based on the provided configuration. Keyword Args: config (fastapi_aad_auth.config.Config): Authentication configuration (includes ui and routing, as well as AAD Application and Tenant IDs) add_to_base_routes (bool): Add the authentication to the router """ + super().__init__() if config is None: config = Config() + if user_klass is not None: + config.user_klass = user_klass self.config = config - self.oauth_backend = AADOAuthBackend.from_config(self.config) - if add_to_base_routes: - self._add_to_base_routes = True - - def app_routes_add_auth(self, app: FastAPI, route_list: List[str], invert: bool = False): - """Add authentication to specified routes in application router. + if base_context is None: + base_context = {} + if self.config.login_ui.context: + context = self.config.login_ui.context.copy() + context.update(base_context) + base_context = context + self._base_context = base_context + self._add_to_base_routes = add_to_base_routes + self._session_validator = self._init_session_validator() + self._providers = self._init_providers() + self.auth_backend = self._init_auth_backend() + self._ui_routes = self._init_ui() + self._auth_routes = self._init_auth_routes() + + def _init_session_validator(self): + auth_serializer = SessionValidator.get_session_serializer(self.config.auth_session.secret.get_secret_value(), + self.config.auth_session.salt.get_secret_value()) + return SessionValidator(auth_serializer) + # Lets setup the oauth backend + + def _init_providers(self): + return [u._provider_klass.from_config(session_validator=self._session_validator, config=self.config, provider_config=u) for u in self.config.providers] + + def _init_auth_backend(self): + validators = [self._session_validator] + for provider in self._providers: + validators += provider.validators + return BaseOAuthBackend(validators) + + def _init_ui(self): + login_template_path = Path(self.config.login_ui.template_file) + user_template_path = Path(self.config.login_ui.user_template_file) + login_templates = Jinja2Templates(directory=str(login_template_path.parent)) + user_templates = Jinja2Templates(directory=str(user_template_path.parent)) - Used for default routes (e.g. api/docs and api/redocs, openapi.json etc) + async def login(request: Request, *args, **kwargs): + context = self._base_context.copy() # type: ignore + if not self.config.enabled or request.user.is_authenticated: + # This is authenticated so go straight to the homepage + return RedirectResponse(self.config.routing.home_path) + context['request'] = request # type: ignore + if 'login' not in context or context['login'] is None: # type: ignore + post_redirect = self._session_validator.pop_post_auth_redirect(request) + context['login'] = '
'.join([provider.get_login_button(post_redirect) for provider in self._providers]) # type: ignore + return login_templates.TemplateResponse(login_template_path.name, context) # type: ignore - Args: - app: fastapi application - route_list: list of routes to add authentication to (e.g. api docs, redocs etc) + routes = [Route(self.config.routing.landing_path, endpoint=login, methods=['GET'], name='login'), + Mount(self.config.login_ui.static_path, StaticFiles(directory=str(self.config.login_ui.static_directory)), name='static-login')] - Keyword Args: - invert: Switch between using the route list as a block list or an allow list + if self.config.routing.user_path: - """ - if self.oauth_backend.enabled: - routes = app.router.routes - for i, route in enumerate(routes): - # Can use allow list or block list (i.e. invert = True sets all except the route list to have auth - if (route.name in route_list and not invert) or (route.name not in route_list and invert): # type: ignore - route.endpoint = self.auth_required()(route.endpoint) # type: ignore - route.app = request_response(route.endpoint) # type: ignore - app.router.routes[i] = route - return app + @self.auth_required() + async def get_user(request: Request): + context = self._base_context.copy() # type: ignore + self.logger.debug(f'Getting token for {request.user}') + context['request'] = request # type: ignore + if self.config.enabled: + self.logger.debug(f'Auth {request.auth}') + try: + context['user'] = self._session_validator.get_state_from_session(request).user + except ValueError: + # If we have one provider, we can force the login, otherwise... + return self.__force_authenticate(request) + else: + self.logger.debug('Auth not enabled') + context['token'] = None # type: ignore + return user_templates.TemplateResponse(user_template_path.name, context) - def configure_app(self, app: FastAPI): - """Configure the fastapi application to use these authentication handlers. + async def get_token(request: Request, auth_state: AuthenticationState = Depends(self.auth_backend)): + if not isinstance(auth_state, AuthenticationState): + user = self.__get_user_from_request(request) + else: + user = auth_state.user + if hasattr(user, 'username'): # type: ignore + access_token = self.__get_access_token(user) + if access_token: + # We want to get the token for each provider that is authenticated + return JSONResponse(access_token) # type: ignore + else: + if any([u in request.headers['user-agent'] for u in ['Mozilla', 'Gecko', 'Trident', 'WebKit', 'Presto', 'Edge', 'Blink']]): + # If we have one provider, we can force the login, otherwise we need to request which login route + return self.__force_authenticate(request) + else: + return JSONResponse('Unable to access token as user has not authenticated via session') + return RedirectResponse(f'{self.config.routing.landing_path}?redirect=/me/token') - Adds authentication middleware, error handler and adds authnetication - to the default routes as well as adding the authentication specific routes + routes += [Route(self.config.routing.user_path, endpoint=get_user, methods=['GET'], name='user'), + Route(f'{self.config.routing.user_path}/token', endpoint=get_token, methods=['GET'], name='get-token')] + return routes - Args: - app: fastapi application - """ + def _init_auth_routes(self): - def on_auth_error(request: Request, exc: Exception): - logger.exception(f'Error {exc} for request {request}') - self.oauth_backend.authenticator.set_post_auth_redirect(request, request.url.path) - return RedirectResponse(self.config.routing.landing_path) + async def logout(request: Request): + self.logger.debug(f'Logging out - request url {request.url}') + if self.config.enabled: + self.logger.debug(f'Auth {request.auth}') + for provider in self._providers: + provider.logout(request) + self._session_validator.logout(request) + return RedirectResponse(self.config.routing.post_logout_path) + routes = [Route(self.config.routing.logout_path, endpoint=logout, methods=['GET'], name='logout')] + for provider in self._providers: + routes += provider.get_routes(noauth_redirect=self.config.routing.home_path) + # We have a deprecated behaviour here + return routes - app.add_middleware(AuthenticationMiddleware, backend=self.oauth_backend, on_error=on_auth_error) + def __force_authenticate(self, request: Request): + if len(self._providers) == 1: + return self._providers[0].authenticator.process_login_request(request, force=True, redirect=request.url.path) + else: + return RedirectResponse(f'{self.config.routing.landing_path}?redirect={request.url.path}') + + def __get_access_token(self, user): + access_token = None + while access_token is None: + provider = next(self._providers) + try: + access_token = provider.autheticator.get_access_token(user) + except ValueError: + pass + return access_token + + def __get_user_from_request(self, request: Request): + if hasattr(request.user, 'username'): + user = request.user + else: + auth_state = self.auth_backend.check(request) + user = auth_state.user + return user + def _set_error_handlers(self, app): error_template_path = Path(self.config.login_ui.error_template_file) error_templates = Jinja2Templates(directory=str(error_template_path.parent)) - - if self.config.login_ui.context: - context = self.config.login_ui.context - else: - context = {} if self.config.login_ui.app_name: - context['appname'] = self.config.login_ui.app_name + self._base_context['appname'] = self.config.login_ui.app_name else: - context['appname'] = app.title - context['static_path'] = self.config.login_ui.static_path + self._base_context['appname'] = app.title + self._base_context['static_path'] = self.config.login_ui.static_path @app.exception_handler(ConfigurationError) async def configuration_error_handler(request: Request, exc: ConfigurationError) -> Response: error_message = "Oops! It seems like the application has not been configured correctly, please contact an admin" error_type = 'Authentication Configuration Error' status_code = 500 - return base_error_handler(request, exc, error_type, error_message, error_templates, error_template_path, context=context.copy(), status_code=status_code) + return base_error_handler(request, exc, error_type, error_message, error_templates, error_template_path, context=self._base_context.copy(), status_code=status_code) @app.exception_handler(AuthenticationError) async def authentication_error_handler(request: Request, exc: AuthenticationError) -> Response: error_message = "Oops! It seems like you cannot access this information. If this is an error, please contact an admin" error_type = 'Authentication Error' status_code = 403 - return base_error_handler(request, exc, error_type, error_message, error_templates, error_template_path, context=context.copy(), status_code=status_code) - - # Check if session middleware is there - if not any([SessionMiddleware in u.cls.__mro__ for u in app.user_middleware]): - app.add_middleware(SessionMiddleware, **self.config.session.dict()) - if self._add_to_base_routes: - self.app_routes_add_auth(app, _BASE_ROUTES) - app.routes.extend(self.auth_routes) - app.routes.extend(self.build_auth_ui(context)) + return base_error_handler(request, exc, error_type, error_message, error_templates, error_template_path, context=self._base_context.copy(), status_code=status_code) def auth_required(self, scopes: str = 'authenticated', redirect: str = 'login'): """Decorator to require specific scopes (and redirect to the login ui) for an endpoint. @@ -137,7 +213,7 @@ def wrapper(endpoint): @wraps(endpoint) async def require_endpoint(request: Request, *args, **kwargs): - self.oauth_backend.authenticator.set_post_auth_redirect(request, request.url.path) + self._session_validator.set_post_auth_redirect(request, request.url.path) @requires(scopes, redirect=redirect) async def req_wrapper(request: Request, *args, **kwargs): @@ -151,118 +227,70 @@ async def req_wrapper(request: Request, *args, **kwargs): return wrapper - @property - def auth_routes(self): - """Get the default authentication routes and methods. - - Includes login, logout and the login callback - """ + def app_routes_add_auth(self, app: FastAPI, route_list: List[str], invert: bool = False): + """Add authentication to specified routes in application router. - async def logout(request: Request): - logger.debug(f'Logging out - request url {request.url}') - if self.oauth_backend.enabled: - logger.debug(f'Auth {request.auth}') - self.oauth_backend.authenticator.logout(request) - return RedirectResponse(self.config.routing.post_logout_path) + Used for default routes (e.g. api/docs and api/redocs, openapi.json etc) - async def login(request: Request): - logger.debug(f'Logging in - request url {request.url}') - if self.oauth_backend.enabled: - logger.debug(f'Auth {request.auth}') - return self.oauth_backend.authenticator.process_login_request(request) - else: - logger.debug('Auth not enabled') - return RedirectResponse(self.config.routing.home_path) + Args: + app: fastapi application + route_list: list of routes to add authentication to (e.g. api docs, redocs etc) - async def login_callback(request: Request): - logger.info('Processing login callback from Azure AD') - logger.debug(f'request url {request.url}') - if self.oauth_backend.enabled: - return self.oauth_backend.authenticator.process_login_callback(request) - else: - logger.debug('Auth not enabled') - return RedirectResponse(self.config.routing.landing_path) + Keyword Args: + invert: Switch between using the route list as a block list or an allow list - routes = [Route(self.config.routing.logout_path, endpoint=logout, methods=['GET'], name='logout'), - Route(self.config.routing.login_path, endpoint=login, methods=['GET'], name='login_oauth'), - Route(self.config.routing.login_redirect_path, endpoint=login_callback, methods=['GET'], name='login_callback')] + """ + if self.config.enabled: + routes = app.router.routes + for i, route in enumerate(routes): + # Can use allow list or block list (i.e. invert = True sets all except the route list to have auth + if (route.name in route_list and not invert) or (route.name not in route_list and invert): # type: ignore + route.endpoint = self.auth_required()(route.endpoint) # type: ignore + route.app = request_response(route.endpoint) # type: ignore + app.router.routes[i] = route + return app - return routes + def configure_app(self, app: FastAPI, add_error_handlers=True): + """Configure the fastapi application to use these authentication handlers. - def build_auth_ui(self, context: Dict[str, Any] = None): - """Build the ui route and static data for the login UI. + Adds authentication middleware, error handler and adds authentication + to the default routes as well as adding the authentication specific routes - The context kwargs can include ``login`` - button HTML (different to the default Microsoft UI button), - ``appname`` - the application name (for the login page title) + Args: + app: fastapi application Keyword Args: - contex: a dicitionary of predefined parameters to pass to the Jinja2 Login UI template + add_error_handlers (bool) : add the error handlers to the app (default is true, but can be set to False to configure specific handling) """ - login_template_path = Path(self.config.login_ui.template_file) - user_template_path = Path(self.config.login_ui.user_template_file) - login_templates = Jinja2Templates(directory=str(login_template_path.parent)) - user_templates = Jinja2Templates(directory=str(user_template_path.parent)) - if context is None: - context = {} - async def login(request: Request, *args, **kwargs): - nonlocal context - view_context = context.copy() # type: ignore - if not self.oauth_backend.enabled or request.user.is_authenticated: - # This is authenticated so go straight to the homepage - return RedirectResponse(self.config.routing.home_path) - view_context['request'] = request # type: ignore - if 'login' not in view_context or view_context['login'] is None: # type: ignore - post_redirect = self.oauth_backend.authenticator.pop_post_auth_redirect(request) - view_context['login'] = self.oauth_backend.authenticator.get_login_button(self.config.routing.login_path, post_redirect) # type: ignore - return login_templates.TemplateResponse(login_template_path.name, view_context) # type: ignore + def on_auth_error(request: Request, exc: Exception): + self.logger.exception(f'Error {exc} for request {request}') + self._session_validator.set_post_auth_redirect(request, request.url.path) + return RedirectResponse(self.config.routing.landing_path) - routes = [Route(self.config.routing.landing_path, endpoint=login, methods=['GET'], name='login'), - Mount(self.config.login_ui.static_path, StaticFiles(directory=str(self.config.login_ui.static_directory)), name='static-login')] + app.add_middleware(AuthenticationMiddleware, backend=self. auth_backend, on_error=on_auth_error) + if add_error_handlers: + self._set_error_handlers(app) + # Check if session middleware is there + if not any([SessionMiddleware in u.cls.__mro__ for u in app.user_middleware]): + app.add_middleware(SessionMiddleware, **self.config.session.dict()) + if self._add_to_base_routes: + self.app_routes_add_auth(app, _BASE_ROUTES) + app.routes.extend(self._ui_routes) + app.routes.extend(self._auth_routes) + # TODO: select a specific provider to use here + app.swagger_ui_init_oauth = self._providers[0].validators[0].init_oauth - if self.config.routing.user_path: - @self.auth_required() - async def get_user(request: Request): - nonlocal context - view_context = context.copy() # type: ignore - logger.debug(f'Getting token for {request.user}') - view_context['request'] = request - if self.oauth_backend.enabled: - logger.debug(f'Auth {request.auth}') - try: - view_context['user'] = self.oauth_backend.authenticator.get_user_from_request(request) - view_context['token'] = self.oauth_backend.authenticator.get_access_token(view_context['user']) - except ValueError: - return self.oauth_backend.authenticator.process_login_request(request, force=True, redirect=request.url.path) - else: - logger.debug('Auth not enabled') - view_context['token'] = None - return user_templates.TemplateResponse(user_template_path.name, view_context) - - async def get_token(request: Request, auth_state: AuthenticationState = Depends(self.api_auth_scheme)): - if not isinstance(auth_state, AuthenticationState): - if hasattr(request.user, 'username'): - user = request.user - else: - auth_state = await self.api_auth_scheme(request) - user = auth_state.user - if hasattr(user, 'username'): # type: ignore - try: - return JSONResponse(self.oauth_backend.authenticator.get_access_token(user)) # type: ignore - except ValueError: - if any([u in request.headers['user-agent'] for u in ['Mozilla', 'Gecko', 'Trident', 'WebKit', 'Presto', 'Edge', 'Blink']]): - return self.oauth_backend.authenticator.process_login_request(request, force=True, redirect=request.url.path) - else: - return JSONResponse('Unable to access token as user has not authenticated via session') - return RedirectResponse(f'{self.config.routing.landing_path}?redirect=/me/token') +_DEPRECATED_VERSION = '0.2.0' - routes += [Route(self.config.routing.user_path, endpoint=get_user, methods=['GET'], name='user'), - Route(f'{self.config.routing.user_path}/token', endpoint=get_token, methods=['GET'], name='get-token')] - return routes +@deprecate(_DEPRECATED_VERSION, replaced_by=f'{Authenticator.__module__}:{Authenticator.__name__}') +class AADAuth(Authenticator): # noqa: D101 + __doc__ = Authenticator.__doc__ - @property + @property # type: ignore + @deprecate(_DEPRECATED_VERSION, replaced_by=f'{Authenticator.__module__}:{Authenticator.__name__}.auth_backend.requires_auth') def api_auth_scheme(self): - """Get the authentication scheme for the api page.""" - return self.oauth_backend.api_auth_scheme + """Get the API Authentication Schema.""" + return self.auth_backend.requires_auth() diff --git a/src/fastapi_aad_auth/config.py b/src/fastapi_aad_auth/config.py index 343f616..96c995f 100644 --- a/src/fastapi_aad_auth/config.py +++ b/src/fastapi_aad_auth/config.py @@ -1,51 +1,20 @@ """fastapi_aad_auth configuration options.""" import importlib -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import uuid from pkg_resources import resource_filename -from pydantic import BaseSettings, DirectoryPath, Field, FilePath, HttpUrl, SecretStr, validator - - -def bool_from_env(env_value): - """Convert environment variable to boolean.""" - if isinstance(env_value, str): - env_value = env_value.lower() in ['true', '1'] - return env_value - - -def list_from_env(env_value): - """Convert environment variable to list.""" - if isinstance(env_value, str): - env_value = [u for u in env_value.split(',') if u] - return env_value - - -def expand_doc(klass): - """Expand pydantic model documentation to enable autodoc.""" - docs = ['', '', 'Keyword Args:'] - for name, field in klass.__fields__.items(): - default_str = '' - if field.default: - default_str = f' [default: ``{field.default}``]' - module = field.outer_type_.__module__ - if module != 'builtins': - if hasattr(field.outer_type_, '__origin__'): - type_ = f' ({field.outer_type_.__origin__.__name__}) ' - elif not hasattr(field.outer_type_, '__name__'): - type_ = '' - else: - type_ = f' ({module}.{field.outer_type_.__name__}) ' - else: - type_ = f' ({field.outer_type_.__name__}) ' - env_var = '' - if 'env' in field.field_info.extra: - env_var = f' (Can be set by ``{field.field_info.extra["env"]}`` environment variable)' - docs.append(f' {name}{type_}: {field.field_info.description}{default_str}{env_var}') - if klass.__doc__ is None: - klass.__doc__ = '' - klass.__doc__ += '\n'.join(docs) - return klass +from pydantic import BaseSettings as _BaseSettings, DirectoryPath, Field, FilePath, SecretStr, validator + +from fastapi_aad_auth.providers.aad import AADConfig +from fastapi_aad_auth.utilities import bool_from_env, DeprecatableFieldsMixin, DeprecatedField, expand_doc + + +class BaseSettings(DeprecatableFieldsMixin, _BaseSettings): + """Allow deprecations in the BaseSettings object.""" + + +_DEPRECATION_VERSION = '0.2.0' @expand_doc @@ -59,8 +28,10 @@ class RoutingConfig(BaseSettings): page (defaults to the application root), and the ``post_logout_path`` for any specific routing once a logout has completed. """ - login_path: str = Field('/login/oauth', description="Path for initiating the AAD oauth call", env='FASTAPI_AUTH_LOGIN_ROUTE') - login_redirect_path: str = Field('/login/oauth/redirect', description="Path for handling the AAD redirect call", env='FASTAPI_AUTH_LOGIN_REDIRECT_ROUTE') + + login_path: str = DeprecatedField('/login/oauth', description="Path for initiating the AAD oauth call", env='FASTAPI_AUTH_LOGIN_ROUTE', deprecated_in=_DEPRECATION_VERSION, replaced_by='Routing.oauth_base_route', additional_info=' - To access the new behaviour, set this value to None or an empty string') + login_redirect_path: str = DeprecatedField('/login/oauth/redirect', description="Path for handling the AAD redirect call", env='FASTAPI_AUTH_LOGIN_REDIRECT_ROUTE', deprecated_in=_DEPRECATION_VERSION, replaced_by='Routing.oauth_base_route', additional_info=' - To access the new behaviour, set this value to None or an empty string') + oauth_base_route: str = Field('/oauth', description="Base Path for initiating the oauth calls", env='FASTAPI_OAUTH_BASE_ROUTE') logout_path: str = Field('/logout', description="Path for processing a logout request", env='FASTAPI_AUTH_LOGOUT_ROUTE') landing_path: str = Field('/login', description="Path for the login UI page", env='FASTAPI_AUTH_LOGIN_UI_ROUTE') user_path: Optional[str] = Field('/me', description="Path for getting the user view", env='FASTAPI_AUTH_USER_ROUTE') @@ -108,40 +79,6 @@ class Config: # noqa D106 env_file = '.env' -@expand_doc -class AADConfig(BaseSettings): - """Configuration for the AAD application. - - Includes expected claims, application registration, etc. - - Can also provide additional client application ids to accept. - - A list of roles can be provided to accept (requires configuring the - roles in the AAD application registration manifest) - """ - client_id: SecretStr = Field(..., description="Application Registration Client ID", env='AAD_CLIENT_ID') - tenant_id: SecretStr = Field(..., description="Application Registration Tenant ID", env='AAD_TENANT_ID') - client_secret: Optional[SecretStr] = Field(None, description="Application Registration Client Secret (if required)", env='AAD_CLIENT_SECRET') - scopes: List[str] = Field(["Read"], description="Additional scopes requested") - client_app_ids: Optional[List[str]] = Field(None, description="Additional Client App IDs to accept tokens from (when running as a backend service)", - env='AAD_CLIENT_APP_IDS') - strict: bool = Field(True, description="Check that all claims are provided", env='AAD_STRICT_CLAIM_CHECK') - api_audience: Optional[str] = Field(None, description="Corresponds to the Application ID URI - used for token validation, defaults to api://{client_id}", - env='AAD_API_AUDIENCE') - redirect_uri: Optional[HttpUrl] = Field(None, description="The redirect URI to use - overwrites the default path handling etc", - env='AAD_REDIRECT_URI') - prompt: Optional[str] = Field(None, description="AAD prompt to request", env='AAD_PROMPT') - domain_hint: Optional[str] = Field(None, description="AAD domain hint", env='AAD_DOMAIN_HINT') - roles: Optional[List[str]] = Field(None, description="AAD roles required in claims", env='AAD_ROLES') - - class Config: # noqa D106 - env_file = '.env' - - _validate_strict = validator('strict', allow_reuse=True)(bool_from_env) - _validate_client_app_ids = validator('client_app_ids', allow_reuse=True)(list_from_env) - _validate_roles = validator('roles', allow_reuse=True)(list_from_env) - - @expand_doc class AuthSessionConfig(BaseSettings): """Authentication Session configuration. @@ -188,43 +125,51 @@ class Config(BaseSettings): """The overall configuration for the AAD authentication.""" enabled: bool = Field(True, description="Enable authentication", env='FASTAPI_AUTH_ENABLED') - aad: AADConfig = Field(None, description="The AAD configuration to use") + providers: List[Union[AADConfig]] = Field(None, description="The provider configurations to use") + aad: AADConfig = DeprecatedField(None, description='AAD Configuration information', deprecated_in='0.2.0', replaced_by='Config.providers') auth_session: AuthSessionConfig = Field(None, description="The configuration for encoding the authentication information in the session") routing: RoutingConfig = Field(None, description="Configuration for routing") session: SessionConfig = Field(None, description="Configuration for the session middleware") login_ui: LoginUIConfig = Field(None, description="Login UI Configuration") - user_klass: type = Field('fastapi_aad_auth.oauth.state:User', + user_klass: type = Field('fastapi_aad_auth._base.state:User', description="User class to use within the AADOAuthBackend, this will be treated as an import path " "if provided as a string, with the last part the class to load", env='FASTAPI_AUTH_USER_KLASS') class Config: # noqa D106 env_file = '.env' - @validator('aad') - def _validate_aad(cls, value): + @validator('providers', always=True, pre=True) + def _validate_providers(cls, value): + if value is None: + value = [AADConfig(_env_file=cls.Config.env_file)] + return value + + @validator('aad', always=True, pre=True) + def _validate_aad(cls, value, values): if value is None: - value = AADConfig(_env_file=cls.Config.env_file) + providers = values.get('providers', [AADConfig(_env_file=cls.Config.env_file)]) + value = [u for u in providers if isinstance(u, AADConfig)][0] return value - @validator('auth_session') + @validator('auth_session', always=True, pre=True) def _validate_auth_session(cls, value): if value is None: value = AuthSessionConfig(_env_file=cls.Config.env_file) return value - @validator('routing') + @validator('routing', always=True, pre=True) def _validate_routing(cls, value): if value is None: value = RoutingConfig(_env_file=cls.Config.env_file) return value - @validator('session') + @validator('session', always=True, pre=True) def _validate_session(cls, value): if value is None: value = SessionConfig(_env_file=cls.Config.env_file) return value - @validator('login_ui') + @validator('login_ui', always=True, pre=True) def _validate_login_ui(cls, value): if value is None: value = LoginUIConfig(_env_file=cls.Config.env_file) diff --git a/src/fastapi_aad_auth/errors.py b/src/fastapi_aad_auth/errors.py index 0bb2419..b57558c 100644 --- a/src/fastapi_aad_auth/errors.py +++ b/src/fastapi_aad_auth/errors.py @@ -1,10 +1,10 @@ """fastapi_aad_auth errors.""" -import logging - from starlette.responses import JSONResponse, Response -logger = logging.getLogger(__name__) +from fastapi_aad_auth.utilities.logging import getLogger + +logger = getLogger(__name__) def base_error_handler(request, exception, error_type, error_message, templates, template_path, context=None, status_code=500) -> Response: diff --git a/src/fastapi_aad_auth/mixins/__init__.py b/src/fastapi_aad_auth/mixins/__init__.py new file mode 100644 index 0000000..58faf93 --- /dev/null +++ b/src/fastapi_aad_auth/mixins/__init__.py @@ -0,0 +1,3 @@ +"""Mixins for use in objects.""" +from fastapi_aad_auth.mixins.logging import LoggingMixin # noqa F401 +from fastapi_aad_auth.mixins.not_authenticated import NotAuthenticatedMixin # noqa F401 diff --git a/src/fastapi_aad_auth/mixins/logging.py b/src/fastapi_aad_auth/mixins/logging.py new file mode 100644 index 0000000..78e9e35 --- /dev/null +++ b/src/fastapi_aad_auth/mixins/logging.py @@ -0,0 +1,18 @@ +"""Add logger to a class.""" + +from fastapi_aad_auth.utilities import logging + + +class LoggingMixin: + """Add logger to class based on name.""" + def __init__(self, *args, **kwargs): + """Initialise the logger.""" + self._logger = None + super().__init__(*args, **kwargs) + + @property + def logger(self): + """Get the logger object.""" + if self._logger is None: + self._logger = logging.getLogger(self.__class__.__name__) + return self._logger diff --git a/src/fastapi_aad_auth/mixins/not_authenticated.py b/src/fastapi_aad_auth/mixins/not_authenticated.py new file mode 100644 index 0000000..74d0996 --- /dev/null +++ b/src/fastapi_aad_auth/mixins/not_authenticated.py @@ -0,0 +1,16 @@ +"""Add not_authenticated error to a class.""" +from fastapi import HTTPException +from starlette.status import HTTP_401_UNAUTHORIZED + + +class NotAuthenticatedMixin: + """Provide an error for not authenticated error.""" + + @property + def not_authenticated(self): + """Create an error for unauthenticated requests.""" + return HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) diff --git a/src/fastapi_aad_auth/oauth/__init__.py b/src/fastapi_aad_auth/oauth/__init__.py index d239737..8dd4800 100644 --- a/src/fastapi_aad_auth/oauth/__init__.py +++ b/src/fastapi_aad_auth/oauth/__init__.py @@ -1,4 +1,8 @@ -"""Handlers for oauth.""" +"""OAuth handlers.""" +from fastapi_aad_auth.oauth.state import AuthenticationState # noqa: F401 +from fastapi_aad_auth.utilities import deprecate_module # noqa: F401 -from fastapi_aad_auth.oauth.aad import AADOAuthBackend # noqa F401 -from fastapi_aad_auth.oauth.state import AuthenticationState # noqa F401 \ No newline at end of file +_DEPRECATED_VERSION = '0.2.0' + + +deprecate_module(locals(), _DEPRECATED_VERSION) diff --git a/src/fastapi_aad_auth/oauth/_base.py b/src/fastapi_aad_auth/oauth/_base.py deleted file mode 100644 index b9fd236..0000000 --- a/src/fastapi_aad_auth/oauth/_base.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Base OAuthBackend with token and session validators.""" - -import logging -from typing import Optional - -from starlette.authentication import AuthCredentials, AuthenticationBackend, UnauthenticatedUser -from starlette.requests import Request - -from fastapi_aad_auth.oauth.state import AuthenticationState - -logger = logging.getLogger(__name__) - - -class BaseOAuthBackend(AuthenticationBackend): - """Base OAuthBackend with token and session validators.""" - - def __init__(self, token_validator, session_validator=None, authenticator=None): - """Initialise the validators and authenticator.""" - self.validators = [] - if session_validator: - self.validators.append(session_validator) - self.validators.append(token_validator) - self._token_validator = token_validator - self.authenticator = authenticator - - async def authenticate(self, request): - """Authenticate a request.""" - state = self.check(request) - if state is None: - return AuthCredentials([]), UnauthenticatedUser() - return state.credentials, state.authenticated_user - - def is_authenticated(self, request): - """Check if a request is authenticated.""" - state = self.check(request) - return state is not None - - async def __call__(self, request: Request) -> Optional[AuthenticationState]: - """Check/validate a request.""" - return self.check(request) - - def check(self, request: Request) -> Optional[AuthenticationState]: - """Check/validate a request.""" - state = None - while state is None: - validator = next(self.iter_validators()) - state = validator.check(request) - logger.debug(f'Authentication state {state} from validator {validator}') - return state - - def iter_validators(self): - """Iterate over authentication validators.""" - for validator in self.validators: - yield validator - - @property - def api_auth_scheme(self): - """Get the API Authentication Schema.""" - return self._token_validator diff --git a/src/fastapi_aad_auth/oauth/aad.py b/src/fastapi_aad_auth/oauth/aad.py index 740130d..65c5939 100644 --- a/src/fastapi_aad_auth/oauth/aad.py +++ b/src/fastapi_aad_auth/oauth/aad.py @@ -1,70 +1,22 @@ -"""AAD OAuth handlers.""" -import logging -from typing import List, Optional +"""AAD handlers.""" +from fastapi_aad_auth._base.validators import SessionValidator as _SessionValidator +from fastapi_aad_auth.auth import Authenticator as _Authenticator +from fastapi_aad_auth.config import AADConfig, Config # noqa: F401 +from fastapi_aad_auth.providers.aad import AADProvider as _AADProvider +from fastapi_aad_auth.utilities import deprecate, deprecate_module -from itsdangerous import URLSafeSerializer +_DEPRECATED_VERSION = '0.2.0' -from fastapi_aad_auth.config import Config -from fastapi_aad_auth.oauth._base import BaseOAuthBackend -from fastapi_aad_auth.oauth.authenticators import AADSessionAuthenticator -from fastapi_aad_auth.oauth.state import User -from fastapi_aad_auth.oauth.validators import AADSessionValidator, AADTokenValidator, get_session_serializer -logger = logging.getLogger(__name__) +deprecate_module(locals(), _DEPRECATED_VERSION, replaced_by=f'{_AADProvider.__module__} and {_Authenticator.__module__}') -class AADOAuthBackend(BaseOAuthBackend): - """fastapi auth backend for Azure Active Directory.""" - - def __init__( - self, - session_serializer: URLSafeSerializer, - client_id: str, - tenant_id: str, - redirect_path: str = '/login/oauth/redirect', - prompt: Optional[str] = None, - client_secret: Optional[str] = None, - scopes: Optional[List[str]] = None, - enabled: bool = True, - client_app_ids: Optional[List[str]] = None, - strict_token: bool = True, - api_audience: Optional[str] = None, - redirect_uri: Optional[str] = None, - domain_hint: Optional[str] = None, - user_klass: type = User): - """Initialise the auth backend. - - Args: - session_serializer: Session serializer object - client_id: Client ID from Azure App Registration - tenant_id: Tenant ID to connect to for Azure App Registration - - Keyword Args: - redirect_path: Path to redirect to on return - prompt: Prompt options for Azure AD - client_secret: Client secret value - scopes: Additional scopes requested - enabled: Boolean flag to enable this backend - client_app_ids: List of client apps to accept tokens from - strict_token: Strictly evaluate token - api_audience: Api Audience declared in Azure AD App registration - redirect_uri: Full URI for post authentication callbacks - domain_hint: Hint for the domain - user_klass: Class to use as a user. - """ - self.session_serializer = session_serializer - self.enabled = enabled - token_validator = AADTokenValidator(client_id=client_id, tenant_id=tenant_id, api_audience=api_audience, - client_app_ids=client_app_ids, scopes={}, enabled=enabled, strict=strict_token, - user_klass=user_klass) - session_validator = AADSessionValidator(session_serializer) - session_authenticator = AADSessionAuthenticator(session_validator=session_validator, token_validator=token_validator, - client_id=client_id, tenant_id=tenant_id, redirect_path=redirect_path, - prompt=prompt, client_secret=client_secret, scopes=scopes, - redirect_uri=redirect_uri, domain_hint=domain_hint) - super().__init__(token_validator, session_validator, authenticator=session_authenticator) +@deprecate(_DEPRECATED_VERSION, replaced_by=f'{_AADProvider.__module__}:{_AADProvider.__name__}') +class AADOAuthBackend(_AADProvider): # noqa: D101 + __doc__ = _AADProvider.__doc__ @classmethod + @deprecate(_DEPRECATED_VERSION, replaced_by=f'{_AADProvider.__module__}:{_AADProvider.__name__}.from_config and {_Authenticator.__module__}:{_Authenticator.__name__}') def from_config(cls, config: Config): """Load the auth backend from a config. @@ -74,17 +26,7 @@ def from_config(cls, config: Config): Keyword Args: user_klass: The class to use as a user """ - auth_serializer = get_session_serializer(config.auth_session.secret.get_secret_value(), - config.auth_session.salt.get_secret_value()) - client_secret = config.aad.client_secret - if client_secret is not None: - client_secret = client_secret.get_secret_value() # type: ignore - - return cls(session_serializer=auth_serializer, client_id=config.aad.client_id.get_secret_value(), - tenant_id=config.aad.tenant_id.get_secret_value(), - redirect_path=config.routing.login_redirect_path, - client_secret=client_secret, enabled=config.enabled, # type: ignore - scopes=config.aad.scopes, client_app_ids=config.aad.client_app_ids, - strict_token=config.aad.strict, api_audience=config.aad.api_audience, - prompt=config.aad.prompt, domain_hint=config.aad.domain_hint, - redirect_uri=config.aad.redirect_uri, user_klass=config.user_klass) + session_validator = _SessionValidator.get_session_serializer(config.auth_session.secret.get_secret_value(), + config.auth_session.salt.get_secret_value()) + provider = super().from_config(session_validator, config, config.aad) + return provider.auth_backend diff --git a/src/fastapi_aad_auth/oauth/authenticators/__init__.py b/src/fastapi_aad_auth/oauth/authenticators/__init__.py deleted file mode 100644 index 8bf0446..0000000 --- a/src/fastapi_aad_auth/oauth/authenticators/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from fastapi_aad_auth.oauth.authenticators.session import AADSessionAuthenticator # noqa F401 \ No newline at end of file diff --git a/src/fastapi_aad_auth/oauth/authenticators/session.py b/src/fastapi_aad_auth/oauth/authenticators/session.py deleted file mode 100644 index 83c53df..0000000 --- a/src/fastapi_aad_auth/oauth/authenticators/session.py +++ /dev/null @@ -1,215 +0,0 @@ -"""Authenticator for interactive (UI) sessions.""" -import base64 -import logging - -import msal -from pkg_resources import resource_string -from starlette.requests import Request -from starlette.responses import RedirectResponse - -from fastapi_aad_auth.errors import ConfigurationError -from fastapi_aad_auth.oauth.state import AuthenticationState - -logger = logging.getLogger(__name__) - - -class SessionAuthenticator: - """Authenticator for interactive (UI) sessions.""" - - def __init__(self, session_validator, token_validator): - """Initialise the session authenticator.""" - self._session_validator = session_validator - self._token_validator = token_validator - - def redirect_if_authenticated(self, auth_state, redirect='/'): - """Redirect to a target if authenticated.""" - if auth_state.is_authenticated(): - logger.info(f'Logged in, redirecting to {redirect}') - else: - redirect = '/login' - return RedirectResponse(redirect) - - def redirect_to_provider_login(self, auth_state, request): - """Redirect to the provider login.""" - logger.debug(f'state {auth_state}') - auth_state.save_to_session(self._session_validator._session_serializer, request.session) - authorization_url = self._get_authorization_url(request, auth_state.session_state) - return RedirectResponse(authorization_url) - - def _get_authorization_url(self, request, session_state): - raise NotImplementedError('Implement in specific subclass') - - def process_login_request(self, request, force=False, redirect='/'): - """Process the provider login request.""" - logger.debug(f'Logging in - request url {request.url}') - auth_state = self._session_validator.get_state_from_session(request) - if auth_state.is_authenticated() and not force: - logger.debug(f'Authenticated - redirecting {auth_state}') - response = self.redirect_if_authenticated(auth_state) - else: - # Set the redirect parameter here - self._session_validator.set_post_auth_redirect(request, request.query_params.get('redirect', redirect)) - logger.debug(f'No Auth state - redirecting to provider login {auth_state}') - response = self.redirect_to_provider_login(auth_state, request) - return response - - def process_login_callback(self, request): - """Process the provider login callback.""" - code = request.query_params.get('code', None) - state = request.query_params.get('state', None) - if state is None or code is None: - return # not authenticated - auth_state = self._session_validator.get_state_from_session(request) - auth_state.check_session_state(state) - token = self._process_code(request, auth_state, code) - user = self._get_user_from_token(token) - authenticated_state = AuthenticationState.authenticate_as(user, self._session_validator._session_serializer, request.session) - redirect = self._session_validator.pop_post_auth_redirect(request) - return self.redirect_if_authenticated(authenticated_state, redirect=redirect) - - def _process_code(self, request, auth_state, code): - raise NotImplementedError('Implement in subclass') - - def get_access_token(self, user): - """Get the access token for the user.""" - raise NotImplementedError('Implement in subclass') - - def get_access_token_from_request(self, request: Request): - """Get the access token from a request object.""" - auth_state = self._session_validator.get_state_from_session(request) - if auth_state.is_authenticated(): - return self.get_access_token(auth_state.user)['access_token'] - return None - - def get_user_from_request(self, request: Request): - """Get the user from a request object.""" - auth_state = self._session_validator.get_state_from_session(request) - return auth_state.user - - def _get_user_from_token(self, token, options=None): - validated_claims = self._token_validator.validate_token(token, options=options) - return self._token_validator._get_user_from_claims(validated_claims) - - def _add_redirect_to_url(self, url, post_redirect=None): - if post_redirect is not None: - url = f'{url}?redirect={post_redirect}' - return url - - def get_login_button(self, url, post_redirect='/'): - """Get a UI login button.""" - url = self._add_redirect_to_url(url, post_redirect) - return f'Sign in' - - def logout(self, request): - """Process a logout request.""" - AuthenticationState.logout(self._session_validator._session_serializer, request.session) - - def pop_post_auth_redirect(self, *args, **kwargs): - """Clear post-authentication redirects.""" - return self._session_validator.pop_post_auth_redirect(*args, **kwargs) - - def set_post_auth_redirect(self, *args, **kwargs): - """Set post-authentication redirects.""" - self._session_validator.set_post_auth_redirect(*args, **kwargs) - - -class AADSessionAuthenticator(SessionAuthenticator): - """AAD Authenticator for interactive (UI) sessions.""" - - def __init__( - self, - session_validator, - token_validator, - client_id, - tenant_id, - redirect_path='/login/oauth/redirect', - prompt=None, - client_secret=None, - scopes=None, - redirect_uri=None, - domain_hint=None): - """Initialise AAD Authenticator for interactive (UI) sessions.""" - super().__init__(session_validator, token_validator) - self._redirect_path = redirect_path - self._redirect_uri = redirect_uri - self._domain_hint = domain_hint - self._prompt = prompt - self.client_id = client_id - if scopes is None: - scopes = [f'api://{self.client_id}'] - elif isinstance(scopes, str): - scopes = [scopes] - self._scopes = scopes - self._authority = f'https://login.microsoftonline.com/{tenant_id}' - - if client_secret is not None: - logger.info('Client secret provided, using Confidential Client') - self.msal_application = msal.ConfidentialClientApplication( - client_id, - authority=self._authority, - client_credential=client_secret) - else: - logger.info('Client secret not provided, using Public Client') - self.msal_application = msal.PublicClientApplication( - client_id, - authority=self._authority) - - def _build_redirect_uri(self, request): - if self._redirect_uri: - redirect_uri = self._redirect_uri - else: - if request.url.port is None or (request.url.port == 80 and request.url.scheme == 'http') or (request.url.port == 443 and request.url.scheme == 'https'): - port = '' - else: - port = f':{request.url.port}' - redirect_uri = f'{request.url.scheme}://{request.url.hostname}{port}{self._redirect_path}' - return redirect_uri - - def _process_code(self, request, auth_state, code): - # Let's build up the redirect_uri - result = self.msal_application.acquire_token_by_authorization_code(code, scopes=[], - redirect_uri=self._build_redirect_uri(request)) - logger.debug(f'Result {result}') - if 'error' in result and result['error']: - raise ConfigurationError(result) - return result['id_token'] - - def _get_user_from_token(self, token, options=None): - if options is None: - options = self._token_validator._claims_options - options.pop('azp', None) - options.pop('appid', None) - return super()._get_user_from_token(token, options=options) - - def _get_authorization_url(self, request, session_state): - return self.msal_application.get_authorization_request_url([], - state=session_state, - claims_challenge='{"id_token": {"roles": {"essential": true} } }', - redirect_uri=self._build_redirect_uri(request), - prompt=self._prompt, - domain_hint=self._domain_hint) - - def get_login_button(self, url, post_redirect='/'): - """Get the AAD Login Button.""" - url = self._add_redirect_to_url(url, post_redirect) - logo = base64.b64encode(resource_string('fastapi_aad_auth.oauth', 'ms-logo.png')).decode() - return f'' - - def get_access_token(self, user): - """Get the access token for the user.""" - result = None - account = None - if user.username: - account = self.msal_application.get_accounts(user.username) - if account: - account = account[0] - logger.info(account) - # This needs you to register the openid api - result = self.msal_application.acquire_token_silent_with_error(scopes=[f'api://{self.client_id}/openid'], account=account) - logger.info(result) - if result is None: - raise ValueError('Token not found') - else: - return {'token_type': result['token_type'], - 'expires_in': result['expires_in'], - 'access_token': result['access_token']} diff --git a/src/fastapi_aad_auth/oauth/state.py b/src/fastapi_aad_auth/oauth/state.py index 22829b3..fc3e096 100644 --- a/src/fastapi_aad_auth/oauth/state.py +++ b/src/fastapi_aad_auth/oauth/state.py @@ -1,131 +1,18 @@ -"""Authentication State Handler.""" -from enum import Enum -import json -import logging -from typing import List, Optional -import uuid +"""Authentication State.""" +from fastapi_aad_auth._base.state import AuthenticationState as _AuthenticationState, User as _User +from fastapi_aad_auth.utilities import deprecate, deprecate_module -from itsdangerous import URLSafeSerializer -from itsdangerous.exc import BadSignature -from pydantic import BaseModel, root_validator -from starlette.authentication import AuthCredentials, AuthenticationError, SimpleUser, UnauthenticatedUser +_DEPRECATED_VERSION = '0.2.0' -logger = logging.getLogger(__name__) +deprecate_module(locals(), _DEPRECATED_VERSION, replaced_by=_AuthenticationState.__module__) -SESSION_STORE_KEY = 'auth' +@deprecate(_DEPRECATED_VERSION, replaced_by=f'{_AuthenticationState.__module__}:{_AuthenticationState.__name__}') +class AuthenticationState(_AuthenticationState): # noqa: D101 + __doc__ = _AuthenticationState.__doc__ -class AuthenticationOptions(Enum): - """Authentication Options.""" - unauthenticated = 0 - not_allowed = -1 - authenticated = 1 - - -class User(BaseModel): - """User Model.""" - name: str - email: str - username: str - roles: Optional[List[str]] = None - groups: Optional[List[str]] = None - - @property - def permissions(self): - """User Permissions.""" - return [] - - -class AuthenticationState(BaseModel): - """Authentication State.""" - session_state: str = str(uuid.uuid4()) - state: AuthenticationOptions = AuthenticationOptions.unauthenticated - user: Optional[User] = None - - @root_validator(pre=True) - def _validate_user(cls, values): - if values.get('user', None) is None: - values['state'] = AuthenticationOptions.unauthenticated - return values - - def check_session_state(self, session_state): - """Check state againste session state.""" - if session_state != self.session_state: - raise AuthenticationError("Session states do not match") - return True - - def store(self, serializer): - """Store in serializer.""" - return serializer.dumps(self.json()) - - @classmethod - def load(cls, serializer: URLSafeSerializer, encoded_state: Optional[str] = None): - """Load from encoded state. - - Args: - serializer: Serializer object containing the en/decoding secrets - Keyword Args: - encoded_state: The encoded state to be decoded - """ - if encoded_state: - try: - state = json.loads(serializer.loads(encoded_state)) - loaded_state = cls(**state) - except BadSignature: - loaded_state = cls() - else: - loaded_state = cls() - return loaded_state - - @classmethod - def logout(cls, serializer: URLSafeSerializer, session): - """Clear the sessions state.""" - state = cls.load_from_session(serializer, session) - state.user = None - state.state = AuthenticationOptions.unauthenticated - session[SESSION_STORE_KEY] = state.store(serializer) - - @classmethod - def load_from_session(cls, serializer: URLSafeSerializer, session): - """Load from a session.""" - return cls.load(serializer, session.get(SESSION_STORE_KEY, None)) - - def save_to_session(self, serializer: URLSafeSerializer, session): - """Save to a session.""" - session[SESSION_STORE_KEY] = self.store(serializer) - return session - - def is_authenticated(self): - """Check if the state is authenticated.""" - return self.user is not None and self.state == AuthenticationOptions.authenticated - - @property - def authenticated_user(self): - """Get the authenticated user.""" - if self.is_authenticated() and self.user: - if isinstance(self.user, User): - return SimpleUser(self.user.email) - return UnauthenticatedUser() - - @property - def credentials(self): - """Get the credentials object.""" - if self.user and self.is_authenticated(): - return AuthCredentials(['authenticated'] + self.user.permissions) - else: - return AuthCredentials() - - @classmethod - def authenticate_as(cls, user, serializer, session): - """Store the authenticated user.""" - state = cls(user=user, state=AuthenticationOptions.authenticated) - if serializer is not None and session is not None: - state.save_to_session(serializer, session) - return state - - @classmethod - def as_unauthenticated(cls, serializer, session): - """Store as an un-authenticated user.""" - return cls.authenticate_as(None, serializer, session) +@deprecate(_DEPRECATED_VERSION, replaced_by=f'{_User.__module__}:{_User.__name__}') +class User(_User): # noqa: D101 + __doc__ = _User.__doc__ diff --git a/src/fastapi_aad_auth/oauth/validators/__init__.py b/src/fastapi_aad_auth/oauth/validators/__init__.py deleted file mode 100644 index be6510b..0000000 --- a/src/fastapi_aad_auth/oauth/validators/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Validators for different authentication methods.""" -from fastapi_aad_auth.oauth.validators.session import SessionValidator as AADSessionValidator, get_session_serializer # noqa F401 -from fastapi_aad_auth.oauth.validators.token import AADTokenValidator # noqa F401 diff --git a/src/fastapi_aad_auth/oauth/validators/token.py b/src/fastapi_aad_auth/oauth/validators/token.py deleted file mode 100644 index 0cd6f31..0000000 --- a/src/fastapi_aad_auth/oauth/validators/token.py +++ /dev/null @@ -1,243 +0,0 @@ -"""Validator for token based authentication.""" -import logging -from typing import List, Optional - -from authlib.jose import errors as jwt_errors, jwk, jwt -from authlib.jose.util import extract_header -from cryptography.hazmat.primitives import serialization -from fastapi import HTTPException, status -from fastapi.security import OAuth2AuthorizationCodeBearer -from fastapi.security.utils import get_authorization_scheme_param -from pydantic import BaseModel -import requests -from starlette.middleware.authentication import AuthenticationError -from starlette.requests import Request -from starlette.status import HTTP_401_UNAUTHORIZED - -from fastapi_aad_auth.oauth.state import AuthenticationState, User - -logger = logging.getLogger(__name__) - - -class InitOAuth(BaseModel): - """OAuth information for openapi docs.""" - clientId: str - scopes: str - usePkceWithAuthorizationCodeGrant: bool - - -class TokenValidator(OAuth2AuthorizationCodeBearer): - """Validator for token based authentication.""" - - def __init__( - self, - client_id: str, - authorizationUrl: str, - tokenUrl: str, - api_audience: str = None, - scheme_name: str = None, - scopes: dict = None, - auto_error: bool = False, - enabled: bool = True, - use_pkce: bool = True, - user_klass: type = User - ): - """Initialise validator for token based authentication.""" - super().__init__(authorizationUrl=authorizationUrl, tokenUrl=tokenUrl, refreshUrl=api_audience, scheme_name=scheme_name, scopes=scopes, auto_error=auto_error) - self.client_id = client_id - self.enabled = enabled - if api_audience is None: - api_audience = f"api://{client_id}" - self.api_audience = api_audience - self._use_pkce = use_pkce - self._user_klass = user_klass - - def check(self, request): - """Check the authentication from the request.""" - token = self.get_token(request) - if token is None: - return AuthenticationState.as_unauthenticated(None, None) - claims = self.validate_token(token) - user = self._get_user_from_claims(claims) - return AuthenticationState.authenticate_as(user, None, None) - - def get_token(self, request): - """Get the token from the request.""" - authorization = request.headers.get("Authorization") - scheme, param = get_authorization_scheme_param(authorization) - if not authorization or scheme.lower() != "bearer": - if self.auto_error: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Not authenticated", - headers={"WWW-Authenticate": "Bearer"}, - ) - else: - return None # pragma: nocover - return param - - @property - def init_oauth(self): - """Get the openapi docs config.""" - return InitOAuth(clientId=self.client_id, scopes=f'{self.api_audience}/openid', usePkceWithAuthorizationCodeGrant=self._use_pkce).dict() - - def _validate_claims(self, claims, options=None): - if options is None: - options = self._claims_options - claims.options = options - try: - claims.validate() - except jwt_errors.ExpiredTokenError as e: - logger.error(f'Expired token:\n\t{self._compare_claims(claims)}') - raise AuthenticationError(f"Token is expired {e.args}") - except jwt_errors.InvalidClaimError as e: - logger.error(f'Invalid claims:\n\t{self._compare_claims(claims)}') - raise AuthenticationError(f"Invalid claims {e.args}") - except jwt_errors.MissingClaimError as e: - logger.error(f'Missing claims:\n\t{self._compare_claims(claims)}') - raise AuthenticationError(f"Missing claims {e.args}") - except Exception as e: - logger.exception('Unable to parse error') - raise AuthenticationError(f"Unable to parse authentication token {e.args}") - return claims - - @property - def _claims_options(self): - options = {"sub": {"essential": True}, - "aud": {"essential": True, "values": [self.api_audience]}, - "exp": {"essential": True}, - "nbf": {"essential": True}, - "iat": {"essential": True}} - return options - - def _decode_token(self, token): - raise NotImplementedError('Implement in base class') - - def validate_token(self, token, options=None): - """Validate provided token.""" - claims = self._decode_token(token) - return self._validate_claims(claims, options) - - @staticmethod - def _compare_claims(claims): - return '\n\t'.join([f'{key}: {value} - {claims.options.get(key, None)}' for key, value in claims.items()]) - - def _get_user_from_claims(self, claims): - raise NotImplementedError('Implement in sub class') - - # TODO change pattern to better depend on alternate method - async def __call__(self, request: Request) -> AuthenticationState: # type: ignore - """Validate the request authentication. - - Returns an AuthenticationState object or raises an Unauthorized eror - """ - result = self.check(request) - logger.info(f'Identified state {result}') - if not result.is_authenticated(): - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Not authenticated", - headers={"WWW-Authenticate": "Bearer"}, - ) - return result - - -class AADTokenValidator(TokenValidator): - """Validator for AAD token based authentication.""" - - def __init__(self, - client_id: str, - tenant_id: str, - api_audience: str = None, - scheme_name: str = None, - scopes: dict = None, - auto_error: bool = False, - enabled: bool = True, - use_pkce: bool = True, - strict: bool = True, - client_app_ids: Optional[List[str]] = None, - user_klass: type = User): - """Initialise validator for AAD token based authentication.""" - authorization_url = f"https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/authorize" - token_url = f"https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/token" - self.key_url = f"https://login.microsoftonline.com/{tenant_id}/discovery/v2.0/keys" - self.tenant_id = tenant_id - super().__init__(client_id=client_id, authorizationUrl=authorization_url, tokenUrl=token_url, api_audience=api_audience, scheme_name=scheme_name, - scopes=scopes, auto_error=auto_error, enabled=enabled, use_pkce=use_pkce, user_klass=user_klass) - self.strict = strict - if client_app_ids is None: - client_app_ids = [] - self.client_app_ids = client_app_ids - - def _get_ms_jwk(self, token): - try: - logger.info(f'Getting signing keys from {self.key_url}') - jwks = requests.get(self.key_url).json() - token_header = token.split(".")[0].encode() - unverified_header = extract_header(token_header, jwt_errors.DecodeError) - for key in jwks["keys"]: - if key["kid"] == unverified_header["kid"]: - logger.info(f'Identified key {key["kid"]}') - return jwk.loads(key) - except jwt_errors.DecodeError: - logger.exception('Error parsing signing keys') - raise AuthenticationError("Unable to parse signing keys") - - def _decode_token(self, token): - jwk_ = self._get_ms_jwk(token) - claims = None - logger.debug(f'Key is {jwk_}') - try: - if hasattr(jwk, 'public_bytes'): - public_bytes = jwk_.public_bytes(serialization.Encoding.PEM, serialization.PublicFormat.PKCS1) - else: - public_bytes = jwk_.raw_key.public_bytes(serialization.Encoding.PEM, serialization.PublicFormat.PKCS1) - claims = jwt.decode( - token, - public_bytes, - ) - except Exception: - logger.exception('Unable to parse error') - raise AuthenticationError("Unable to parse authentication token") - return claims - - def _validate_claims(self, claims, options=None): - if options is None: - options = self._claims_options - # We need to do some 1.0/2.0 handling because it doesn't seem to work properly - # TODO: validate whether we want this claim here? - # TODO: validate whether the user is approved for the app - if 'appid' in options and 'azp' in options: - if 'appid' not in claims: - options.pop('appid') - elif 'azp'not in claims: - options.pop('azp') - if not ('appid' in claims or 'azp' in claims): - if self.strict: - logger.error('No appid/azp claims found in token') - raise AuthenticationError('No appid/azp claims found in token') - else: - logger.warning('No appid/azp claims found in token - we are ignoring for now') - return super()._validate_claims(claims, options) - - @property - def _claims_options(self): - options = super()._claims_options - options["iss"] = {"essential": True, "values": [f"https://sts.windows.net/{self.tenant_id}/", f"https://login.microsoftonline.com/{self.tenant_id}/v2.0"]} - options["aud"] = {"essential": True, "values": [self.api_audience] + [self.client_id] + self.client_app_ids} - options["azp"] = {"essential": True, "values": [self.client_id] + self.client_app_ids} - options["appid"] = {"essential": True, "values": [self.client_id] + self.client_app_ids} - logger.debug(f'Claims options {options}') - return options - - def _get_user_from_claims(self, claims): - logger.debug(f'Processing claims: {claims}') - username_key = 'preferred_username' - if username_key not in claims: - username_key = 'unique_name' - if 'name' not in claims and 'appid' in claims: - # This is an application/service principal - return self._user_klass(name=claims['appid'], email='', username=claims['appid'], groups=claims.get('groups', None), roles=claims.get('roles', None)) - - else: - return self._user_klass(name=claims['name'], email=claims[username_key], username=claims[username_key], groups=claims.get('groups', None), roles=claims.get('roles', None)) diff --git a/src/fastapi_aad_auth/providers/__init__.py b/src/fastapi_aad_auth/providers/__init__.py new file mode 100644 index 0000000..8b515d6 --- /dev/null +++ b/src/fastapi_aad_auth/providers/__init__.py @@ -0,0 +1,3 @@ +"""Handlers for oauth.""" + +from fastapi_aad_auth.providers.aad import AADConfig, AADProvider # noqa F401 diff --git a/src/fastapi_aad_auth/providers/aad.py b/src/fastapi_aad_auth/providers/aad.py new file mode 100644 index 0000000..3d2f9da --- /dev/null +++ b/src/fastapi_aad_auth/providers/aad.py @@ -0,0 +1,347 @@ +"""AAD OAuth handlers.""" + +import base64 +from typing import List, Optional + +from authlib.jose import errors as jwt_errors, jwk, jwt +from authlib.jose.util import extract_header +from cryptography.hazmat.primitives import serialization +import msal +from pkg_resources import resource_string +from pydantic import BaseSettings as _BaseSettings, Field, HttpUrl, PrivateAttr, SecretStr, validator +import requests +from starlette.middleware.authentication import AuthenticationError +from starlette.requests import Request + +from fastapi_aad_auth._base.authenticators import SessionAuthenticator +from fastapi_aad_auth._base.provider import Provider +from fastapi_aad_auth._base.state import User +from fastapi_aad_auth._base.validators import SessionValidator, TokenValidator +from fastapi_aad_auth.errors import ConfigurationError +from fastapi_aad_auth.utilities import bool_from_env, DeprecatableFieldsMixin, expand_doc, is_deprecated, list_from_env, urls + + +class BaseSettings(DeprecatableFieldsMixin, _BaseSettings): + """Base Settings with Deprecatable Fields.""" + pass + + +class AADSessionAuthenticator(SessionAuthenticator): + """AAD Authenticator for interactive (UI) sessions.""" + + def __init__( + self, + session_validator, + token_validator, + client_id, + tenant_id, + redirect_path='/oauth/aad/redirect', + prompt=None, + client_secret=None, + scopes=None, + redirect_uri=None, + domain_hint=None): + """Initialise AAD Authenticator for interactive (UI) sessions.""" + super().__init__(session_validator, token_validator) + self._redirect_path = redirect_path + self._redirect_uri = redirect_uri + self._domain_hint = domain_hint + self._prompt = prompt + self.client_id = client_id + if scopes is None: + scopes = [f'api://{self.client_id}'] + elif isinstance(scopes, str): + scopes = [scopes] + self._scopes = scopes + self._authority = f'https://login.microsoftonline.com/{tenant_id}' + + if client_secret is not None: + self.logger.info('Client secret provided, using Confidential Client') + self.msal_application = msal.ConfidentialClientApplication( + client_id, + authority=self._authority, + client_credential=client_secret) + else: + self.logger.info('Client secret not provided, using Public Client') + self.msal_application = msal.PublicClientApplication( + client_id, + authority=self._authority) + + def _build_redirect_uri(self, request: Request): + if self._redirect_uri: + redirect_uri = self._redirect_uri + else: + if request.url.port is None or (request.url.port == 80 and request.url.scheme == 'http') or (request.url.port == 443 and request.url.scheme == 'https'): + port = '' + else: + port = f':{request.url.port}' + redirect_uri = f'{request.url.scheme}://{request.url.hostname}{port}{self._redirect_path}' + return redirect_uri + + def _process_code(self, request: Request, auth_state, code): + # Let's build up the redirect_uri + result = self.msal_application.acquire_token_by_authorization_code(code, scopes=[], + redirect_uri=self._build_redirect_uri(request)) + self.logger.debug(f'Result {result}') + if 'error' in result and result['error']: + raise ConfigurationError(result) + return result['id_token'] + + def _get_user_from_token(self, token, options=None): + if options is None: + options = self._token_validator._claims_options + options.pop('azp', None) + options.pop('appid', None) + return super()._get_user_from_token(token, options=options) + + def _get_authorization_url(self, request, session_state): + return self.msal_application.get_authorization_request_url([], + state=session_state, + claims_challenge='{"id_token": {"roles": {"essential": true} } }', + redirect_uri=self._build_redirect_uri(request), + prompt=self._prompt, + domain_hint=self._domain_hint) + + def get_access_token(self, user): + """Get the access token for the user.""" + result = None + account = None + if user.username: + account = self.msal_application.get_accounts(user.username) + if account: + account = account[0] + self.logger.info(account) + # This needs you to register the openid api + result = self.msal_application.acquire_token_silent_with_error(scopes=[f'api://{self.client_id}/openid'], account=account) + self.logger.info(result) + if result is None: + raise ValueError('Token not found') + else: + return {'token_type': result['token_type'], + 'expires_in': result['expires_in'], + 'access_token': result['access_token']} + + +class AADTokenValidator(TokenValidator): + """Validator for AAD token based authentication.""" + + def __init__(self, + client_id: str, + tenant_id: str, + api_audience: str = None, + scheme_name: str = None, + scopes: dict = None, + auto_error: bool = False, + enabled: bool = True, + use_pkce: bool = True, + strict: bool = True, + client_app_ids: Optional[List[str]] = None, + user_klass: type = User): + """Initialise validator for AAD token based authentication.""" + authorization_url = f"https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/authorize" + token_url = f"https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/token" + self.key_url = f"https://login.microsoftonline.com/{tenant_id}/discovery/v2.0/keys" + self.tenant_id = tenant_id + super().__init__(client_id=client_id, authorizationUrl=authorization_url, tokenUrl=token_url, api_audience=api_audience, scheme_name=scheme_name, + scopes=scopes, auto_error=auto_error, enabled=enabled, use_pkce=use_pkce, user_klass=user_klass) + self.strict = strict + if client_app_ids is None: + client_app_ids = [] + self.client_app_ids = client_app_ids + + def _get_ms_jwk(self, token): + try: + self.logger.info(f'Getting signing keys from {self.key_url}') + jwks = requests.get(self.key_url).json() + token_header = token.split(".")[0].encode() + unverified_header = extract_header(token_header, jwt_errors.DecodeError) + for key in jwks["keys"]: + if key["kid"] == unverified_header["kid"]: + self.logger.info(f'Identified key {key["kid"]}') + return jwk.loads(key) + except jwt_errors.DecodeError: + self.logger.exception('Error parsing signing keys') + raise AuthenticationError("Unable to parse signing keys") + + def _decode_token(self, token): + jwk_ = self._get_ms_jwk(token) + claims = None + self.logger.debug(f'Key is {jwk_}') + try: + if hasattr(jwk, 'public_bytes'): + public_bytes = jwk_.public_bytes(serialization.Encoding.PEM, serialization.PublicFormat.PKCS1) + else: + public_bytes = jwk_.raw_key.public_bytes(serialization.Encoding.PEM, serialization.PublicFormat.PKCS1) + claims = jwt.decode( + token, + public_bytes, + ) + except Exception: + self.logger.exception('Unable to parse error') + raise AuthenticationError("Unable to parse authentication token") + return claims + + def _validate_claims(self, claims, options=None): + if options is None: + options = self._claims_options + # We need to do some 1.0/2.0 handling because it doesn't seem to work properly + # TODO: validate whether we want this claim here? + # TODO: validate whether the user is approved for the app + if 'appid' in options and 'azp' in options: + if 'appid' not in claims: + options.pop('appid') + elif 'azp'not in claims: + options.pop('azp') + if not ('appid' in claims or 'azp' in claims): + if self.strict: + self.logger.error('No appid/azp claims found in token') + raise AuthenticationError('No appid/azp claims found in token') + else: + self.logger.warning('No appid/azp claims found in token - we are ignoring for now') + return super()._validate_claims(claims, options) + + @property + def _claims_options(self): + options = super()._claims_options + options["iss"] = {"essential": True, "values": [f"https://sts.windows.net/{self.tenant_id}/", f"https://login.microsoftonline.com/{self.tenant_id}/v2.0"]} + options["aud"] = {"essential": True, "values": [self.api_audience] + [self.client_id] + self.client_app_ids} + options["azp"] = {"essential": True, "values": [self.client_id] + self.client_app_ids} + options["appid"] = {"essential": True, "values": [self.client_id] + self.client_app_ids} + self.logger.debug(f'Claims options {options}') + return options + + def _get_user_from_claims(self, claims): + self.logger.debug(f'Processing claims: {claims}') + username_key = 'preferred_username' + if username_key not in claims: + username_key = 'unique_name' + if 'name' not in claims and 'appid' in claims: + # This is an application/service principal + return self._user_klass(name=claims['appid'], email='', username=claims['appid'], groups=claims.get('groups', None), roles=claims.get('roles', None)) + + else: + return self._user_klass(name=claims['name'], email=claims[username_key], username=claims[username_key], groups=claims.get('groups', None), roles=claims.get('roles', None)) + + +class AADProvider(Provider): + """fastapi auth backend for Azure Active Directory.""" + name: str = 'aad' + + def __init__( + self, + session_validator: SessionValidator, + client_id: str, + tenant_id: str, + prompt: Optional[str] = None, + client_secret: Optional[str] = None, + scopes: Optional[List[str]] = None, + enabled: bool = True, + client_app_ids: Optional[List[str]] = None, + strict_token: bool = True, + api_audience: Optional[str] = None, + redirect_uri: Optional[str] = None, + domain_hint: Optional[str] = None, + user_klass: type = User, + oauth_base_route: str = '/oauth'): + """Initialise the auth backend. + + Args: + session_serializer: Session serializer object + client_id: Client ID from Azure App Registration + tenant_id: Tenant ID to connect to for Azure App Registration + + Keyword Args: + prompt: Prompt options for Azure AD + client_secret: Client secret value + scopes: Additional scopes requested + enabled: Boolean flag to enable this backend + client_app_ids: List of client apps to accept tokens from + strict_token: Strictly evaluate token + api_audience: Api Audience declared in Azure AD App registration + redirect_uri: Full URI for post authentication callbacks + domain_hint: Hint for the domain + user_klass: Class to use as a user. + """ + redirect_path = self._build_oauth_url(oauth_base_route, 'redirect') + token_validator = AADTokenValidator(client_id=client_id, tenant_id=tenant_id, api_audience=api_audience, + client_app_ids=client_app_ids, scopes={}, enabled=enabled, strict=strict_token, + user_klass=user_klass) + session_authenticator = AADSessionAuthenticator(session_validator=session_validator, token_validator=token_validator, + client_id=client_id, tenant_id=tenant_id, redirect_path=redirect_path, + prompt=prompt, client_secret=client_secret, scopes=scopes, + redirect_uri=redirect_uri, domain_hint=domain_hint) + super().__init__(validators=[token_validator], authenticator=session_authenticator, enabled=enabled, oauth_base_route=oauth_base_route) + + @classmethod + def from_config(cls, session_validator, config, provider_config, user_klass: Optional[type] = None): + """Load the auth backend from a config. + + Args: + session_validator (SessionValidator): the session validator to use + config: Loaded configuration + + Keyword Args: + user_klass: The class to use as a user + """ + client_secret = provider_config.client_secret + if client_secret is not None: + client_secret = client_secret.get_secret_value() # type: ignore + + if user_klass is None: + user_klass = config.user_klass + + obj = cls(session_validator=session_validator, client_id=provider_config.client_id.get_secret_value(), + tenant_id=provider_config.tenant_id.get_secret_value(), + client_secret=client_secret, enabled=config.enabled, # type: ignore + scopes=provider_config.scopes, client_app_ids=provider_config.client_app_ids, + strict_token=provider_config.strict, api_audience=provider_config.api_audience, + prompt=provider_config.prompt, domain_hint=provider_config.domain_hint, + redirect_uri=provider_config.redirect_uri, user_klass=user_klass, oauth_base_route=config.routing.oauth_base_route) + # We need to override the login and redirect etc until it is deprecated + if hasattr(config.routing, 'login_path') and config.routing.login_path and not is_deprecated(config.routing.__fields__['login_path']): + obj._login_url = config.routing.login_path + if hasattr(config.routing, 'login_redirect_path') and config.routing.login_redirect_path and not is_deprecated(config.routing.__fields__['login_redirect_path']): + obj._redirect_url = config.routing.login_redirect_path + obj.authenticator._redirect_path = obj.redirect_url # type: ignore + return obj + + def get_login_button(self, post_redirect='/'): + """Get the AAD Login Button.""" + url = urls.with_redirect(self.login_url, post_redirect) + logo = base64.b64encode(resource_string('fastapi_aad_auth.providers', 'ms-logo.png')).decode() + return f'' + + +@expand_doc +class AADConfig(BaseSettings): + """Configuration for the AAD application. + + Includes expected claims, application registration, etc. + + Can also provide additional client application ids to accept. + + A list of roles can be provided to accept (requires configuring the + roles in the AAD application registration manifest) + """ + client_id: SecretStr = Field(..., description="Application Registration Client ID", env='AAD_CLIENT_ID') + tenant_id: SecretStr = Field(..., description="Application Registration Tenant ID", env='AAD_TENANT_ID') + client_secret: Optional[SecretStr] = Field(None, description="Application Registration Client Secret (if required)", env='AAD_CLIENT_SECRET') + scopes: List[str] = Field(["Read"], description="Additional scopes requested") + client_app_ids: Optional[List[str]] = Field(None, description="Additional Client App IDs to accept tokens from (when running as a backend service)", + env='AAD_CLIENT_APP_IDS') + strict: bool = Field(True, description="Check that all claims are provided", env='AAD_STRICT_CLAIM_CHECK') + api_audience: Optional[str] = Field(None, description="Corresponds to the Application ID URI - used for token validation, defaults to api://{client_id}", + env='AAD_API_AUDIENCE') + redirect_uri: Optional[HttpUrl] = Field(None, description="The redirect URI to use - overwrites the default path handling etc", + env='AAD_REDIRECT_URI') + prompt: Optional[str] = Field(None, description="AAD prompt to request", env='AAD_PROMPT') + domain_hint: Optional[str] = Field(None, description="AAD domain hint", env='AAD_DOMAIN_HINT') + roles: Optional[List[str]] = Field(None, description="AAD roles required in claims", env='AAD_ROLES') + _provider_klass: type = PrivateAttr(AADProvider) + + class Config: # noqa D106 + env_file = '.env' + + _validate_strict = validator('strict', allow_reuse=True)(bool_from_env) + _validate_client_app_ids = validator('client_app_ids', allow_reuse=True)(list_from_env) + _validate_roles = validator('roles', allow_reuse=True)(list_from_env) diff --git a/src/fastapi_aad_auth/oauth/ms-logo.png b/src/fastapi_aad_auth/providers/ms-logo.png similarity index 100% rename from src/fastapi_aad_auth/oauth/ms-logo.png rename to src/fastapi_aad_auth/providers/ms-logo.png diff --git a/src/fastapi_aad_auth/utilities/__init__.py b/src/fastapi_aad_auth/utilities/__init__.py new file mode 100644 index 0000000..f3ad11d --- /dev/null +++ b/src/fastapi_aad_auth/utilities/__init__.py @@ -0,0 +1,45 @@ +"""Utilities.""" +from fastapi_aad_auth.utilities import logging # noqa: F401 +from fastapi_aad_auth.utilities import urls # noqa: F401 +from fastapi_aad_auth.utilities.deprecate import DeprecatableFieldsMixin, deprecate, deprecate_module, DeprecatedField, is_deprecated # noqa: F401 + + +def bool_from_env(env_value): + """Convert environment variable to boolean.""" + if isinstance(env_value, str): + env_value = env_value.lower() in ['true', '1'] + return env_value + + +def list_from_env(env_value): + """Convert environment variable to list.""" + if isinstance(env_value, str): + env_value = [u for u in env_value.split(',') if u] + return env_value + + +def expand_doc(klass): + """Expand pydantic model documentation to enable autodoc.""" + docs = ['', '', 'Keyword Args:'] + for name, field in klass.__fields__.items(): + default_str = '' + if field.default: + default_str = f' [default: ``{field.default}``]' + module = field.outer_type_.__module__ + if module != 'builtins': + if hasattr(field.outer_type_, '__origin__'): + type_ = f' ({field.outer_type_.__origin__.__name__}) ' + elif not hasattr(field.outer_type_, '__name__'): + type_ = '' + else: + type_ = f' ({module}.{field.outer_type_.__name__}) ' + else: + type_ = f' ({field.outer_type_.__name__}) ' + env_var = '' + if 'env' in field.field_info.extra: + env_var = f' (Can be set by ``{field.field_info.extra["env"]}`` environment variable)' + docs.append(f' {name}{type_}: {field.field_info.description}{default_str}{env_var}') + if klass.__doc__ is None: + klass.__doc__ = '' + klass.__doc__ += '\n'.join(docs) + return klass diff --git a/src/fastapi_aad_auth/utilities/deprecate.py b/src/fastapi_aad_auth/utilities/deprecate.py new file mode 100644 index 0000000..087283e --- /dev/null +++ b/src/fastapi_aad_auth/utilities/deprecate.py @@ -0,0 +1,148 @@ +"""Method, class, module and Field Deprecation handlers.""" +from functools import wraps +import warnings + +from pkg_resources import parse_version +from pydantic import Field + +from fastapi_aad_auth._version import get_versions + + +__version__ = get_versions()['version'] +del get_versions + +BASE_VERSION = parse_version(parse_version(__version__).base_version) # type: ignore + + +class APIDeprecationWarning(FutureWarning): + """Warning when an API component is being deprecated.""" + + +class DeprecatedError(Exception): + """Warning for developers when a component is planned to be deprecated.""" + + +@wraps(Field) +def DeprecatedField(*args, **kwargs): # noqa: D103 + deprecated_in = kwargs.get('deprecated_in') + kwargs['warn_from'] = kwargs.get('warn_from', __version__) + replaced_by = kwargs.get('replaced_by', None) + description = kwargs.get('description', '') + additional_info = kwargs.get('additional_info', '') + if not description: + description = '' + sep = '' + else: + sep = ' ' + deprecation_message = _get_deprecation_message('Field', deprecated_in, replaced_by, additional_info) + description += sep + deprecation_message + kwargs['description'] = description + kwargs['deprecated'] = True + return Field(*args, **kwargs) + + +class DeprecatableFieldsMixin: + """Mixin for deprecatable fields.""" + def __new__(cls, *args, **kwargs): + """Initialise the Field Deprecation Validator.""" + for field_name, field in cls.__fields__.items(): + if field.field_info.extra.get('deprecated', False): + if field.pre_validators is None: + field.pre_validators = [] + field.pre_validators.insert(0, cls._deprecator_validator) + return super().__new__(cls) + + @staticmethod + def _deprecator_validator(cls, value, kw, field, *args, **kwargs): + if field.field_info.extra.get('deprecated', False): + deprecated_object_description = f'{cls.__module__}:{cls.__name__}.{field.name}' + env = field.field_info.extra.get('env', None) + if env: + deprecated_object_description += f' (env={env})' + deprecated_in = field.field_info.extra.get('deprecated_in') + warn_from = field.field_info.extra.get('warn_from', __version__) + replaced_by = field.field_info.extra.get('replaced_by', None) + deprecation_message = _get_deprecation_message(deprecated_object_description, deprecated_in, replaced_by) + _warn(deprecation_message, deprecated_in, warn_from) + return value + + +def deprecate(deprecated_in, warn_from=__version__, replaced_by=None, additional_info=''): + """Deprecate a function, method or class.""" + + def wrapper(deprecated_object, deprecation_message=None): + if deprecation_message is None: + deprecated_object_description = f'{deprecated_object.__module__}:{deprecated_object.__qualname__}' + deprecation_message = _get_deprecation_message(deprecated_object_description, deprecated_in, replaced_by, additional_info) + + try: + deprecated_object.__doc__ = _update_docstring(deprecation_message, deprecated_object.__doc__) + except AttributeError: + pass + + if hasattr(deprecated_object, 'mro'): + deprecated_object.__init__ = wrapper(deprecated_object.__init__, deprecation_message) + wrapped = deprecated_object + + else: + @wraps(deprecated_object) + def wrapped(*args, **kwargs): + _warn(deprecation_message, deprecated_in, warn_from) + return deprecated_object(*args, **kwargs) + + wrapped.deprecated_in = deprecated_in + + return wrapped + + return wrapper + + +def deprecate_module(module_locals, deprecated_in, warn_from=__version__, replaced_by=None, additional_info=''): + """Deprecate a module.""" + deprecated_object_description = module_locals['__name__'] + deprecation_message = _get_deprecation_message(deprecated_object_description, deprecated_in, replaced_by, additional_info) + module_locals['__doc__'] = _update_docstring(deprecation_message, module_locals['__doc__']) + _warn(deprecation_message, deprecated_in, warn_from) + + +def _update_docstring(deprecation_message, docstring=None): + if docstring is None: + docstring = '' + else: + docstring += '\n\n' + docstring += f"DEPRECATED - {deprecation_message}" + return docstring + + +def _get_deprecation_message(deprecated_object_description, deprecated_in, replaced_by=None, additional_info=''): + replacement = '' + if replaced_by: + replacement = f', and is replaced by {replaced_by}' + if parse_version(__version__) < parse_version(deprecated_in): + tense = ' will be' + joiner = 'in' + else: + tense = ' is' + joiner = 'since' + deprecation_message = f'{deprecated_object_description}{tense} deprecated {joiner} version {deprecated_in}{replacement}{additional_info}' + return deprecation_message + + +def _warn(deprecation_message, deprecated_in, warn_from): + if BASE_VERSION >= parse_version(deprecated_in): + raise DeprecatedError(deprecation_message) + else: + if BASE_VERSION >= parse_version(warn_from): + warnings.warn(deprecation_message, APIDeprecationWarning) + else: + warnings.warn(deprecation_message, DeprecationWarning) + + +def is_deprecated(obj): + """Check if an object is deprecated.""" + deprecated_in = getattr(obj, 'deprecated_in', None) + if deprecated_in is None: + # Check if it's a field + if hasattr(obj, 'field_info'): + deprecated_in = obj.field_info.extra.get('deprecated_in', None) + return (deprecated_in is not None) and (BASE_VERSION >= parse_version(deprecated_in)) diff --git a/src/fastapi_aad_auth/utilities/logging.py b/src/fastapi_aad_auth/utilities/logging.py new file mode 100644 index 0000000..4c11817 --- /dev/null +++ b/src/fastapi_aad_auth/utilities/logging.py @@ -0,0 +1,2 @@ +"""Handle logging for fastapi_aad_auth.""" +from logging import getLogger # noqa: F401 diff --git a/src/fastapi_aad_auth/utilities/urls.py b/src/fastapi_aad_auth/utilities/urls.py new file mode 100644 index 0000000..b0f95ab --- /dev/null +++ b/src/fastapi_aad_auth/utilities/urls.py @@ -0,0 +1,18 @@ +"""URL utilities.""" + + +def with_redirect(url, post_redirect=None): + """Append a redirect query parameter.""" + if post_redirect is not None: + url = f'{url}?redirect={post_redirect}' + return url + + +def append(base_url, *args): + """Append paths together.""" + extension = '/'.join([u.strip('/') for u in args]) + if extension: + url = base_url.rstrip('/')+'/'+extension + else: + url = base_url + return url diff --git a/tests/testapp/server.py b/tests/testapp/server.py index 372db07..614f52f 100644 --- a/tests/testapp/server.py +++ b/tests/testapp/server.py @@ -10,14 +10,14 @@ import uvicorn -from fastapi_aad_auth import __version__, AADAuth, AuthenticationState +from fastapi_aad_auth import __version__, Authenticator, AuthenticationState -auth_provider = AADAuth() +auth_provider = Authenticator() router = APIRouter() @router.get('/hello') -async def hello_world(auth_state:AuthenticationState=Depends(auth_provider.api_auth_scheme)): +async def hello_world(auth_state: AuthenticationState = Depends(auth_provider.auth_backend.requires_auth(allow_session=True))): print(auth_state) return {'hello': 'world'} @@ -43,14 +43,12 @@ async def test(request): Route("/", endpoint=homepage), Route("/test", endpoint=test) ] - app = FastAPI(title='fastapi_aad_auth test app', description='Testapp for Adding Azure Active Directory Authentication for FastAPI', version=__version__, openapi_url=f"/api/v{API_VERSION}/openapi.json", docs_url='/api/docs', - swagger_ui_init_oauth=auth_provider.api_auth_scheme.init_oauth, redoc_url='/api/redoc', routes=routes)