Skip to content

Commit

Permalink
use trusted header to get the user id instead of authentication; fix #69
Browse files Browse the repository at this point in the history
  • Loading branch information
hahahannes committed Jan 8, 2025
1 parent e706e86 commit b15b6ec
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 10 deletions.
16 changes: 16 additions & 0 deletions mlflow_oidc_auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from mlflow_oidc_auth.config import config
from mlflow_oidc_auth.store import store
from mlflow_oidc_auth.user import create_user
from mlflow.exceptions import MlflowException


_oauth_instance: Optional[OAuth] = None
Expand Down Expand Up @@ -77,3 +79,17 @@ def authenticate_request_bearer_token() -> Union[Authorization, Response]:
except Exception as e:
app.logger.debug("JWT auth failed")
return False


def login_with_trusted_header():
from mlflow_oidc_auth.app import app

email = request.headers.get(config.TRUSTED_USER_ID_HEADER)
if not email:
return False
try:
create_user(username=email, display_name=email, is_admin=False)
app.logger.debug("User %s logged in for first time -> created", email)
except MlflowException as e:
app.logger.debug(f"User {email} logged in for first time but could no be created")
return True
2 changes: 2 additions & 0 deletions mlflow_oidc_auth/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def __init__(self):
self.OIDC_REDIRECT_URI = os.environ.get("OIDC_REDIRECT_URI", None)
self.OIDC_CLIENT_ID = os.environ.get("OIDC_CLIENT_ID", None)
self.OIDC_CLIENT_SECRET = os.environ.get("OIDC_CLIENT_SECRET", None)
self.USE_TRUSTED_USER_ID_HEADER = os.environ.get("USE_TRUSTED_USER_ID_HEADER", str(False)).lower() in ("true", "1", "t")
self.TRUSTED_USER_ID_HEADER = os.environ.get("TRUSTED_USER_ID_HEADER", None)

# session
self.SESSION_TYPE = os.environ.get("SESSION_TYPE", "cachelib")
Expand Down
25 changes: 15 additions & 10 deletions mlflow_oidc_auth/hooks/before_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@

import mlflow_oidc_auth.responses as responses
from mlflow_oidc_auth import routes
from mlflow_oidc_auth.auth import authenticate_request_basic_auth, authenticate_request_bearer_token
from mlflow_oidc_auth.auth import authenticate_request_basic_auth, authenticate_request_bearer_token, login_with_trusted_header
from mlflow_oidc_auth.config import config
from mlflow_oidc_auth.utils import get_is_admin
from mlflow_oidc_auth.validators import (
Expand Down Expand Up @@ -185,21 +185,26 @@ def before_request_hook():
the view function for the matched route is called and returns a response"""
if _is_unprotected_route(request.path):
return
if request.authorization is not None:

if config.USE_TRUSTED_USER_ID_HEADER:
if not login_with_trusted_header():
return responses.make_auth_required_response()
elif request.authorization is not None:
if request.authorization.type == "basic":
if not authenticate_request_basic_auth():
return responses.make_basic_auth_response()
if request.authorization.type == "bearer":
if not authenticate_request_bearer_token():
return responses.make_auth_required_response()
else:
if session.get("username") is None:
session.clear()
return render_template(
"auth.html",
username=None,
provide_display_name=config.OIDC_PROVIDER_DISPLAY_NAME,
)
else:
return responses.make_auth_required_response()
elif session.get("username") is None:
session.clear()
return render_template(
"auth.html",
username=None,
provide_display_name=config.OIDC_PROVIDER_DISPLAY_NAME,
)
# admins don't need to be authorized
if get_is_admin():
return
Expand Down
4 changes: 4 additions & 0 deletions mlflow_oidc_auth/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ def get_username():
username = validate_token(request.authorization.token).get("email")
app.logger.debug(f"Username from bearer token: {username}")
return username
if config.TRUSTED_USER_ID_HEADER:
username = request.headers.get(config.TRUSTED_USER_ID_HEADER)
app.logger.debug(f"Username from trusted header {config.TRUSTED_USER_ID_HEADER}: {username}")
return username
return None


Expand Down

0 comments on commit b15b6ec

Please sign in to comment.