Skip to content

Commit

Permalink
Merge pull request #33 from djpugh/feature/extensible-templates
Browse files Browse the repository at this point in the history
  • Loading branch information
djpugh authored Dec 14, 2020
2 parents 09b2183 + 39293d0 commit a6bd331
Show file tree
Hide file tree
Showing 16 changed files with 323 additions and 164 deletions.
15 changes: 15 additions & 0 deletions docs/source/advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,18 @@ associated environment variable, or in the argument, which overrides all other s

auth = Authenticator(config, user_klass=MyUserClass)

Customising the UI
~~~~~~~~~~~~~~~~~~

The UI templates are rendered using Jinja2 Templates, with a customisation from :py:class:`~fastapi_aad_auth.ui.jinja.Jinja2Templates`
that uses a loader that allows a package resource to be used in place of a file (using ``{% extends <package>:<resource> %}``).

Additionally, the :py:class:`~fastapi_aad_auth.config.LoginUIConfig` has an attribute ``ui_klass`` that can be used to customise how
the context is built (note that this class should inherit from (or duck-type the public API of) :class:`~fastapi_aad_auth.ui.UI`)

These jinja templates also are structured (see :doc:`module/fastapi_aad_auth.ui` docs for the other templates) from a base template that is relatively generic:

.. literalinclude:: ../../src/fastapi_aad_auth/ui/base.html
:language: html

And can easily be extended or customised.
5 changes: 5 additions & 0 deletions docs/source/module/fastapi_aad_auth.ui.base.html.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
fastapi_aad_auth.ui:base.html
*****************************

.. literalinclude:: ../../../src/fastapi_aad_auth/ui/base.html
:language: html
5 changes: 5 additions & 0 deletions docs/source/module/fastapi_aad_auth.ui.error.html.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
fastapi_aad_auth.ui:error.html
******************************

.. literalinclude:: ../../../src/fastapi_aad_auth/ui/error.html
:language: html
5 changes: 5 additions & 0 deletions docs/source/module/fastapi_aad_auth.ui.jinja.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
fastapi_aad_auth.ui.jinja
*************************

.. automodule:: fastapi_aad_auth.ui.jinja
:members:
5 changes: 5 additions & 0 deletions docs/source/module/fastapi_aad_auth.ui.login.html.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
fastapi_aad_auth.ui:login.html
******************************

.. literalinclude:: ../../../src/fastapi_aad_auth/ui/login.html
:language: html
11 changes: 11 additions & 0 deletions docs/source/module/fastapi_aad_auth.ui.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,14 @@ fastapi_aad_auth.ui
*******************

.. automodule:: fastapi_aad_auth.ui
:members:

.. toctree::
:caption: Sub-modules:
:maxdepth: 1

fastapi_aad_auth.ui.jinja
fastapi_aad_auth.ui.base.html
fastapi_aad_auth.ui.error.html
fastapi_aad_auth.ui.login.html
fastapi_aad_auth.ui.user.html
5 changes: 5 additions & 0 deletions docs/source/module/fastapi_aad_auth.ui.user.html.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
fastapi_aad_auth.ui:user.html
*****************************

.. literalinclude:: ../../../src/fastapi_aad_auth/ui/user.html
:language: html
101 changes: 10 additions & 91 deletions src/fastapi_aad_auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,20 @@
from pathlib import Path
from typing import Any, Dict, List, Optional

from fastapi import Depends, FastAPI
from fastapi import FastAPI
from starlette.authentication import requires
from starlette.middleware.authentication import AuthenticationError, AuthenticationMiddleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, RedirectResponse, Response
from starlette.routing import Mount, request_response, Route
from starlette.staticfiles import StaticFiles
from starlette.templating import Jinja2Templates
from starlette.responses import RedirectResponse, Response
from starlette.routing import request_response, Route

from fastapi_aad_auth._base.backend import BaseOAuthBackend
from fastapi_aad_auth._base.state import AuthenticationState
from fastapi_aad_auth._base.validators import SessionValidator
from fastapi_aad_auth.config import Config
from fastapi_aad_auth.errors import base_error_handler, ConfigurationError
from fastapi_aad_auth.mixins import LoggingMixin
from fastapi_aad_auth.ui.jinja import Jinja2Templates
from fastapi_aad_auth.utilities import deprecate


Expand All @@ -31,12 +29,14 @@ class Authenticator(LoggingMixin):
Creates the key components based on the provided configurations
"""

def __init__(self, config: Config = None, add_to_base_routes: bool = True, base_context: Dict[str, Any] = None, user_klass: Optional[type] = None):
"""Initialise the AAD config based on the provided configuration.
def __init__(self, config: Config = None, add_to_base_routes: bool = True, base_context: Optional[Dict[str, Any]] = None, user_klass: Optional[type] = None):
"""Initialise the Authenticator based on the provided configuration.
Keyword Args:
config (fastapi_aad_auth.config.Config): Authentication configuration (includes ui and routing, as well as AAD Application and Tenant IDs)
add_to_base_routes (bool): Add the authentication to the router
base_context (Dict[str, Any]): a base context to provide
user_klass (type): The user class to use as part of the auth state
"""
super().__init__()
if config is None:
Expand Down Expand Up @@ -74,65 +74,8 @@ def _init_auth_backend(self):
return BaseOAuthBackend(validators)

def _init_ui(self):
login_template_path = Path(self.config.login_ui.template_file)
user_template_path = Path(self.config.login_ui.user_template_file)
login_templates = Jinja2Templates(directory=str(login_template_path.parent))
user_templates = Jinja2Templates(directory=str(user_template_path.parent))

async def login(request: Request, *args, **kwargs):
context = self._base_context.copy() # type: ignore
if not self.config.enabled or request.user.is_authenticated:
# This is authenticated so go straight to the homepage
return RedirectResponse(self.config.routing.home_path)
context['request'] = request # type: ignore
if 'login' not in context or context['login'] is None: # type: ignore
post_redirect = self._session_validator.pop_post_auth_redirect(request)
context['login'] = '<br>'.join([provider.get_login_button(post_redirect) for provider in self._providers]) # type: ignore
return login_templates.TemplateResponse(login_template_path.name, context) # type: ignore

routes = [Route(self.config.routing.landing_path, endpoint=login, methods=['GET'], name='login'),
Mount(self.config.login_ui.static_path, StaticFiles(directory=str(self.config.login_ui.static_directory)), name='static-login')]

if self.config.routing.user_path:

@self.auth_required()
async def get_user(request: Request):
context = self._base_context.copy() # type: ignore
self.logger.debug(f'Getting token for {request.user}')
context['request'] = request # type: ignore
if self.config.enabled:
self.logger.debug(f'Auth {request.auth}')
try:
context['user'] = self._session_validator.get_state_from_session(request).user
except ValueError:
# If we have one provider, we can force the login, otherwise...
return self.__force_authenticate(request)
else:
self.logger.debug('Auth not enabled')
context['token'] = None # type: ignore
return user_templates.TemplateResponse(user_template_path.name, context)

async def get_token(request: Request, auth_state: AuthenticationState = Depends(self.auth_backend)):
if not isinstance(auth_state, AuthenticationState):
user = self.__get_user_from_request(request)
else:
user = auth_state.user
if hasattr(user, 'username'): # type: ignore
access_token = self.__get_access_token(user)
if access_token:
# We want to get the token for each provider that is authenticated
return JSONResponse(access_token) # type: ignore
else:
if any([u in request.headers['user-agent'] for u in ['Mozilla', 'Gecko', 'Trident', 'WebKit', 'Presto', 'Edge', 'Blink']]):
# If we have one provider, we can force the login, otherwise we need to request which login route
return self.__force_authenticate(request)
else:
return JSONResponse('Unable to access token as user has not authenticated via session')
return RedirectResponse(f'{self.config.routing.landing_path}?redirect=/me/token')

routes += [Route(self.config.routing.user_path, endpoint=get_user, methods=['GET'], name='user'),
Route(f'{self.config.routing.user_path}/token', endpoint=get_token, methods=['GET'], name='get-token')]
return routes
ui = self.config.login_ui.ui_klass(self.config, self, self._base_context)
return ui.routes

def _init_auth_routes(self):

Expand All @@ -150,30 +93,6 @@ async def logout(request: Request):
# We have a deprecated behaviour here
return routes

def __force_authenticate(self, request: Request):
if len(self._providers) == 1:
return self._providers[0].authenticator.process_login_request(request, force=True, redirect=request.url.path)
else:
return RedirectResponse(f'{self.config.routing.landing_path}?redirect={request.url.path}')

def __get_access_token(self, user):
access_token = None
while access_token is None:
provider = next(self._providers)
try:
access_token = provider.autheticator.get_access_token(user)
except ValueError:
pass
return access_token

def __get_user_from_request(self, request: Request):
if hasattr(request.user, 'username'):
user = request.user
else:
auth_state = self.auth_backend.check(request)
user = auth_state.user
return user

def _set_error_handlers(self, app):
error_template_path = Path(self.config.login_ui.error_template_file)
error_templates = Jinja2Templates(directory=str(error_template_path.parent))
Expand Down
27 changes: 11 additions & 16 deletions src/fastapi_aad_auth/config.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
"""fastapi_aad_auth configuration options."""
import importlib
from typing import Dict, List, Optional, Union
import uuid

from pkg_resources import resource_filename
from pydantic import BaseSettings as _BaseSettings, DirectoryPath, Field, FilePath, SecretStr, validator

from fastapi_aad_auth.providers.aad import AADConfig
from fastapi_aad_auth.utilities import bool_from_env, DeprecatableFieldsMixin, DeprecatedField, expand_doc
from fastapi_aad_auth.utilities import bool_from_env, DeprecatableFieldsMixin, DeprecatedField, expand_doc, klass_from_str


class BaseSettings(DeprecatableFieldsMixin, _BaseSettings):
Expand Down Expand Up @@ -74,10 +73,15 @@ class LoginUIConfig(BaseSettings):
description="Path to mount the login static dir in",
env='FASTAPI_AUTH_LOGIN_STATIC_PATH')
context: Optional[Dict[str, str]] = Field(None, description="Any additional context variables required for the template")
ui_klass: type = Field('fastapi_aad_auth.ui:UI',
description="UI class to use to handle creating and returning the routes for the login, error and user screens, this will be treated as an import path "
"if provided as a string, with the last part the class to load", env='FASTAPI_AUTH_UI_KLASS')

class Config: # noqa D106
env_file = '.env'

_validate_klass = validator('ui_klass', pre=True, always=True, allow_reuse=True)(klass_from_str)


@expand_doc
class AuthSessionConfig(BaseSettings):
Expand Down Expand Up @@ -122,7 +126,10 @@ class Config: # noqa D106

@expand_doc
class Config(BaseSettings):
"""The overall configuration for the AAD authentication."""
"""The overall configuration for the AAD authentication.
Provides the overall configuration and parameters.
"""

enabled: bool = Field(True, description="Enable authentication", env='FASTAPI_AUTH_ENABLED')
providers: List[Union[AADConfig]] = Field(None, description="The provider configurations to use")
Expand Down Expand Up @@ -175,17 +182,5 @@ def _validate_login_ui(cls, value):
value = LoginUIConfig(_env_file=cls.Config.env_file)
return value

@validator('user_klass', pre=True, always=True)
def _validate_klass(cls, value):
if isinstance(value, str):
if ':' in value:
module_name, klass_name = value.split(':')
else:
split_path = value.split('.')
module_name = '.'.join(split_path[:-1])
klass_name = split_path[-1]
module = importlib.import_module(module_name)
value = getattr(module, klass_name)
return value

_validate_klass = validator('user_klass', pre=True, always=True, allow_reuse=True)(klass_from_str)
_validate_enabled = validator('enabled', allow_reuse=True)(bool_from_env)
Loading

0 comments on commit a6bd331

Please sign in to comment.