diff --git a/mlflow_oidc_auth/config.py b/mlflow_oidc_auth/config.py index 1121d78..0590b27 100644 --- a/mlflow_oidc_auth/config.py +++ b/mlflow_oidc_auth/config.py @@ -34,6 +34,11 @@ class AppConfig: OIDC_REDIRECT_URI = os.environ.get("OIDC_REDIRECT_URI", None) OIDC_CLIENT_ID = os.environ.get("OIDC_CLIENT_ID", None) OIDC_CLIENT_SECRET = os.environ.get("OIDC_CLIENT_SECRET", None) + OIDC_AUDIENCE = os.environ.get("OIDC_AUDIENCE", None) + OIDC_PUBLIC_KEYS_URL = os.environ.get("OIDC_PUBLIC_KEYS_URL", None) + OIDC_USERNAME_TOKEN_ATTRIBUTE = os.environ.get("OIDC_USERNAME_TOKEN_ATTRIBUTE", "email") + OIDC_SIGNING_ALG = os.environ.get("OIDC_SIGNING_ALG", None) # ES256, EdDSA, PS256, RS256, HS256 + OIDC_HS256_SECRET = os.environ.get("OIDC_HS256_SECRET", None) # ES256, EdDSA, PS256, RS256, HS256 # https://flask-session.readthedocs.io/en/latest/config.html SESSION_TYPE = os.environ.get("SESSION_TYPE", "filesystem") diff --git a/mlflow_oidc_auth/views.py b/mlflow_oidc_auth/views.py index 88df955..ecacb8b 100644 --- a/mlflow_oidc_auth/views.py +++ b/mlflow_oidc_auth/views.py @@ -1,3 +1,4 @@ +import json import os import re import requests @@ -96,6 +97,8 @@ from mlflow.server import app +import jwt + # Create the OAuth2 client auth_client = WebApplicationClient(AppConfig.get_property("OIDC_CLIENT_ID")) store = SqlAlchemyStore() @@ -186,6 +189,9 @@ def _get_permission_from_store_or_default( def authenticate_request_basic_auth() -> Union[Authorization, Response]: username = request.authorization.username + if username == "" or username is None: + app.logger.debug("Username is not set in basic auth") + return False password = request.authorization.password app.logger.debug("Authenticating user %s", username) if store.authenticate_user(username.lower(), password): @@ -195,6 +201,64 @@ def authenticate_request_basic_auth() -> Union[Authorization, Response]: else: app.logger.debug("User %s not authenticated", username) return False + + +def _get_public_keys(): + """ + Returns: + List of RSA public keys usable by PyJWT. + """ + r = requests.get(AppConfig.get_property("OIDC_PUBLIC_KEYS_URL")) + public_keys = [] + jwk_set = r.json() + for key_dict in jwk_set["keys"]: + public_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(key_dict)) + public_keys.append(public_key) + return public_keys + + +def validate_token(token, key, sign_alg): + try: + token = jwt.decode(token, key=key, audience=AppConfig.get_property("OIDC_AUDIENCE"), algorithms=[sign_alg]) + except jwt.exceptions.InvalidTokenError as e: + app.logger.debug(f"Token is not valid: {token}") + raise MlflowException(f"Token is not valid: {str(e)}") + username_token_attr = AppConfig.get_property("OIDC_USERNAME_TOKEN_ATTRIBUTE") + username = token[username_token_attr] + if username == "" or username is None: + app.logger.debug(f"username from token attribute {username_token_attr} is {username}") + raise MlflowException(f"Username is not set at attribute: {username_token_attr}") + _set_username(username) + + +def authenticate_token(): + """ + Verify the token in the request. + """ + token = request.authorization.token + if token == "" or token is None: + app.logger.debug(f"Token is not set: {token}") + return False + sign_alg = AppConfig.get_property("OIDC_SIGNING_ALG") + token_is_valid = False + if sign_alg == "HS256": + key = AppConfig.get_property("OIDC_HS256_SECRET") + try: + validate_token(token, key, sign_alg) + token_is_valid = True + except MlflowException as e: + return False + + keys = _get_public_keys() + for key in keys: + try: + validate_token(token, key, sign_alg) + token_is_valid = True + break + except MlflowException as e: + return False + + return token_is_valid def _get_username(): @@ -617,8 +681,13 @@ def before_request_hook(): if _is_unprotected_route(request.path): return if request.authorization is not None: - if not authenticate_request_basic_auth(): - return make_basic_auth_response() + if not authenticate_token(): + app.logger.debug("No valid token authentication found") + # TODO maybe return 401 here instead of basic auth response + + if not authenticate_request_basic_auth(): + app.logger.debug("No valid basic authentication found") + return make_basic_auth_response() else: # authentication if not _get_username(): diff --git a/pyproject.toml b/pyproject.toml index c905976..6619248 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "Flask-Session>=0.7.0", "gunicorn<24; platform_system != 'Windows'", "alembic<2,!=1.10.0", + "pyjwt[crypto]>=2.9.0,<3.0.0" ] [project.optional-dependencies]