Skip to content

Commit

Permalink
🗃️ [#94] Set up new/custom DB field to configure claims
Browse files Browse the repository at this point in the history
The field abstracts away the underlying ArrayField usage.

A data migration is included that copies the existing
configuration into the new format.
  • Loading branch information
sergei-maertens committed May 1, 2024
1 parent 21a910d commit 08802a3
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 9 deletions.
16 changes: 16 additions & 0 deletions mozilla_django_oidc_db/fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from django.db import models
from django.utils.translation import gettext_lazy as _

from django_jsonform.models.fields import ArrayField


class ClaimField(ArrayField):
"""
A field to store a path to claims holding the desired value(s).
Each item is a segment in the path from the root to leaf for nested claims.
"""

def __init__(self, *args, **kwargs):
kwargs["base_field"] = models.CharField(_("claim path segment"), max_length=50)
super().__init__(*args, **kwargs)
91 changes: 91 additions & 0 deletions mozilla_django_oidc_db/migrations/0002_migrate_to_claim_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Generated by Django 4.2.9 on 2024-05-01 16:10

from django.conf import settings
from django.core.cache import caches
from django.db import migrations, models, transaction

import mozilla_django_oidc_db.fields
import mozilla_django_oidc_db.models


def flush_cache():
if not (cache_name := getattr(settings, "SOLO_CACHE", "")):
return
caches[cache_name].clear()


def forward(config) -> None:
config.new_username_claim = config.username_claim.split(".")
config.new_groups_claim = config.groups_claim.split(".")
config.claim_mapping = {
key: value.split(".") for key, value in config.claim_mapping.items()
}


def reverse(config) -> None:
config.username_claim = ".".join(config.new_username_claim)
config.groups_claim = ".".join(config.new_groups_claim)
config.claim_mapping = {
key: ".".join(value) for key, value in config.claim_mapping.items()
}


def action_factory(transformer):
def _run_python_action(apps, _) -> None:
OpenIDConnectConfig = apps.get_model(
"mozilla_django_oidc_db", "OpenIDConnectConfig"
)

# Solo model, so there's only ever one instance
config = OpenIDConnectConfig.objects.first()
if config is None:
return

transformer(config)

config.save()
transaction.on_commit(flush_cache)

return _run_python_action


copy_forward = action_factory(transformer=forward)
copy_reverse = action_factory(transformer=reverse)


class Migration(migrations.Migration):

dependencies = [
("mozilla_django_oidc_db", "0001_initial_to_v015"),
]

operations = [
migrations.AddField(
model_name="openidconnectconfig",
name="new_groups_claim",
field=mozilla_django_oidc_db.fields.ClaimField(
base_field=models.CharField(
max_length=50, verbose_name="claim path segment"
),
blank=True,
default=mozilla_django_oidc_db.models.get_default_groups_claim,
help_text="The name of the OIDC claim that holds the values to map to local user groups.",
size=None,
verbose_name="groups claim",
),
),
migrations.AddField(
model_name="openidconnectconfig",
name="new_username_claim",
field=mozilla_django_oidc_db.fields.ClaimField(
base_field=models.CharField(
max_length=50, verbose_name="claim path segment"
),
default=mozilla_django_oidc_db.models.get_default_username_claim,
help_text="The name of the OIDC claim that is used as the username",
size=None,
verbose_name="username claim",
),
),
migrations.RunPython(copy_forward, copy_reverse),
]
36 changes: 29 additions & 7 deletions mozilla_django_oidc_db/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Dict, List

from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.auth.models import Group
Expand All @@ -14,30 +12,39 @@
import mozilla_django_oidc_db.settings as oidc_settings

from .compat import classproperty
from .fields import ClaimField


class UserInformationClaimsSources(models.TextChoices):
userinfo_endpoint = "userinfo_endpoint", _("Userinfo endpoint")
id_token = "id_token", _("ID token")


def get_default_scopes() -> List[str]:
def get_default_scopes() -> list[str]:
"""
Returns the default scopes to request for OpenID Connect logins
"""
return ["openid", "email", "profile"]


def get_claim_mapping() -> Dict[str, str]:
def get_claim_mapping() -> dict[str, list[str]]:
# Map (some) claim names from https://openid.net/specs/openid-connect-core-1_0.html#Claims
# to corresponding field names on the User model
return {
"email": "email",
"first_name": "given_name",
"last_name": "family_name",
"email": ["email"],
"first_name": ["given_name"],
"last_name": ["family_name"],
}


def get_default_username_claim() -> list[str]:
return ["sub"]


def get_default_groups_claim() -> list[str]:
return ["roles"]


class CachingMixin:
@classmethod
def clear_cache(cls):
Expand Down Expand Up @@ -254,6 +261,12 @@ class OpenIDConnectConfig(CachingMixin, OpenIDConnectConfigBase):
default="sub",
help_text=_("The name of the OIDC claim that is used as the username"),
)
new_username_claim = ClaimField(
verbose_name=_("username claim"),
default=get_default_username_claim,
help_text=_("The name of the OIDC claim that is used as the username"),
)

claim_mapping = models.JSONField(
_("claim mapping"),
default=get_claim_mapping,
Expand All @@ -268,6 +281,15 @@ class OpenIDConnectConfig(CachingMixin, OpenIDConnectConfigBase):
),
blank=True,
)
new_groups_claim = ClaimField(
verbose_name=_("groups claim"),
default=get_default_groups_claim,
help_text=_(
"The name of the OIDC claim that holds the values to map to local user groups."
),
blank=True,
)

sync_groups = models.BooleanField(
_("Create local user groups if they do not exist yet"),
default=True,
Expand Down
2 changes: 0 additions & 2 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,6 @@ def test_backend_create_user(mock_get_solo):
oidc_op_user_endpoint="http://some.endpoint/v1/user",
)

User = get_user_model()

claims = {
"sub": "123456",
"email": "admin@localhost",
Expand Down

0 comments on commit 08802a3

Please sign in to comment.