Skip to content

Commit

Permalink
Merge pull request #41 from djpugh/fix/scopes
Browse files Browse the repository at this point in the history
  • Loading branch information
djpugh authored Dec 18, 2020
2 parents de2b4a5 + 10a8f6f commit 10d7b0e
Show file tree
Hide file tree
Showing 11 changed files with 240 additions and 80 deletions.
14 changes: 13 additions & 1 deletion docs/source/advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ The :class:`~fastapi_aad_auth.auth.Authenticator` object takes a ``user_klass``
:start-at: class Authenticator
:end-before: """Initialise

which defaults to the really basic :class:`~fastapi_aad_auth.oauth.state.User` class, but any object with the same
which defaults to the really basic :class:`~fastapi_aad_auth.oauth.state.User` class, but any object with the same
interface should work, so you can add e.g. database calls etc. to validate/persist/check the user and any other
desired behaviours.

Expand Down Expand Up @@ -65,3 +65,15 @@ These jinja templates also are structured (see :doc:`module/fastapi_aad_auth.ui`
:language: html

And can easily be extended or customised.


Token Scopes
~~~~~~~~~~~~

:mod:`fastapi_aad_auth` is used for providing authentication and authorisation on an API using Azure Active Directory as an authorisation provider.

This means that scopes are requested against the ``client_id`` of the application rather than e.g. MS Graph or similar, if your backend API needs to
access Microsoft (or other APIs) you will need to use e.g. an additional msal instance (or provide specific additional ``scopes`` through the
:py:meth:`fastapi_aad_auth.providers.aad.AADSessionAuthenticator.get_access_token`, with ``app_scopes=False``), if those permissions are added on the App Registration.

Alternatively, you can use an on-behalf-of flow (see `Azure Docs <https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-on-behalf-of-flow>`_).
8 changes: 7 additions & 1 deletion src/fastapi_aad_auth/_base/authenticators/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from starlette.responses import RedirectResponse

from fastapi_aad_auth._base.state import AuthenticationState
from fastapi_aad_auth.errors import ConfigurationError
from fastapi_aad_auth.mixins import LoggingMixin


Expand Down Expand Up @@ -49,6 +50,11 @@ def process_login_request(self, request, force=False, redirect='/'):

def process_login_callback(self, request):
"""Process the provider login callback."""
if 'error' in request.query_params:
error_args = [request.query_params['error'], ]
if 'error_description' in request.query_params:
error_args.append(request.query_params['error_description'])
raise ConfigurationError(*error_args)
code = request.query_params.get('code', None)
state = request.query_params.get('state', None)
if state is None or code is None:
Expand All @@ -64,7 +70,7 @@ def process_login_callback(self, request):
def _process_code(self, request, auth_state, code):
raise NotImplementedError('Implement in subclass')

def get_access_token(self, user):
def get_access_token(self, user, scopes=None, app_scopes=True):
"""Get the access token for the user."""
raise NotImplementedError('Implement in subclass')

Expand Down
29 changes: 21 additions & 8 deletions src/fastapi_aad_auth/_base/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@

from itsdangerous import URLSafeSerializer
from itsdangerous.exc import BadSignature
from pydantic import BaseModel, root_validator
from starlette.authentication import AuthCredentials, AuthenticationError, SimpleUser, UnauthenticatedUser
from pydantic import BaseModel, Field, root_validator, validator
from starlette.authentication import AuthCredentials, SimpleUser, UnauthenticatedUser

from fastapi_aad_auth.errors import AuthenticationError
from fastapi_aad_auth.mixins import LoggingMixin


Expand All @@ -24,16 +25,28 @@ class AuthenticationOptions(Enum):

class User(BaseModel):
"""User Model."""
name: str
email: str
username: str
roles: Optional[List[str]] = None
groups: Optional[List[str]] = None
name: str = Field(..., description='Full name')
email: str = Field(..., description='User email')
username: str = Field(..., description='Username')
roles: Optional[List[str]] = Field(None, description='Any roles provided')
groups: Optional[List[str]] = Field(None, description='Any groups provided')
scopes: Optional[List[str]] = Field(None, description='Token scopes provided')

@property
def permissions(self):
"""User Permissions."""
return []
permissions = []
if self.scopes:
for scope in self.scopes:
if not scope.startswith('.'):
permissions.append(scope)
return permissions[:]

@validator('scopes', always=True, pre=True)
def _validate_scopes(cls, value):
if isinstance(value, str):
value = value.split(' ')
return value


class AuthenticationState(LoggingMixin, BaseModel):
Expand Down
14 changes: 13 additions & 1 deletion src/fastapi_aad_auth/_base/validators/session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
"""Session based validator for interactive (UI) sessions."""
import fnmatch
from functools import partial
from typing import List, Optional

from itsdangerous import URLSafeSerializer

from fastapi_aad_auth._base.state import AuthenticationState
Expand All @@ -11,9 +15,10 @@
class SessionValidator(Validator):
"""Validator for session based authentication."""

def __init__(self, session_serializer: URLSafeSerializer, *args, **kwargs):
def __init__(self, session_serializer: URLSafeSerializer, ignore_redirect_routes: Optional[List[str]] = None, *args, **kwargs):
"""Initialise validator for session based authentication."""
self._session_serializer = session_serializer
self._ignore_redirect_routes = ignore_redirect_routes
super().__init__(*args, **kwargs) # type: ignore

def get_state_from_session(self, request):
Expand All @@ -36,8 +41,15 @@ def pop_post_auth_redirect(self, request):

def set_post_auth_redirect(self, request, redirect='/'):
"""Set post-authentication redirects."""
if not self.is_valid_redirect(redirect):
redirect = '/'

request.session[REDIRECT_KEY] = redirect

def is_valid_redirect(self, redirect):
"""Check if the redirect is not to endpoints that we don't want to redirect to."""
return not any(map(partial(fnmatch.fnmatch, redirect), self._ignore_redirect_routes))

@staticmethod
def get_session_serializer(secret, salt):
"""Get or Initialise the session serializer."""
Expand Down
37 changes: 26 additions & 11 deletions src/fastapi_aad_auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from fastapi import FastAPI
from starlette.authentication import requires
from starlette.middleware.authentication import AuthenticationError, AuthenticationMiddleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.requests import Request
from starlette.responses import RedirectResponse, Response
Expand All @@ -14,10 +14,10 @@
from fastapi_aad_auth._base.backend import BaseOAuthBackend
from fastapi_aad_auth._base.validators import SessionValidator
from fastapi_aad_auth.config import Config
from fastapi_aad_auth.errors import base_error_handler, ConfigurationError
from fastapi_aad_auth.errors import AuthenticationError, AuthorisationError, base_error_handler, ConfigurationError, json_error_handler, redirect_error_handler
from fastapi_aad_auth.mixins import LoggingMixin
from fastapi_aad_auth.ui.jinja import Jinja2Templates
from fastapi_aad_auth.utilities import deprecate
from fastapi_aad_auth.utilities import deprecate, is_interactive


_BASE_ROUTES = ['openapi', 'swagger_ui_html', 'swagger_ui_redirect', 'redoc_html']
Expand Down Expand Up @@ -66,7 +66,7 @@ def __init__(self, config: Config = None, add_to_base_routes: bool = True, base_
def _init_session_validator(self):
auth_serializer = SessionValidator.get_session_serializer(self.config.auth_session.secret.get_secret_value(),
self.config.auth_session.salt.get_secret_value())
return SessionValidator(auth_serializer)
return SessionValidator(auth_serializer, ignore_redirect_routes=self.config.routing.no_redirect_routes)
# Lets setup the oauth backend

def _init_providers(self):
Expand Down Expand Up @@ -114,13 +114,30 @@ async def configuration_error_handler(request: Request, exc: ConfigurationError)
status_code = 500
return base_error_handler(request, exc, error_type, error_message, error_templates, error_template_path, context=self._base_context.copy(), status_code=status_code)

@app.exception_handler(AuthenticationError)
async def authentication_error_handler(request: Request, exc: AuthenticationError) -> Response:
@app.exception_handler(AuthorisationError)
async def authorisation_error_handler(request: Request, exc: AuthorisationError) -> Response:
error_message = "Oops! It seems like you cannot access this information. If this is an error, please contact an admin"
error_type = 'Authentication Error'
error_type = 'Authorisation Error'
status_code = 403
return base_error_handler(request, exc, error_type, error_message, error_templates, error_template_path, context=self._base_context.copy(), status_code=status_code)

@app.exception_handler(AuthenticationError)
async def authentication_error_handler(request: Request, exc: AuthenticationError) -> Response:
return self._authentication_error_handler(request, exc)

def _authentication_error_handler(self, request: Request, exc: AuthenticationError) -> Response:
error_message = "Oops! It seems like you are not correctly authenticated"
status_code = 401
self.logger.exception(f'Error {exc} for request {request}')
if is_interactive(request):
self._session_validator.set_post_auth_redirect(request, request.url.path)
kwargs = {}
if self._session_validator.is_valid_redirect(request.url.path):
kwargs['redirect'] = request.url.path
return redirect_error_handler(self.config.routing.landing_path, exc, **kwargs)
else:
return json_error_handler(error_message, status_code=status_code)

def auth_required(self, scopes: str = 'authenticated', redirect: str = 'login'):
"""Decorator to require specific scopes (and redirect to the login ui) for an endpoint.
Expand Down Expand Up @@ -186,10 +203,8 @@ def configure_app(self, app: FastAPI, add_error_handlers=True):
Keyword Args:
add_error_handlers (bool) : add the error handlers to the app (default is true, but can be set to False to configure specific handling)
"""
def on_auth_error(request: Request, exc: Exception):
self.logger.exception(f'Error {exc} for request {request}')
self._session_validator.set_post_auth_redirect(request, request.url.path)
return RedirectResponse(self.config.routing.landing_path)
def on_auth_error(request: Request, exc: AuthenticationError):
return self._authentication_error_handler(request, exc)

app.add_middleware(AuthenticationMiddleware, backend=self. auth_backend, on_error=on_auth_error)
if add_error_handlers:
Expand Down
5 changes: 5 additions & 0 deletions src/fastapi_aad_auth/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ def _validate_post_logout_path(cls, value, values):
value = values.get('landing_path')
return value

@property
def no_redirect_routes(self):
"""Routes that we don't want to redirect to."""
return [self.login_path, self.login_redirect_path, f'{self.oauth_base_route}/*']


@expand_doc
class LoginUIConfig(BaseSettings):
Expand Down
79 changes: 57 additions & 22 deletions src/fastapi_aad_auth/errors.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,74 @@
"""fastapi_aad_auth errors."""
from starlette.responses import JSONResponse, Response
from pathlib import Path
from typing import Dict, Optional

from fastapi_aad_auth.utilities import is_interactive
from starlette.authentication import AuthenticationError
from starlette.requests import Request
from starlette.responses import JSONResponse, RedirectResponse, Response
from starlette.templating import Jinja2Templates

from fastapi_aad_auth.utilities import is_interactive, urls
from fastapi_aad_auth.utilities.logging import getLogger

logger = getLogger(__name__)


def base_error_handler(request, exception, error_type, error_message, templates, template_path, context=None, status_code=500) -> Response:
def base_error_handler(request: Request, exception: Exception, error_type: str, error_message: str, templates: Jinja2Templates, template_path: Path, context: Optional[Dict] = None, status_code: int = 500) -> Response:
"""Handle Error as JSON or HTML response depending on request type."""
if context is None:
context = {}
logger.warning(f'Handling error {exception}')
if is_interactive(request):
logger.info('Interactive environment so returning error template')
logger.debug(f'Path: {template_path}')
error_context = context.copy()
error_context.update({'error': str(exception),
'status_code': str(status_code),
'error_type': error_type,
'error_description': error_message,
'request': request}) # type: ignore
response = templates.TemplateResponse(template_path.name,
error_context,
status_code=status_code)
response = ui_error_handler(request, exception, error_type, error_message, templates, template_path, context, status_code)
else:
logger.info('Non-Interactive environment so returning JSON message')

response = JSONResponse( # type: ignore
status_code=status_code,
content={"message": error_message}
)
response = json_error_handler(error_message, status_code)
logger.debug(f'Response {response}')
return response


def json_error_handler(error_message: str, status_code: int = 500) -> JSONResponse:
"""Handle error as a JSON."""
logger.info('Non-Interactive environment so returning JSON message')

return JSONResponse( # type: ignore
status_code=status_code,
content={"message": error_message}
)


def redirect_error_handler(redirect_path: str, exception: Exception, **kwargs) -> RedirectResponse:
"""Handle error as a redirect with error info in the query parameters."""
return RedirectResponse(urls.with_query_params(redirect_path, error=exception, **kwargs))


def ui_error_handler(request: Request, exception: Exception, error_type: str, error_message: str, templates: Jinja2Templates, template_path: Path, context: Optional[Dict] = None, status_code: int = 500) -> Response:
"""Return a UI view of the error."""
logger.info('Interactive environment so returning error template')
logger.debug(f'Path: {template_path}')
logger.debug(f'Exception: {exception}')
if context is None:
context = {}
error_context = context.copy()
error = exception
detail = ''
if exception.args:
logger.info('Getting args')
error = exception.args[0]
if len(exception.args) > 1:
detail = exception.args[1]
error_context.update({'error': str(error),
'status_code': str(status_code),
'error_type': error_type,
'error_description': error_message,
'error_detail': str(detail),
'request': request}) # type: ignore
logger.debug(f'Error context: {error_context}')
return templates.TemplateResponse(template_path.name,
error_context,
status_code=status_code)


class ConfigurationError(Exception):
"""Misconfigured application."""


class AuthorisationError(AuthenticationError):
"""Not Authorised to access this resource."""
Loading

0 comments on commit 10d7b0e

Please sign in to comment.