diff --git a/src/fastapi_aad_auth/_base/backend.py b/src/fastapi_aad_auth/_base/backend.py index 1268037..9669e31 100644 --- a/src/fastapi_aad_auth/_base/backend.py +++ b/src/fastapi_aad_auth/_base/backend.py @@ -48,7 +48,7 @@ def check(self, request: Request, allow_session=True) -> Optional[Authentication continue state = validator.check(request) self.logger.debug(f'Authentication state {state} from validator {validator}') - if state is not None: + if state is not None and state.is_authenticated(): break self.logger.info(f'Identified state {state}') return state diff --git a/src/fastapi_aad_auth/_base/validators/token.py b/src/fastapi_aad_auth/_base/validators/token.py index 5229a4e..a531e71 100644 --- a/src/fastapi_aad_auth/_base/validators/token.py +++ b/src/fastapi_aad_auth/_base/validators/token.py @@ -45,12 +45,16 @@ def __init__( def check(self, request: Request): """Check the authentication from the request.""" - token = self.get_token(request) - if token is None: - return AuthenticationState.as_unauthenticated(None, None) - claims = self.validate_token(token) - user = self._get_user_from_claims(claims) - return AuthenticationState.authenticate_as(user, None, None) + state = AuthenticationState.as_unauthenticated(None, None) + try: + token = self.get_token(request) + if token is not None: + claims = self.validate_token(token) + user = self._get_user_from_claims(claims) + state = AuthenticationState.authenticate_as(user, None, None) + except Exception: + self.logger.exception('Error authenticating via token') + return state def get_token(self, request: Request): """Get the token from the request.""" diff --git a/src/fastapi_aad_auth/errors.py b/src/fastapi_aad_auth/errors.py index b57558c..3139b7c 100644 --- a/src/fastapi_aad_auth/errors.py +++ b/src/fastapi_aad_auth/errors.py @@ -1,7 +1,7 @@ """fastapi_aad_auth errors.""" from starlette.responses import JSONResponse, Response - +from fastapi_aad_auth.utilities import is_interactive from fastapi_aad_auth.utilities.logging import getLogger logger = getLogger(__name__) @@ -12,8 +12,7 @@ def base_error_handler(request, exception, error_type, error_message, templates, if context is None: context = {} logger.warning(f'Handling error {exception}') - status_code = 500 - if any([u in request.headers['user-agent'] for u in ['Mozilla', 'Gecko', 'Trident', 'WebKit', 'Presto', 'Edge', 'Blink']]): + if is_interactive(request): logger.info('Interactive environment so returning error template') logger.debug(f'Path: {template_path}') error_context = context.copy() diff --git a/src/fastapi_aad_auth/utilities/__init__.py b/src/fastapi_aad_auth/utilities/__init__.py index 71830fb..bd262df 100644 --- a/src/fastapi_aad_auth/utilities/__init__.py +++ b/src/fastapi_aad_auth/utilities/__init__.py @@ -1,26 +1,33 @@ """Utilities.""" import importlib +from starlette.requests import Request + from fastapi_aad_auth.utilities import logging # noqa: F401 from fastapi_aad_auth.utilities import urls # noqa: F401 from fastapi_aad_auth.utilities.deprecate import DeprecatableFieldsMixin, deprecate, deprecate_module, DeprecatedField, is_deprecated # noqa: F401 -def bool_from_env(env_value): +def is_interactive(request: Request): + """Check if a request is from an interactive client.""" + return any([u in request.headers['user-agent'] for u in ['Mozilla', 'Gecko', 'Trident', 'WebKit', 'Presto', 'Edge', 'Blink']]) + + +def bool_from_env(env_value: str): """Convert environment variable to boolean.""" if isinstance(env_value, str): env_value = env_value.lower() in ['true', '1'] return env_value -def list_from_env(env_value): +def list_from_env(env_value: str): """Convert environment variable to list.""" if isinstance(env_value, str): env_value = [u for u in env_value.split(',') if u] return env_value -def klass_from_str(value): +def klass_from_str(value: str): """Convert an import path to a class.""" if isinstance(value, str): if ':' in value: @@ -34,7 +41,7 @@ def klass_from_str(value): return value -def expand_doc(klass): +def expand_doc(klass: type): """Expand pydantic model documentation to enable autodoc.""" docs = ['', '', 'Keyword Args:'] for name, field in klass.__fields__.items():