From ab89787e28b5a04ff60d48dcdcc86d3d8f1a4051 Mon Sep 17 00:00:00 2001 From: David Pugh Date: Fri, 18 Dec 2020 17:47:12 +0000 Subject: [PATCH 1/2] Adding scopes handling/inclusion in tickets --- docs/source/advanced.rst | 14 +++- .../_base/authenticators/session.py | 8 +- src/fastapi_aad_auth/_base/state.py | 29 +++++-- .../_base/validators/session.py | 14 +++- src/fastapi_aad_auth/auth.py | 37 ++++++--- src/fastapi_aad_auth/config.py | 5 ++ src/fastapi_aad_auth/errors.py | 79 +++++++++++++------ src/fastapi_aad_auth/providers/aad.py | 76 ++++++++++++------ src/fastapi_aad_auth/ui/__init__.py | 33 +++++--- src/fastapi_aad_auth/ui/error.html | 6 +- src/fastapi_aad_auth/utilities/urls.py | 19 ++++- 11 files changed, 240 insertions(+), 80 deletions(-) diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst index 7815102..4b12bfb 100644 --- a/docs/source/advanced.rst +++ b/docs/source/advanced.rst @@ -36,7 +36,7 @@ The :class:`~fastapi_aad_auth.auth.Authenticator` object takes a ``user_klass`` :start-at: class Authenticator :end-before: """Initialise -which defaults to the really basic :class:`~fastapi_aad_auth.oauth.state.User` class, but any object with the same +which defaults to the really basic :class:`~fastapi_aad_auth.oauth.state.User` class, but any object with the same interface should work, so you can add e.g. database calls etc. to validate/persist/check the user and any other desired behaviours. @@ -65,3 +65,15 @@ These jinja templates also are structured (see :doc:`module/fastapi_aad_auth.ui` :language: html And can easily be extended or customised. + + +Token Scopes +~~~~~~~~~~~~ + +:mod:`fastapi_aad_auth` is used for providing authentication and authorisation on an API using Azure Active Directory as an authorisation provider. + +This means that scopes are requested against the ``client_id`` of the application rather than e.g. MS Graph or similar, if your backend API needs to +access Microsoft (or other APIs) you will need to use e.g. an additional msal instance (or provide specific additional ``scopes`` through the +:py:method:`fastapi_aad_auth._base.authenticators.session.get_access_token`, with ``app_scopes=False``), if those permissions are added on the App Registration. + +Alternatively, you can use an on-behalf-of flow (see `Azure Docs `_). diff --git a/src/fastapi_aad_auth/_base/authenticators/session.py b/src/fastapi_aad_auth/_base/authenticators/session.py index 6330301..c47af78 100644 --- a/src/fastapi_aad_auth/_base/authenticators/session.py +++ b/src/fastapi_aad_auth/_base/authenticators/session.py @@ -3,6 +3,7 @@ from starlette.responses import RedirectResponse from fastapi_aad_auth._base.state import AuthenticationState +from fastapi_aad_auth.errors import ConfigurationError from fastapi_aad_auth.mixins import LoggingMixin @@ -49,6 +50,11 @@ def process_login_request(self, request, force=False, redirect='/'): def process_login_callback(self, request): """Process the provider login callback.""" + if 'error' in request.query_params: + error_args = [request.query_params['error'], ] + if 'error_description' in request.query_params: + error_args.append(request.query_params['error_description']) + raise ConfigurationError(*error_args) code = request.query_params.get('code', None) state = request.query_params.get('state', None) if state is None or code is None: @@ -64,7 +70,7 @@ def process_login_callback(self, request): def _process_code(self, request, auth_state, code): raise NotImplementedError('Implement in subclass') - def get_access_token(self, user): + def get_access_token(self, user, scopes=None): """Get the access token for the user.""" raise NotImplementedError('Implement in subclass') diff --git a/src/fastapi_aad_auth/_base/state.py b/src/fastapi_aad_auth/_base/state.py index 9f9f177..e8207bf 100644 --- a/src/fastapi_aad_auth/_base/state.py +++ b/src/fastapi_aad_auth/_base/state.py @@ -6,9 +6,10 @@ from itsdangerous import URLSafeSerializer from itsdangerous.exc import BadSignature -from pydantic import BaseModel, root_validator -from starlette.authentication import AuthCredentials, AuthenticationError, SimpleUser, UnauthenticatedUser +from pydantic import BaseModel, Field, root_validator, validator +from starlette.authentication import AuthCredentials, SimpleUser, UnauthenticatedUser +from fastapi_aad_auth.errors import AuthenticationError from fastapi_aad_auth.mixins import LoggingMixin @@ -24,16 +25,28 @@ class AuthenticationOptions(Enum): class User(BaseModel): """User Model.""" - name: str - email: str - username: str - roles: Optional[List[str]] = None - groups: Optional[List[str]] = None + name: str = Field(..., description='Full name') + email: str = Field(..., description='User email') + username: str = Field(..., description='Username') + roles: Optional[List[str]] = Field(None, description='Any roles provided') + groups: Optional[List[str]] = Field(None, description='Any groups provided') + scopes: Optional[List[str]] = Field(None, description='Token scopes provided') @property def permissions(self): """User Permissions.""" - return [] + permissions = [] + if self.scopes: + for scope in self.scopes: + if not scope.startswith('.'): + permissions.append(scope) + return permissions[:] + + @validator('scopes', always=True, pre=True) + def _validate_scopes(cls, value): + if isinstance(value, str): + value = value.split(' ') + return value class AuthenticationState(LoggingMixin, BaseModel): diff --git a/src/fastapi_aad_auth/_base/validators/session.py b/src/fastapi_aad_auth/_base/validators/session.py index c3cd58a..0ee74a6 100644 --- a/src/fastapi_aad_auth/_base/validators/session.py +++ b/src/fastapi_aad_auth/_base/validators/session.py @@ -1,4 +1,8 @@ """Session based validator for interactive (UI) sessions.""" +import fnmatch +from functools import partial +from typing import List, Optional + from itsdangerous import URLSafeSerializer from fastapi_aad_auth._base.state import AuthenticationState @@ -11,9 +15,10 @@ class SessionValidator(Validator): """Validator for session based authentication.""" - def __init__(self, session_serializer: URLSafeSerializer, *args, **kwargs): + def __init__(self, session_serializer: URLSafeSerializer, ignore_redirect_routes: Optional[List[str]] = None, *args, **kwargs): """Initialise validator for session based authentication.""" self._session_serializer = session_serializer + self._ignore_redirect_routes = ignore_redirect_routes super().__init__(*args, **kwargs) # type: ignore def get_state_from_session(self, request): @@ -36,8 +41,15 @@ def pop_post_auth_redirect(self, request): def set_post_auth_redirect(self, request, redirect='/'): """Set post-authentication redirects.""" + if not self.is_valid_redirect(redirect): + redirect = '/' + request.session[REDIRECT_KEY] = redirect + def is_valid_redirect(self, redirect): + """Check if the redirect is not to endpoints that we don't want to redirect to.""" + return not any(map(partial(fnmatch.fnmatch, redirect), self._ignore_redirect_routes)) + @staticmethod def get_session_serializer(secret, salt): """Get or Initialise the session serializer.""" diff --git a/src/fastapi_aad_auth/auth.py b/src/fastapi_aad_auth/auth.py index 6c66fa4..3d480ba 100644 --- a/src/fastapi_aad_auth/auth.py +++ b/src/fastapi_aad_auth/auth.py @@ -5,7 +5,7 @@ from fastapi import FastAPI from starlette.authentication import requires -from starlette.middleware.authentication import AuthenticationError, AuthenticationMiddleware +from starlette.middleware.authentication import AuthenticationMiddleware from starlette.middleware.sessions import SessionMiddleware from starlette.requests import Request from starlette.responses import RedirectResponse, Response @@ -14,10 +14,10 @@ from fastapi_aad_auth._base.backend import BaseOAuthBackend from fastapi_aad_auth._base.validators import SessionValidator from fastapi_aad_auth.config import Config -from fastapi_aad_auth.errors import base_error_handler, ConfigurationError +from fastapi_aad_auth.errors import AuthenticationError, AuthorisationError, base_error_handler, ConfigurationError, json_error_handler, redirect_error_handler from fastapi_aad_auth.mixins import LoggingMixin from fastapi_aad_auth.ui.jinja import Jinja2Templates -from fastapi_aad_auth.utilities import deprecate +from fastapi_aad_auth.utilities import deprecate, is_interactive _BASE_ROUTES = ['openapi', 'swagger_ui_html', 'swagger_ui_redirect', 'redoc_html'] @@ -66,7 +66,7 @@ def __init__(self, config: Config = None, add_to_base_routes: bool = True, base_ def _init_session_validator(self): auth_serializer = SessionValidator.get_session_serializer(self.config.auth_session.secret.get_secret_value(), self.config.auth_session.salt.get_secret_value()) - return SessionValidator(auth_serializer) + return SessionValidator(auth_serializer, ignore_redirect_routes=self.config.routing.no_redirect_routes) # Lets setup the oauth backend def _init_providers(self): @@ -114,13 +114,30 @@ async def configuration_error_handler(request: Request, exc: ConfigurationError) status_code = 500 return base_error_handler(request, exc, error_type, error_message, error_templates, error_template_path, context=self._base_context.copy(), status_code=status_code) - @app.exception_handler(AuthenticationError) - async def authentication_error_handler(request: Request, exc: AuthenticationError) -> Response: + @app.exception_handler(AuthorisationError) + async def authorisation_error_handler(request: Request, exc: AuthorisationError) -> Response: error_message = "Oops! It seems like you cannot access this information. If this is an error, please contact an admin" - error_type = 'Authentication Error' + error_type = 'Authorisation Error' status_code = 403 return base_error_handler(request, exc, error_type, error_message, error_templates, error_template_path, context=self._base_context.copy(), status_code=status_code) + @app.exception_handler(AuthenticationError) + async def authentication_error_handler(request: Request, exc: AuthenticationError) -> Response: + return self._authentication_error_handler(request, exc) + + def _authentication_error_handler(self, request: Request, exc: AuthenticationError) -> Response: + error_message = "Oops! It seems like you are not correctly authenticated" + status_code = 401 + self.logger.exception(f'Error {exc} for request {request}') + if is_interactive(request): + self._session_validator.set_post_auth_redirect(request, request.url.path) + kwargs = {} + if self._session_validator.is_valid_redirect(request.url.path): + kwargs['redirect'] = request.url.path + return redirect_error_handler(self.config.routing.landing_path, exc, **kwargs) + else: + return json_error_handler(error_message, status_code=status_code) + def auth_required(self, scopes: str = 'authenticated', redirect: str = 'login'): """Decorator to require specific scopes (and redirect to the login ui) for an endpoint. @@ -186,10 +203,8 @@ def configure_app(self, app: FastAPI, add_error_handlers=True): Keyword Args: add_error_handlers (bool) : add the error handlers to the app (default is true, but can be set to False to configure specific handling) """ - def on_auth_error(request: Request, exc: Exception): - self.logger.exception(f'Error {exc} for request {request}') - self._session_validator.set_post_auth_redirect(request, request.url.path) - return RedirectResponse(self.config.routing.landing_path) + def on_auth_error(request: Request, exc: AuthenticationError): + return self._authentication_error_handler(request, exc) app.add_middleware(AuthenticationMiddleware, backend=self. auth_backend, on_error=on_auth_error) if add_error_handlers: diff --git a/src/fastapi_aad_auth/config.py b/src/fastapi_aad_auth/config.py index 30c48a6..96b9dc2 100644 --- a/src/fastapi_aad_auth/config.py +++ b/src/fastapi_aad_auth/config.py @@ -48,6 +48,11 @@ def _validate_post_logout_path(cls, value, values): value = values.get('landing_path') return value + @property + def no_redirect_routes(self): + """Routes that we don't want to redirect to.""" + return [self.login_path, self.login_redirect_path, f'{self.oauth_base_route}/*'] + @expand_doc class LoginUIConfig(BaseSettings): diff --git a/src/fastapi_aad_auth/errors.py b/src/fastapi_aad_auth/errors.py index 3139b7c..5131618 100644 --- a/src/fastapi_aad_auth/errors.py +++ b/src/fastapi_aad_auth/errors.py @@ -1,39 +1,74 @@ """fastapi_aad_auth errors.""" -from starlette.responses import JSONResponse, Response +from pathlib import Path +from typing import Dict, Optional -from fastapi_aad_auth.utilities import is_interactive +from starlette.authentication import AuthenticationError +from starlette.requests import Request +from starlette.responses import JSONResponse, RedirectResponse, Response +from starlette.templating import Jinja2Templates + +from fastapi_aad_auth.utilities import is_interactive, urls from fastapi_aad_auth.utilities.logging import getLogger logger = getLogger(__name__) -def base_error_handler(request, exception, error_type, error_message, templates, template_path, context=None, status_code=500) -> Response: +def base_error_handler(request: Request, exception: Exception, error_type: str, error_message: str, templates: Jinja2Templates, template_path: Path, context: Optional[Dict] = None, status_code: int = 500) -> Response: """Handle Error as JSON or HTML response depending on request type.""" - if context is None: - context = {} logger.warning(f'Handling error {exception}') if is_interactive(request): - logger.info('Interactive environment so returning error template') - logger.debug(f'Path: {template_path}') - error_context = context.copy() - error_context.update({'error': str(exception), - 'status_code': str(status_code), - 'error_type': error_type, - 'error_description': error_message, - 'request': request}) # type: ignore - response = templates.TemplateResponse(template_path.name, - error_context, - status_code=status_code) + response = ui_error_handler(request, exception, error_type, error_message, templates, template_path, context, status_code) else: - logger.info('Non-Interactive environment so returning JSON message') - - response = JSONResponse( # type: ignore - status_code=status_code, - content={"message": error_message} - ) + response = json_error_handler(error_message, status_code) logger.debug(f'Response {response}') return response +def json_error_handler(error_message: str, status_code: int = 500) -> JSONResponse: + """Handle error as a JSON.""" + logger.info('Non-Interactive environment so returning JSON message') + + return JSONResponse( # type: ignore + status_code=status_code, + content={"message": error_message} + ) + + +def redirect_error_handler(redirect_path: str, exception: Exception, **kwargs) -> RedirectResponse: + """Handle error as a redirect with error info in the query parameters.""" + return RedirectResponse(urls.with_query_params(redirect_path, error=exception, **kwargs)) + + +def ui_error_handler(request: Request, exception: Exception, error_type: str, error_message: str, templates: Jinja2Templates, template_path: Path, context: Optional[Dict] = None, status_code: int = 500) -> Response: + """Return a UI view of the error.""" + logger.info('Interactive environment so returning error template') + logger.debug(f'Path: {template_path}') + logger.debug(f'Exception: {exception}') + if context is None: + context = {} + error_context = context.copy() + error = exception + detail = '' + if exception.args: + logger.info('Getting args') + error = exception.args[0] + if len(exception.args) > 1: + detail = exception.args[1] + error_context.update({'error': str(error), + 'status_code': str(status_code), + 'error_type': error_type, + 'error_description': error_message, + 'error_detail': str(detail), + 'request': request}) # type: ignore + logger.debug(f'Error context: {error_context}') + return templates.TemplateResponse(template_path.name, + error_context, + status_code=status_code) + + class ConfigurationError(Exception): """Misconfigured application.""" + + +class AuthorisationError(AuthenticationError): + """Not Authorised to access this resource.""" diff --git a/src/fastapi_aad_auth/providers/aad.py b/src/fastapi_aad_auth/providers/aad.py index 3d2f9da..a59469d 100644 --- a/src/fastapi_aad_auth/providers/aad.py +++ b/src/fastapi_aad_auth/providers/aad.py @@ -1,7 +1,8 @@ """AAD OAuth handlers.""" import base64 -from typing import List, Optional +from enum import Enum +from typing import List, Optional, Union from authlib.jose import errors as jwt_errors, jwk, jwt from authlib.jose.util import extract_header @@ -21,6 +22,12 @@ from fastapi_aad_auth.utilities import bool_from_env, DeprecatableFieldsMixin, expand_doc, is_deprecated, list_from_env, urls +class TokenType(Enum): + """Type of token to use.""" + access = 'access_token' + id = 'id_token' + + class BaseSettings(DeprecatableFieldsMixin, _BaseSettings): """Base Settings with Deprecatable Fields.""" pass @@ -40,7 +47,8 @@ def __init__( client_secret=None, scopes=None, redirect_uri=None, - domain_hint=None): + domain_hint=None, + token_type=TokenType.access): """Initialise AAD Authenticator for interactive (UI) sessions.""" super().__init__(session_validator, token_validator) self._redirect_path = redirect_path @@ -49,10 +57,12 @@ def __init__( self._prompt = prompt self.client_id = client_id if scopes is None: - scopes = [f'api://{self.client_id}'] - elif isinstance(scopes, str): - scopes = [scopes] + scopes = [f'{self.client_id}/openid'] + self.logger.info(f'Scopes {scopes}') self._scopes = scopes + if isinstance(token_type, Enum): + token_type = token_type.value + self._token_type = token_type self._authority = f'https://login.microsoftonline.com/{tenant_id}' if client_secret is not None: @@ -80,12 +90,15 @@ def _build_redirect_uri(self, request: Request): def _process_code(self, request: Request, auth_state, code): # Let's build up the redirect_uri - result = self.msal_application.acquire_token_by_authorization_code(code, scopes=[], + result = self.msal_application.acquire_token_by_authorization_code(code, scopes=self._scopes, redirect_uri=self._build_redirect_uri(request)) self.logger.debug(f'Result {result}') if 'error' in result and result['error']: - raise ConfigurationError(result) - return result['id_token'] + error_args = [result['error']] + if 'error_description' in result: + error_args.append(result['error_description']) + raise ConfigurationError(*error_args) + return result[self._token_type] def _get_user_from_token(self, token, options=None): if options is None: @@ -95,32 +108,44 @@ def _get_user_from_token(self, token, options=None): return super()._get_user_from_token(token, options=options) def _get_authorization_url(self, request, session_state): - return self.msal_application.get_authorization_request_url([], + return self.msal_application.get_authorization_request_url(self._scopes, state=session_state, claims_challenge='{"id_token": {"roles": {"essential": true} } }', redirect_uri=self._build_redirect_uri(request), prompt=self._prompt, domain_hint=self._domain_hint) - def get_access_token(self, user): + def get_access_token(self, user, scopes=None, _app_scopes=True): """Get the access token for the user.""" result = None account = None + if scopes is None: + scopes = self._scopes + elif _app_scopes: + scopes = self.as_app_scopes(scopes) if user.username: account = self.msal_application.get_accounts(user.username) if account: account = account[0] self.logger.info(account) - # This needs you to register the openid api - result = self.msal_application.acquire_token_silent_with_error(scopes=[f'api://{self.client_id}/openid'], account=account) - self.logger.info(result) + # This needs you to register the scopes in the app registration + result = self.msal_application.acquire_token_silent_with_error(scopes=scopes, account=account) + self.logger.info(f'Acquired Token: {result}') if result is None: raise ValueError('Token not found') + elif 'error' in result: + raise ConfigurationError(result['error'], result['error_description']) else: return {'token_type': result['token_type'], 'expires_in': result['expires_in'], 'access_token': result['access_token']} + def as_app_scopes(self, scopes): + """Add the application client id to the scopes so that the tokens are valid for this app.""" + if self.client_id not in scopes[0]: + scopes[0] = f'{self.client_id}/{scopes[0]}' + return scopes + class AADTokenValidator(TokenValidator): """Validator for AAD token based authentication.""" @@ -169,12 +194,13 @@ def _decode_token(self, token): self.logger.debug(f'Key is {jwk_}') try: if hasattr(jwk, 'public_bytes'): - public_bytes = jwk_.public_bytes(serialization.Encoding.PEM, serialization.PublicFormat.PKCS1) + key = jwk_.public_bytes(serialization.Encoding.PEM, serialization.PublicFormat.PKCS1) else: - public_bytes = jwk_.raw_key.public_bytes(serialization.Encoding.PEM, serialization.PublicFormat.PKCS1) + key = jwk_.raw_key.public_bytes(serialization.Encoding.PEM, serialization.PublicFormat.PKCS1) + self.logger.debug(f'Processed Key: {key}') claims = jwt.decode( token, - public_bytes, + key, ) except Exception: self.logger.exception('Unable to parse error') @@ -186,7 +212,6 @@ def _validate_claims(self, claims, options=None): options = self._claims_options # We need to do some 1.0/2.0 handling because it doesn't seem to work properly # TODO: validate whether we want this claim here? - # TODO: validate whether the user is approved for the app if 'appid' in options and 'azp' in options: if 'appid' not in claims: options.pop('appid') @@ -217,10 +242,10 @@ def _get_user_from_claims(self, claims): username_key = 'unique_name' if 'name' not in claims and 'appid' in claims: # This is an application/service principal - return self._user_klass(name=claims['appid'], email='', username=claims['appid'], groups=claims.get('groups', None), roles=claims.get('roles', None)) - + user = self._user_klass(name=claims['appid'], email='', username=claims['appid'], groups=claims.get('groups', None), roles=claims.get('roles', None), scopes=claims.get('scp', None)) else: - return self._user_klass(name=claims['name'], email=claims[username_key], username=claims[username_key], groups=claims.get('groups', None), roles=claims.get('roles', None)) + user = self._user_klass(name=claims['name'], email=claims[username_key], username=claims[username_key], groups=claims.get('groups', None), roles=claims.get('roles', None), scopes=claims.get('scp', None)) + return user class AADProvider(Provider): @@ -242,7 +267,8 @@ def __init__( redirect_uri: Optional[str] = None, domain_hint: Optional[str] = None, user_klass: type = User, - oauth_base_route: str = '/oauth'): + oauth_base_route: str = '/oauth', + token_type: Union[str, TokenType] = TokenType.access): """Initialise the auth backend. Args: @@ -269,7 +295,7 @@ def __init__( session_authenticator = AADSessionAuthenticator(session_validator=session_validator, token_validator=token_validator, client_id=client_id, tenant_id=tenant_id, redirect_path=redirect_path, prompt=prompt, client_secret=client_secret, scopes=scopes, - redirect_uri=redirect_uri, domain_hint=domain_hint) + redirect_uri=redirect_uri, domain_hint=domain_hint, token_type=token_type) super().__init__(validators=[token_validator], authenticator=session_authenticator, enabled=enabled, oauth_base_route=oauth_base_route) @classmethod @@ -296,7 +322,8 @@ def from_config(cls, session_validator, config, provider_config, user_klass: Opt scopes=provider_config.scopes, client_app_ids=provider_config.client_app_ids, strict_token=provider_config.strict, api_audience=provider_config.api_audience, prompt=provider_config.prompt, domain_hint=provider_config.domain_hint, - redirect_uri=provider_config.redirect_uri, user_klass=user_klass, oauth_base_route=config.routing.oauth_base_route) + redirect_uri=provider_config.redirect_uri, user_klass=user_klass, + oauth_base_route=config.routing.oauth_base_route, token_type=provider_config.token_type) # We need to override the login and redirect etc until it is deprecated if hasattr(config.routing, 'login_path') and config.routing.login_path and not is_deprecated(config.routing.__fields__['login_path']): obj._login_url = config.routing.login_path @@ -326,7 +353,7 @@ class AADConfig(BaseSettings): client_id: SecretStr = Field(..., description="Application Registration Client ID", env='AAD_CLIENT_ID') tenant_id: SecretStr = Field(..., description="Application Registration Tenant ID", env='AAD_TENANT_ID') client_secret: Optional[SecretStr] = Field(None, description="Application Registration Client Secret (if required)", env='AAD_CLIENT_SECRET') - scopes: List[str] = Field(["Read"], description="Additional scopes requested") + scopes: Optional[List[str]] = Field(None, description="Additional scopes requested - if the scope is not configured to the application this will throw an error when validating the token") client_app_ids: Optional[List[str]] = Field(None, description="Additional Client App IDs to accept tokens from (when running as a backend service)", env='AAD_CLIENT_APP_IDS') strict: bool = Field(True, description="Check that all claims are provided", env='AAD_STRICT_CLAIM_CHECK') @@ -337,6 +364,7 @@ class AADConfig(BaseSettings): prompt: Optional[str] = Field(None, description="AAD prompt to request", env='AAD_PROMPT') domain_hint: Optional[str] = Field(None, description="AAD domain hint", env='AAD_DOMAIN_HINT') roles: Optional[List[str]] = Field(None, description="AAD roles required in claims", env='AAD_ROLES') + token_type: TokenType = Field(TokenType.access, description='The AAD token type to use to validate (we should use the access token if it is configured, unless we are acting as a pure UI component') _provider_klass: type = PrivateAttr(AADProvider) class Config: # noqa D106 diff --git a/src/fastapi_aad_auth/ui/__init__.py b/src/fastapi_aad_auth/ui/__init__.py index 93e52dc..76c02ef 100644 --- a/src/fastapi_aad_auth/ui/__init__.py +++ b/src/fastapi_aad_auth/ui/__init__.py @@ -8,7 +8,7 @@ * ``user.html``: View the user's information and get an access token """ from pathlib import Path -from typing import Any, Dict +from typing import Any, Dict, List, Optional from fastapi import Depends from starlette.requests import Request @@ -21,6 +21,7 @@ from fastapi_aad_auth._base.state import AuthenticationState from fastapi_aad_auth.mixins import LoggingMixin from fastapi_aad_auth.ui.jinja import Jinja2Templates +from fastapi_aad_auth.utilities import urls class UI(LoggingMixin): @@ -81,14 +82,18 @@ def _get_user(self, request: Request, **kwargs): context['token_api_path'] = None # type: ignore return self.user_templates.TemplateResponse(self.user_template_path.name, context) - def _get_token(self, request: Request, auth_state: AuthenticationState): + def _get_token(self, request: Request, auth_state: AuthenticationState, scopes: Optional[List[str]] = None): """Return the access token for the user.""" if not isinstance(auth_state, AuthenticationState): user = self.__get_user_from_request(request) else: user = auth_state.user if hasattr(user, 'username'): # type: ignore - access_token = self.__get_access_token(user) + if scopes is None: + scopes = request.query_params.get('scopes', None) + if isinstance(scopes, str): + scopes = scopes.split(' ') # type: ignore + access_token = self.__get_access_token(user, scopes) if access_token: # We want to get the token for each provider that is authenticated return JSONResponse(access_token) # type: ignore @@ -98,7 +103,11 @@ def _get_token(self, request: Request, auth_state: AuthenticationState): return self.__force_authenticate(request) else: return JSONResponse('Unable to access token as user has not authenticated via session') - return RedirectResponse(f'{self.config.routing.landing_path}?redirect=/me/token') + redirect = '/me/token' + if scopes: + self.logger.debug(f'Getting Access Token with scopes {scopes}') + redirect = urls.with_query_params(redirect, scopes=scopes) + return RedirectResponse(urls.with_query_params(self.config.routing.landing_path, redirect=redirect)) @property def routes(self): @@ -120,8 +129,8 @@ async def login(request: Request, *args, **kwargs): async def get_user(request: Request): return self._get_user(request) - async def get_token(request: Request, auth_state: AuthenticationState = Depends(self._authenticator.auth_backend.requires_auth(allow_session=True))): - return self._get_token(request, auth_state) + async def get_token(request: Request, auth_state: AuthenticationState = Depends(self._authenticator.auth_backend.requires_auth(allow_session=True)), scopes: Optional[List[str]] = None): + return self._get_token(request, auth_state, scopes) routes += [Route(self.config.routing.user_path, endpoint=get_user, methods=['GET'], name='user'), Route(f'{self.config.routing.user_path}/token', endpoint=get_token, methods=['GET'], name='get-token')] @@ -129,16 +138,20 @@ async def get_token(request: Request, auth_state: AuthenticationState = Depends( return routes def __force_authenticate(self, request: Request): + # lets get the full redirect including any query parameters + redirect = urls.with_query_params(request.url.path, **request.query_params) + self.logger.debug(f'Request {request.url}') + self.logger.info(f'Forcing authentication with redirect = {redirect}') if len(self._authenticator._providers) == 1: - return self._authenticator._providers[0].authenticator.process_login_request(request, force=True, redirect=request.url.path) + return self._authenticator._providers[0].authenticator.process_login_request(request, force=True, redirect=redirect) else: - return RedirectResponse(f'{self.config.routing.landing_path}?redirect={request.url.path}') + return RedirectResponse(urls.with_query_params(self.config.routing.landing_path, redirect=redirect)) - def __get_access_token(self, user): + def __get_access_token(self, user, scopes=None): access_token = None for provider in self._authenticator._providers: try: - access_token = provider.authenticator.get_access_token(user) + access_token = provider.authenticator.get_access_token(user, scopes) except ValueError: pass if access_token is not None: diff --git a/src/fastapi_aad_auth/ui/error.html b/src/fastapi_aad_auth/ui/error.html index 15b9641..f7ca418 100644 --- a/src/fastapi_aad_auth/ui/error.html +++ b/src/fastapi_aad_auth/ui/error.html @@ -21,7 +21,11 @@

{{error_type}}

{{ error | safe }} - + {% if error_detail %} +
+ {{ error_detail }} + {% endif %} + {% endblock Error %} {% block ContentFooter %} diff --git a/src/fastapi_aad_auth/utilities/urls.py b/src/fastapi_aad_auth/utilities/urls.py index b0f95ab..c1369af 100644 --- a/src/fastapi_aad_auth/utilities/urls.py +++ b/src/fastapi_aad_auth/utilities/urls.py @@ -1,10 +1,16 @@ """URL utilities.""" +import logging + +from starlette.datastructures import URL + + +logger = logging.getLogger(__name__) def with_redirect(url, post_redirect=None): """Append a redirect query parameter.""" if post_redirect is not None: - url = f'{url}?redirect={post_redirect}' + url = with_query_params(url, redirect=post_redirect) return url @@ -16,3 +22,14 @@ def append(base_url, *args): else: url = base_url return url + + +def with_query_params(url, **query_params): + """Add query parameters to a url.""" + logger.debug(f'Adding {query_params} to {url}') + parsed_url = URL(url) + logger.debug(f'Existing query params {parsed_url.query}') + new_url = parsed_url.include_query_params(**query_params) + logger.debug(f'Updated query params {new_url.query}') + logger.debug(f'Updated url {new_url}') + return str(new_url) From 10a8f6f4c5249d336130a8fe8cddcf86ede940d9 Mon Sep 17 00:00:00 2001 From: David Pugh Date: Fri, 18 Dec 2020 17:52:37 +0000 Subject: [PATCH 2/2] Fixing docs --- docs/source/advanced.rst | 2 +- src/fastapi_aad_auth/_base/authenticators/session.py | 2 +- src/fastapi_aad_auth/providers/aad.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst index 4b12bfb..e8e951e 100644 --- a/docs/source/advanced.rst +++ b/docs/source/advanced.rst @@ -74,6 +74,6 @@ Token Scopes This means that scopes are requested against the ``client_id`` of the application rather than e.g. MS Graph or similar, if your backend API needs to access Microsoft (or other APIs) you will need to use e.g. an additional msal instance (or provide specific additional ``scopes`` through the -:py:method:`fastapi_aad_auth._base.authenticators.session.get_access_token`, with ``app_scopes=False``), if those permissions are added on the App Registration. +:py:meth:`fastapi_aad_auth.providers.aad.AADSessionAuthenticator.get_access_token`, with ``app_scopes=False``), if those permissions are added on the App Registration. Alternatively, you can use an on-behalf-of flow (see `Azure Docs `_). diff --git a/src/fastapi_aad_auth/_base/authenticators/session.py b/src/fastapi_aad_auth/_base/authenticators/session.py index c47af78..833ac27 100644 --- a/src/fastapi_aad_auth/_base/authenticators/session.py +++ b/src/fastapi_aad_auth/_base/authenticators/session.py @@ -70,7 +70,7 @@ def process_login_callback(self, request): def _process_code(self, request, auth_state, code): raise NotImplementedError('Implement in subclass') - def get_access_token(self, user, scopes=None): + def get_access_token(self, user, scopes=None, app_scopes=True): """Get the access token for the user.""" raise NotImplementedError('Implement in subclass') diff --git a/src/fastapi_aad_auth/providers/aad.py b/src/fastapi_aad_auth/providers/aad.py index a59469d..653c075 100644 --- a/src/fastapi_aad_auth/providers/aad.py +++ b/src/fastapi_aad_auth/providers/aad.py @@ -115,13 +115,13 @@ def _get_authorization_url(self, request, session_state): prompt=self._prompt, domain_hint=self._domain_hint) - def get_access_token(self, user, scopes=None, _app_scopes=True): + def get_access_token(self, user, scopes=None, app_scopes=True): """Get the access token for the user.""" result = None account = None if scopes is None: scopes = self._scopes - elif _app_scopes: + elif app_scopes: scopes = self.as_app_scopes(scopes) if user.username: account = self.msal_application.get_accounts(user.username)