Skip to content

Commit

Permalink
Merge pull request #37 from djpugh/fix/optional-auth
Browse files Browse the repository at this point in the history
  • Loading branch information
djpugh authored Dec 15, 2020
2 parents c9c79de + b8edcdc commit 5e7989d
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 30 deletions.
48 changes: 28 additions & 20 deletions src/fastapi_aad_auth/_base/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
class BaseOAuthBackend(NotAuthenticatedMixin, LoggingMixin, AuthenticationBackend):
"""Base OAuthBackend with token and session validators."""

def __init__(self, validators: List[Validator]):
def __init__(self, validators: List[Validator], enabled: bool = True):
"""Initialise the validators."""
super().__init__()
self.enabled = enabled
self.validators = validators[:]

async def authenticate(self, request):
Expand Down Expand Up @@ -62,25 +63,32 @@ def requires_auth(self, allow_session: bool = False):
# This is a bit horrible, but is needed for fastapi to get this into OpenAPI (or similar) - it needs to be an OAuth2 object
# We create this here "dynamically" for each endpoint, as we allow customisation on whether a session is permissible

class OAuthValidator(OAuth2AuthorizationCodeBearer):
"""OAuthValidator for API Auth."""

def __init__(self_):
"""Initialise the validator."""
token_validators = [u for u in self.validators if isinstance(u, TokenValidator)]
super().__init__(authorizationUrl=token_validators[0].model.flows.authorizationCode.authorizationUrl,
tokenUrl=token_validators[0].model.flows.authorizationCode.tokenUrl,
scopes=token_validators[0].model.flows.authorizationCode.scopes,
refreshUrl=token_validators[0].model.flows.authorizationCode.refreshUrl)

async def __call__(self_, request: Request):
"""Validate a request."""
state = self.check(request, allow_session)
if state is None or not state.is_authenticated():
raise self.not_authenticated
return state

return OAuthValidator()
if self.enabled:
class OAuthValidator(OAuth2AuthorizationCodeBearer):
"""OAuthValidator for API Auth."""

def __init__(self_):
"""Initialise the validator."""
token_validators = [u for u in self.validators if isinstance(u, TokenValidator)]
super().__init__(authorizationUrl=token_validators[0].model.flows.authorizationCode.authorizationUrl,
tokenUrl=token_validators[0].model.flows.authorizationCode.tokenUrl,
scopes=token_validators[0].model.flows.authorizationCode.scopes,
refreshUrl=token_validators[0].model.flows.authorizationCode.refreshUrl)

async def __call__(self_, request: Request):
"""Validate a request."""
state = self.check(request, allow_session)
if state is None or not state.is_authenticated():
raise self.not_authenticated
return state

return OAuthValidator()

else:
def noauth(request: Request):
return AuthenticationState()

return noauth

@property # type: ignore
@deprecate('0.2.0', replaced_by=f'{__name__}:BaseOAuthBackend.requires_auth')
Expand Down
11 changes: 6 additions & 5 deletions src/fastapi_aad_auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(self, config: Config = None, add_to_base_routes: bool = True, base_
self._session_validator = self._init_session_validator()
self._providers = self._init_providers()
self.auth_backend = self._init_auth_backend()
self._ui
self._ui_routes = self._init_ui()
self._auth_routes = self._init_auth_routes()

Expand All @@ -75,11 +76,11 @@ def _init_auth_backend(self):
validators = [self._session_validator]
for provider in self._providers:
validators += provider.validators
return BaseOAuthBackend(validators)
return BaseOAuthBackend(validators, enabled=self.config.enabled)

def _init_ui(self):
ui = self.config.login_ui.ui_klass(self.config, self, self._base_context)
return ui.routes
self._ui = self.config.login_ui.ui_klass(self.config, self, self._base_context)
return self._ui.routes

def _init_auth_routes(self):

Expand Down Expand Up @@ -185,7 +186,6 @@ def configure_app(self, app: FastAPI, add_error_handlers=True):
Keyword Args:
add_error_handlers (bool) : add the error handlers to the app (default is true, but can be set to False to configure specific handling)
"""

def on_auth_error(request: Request, exc: Exception):
self.logger.exception(f'Error {exc} for request {request}')
self._session_validator.set_post_auth_redirect(request, request.url.path)
Expand All @@ -202,7 +202,8 @@ def on_auth_error(request: Request, exc: Exception):
app.routes.extend(self._ui_routes)
app.routes.extend(self._auth_routes)
# TODO: select a specific provider to use here
app.swagger_ui_init_oauth = self._providers[0].validators[0].init_oauth
if self.config.enabled:
app.swagger_ui_init_oauth = self._providers[0].validators[0].init_oauth


_DEPRECATED_VERSION = '0.2.0'
Expand Down
12 changes: 8 additions & 4 deletions src/fastapi_aad_auth/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class Config(BaseSettings):

enabled: bool = Field(True, description="Enable authentication", env='FASTAPI_AUTH_ENABLED')
providers: List[Union[AADConfig]] = Field(None, description="The provider configurations to use")
aad: AADConfig = DeprecatedField(None, description='AAD Configuration information', deprecated_in='0.2.0', replaced_by='Config.providers')
aad: Optional[AADConfig] = DeprecatedField(None, description='AAD Configuration information', deprecated_in='0.2.0', replaced_by='Config.providers')
auth_session: AuthSessionConfig = Field(None, description="The configuration for encoding the authentication information in the session")
routing: RoutingConfig = Field(None, description="Configuration for routing")
session: SessionConfig = Field(None, description="Configuration for the session middleware")
Expand All @@ -146,14 +146,18 @@ class Config: # noqa D106
env_file = '.env'

@validator('providers', always=True, pre=True)
def _validate_providers(cls, value):
def _validate_providers(cls, value, values):
enabled = values.get('enabled', cls.__fields__['enabled'].default)
if value is None:
value = [AADConfig(_env_file=cls.Config.env_file)]
value = []
if enabled:
value.append(AADConfig(_env_file=cls.Config.env_file))
return value

@validator('aad', always=True, pre=True)
def _validate_aad(cls, value, values):
if value is None:
enabled = values.get('enabled', cls.__fields__['enabled'].default)
if value is None and enabled:
providers = values.get('providers', [AADConfig(_env_file=cls.Config.env_file)])
value = [u for u in providers if isinstance(u, AADConfig)][0]
return value
Expand Down
2 changes: 1 addition & 1 deletion src/fastapi_aad_auth/ui/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ async def login(request: Request, *args, **kwargs):
async def get_user(request: Request):
return self._get_user(request)

async def get_token(request: Request, auth_state: AuthenticationState = Depends(self._authenticator.auth_backend)):
async def get_token(request: Request, auth_state: AuthenticationState = Depends(self._authenticator.auth_backend.requires_auth(allow_session=True))):
return self._get_token(request, auth_state)

routes += [Route(self.config.routing.user_path, endpoint=get_user, methods=['GET'], name='user'),
Expand Down

0 comments on commit 5e7989d

Please sign in to comment.