diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst index 7815102..e8e951e 100644 --- a/docs/source/advanced.rst +++ b/docs/source/advanced.rst @@ -36,7 +36,7 @@ The :class:`~fastapi_aad_auth.auth.Authenticator` object takes a ``user_klass`` :start-at: class Authenticator :end-before: """Initialise -which defaults to the really basic :class:`~fastapi_aad_auth.oauth.state.User` class, but any object with the same +which defaults to the really basic :class:`~fastapi_aad_auth.oauth.state.User` class, but any object with the same interface should work, so you can add e.g. database calls etc. to validate/persist/check the user and any other desired behaviours. @@ -65,3 +65,15 @@ These jinja templates also are structured (see :doc:`module/fastapi_aad_auth.ui` :language: html And can easily be extended or customised. + + +Token Scopes +~~~~~~~~~~~~ + +:mod:`fastapi_aad_auth` is used for providing authentication and authorisation on an API using Azure Active Directory as an authorisation provider. + +This means that scopes are requested against the ``client_id`` of the application rather than e.g. MS Graph or similar, if your backend API needs to +access Microsoft (or other APIs) you will need to use e.g. an additional msal instance (or provide specific additional ``scopes`` through the +:py:meth:`fastapi_aad_auth.providers.aad.AADSessionAuthenticator.get_access_token`, with ``app_scopes=False``), if those permissions are added on the App Registration. + +Alternatively, you can use an on-behalf-of flow (see `Azure Docs `_). diff --git a/src/fastapi_aad_auth/_base/authenticators/session.py b/src/fastapi_aad_auth/_base/authenticators/session.py index 6330301..833ac27 100644 --- a/src/fastapi_aad_auth/_base/authenticators/session.py +++ b/src/fastapi_aad_auth/_base/authenticators/session.py @@ -3,6 +3,7 @@ from starlette.responses import RedirectResponse from fastapi_aad_auth._base.state import AuthenticationState +from fastapi_aad_auth.errors import ConfigurationError from fastapi_aad_auth.mixins import LoggingMixin @@ -49,6 +50,11 @@ def process_login_request(self, request, force=False, redirect='/'): def process_login_callback(self, request): """Process the provider login callback.""" + if 'error' in request.query_params: + error_args = [request.query_params['error'], ] + if 'error_description' in request.query_params: + error_args.append(request.query_params['error_description']) + raise ConfigurationError(*error_args) code = request.query_params.get('code', None) state = request.query_params.get('state', None) if state is None or code is None: @@ -64,7 +70,7 @@ def process_login_callback(self, request): def _process_code(self, request, auth_state, code): raise NotImplementedError('Implement in subclass') - def get_access_token(self, user): + def get_access_token(self, user, scopes=None, app_scopes=True): """Get the access token for the user.""" raise NotImplementedError('Implement in subclass') diff --git a/src/fastapi_aad_auth/_base/state.py b/src/fastapi_aad_auth/_base/state.py index 9f9f177..e8207bf 100644 --- a/src/fastapi_aad_auth/_base/state.py +++ b/src/fastapi_aad_auth/_base/state.py @@ -6,9 +6,10 @@ from itsdangerous import URLSafeSerializer from itsdangerous.exc import BadSignature -from pydantic import BaseModel, root_validator -from starlette.authentication import AuthCredentials, AuthenticationError, SimpleUser, UnauthenticatedUser +from pydantic import BaseModel, Field, root_validator, validator +from starlette.authentication import AuthCredentials, SimpleUser, UnauthenticatedUser +from fastapi_aad_auth.errors import AuthenticationError from fastapi_aad_auth.mixins import LoggingMixin @@ -24,16 +25,28 @@ class AuthenticationOptions(Enum): class User(BaseModel): """User Model.""" - name: str - email: str - username: str - roles: Optional[List[str]] = None - groups: Optional[List[str]] = None + name: str = Field(..., description='Full name') + email: str = Field(..., description='User email') + username: str = Field(..., description='Username') + roles: Optional[List[str]] = Field(None, description='Any roles provided') + groups: Optional[List[str]] = Field(None, description='Any groups provided') + scopes: Optional[List[str]] = Field(None, description='Token scopes provided') @property def permissions(self): """User Permissions.""" - return [] + permissions = [] + if self.scopes: + for scope in self.scopes: + if not scope.startswith('.'): + permissions.append(scope) + return permissions[:] + + @validator('scopes', always=True, pre=True) + def _validate_scopes(cls, value): + if isinstance(value, str): + value = value.split(' ') + return value class AuthenticationState(LoggingMixin, BaseModel): diff --git a/src/fastapi_aad_auth/_base/validators/session.py b/src/fastapi_aad_auth/_base/validators/session.py index c3cd58a..0ee74a6 100644 --- a/src/fastapi_aad_auth/_base/validators/session.py +++ b/src/fastapi_aad_auth/_base/validators/session.py @@ -1,4 +1,8 @@ """Session based validator for interactive (UI) sessions.""" +import fnmatch +from functools import partial +from typing import List, Optional + from itsdangerous import URLSafeSerializer from fastapi_aad_auth._base.state import AuthenticationState @@ -11,9 +15,10 @@ class SessionValidator(Validator): """Validator for session based authentication.""" - def __init__(self, session_serializer: URLSafeSerializer, *args, **kwargs): + def __init__(self, session_serializer: URLSafeSerializer, ignore_redirect_routes: Optional[List[str]] = None, *args, **kwargs): """Initialise validator for session based authentication.""" self._session_serializer = session_serializer + self._ignore_redirect_routes = ignore_redirect_routes super().__init__(*args, **kwargs) # type: ignore def get_state_from_session(self, request): @@ -36,8 +41,15 @@ def pop_post_auth_redirect(self, request): def set_post_auth_redirect(self, request, redirect='/'): """Set post-authentication redirects.""" + if not self.is_valid_redirect(redirect): + redirect = '/' + request.session[REDIRECT_KEY] = redirect + def is_valid_redirect(self, redirect): + """Check if the redirect is not to endpoints that we don't want to redirect to.""" + return not any(map(partial(fnmatch.fnmatch, redirect), self._ignore_redirect_routes)) + @staticmethod def get_session_serializer(secret, salt): """Get or Initialise the session serializer.""" diff --git a/src/fastapi_aad_auth/auth.py b/src/fastapi_aad_auth/auth.py index 6c66fa4..3d480ba 100644 --- a/src/fastapi_aad_auth/auth.py +++ b/src/fastapi_aad_auth/auth.py @@ -5,7 +5,7 @@ from fastapi import FastAPI from starlette.authentication import requires -from starlette.middleware.authentication import AuthenticationError, AuthenticationMiddleware +from starlette.middleware.authentication import AuthenticationMiddleware from starlette.middleware.sessions import SessionMiddleware from starlette.requests import Request from starlette.responses import RedirectResponse, Response @@ -14,10 +14,10 @@ from fastapi_aad_auth._base.backend import BaseOAuthBackend from fastapi_aad_auth._base.validators import SessionValidator from fastapi_aad_auth.config import Config -from fastapi_aad_auth.errors import base_error_handler, ConfigurationError +from fastapi_aad_auth.errors import AuthenticationError, AuthorisationError, base_error_handler, ConfigurationError, json_error_handler, redirect_error_handler from fastapi_aad_auth.mixins import LoggingMixin from fastapi_aad_auth.ui.jinja import Jinja2Templates -from fastapi_aad_auth.utilities import deprecate +from fastapi_aad_auth.utilities import deprecate, is_interactive _BASE_ROUTES = ['openapi', 'swagger_ui_html', 'swagger_ui_redirect', 'redoc_html'] @@ -66,7 +66,7 @@ def __init__(self, config: Config = None, add_to_base_routes: bool = True, base_ def _init_session_validator(self): auth_serializer = SessionValidator.get_session_serializer(self.config.auth_session.secret.get_secret_value(), self.config.auth_session.salt.get_secret_value()) - return SessionValidator(auth_serializer) + return SessionValidator(auth_serializer, ignore_redirect_routes=self.config.routing.no_redirect_routes) # Lets setup the oauth backend def _init_providers(self): @@ -114,13 +114,30 @@ async def configuration_error_handler(request: Request, exc: ConfigurationError) status_code = 500 return base_error_handler(request, exc, error_type, error_message, error_templates, error_template_path, context=self._base_context.copy(), status_code=status_code) - @app.exception_handler(AuthenticationError) - async def authentication_error_handler(request: Request, exc: AuthenticationError) -> Response: + @app.exception_handler(AuthorisationError) + async def authorisation_error_handler(request: Request, exc: AuthorisationError) -> Response: error_message = "Oops! It seems like you cannot access this information. If this is an error, please contact an admin" - error_type = 'Authentication Error' + error_type = 'Authorisation Error' status_code = 403 return base_error_handler(request, exc, error_type, error_message, error_templates, error_template_path, context=self._base_context.copy(), status_code=status_code) + @app.exception_handler(AuthenticationError) + async def authentication_error_handler(request: Request, exc: AuthenticationError) -> Response: + return self._authentication_error_handler(request, exc) + + def _authentication_error_handler(self, request: Request, exc: AuthenticationError) -> Response: + error_message = "Oops! It seems like you are not correctly authenticated" + status_code = 401 + self.logger.exception(f'Error {exc} for request {request}') + if is_interactive(request): + self._session_validator.set_post_auth_redirect(request, request.url.path) + kwargs = {} + if self._session_validator.is_valid_redirect(request.url.path): + kwargs['redirect'] = request.url.path + return redirect_error_handler(self.config.routing.landing_path, exc, **kwargs) + else: + return json_error_handler(error_message, status_code=status_code) + def auth_required(self, scopes: str = 'authenticated', redirect: str = 'login'): """Decorator to require specific scopes (and redirect to the login ui) for an endpoint. @@ -186,10 +203,8 @@ def configure_app(self, app: FastAPI, add_error_handlers=True): Keyword Args: add_error_handlers (bool) : add the error handlers to the app (default is true, but can be set to False to configure specific handling) """ - def on_auth_error(request: Request, exc: Exception): - self.logger.exception(f'Error {exc} for request {request}') - self._session_validator.set_post_auth_redirect(request, request.url.path) - return RedirectResponse(self.config.routing.landing_path) + def on_auth_error(request: Request, exc: AuthenticationError): + return self._authentication_error_handler(request, exc) app.add_middleware(AuthenticationMiddleware, backend=self. auth_backend, on_error=on_auth_error) if add_error_handlers: diff --git a/src/fastapi_aad_auth/config.py b/src/fastapi_aad_auth/config.py index 30c48a6..96b9dc2 100644 --- a/src/fastapi_aad_auth/config.py +++ b/src/fastapi_aad_auth/config.py @@ -48,6 +48,11 @@ def _validate_post_logout_path(cls, value, values): value = values.get('landing_path') return value + @property + def no_redirect_routes(self): + """Routes that we don't want to redirect to.""" + return [self.login_path, self.login_redirect_path, f'{self.oauth_base_route}/*'] + @expand_doc class LoginUIConfig(BaseSettings): diff --git a/src/fastapi_aad_auth/errors.py b/src/fastapi_aad_auth/errors.py index 3139b7c..5131618 100644 --- a/src/fastapi_aad_auth/errors.py +++ b/src/fastapi_aad_auth/errors.py @@ -1,39 +1,74 @@ """fastapi_aad_auth errors.""" -from starlette.responses import JSONResponse, Response +from pathlib import Path +from typing import Dict, Optional -from fastapi_aad_auth.utilities import is_interactive +from starlette.authentication import AuthenticationError +from starlette.requests import Request +from starlette.responses import JSONResponse, RedirectResponse, Response +from starlette.templating import Jinja2Templates + +from fastapi_aad_auth.utilities import is_interactive, urls from fastapi_aad_auth.utilities.logging import getLogger logger = getLogger(__name__) -def base_error_handler(request, exception, error_type, error_message, templates, template_path, context=None, status_code=500) -> Response: +def base_error_handler(request: Request, exception: Exception, error_type: str, error_message: str, templates: Jinja2Templates, template_path: Path, context: Optional[Dict] = None, status_code: int = 500) -> Response: """Handle Error as JSON or HTML response depending on request type.""" - if context is None: - context = {} logger.warning(f'Handling error {exception}') if is_interactive(request): - logger.info('Interactive environment so returning error template') - logger.debug(f'Path: {template_path}') - error_context = context.copy() - error_context.update({'error': str(exception), - 'status_code': str(status_code), - 'error_type': error_type, - 'error_description': error_message, - 'request': request}) # type: ignore - response = templates.TemplateResponse(template_path.name, - error_context, - status_code=status_code) + response = ui_error_handler(request, exception, error_type, error_message, templates, template_path, context, status_code) else: - logger.info('Non-Interactive environment so returning JSON message') - - response = JSONResponse( # type: ignore - status_code=status_code, - content={"message": error_message} - ) + response = json_error_handler(error_message, status_code) logger.debug(f'Response {response}') return response +def json_error_handler(error_message: str, status_code: int = 500) -> JSONResponse: + """Handle error as a JSON.""" + logger.info('Non-Interactive environment so returning JSON message') + + return JSONResponse( # type: ignore + status_code=status_code, + content={"message": error_message} + ) + + +def redirect_error_handler(redirect_path: str, exception: Exception, **kwargs) -> RedirectResponse: + """Handle error as a redirect with error info in the query parameters.""" + return RedirectResponse(urls.with_query_params(redirect_path, error=exception, **kwargs)) + + +def ui_error_handler(request: Request, exception: Exception, error_type: str, error_message: str, templates: Jinja2Templates, template_path: Path, context: Optional[Dict] = None, status_code: int = 500) -> Response: + """Return a UI view of the error.""" + logger.info('Interactive environment so returning error template') + logger.debug(f'Path: {template_path}') + logger.debug(f'Exception: {exception}') + if context is None: + context = {} + error_context = context.copy() + error = exception + detail = '' + if exception.args: + logger.info('Getting args') + error = exception.args[0] + if len(exception.args) > 1: + detail = exception.args[1] + error_context.update({'error': str(error), + 'status_code': str(status_code), + 'error_type': error_type, + 'error_description': error_message, + 'error_detail': str(detail), + 'request': request}) # type: ignore + logger.debug(f'Error context: {error_context}') + return templates.TemplateResponse(template_path.name, + error_context, + status_code=status_code) + + class ConfigurationError(Exception): """Misconfigured application.""" + + +class AuthorisationError(AuthenticationError): + """Not Authorised to access this resource.""" diff --git a/src/fastapi_aad_auth/providers/aad.py b/src/fastapi_aad_auth/providers/aad.py index 3d2f9da..653c075 100644 --- a/src/fastapi_aad_auth/providers/aad.py +++ b/src/fastapi_aad_auth/providers/aad.py @@ -1,7 +1,8 @@ """AAD OAuth handlers.""" import base64 -from typing import List, Optional +from enum import Enum +from typing import List, Optional, Union from authlib.jose import errors as jwt_errors, jwk, jwt from authlib.jose.util import extract_header @@ -21,6 +22,12 @@ from fastapi_aad_auth.utilities import bool_from_env, DeprecatableFieldsMixin, expand_doc, is_deprecated, list_from_env, urls +class TokenType(Enum): + """Type of token to use.""" + access = 'access_token' + id = 'id_token' + + class BaseSettings(DeprecatableFieldsMixin, _BaseSettings): """Base Settings with Deprecatable Fields.""" pass @@ -40,7 +47,8 @@ def __init__( client_secret=None, scopes=None, redirect_uri=None, - domain_hint=None): + domain_hint=None, + token_type=TokenType.access): """Initialise AAD Authenticator for interactive (UI) sessions.""" super().__init__(session_validator, token_validator) self._redirect_path = redirect_path @@ -49,10 +57,12 @@ def __init__( self._prompt = prompt self.client_id = client_id if scopes is None: - scopes = [f'api://{self.client_id}'] - elif isinstance(scopes, str): - scopes = [scopes] + scopes = [f'{self.client_id}/openid'] + self.logger.info(f'Scopes {scopes}') self._scopes = scopes + if isinstance(token_type, Enum): + token_type = token_type.value + self._token_type = token_type self._authority = f'https://login.microsoftonline.com/{tenant_id}' if client_secret is not None: @@ -80,12 +90,15 @@ def _build_redirect_uri(self, request: Request): def _process_code(self, request: Request, auth_state, code): # Let's build up the redirect_uri - result = self.msal_application.acquire_token_by_authorization_code(code, scopes=[], + result = self.msal_application.acquire_token_by_authorization_code(code, scopes=self._scopes, redirect_uri=self._build_redirect_uri(request)) self.logger.debug(f'Result {result}') if 'error' in result and result['error']: - raise ConfigurationError(result) - return result['id_token'] + error_args = [result['error']] + if 'error_description' in result: + error_args.append(result['error_description']) + raise ConfigurationError(*error_args) + return result[self._token_type] def _get_user_from_token(self, token, options=None): if options is None: @@ -95,32 +108,44 @@ def _get_user_from_token(self, token, options=None): return super()._get_user_from_token(token, options=options) def _get_authorization_url(self, request, session_state): - return self.msal_application.get_authorization_request_url([], + return self.msal_application.get_authorization_request_url(self._scopes, state=session_state, claims_challenge='{"id_token": {"roles": {"essential": true} } }', redirect_uri=self._build_redirect_uri(request), prompt=self._prompt, domain_hint=self._domain_hint) - def get_access_token(self, user): + def get_access_token(self, user, scopes=None, app_scopes=True): """Get the access token for the user.""" result = None account = None + if scopes is None: + scopes = self._scopes + elif app_scopes: + scopes = self.as_app_scopes(scopes) if user.username: account = self.msal_application.get_accounts(user.username) if account: account = account[0] self.logger.info(account) - # This needs you to register the openid api - result = self.msal_application.acquire_token_silent_with_error(scopes=[f'api://{self.client_id}/openid'], account=account) - self.logger.info(result) + # This needs you to register the scopes in the app registration + result = self.msal_application.acquire_token_silent_with_error(scopes=scopes, account=account) + self.logger.info(f'Acquired Token: {result}') if result is None: raise ValueError('Token not found') + elif 'error' in result: + raise ConfigurationError(result['error'], result['error_description']) else: return {'token_type': result['token_type'], 'expires_in': result['expires_in'], 'access_token': result['access_token']} + def as_app_scopes(self, scopes): + """Add the application client id to the scopes so that the tokens are valid for this app.""" + if self.client_id not in scopes[0]: + scopes[0] = f'{self.client_id}/{scopes[0]}' + return scopes + class AADTokenValidator(TokenValidator): """Validator for AAD token based authentication.""" @@ -169,12 +194,13 @@ def _decode_token(self, token): self.logger.debug(f'Key is {jwk_}') try: if hasattr(jwk, 'public_bytes'): - public_bytes = jwk_.public_bytes(serialization.Encoding.PEM, serialization.PublicFormat.PKCS1) + key = jwk_.public_bytes(serialization.Encoding.PEM, serialization.PublicFormat.PKCS1) else: - public_bytes = jwk_.raw_key.public_bytes(serialization.Encoding.PEM, serialization.PublicFormat.PKCS1) + key = jwk_.raw_key.public_bytes(serialization.Encoding.PEM, serialization.PublicFormat.PKCS1) + self.logger.debug(f'Processed Key: {key}') claims = jwt.decode( token, - public_bytes, + key, ) except Exception: self.logger.exception('Unable to parse error') @@ -186,7 +212,6 @@ def _validate_claims(self, claims, options=None): options = self._claims_options # We need to do some 1.0/2.0 handling because it doesn't seem to work properly # TODO: validate whether we want this claim here? - # TODO: validate whether the user is approved for the app if 'appid' in options and 'azp' in options: if 'appid' not in claims: options.pop('appid') @@ -217,10 +242,10 @@ def _get_user_from_claims(self, claims): username_key = 'unique_name' if 'name' not in claims and 'appid' in claims: # This is an application/service principal - return self._user_klass(name=claims['appid'], email='', username=claims['appid'], groups=claims.get('groups', None), roles=claims.get('roles', None)) - + user = self._user_klass(name=claims['appid'], email='', username=claims['appid'], groups=claims.get('groups', None), roles=claims.get('roles', None), scopes=claims.get('scp', None)) else: - return self._user_klass(name=claims['name'], email=claims[username_key], username=claims[username_key], groups=claims.get('groups', None), roles=claims.get('roles', None)) + user = self._user_klass(name=claims['name'], email=claims[username_key], username=claims[username_key], groups=claims.get('groups', None), roles=claims.get('roles', None), scopes=claims.get('scp', None)) + return user class AADProvider(Provider): @@ -242,7 +267,8 @@ def __init__( redirect_uri: Optional[str] = None, domain_hint: Optional[str] = None, user_klass: type = User, - oauth_base_route: str = '/oauth'): + oauth_base_route: str = '/oauth', + token_type: Union[str, TokenType] = TokenType.access): """Initialise the auth backend. Args: @@ -269,7 +295,7 @@ def __init__( session_authenticator = AADSessionAuthenticator(session_validator=session_validator, token_validator=token_validator, client_id=client_id, tenant_id=tenant_id, redirect_path=redirect_path, prompt=prompt, client_secret=client_secret, scopes=scopes, - redirect_uri=redirect_uri, domain_hint=domain_hint) + redirect_uri=redirect_uri, domain_hint=domain_hint, token_type=token_type) super().__init__(validators=[token_validator], authenticator=session_authenticator, enabled=enabled, oauth_base_route=oauth_base_route) @classmethod @@ -296,7 +322,8 @@ def from_config(cls, session_validator, config, provider_config, user_klass: Opt scopes=provider_config.scopes, client_app_ids=provider_config.client_app_ids, strict_token=provider_config.strict, api_audience=provider_config.api_audience, prompt=provider_config.prompt, domain_hint=provider_config.domain_hint, - redirect_uri=provider_config.redirect_uri, user_klass=user_klass, oauth_base_route=config.routing.oauth_base_route) + redirect_uri=provider_config.redirect_uri, user_klass=user_klass, + oauth_base_route=config.routing.oauth_base_route, token_type=provider_config.token_type) # We need to override the login and redirect etc until it is deprecated if hasattr(config.routing, 'login_path') and config.routing.login_path and not is_deprecated(config.routing.__fields__['login_path']): obj._login_url = config.routing.login_path @@ -326,7 +353,7 @@ class AADConfig(BaseSettings): client_id: SecretStr = Field(..., description="Application Registration Client ID", env='AAD_CLIENT_ID') tenant_id: SecretStr = Field(..., description="Application Registration Tenant ID", env='AAD_TENANT_ID') client_secret: Optional[SecretStr] = Field(None, description="Application Registration Client Secret (if required)", env='AAD_CLIENT_SECRET') - scopes: List[str] = Field(["Read"], description="Additional scopes requested") + scopes: Optional[List[str]] = Field(None, description="Additional scopes requested - if the scope is not configured to the application this will throw an error when validating the token") client_app_ids: Optional[List[str]] = Field(None, description="Additional Client App IDs to accept tokens from (when running as a backend service)", env='AAD_CLIENT_APP_IDS') strict: bool = Field(True, description="Check that all claims are provided", env='AAD_STRICT_CLAIM_CHECK') @@ -337,6 +364,7 @@ class AADConfig(BaseSettings): prompt: Optional[str] = Field(None, description="AAD prompt to request", env='AAD_PROMPT') domain_hint: Optional[str] = Field(None, description="AAD domain hint", env='AAD_DOMAIN_HINT') roles: Optional[List[str]] = Field(None, description="AAD roles required in claims", env='AAD_ROLES') + token_type: TokenType = Field(TokenType.access, description='The AAD token type to use to validate (we should use the access token if it is configured, unless we are acting as a pure UI component') _provider_klass: type = PrivateAttr(AADProvider) class Config: # noqa D106 diff --git a/src/fastapi_aad_auth/ui/__init__.py b/src/fastapi_aad_auth/ui/__init__.py index 93e52dc..76c02ef 100644 --- a/src/fastapi_aad_auth/ui/__init__.py +++ b/src/fastapi_aad_auth/ui/__init__.py @@ -8,7 +8,7 @@ * ``user.html``: View the user's information and get an access token """ from pathlib import Path -from typing import Any, Dict +from typing import Any, Dict, List, Optional from fastapi import Depends from starlette.requests import Request @@ -21,6 +21,7 @@ from fastapi_aad_auth._base.state import AuthenticationState from fastapi_aad_auth.mixins import LoggingMixin from fastapi_aad_auth.ui.jinja import Jinja2Templates +from fastapi_aad_auth.utilities import urls class UI(LoggingMixin): @@ -81,14 +82,18 @@ def _get_user(self, request: Request, **kwargs): context['token_api_path'] = None # type: ignore return self.user_templates.TemplateResponse(self.user_template_path.name, context) - def _get_token(self, request: Request, auth_state: AuthenticationState): + def _get_token(self, request: Request, auth_state: AuthenticationState, scopes: Optional[List[str]] = None): """Return the access token for the user.""" if not isinstance(auth_state, AuthenticationState): user = self.__get_user_from_request(request) else: user = auth_state.user if hasattr(user, 'username'): # type: ignore - access_token = self.__get_access_token(user) + if scopes is None: + scopes = request.query_params.get('scopes', None) + if isinstance(scopes, str): + scopes = scopes.split(' ') # type: ignore + access_token = self.__get_access_token(user, scopes) if access_token: # We want to get the token for each provider that is authenticated return JSONResponse(access_token) # type: ignore @@ -98,7 +103,11 @@ def _get_token(self, request: Request, auth_state: AuthenticationState): return self.__force_authenticate(request) else: return JSONResponse('Unable to access token as user has not authenticated via session') - return RedirectResponse(f'{self.config.routing.landing_path}?redirect=/me/token') + redirect = '/me/token' + if scopes: + self.logger.debug(f'Getting Access Token with scopes {scopes}') + redirect = urls.with_query_params(redirect, scopes=scopes) + return RedirectResponse(urls.with_query_params(self.config.routing.landing_path, redirect=redirect)) @property def routes(self): @@ -120,8 +129,8 @@ async def login(request: Request, *args, **kwargs): async def get_user(request: Request): return self._get_user(request) - async def get_token(request: Request, auth_state: AuthenticationState = Depends(self._authenticator.auth_backend.requires_auth(allow_session=True))): - return self._get_token(request, auth_state) + async def get_token(request: Request, auth_state: AuthenticationState = Depends(self._authenticator.auth_backend.requires_auth(allow_session=True)), scopes: Optional[List[str]] = None): + return self._get_token(request, auth_state, scopes) routes += [Route(self.config.routing.user_path, endpoint=get_user, methods=['GET'], name='user'), Route(f'{self.config.routing.user_path}/token', endpoint=get_token, methods=['GET'], name='get-token')] @@ -129,16 +138,20 @@ async def get_token(request: Request, auth_state: AuthenticationState = Depends( return routes def __force_authenticate(self, request: Request): + # lets get the full redirect including any query parameters + redirect = urls.with_query_params(request.url.path, **request.query_params) + self.logger.debug(f'Request {request.url}') + self.logger.info(f'Forcing authentication with redirect = {redirect}') if len(self._authenticator._providers) == 1: - return self._authenticator._providers[0].authenticator.process_login_request(request, force=True, redirect=request.url.path) + return self._authenticator._providers[0].authenticator.process_login_request(request, force=True, redirect=redirect) else: - return RedirectResponse(f'{self.config.routing.landing_path}?redirect={request.url.path}') + return RedirectResponse(urls.with_query_params(self.config.routing.landing_path, redirect=redirect)) - def __get_access_token(self, user): + def __get_access_token(self, user, scopes=None): access_token = None for provider in self._authenticator._providers: try: - access_token = provider.authenticator.get_access_token(user) + access_token = provider.authenticator.get_access_token(user, scopes) except ValueError: pass if access_token is not None: diff --git a/src/fastapi_aad_auth/ui/error.html b/src/fastapi_aad_auth/ui/error.html index 15b9641..f7ca418 100644 --- a/src/fastapi_aad_auth/ui/error.html +++ b/src/fastapi_aad_auth/ui/error.html @@ -21,7 +21,11 @@

{{error_type}}

{{ error | safe }} - + {% if error_detail %} +
+ {{ error_detail }} + {% endif %} + {% endblock Error %} {% block ContentFooter %} diff --git a/src/fastapi_aad_auth/utilities/urls.py b/src/fastapi_aad_auth/utilities/urls.py index b0f95ab..c1369af 100644 --- a/src/fastapi_aad_auth/utilities/urls.py +++ b/src/fastapi_aad_auth/utilities/urls.py @@ -1,10 +1,16 @@ """URL utilities.""" +import logging + +from starlette.datastructures import URL + + +logger = logging.getLogger(__name__) def with_redirect(url, post_redirect=None): """Append a redirect query parameter.""" if post_redirect is not None: - url = f'{url}?redirect={post_redirect}' + url = with_query_params(url, redirect=post_redirect) return url @@ -16,3 +22,14 @@ def append(base_url, *args): else: url = base_url return url + + +def with_query_params(url, **query_params): + """Add query parameters to a url.""" + logger.debug(f'Adding {query_params} to {url}') + parsed_url = URL(url) + logger.debug(f'Existing query params {parsed_url.query}') + new_url = parsed_url.include_query_params(**query_params) + logger.debug(f'Updated query params {new_url.query}') + logger.debug(f'Updated url {new_url}') + return str(new_url)