Skip to content

Commit

Permalink
check access token for group attribute as well; fix #43
Browse files Browse the repository at this point in the history
  • Loading branch information
hahahannes committed Oct 24, 2024
1 parent bedb8dc commit c2e0976
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 1 deletion.
1 change: 1 addition & 0 deletions mlflow_oidc_auth/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class AppConfig:
OIDC_REDIRECT_URI = os.environ.get("OIDC_REDIRECT_URI", None)
OIDC_CLIENT_ID = os.environ.get("OIDC_CLIENT_ID", None)
OIDC_CLIENT_SECRET = os.environ.get("OIDC_CLIENT_SECRET", None)
OIDC_AUDIENCE = os.environ.get("OIDC_AUDIENCE", None)

# https://flask-session.readthedocs.io/en/latest/config.html
SESSION_TYPE = os.environ.get("SESSION_TYPE", "filesystem")
Expand Down
11 changes: 10 additions & 1 deletion mlflow_oidc_auth/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@

from mlflow.server import app

import jwt

# Create the OAuth2 client
auth_client = WebApplicationClient(AppConfig.get_property("OIDC_CLIENT_ID"))
store = SqlAlchemyStore()
Expand Down Expand Up @@ -724,14 +726,21 @@ def callback():
is_admin = False
user_groups = []

decoded_access_token = jwt.decode(access_token, audience=AppConfig.get_property("OIDC_AUDIENCE"), options={"verify_signature": False})
app.logger.debug(f"{decoded_access_token}")

if AppConfig.get_property("OIDC_GROUP_DETECTION_PLUGIN"):
import importlib

user_groups = importlib.import_module(AppConfig.get_property("OIDC_GROUP_DETECTION_PLUGIN")).get_user_groups(
access_token
)
else:
user_groups = user_data.get(AppConfig.get_property("OIDC_GROUPS_ATTRIBUTE"), [])
attr = AppConfig.get_property("OIDC_GROUPS_ATTRIBUTE")
if attr in decoded_access_token:
user_groups = decoded_access_token[attr]
if attr in user_data:
user_groups = user_data[attr]

app.logger.debug(f"User groups: {user_groups}")

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ dependencies = [
"Flask-Session>=0.7.0",
"gunicorn<24; platform_system != 'Windows'",
"alembic<2,!=1.10.0",
"pyjwt>=2.9.0,<=3.0.0"
]

[project.optional-dependencies]
Expand Down

0 comments on commit c2e0976

Please sign in to comment.