Skip to content

Commit

Permalink
Merge pull request #39 from djpugh/feature/multiple-token-providers
Browse files Browse the repository at this point in the history
  • Loading branch information
djpugh authored Dec 16, 2020
2 parents fc9c3de + 22d2061 commit de2b4a5
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/fastapi_aad_auth/_base/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def check(self, request: Request, allow_session=True) -> Optional[Authentication
continue
state = validator.check(request)
self.logger.debug(f'Authentication state {state} from validator {validator}')
if state is not None:
if state is not None and state.is_authenticated():
break
self.logger.info(f'Identified state {state}')
return state
Expand Down
16 changes: 10 additions & 6 deletions src/fastapi_aad_auth/_base/validators/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,16 @@ def __init__(

def check(self, request: Request):
"""Check the authentication from the request."""
token = self.get_token(request)
if token is None:
return AuthenticationState.as_unauthenticated(None, None)
claims = self.validate_token(token)
user = self._get_user_from_claims(claims)
return AuthenticationState.authenticate_as(user, None, None)
state = AuthenticationState.as_unauthenticated(None, None)
try:
token = self.get_token(request)
if token is not None:
claims = self.validate_token(token)
user = self._get_user_from_claims(claims)
state = AuthenticationState.authenticate_as(user, None, None)
except Exception:
self.logger.exception('Error authenticating via token')
return state

def get_token(self, request: Request):
"""Get the token from the request."""
Expand Down
5 changes: 2 additions & 3 deletions src/fastapi_aad_auth/errors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""fastapi_aad_auth errors."""
from starlette.responses import JSONResponse, Response


from fastapi_aad_auth.utilities import is_interactive
from fastapi_aad_auth.utilities.logging import getLogger

logger = getLogger(__name__)
Expand All @@ -12,8 +12,7 @@ def base_error_handler(request, exception, error_type, error_message, templates,
if context is None:
context = {}
logger.warning(f'Handling error {exception}')
status_code = 500
if any([u in request.headers['user-agent'] for u in ['Mozilla', 'Gecko', 'Trident', 'WebKit', 'Presto', 'Edge', 'Blink']]):
if is_interactive(request):
logger.info('Interactive environment so returning error template')
logger.debug(f'Path: {template_path}')
error_context = context.copy()
Expand Down
19 changes: 14 additions & 5 deletions src/fastapi_aad_auth/utilities/__init__.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,35 @@
"""Utilities."""
import importlib
from typing import List, Union

from pydantic.main import ModelMetaclass
from starlette.requests import Request

from fastapi_aad_auth.utilities import logging # noqa: F401
from fastapi_aad_auth.utilities import urls # noqa: F401
from fastapi_aad_auth.utilities.deprecate import DeprecatableFieldsMixin, deprecate, deprecate_module, DeprecatedField, is_deprecated # noqa: F401


def bool_from_env(env_value):
def is_interactive(request: Request):
"""Check if a request is from an interactive client."""
return any([u in request.headers['user-agent'] for u in ['Mozilla', 'Gecko', 'Trident', 'WebKit', 'Presto', 'Edge', 'Blink']])


def bool_from_env(env_value: Union[bool, str]) -> bool:
"""Convert environment variable to boolean."""
if isinstance(env_value, str):
env_value = env_value.lower() in ['true', '1']
return env_value


def list_from_env(env_value):
def list_from_env(env_value: Union[List[str], str]) -> List[str]:
"""Convert environment variable to list."""
if isinstance(env_value, str):
env_value = [u for u in env_value.split(',') if u]
return env_value


def klass_from_str(value):
def klass_from_str(value: str):
"""Convert an import path to a class."""
if isinstance(value, str):
if ':' in value:
Expand All @@ -34,10 +43,10 @@ def klass_from_str(value):
return value


def expand_doc(klass):
def expand_doc(klass: ModelMetaclass) -> ModelMetaclass:
"""Expand pydantic model documentation to enable autodoc."""
docs = ['', '', 'Keyword Args:']
for name, field in klass.__fields__.items():
for name, field in klass.__fields__.items(): # type: ignore
default_str = ''
if field.default:
default_str = f' [default: ``{field.default}``]'
Expand Down

0 comments on commit de2b4a5

Please sign in to comment.