From f2c84fb83a726b4ffdc58a22d93cd1a25c23011d Mon Sep 17 00:00:00 2001 From: John Christoforidis Date: Mon, 25 Nov 2024 13:59:03 +0000 Subject: [PATCH 1/3] feat: Add custom parameter support to hooks for multi tenant use cases --- README.md | 1 + django_saml2_auth/saml.py | 20 +++++++++++------ django_saml2_auth/tests/test_saml.py | 32 +++++++++++++++++++++++++++- django_saml2_auth/utils.py | 15 ++++++++++--- django_saml2_auth/views.py | 11 +++++----- 5 files changed, 64 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 59bad0b..4de74b2 100644 --- a/README.md +++ b/README.md @@ -262,6 +262,7 @@ Some of the following settings are related to how this module operates. The rest | **TRIGGER.CUSTOM\_CREATE\_JWT** | A hook function to create a custom JWT for the user. This method will be called instead of the `create_jwt_token` default function and should return the token. This method accepts one parameter: `user`. | `str` | `None` | `my_app.models.users.create_custom_token` | | **TRIGGER.CUSTOM\_TOKEN\_QUERY** | A hook function to create a custom query params with the JWT for the user. This method will be called after `CUSTOM_CREATE_JWT` to populate a query and attach it to a URL; should return the query params containing the token (e.g., `?token=encoded.jwt.token`). This method accepts one parameter: `token`. | `str` | `None` | `my_app.models.users.get_custom_token_query` | | **TRIGGER.GET\_CUSTOM\_ASSERTION\_URL** | A hook function to get the assertion URL dynamically. Useful when you have dynamic routing, multi-tenant setup and etc. Overrides `ASSERTION_URL`. | `str` | `None` | `my_app.utils.get_custom_assertion_url` | +| **TRIGGER.GET\_CUSTOM\_ENTITY\_ID** | A hook function to get the Entity ID dynamically. Useful when you have dynamic routing, multi-tenant setup and etc. Overrides `ENTITY_ID`. | `str` | `None` | `my_app.utils.get_custom_entity_id_url` | | **TRIGGER.GET\_CUSTOM\_FRONTEND\_URL** | A hook function to get a dynamic `FRONTEND_URL` dynamically (see below for more details). Overrides `FRONTEND_URL`. Acceots one parameter: `relay_state`. | `str` | `None` | `my_app.utils.get_custom_frontend_url` | | **ASSERTION\_URL** | A URL to validate incoming SAML responses against. By default, `django-saml2-auth` will validate the SAML response's Service Provider address against the actual HTTP request's host and scheme. If this value is set, it will validate against `ASSERTION_URL` instead - perfect for when Django is running behind a reverse proxy. This will only allow to customize the domain part of the URL, for more customization use `GET_CUSTOM_ASSERTION_URL`. | `str` | `None` | `https://example.com` | | **ENTITY\_ID** | The optional entity ID string to be passed in the 'Issuer' element of authentication request, if required by the IDP. | `str` | `None` | `https://exmaple.com/sso/acs` | diff --git a/django_saml2_auth/saml.py b/django_saml2_auth/saml.py index 62cf61b..856a61a 100644 --- a/django_saml2_auth/saml.py +++ b/django_saml2_auth/saml.py @@ -156,9 +156,9 @@ def get_metadata( ) -def get_custom_acs_url() -> Optional[str]: +def get_custom_acs_url(tenant_id: Optional[str] = None) -> Optional[str]: get_custom_acs_url_hook = dictor(settings.SAML2_AUTH, "TRIGGER.GET_CUSTOM_ASSERTION_URL") - return run_hook(get_custom_acs_url_hook) if get_custom_acs_url_hook else None + return run_hook(get_custom_acs_url_hook, tenant_id=tenant_id) if get_custom_acs_url_hook else None def get_saml_client( @@ -166,6 +166,7 @@ def get_saml_client( acs: Callable[..., HttpResponse], user_id: Optional[str] = None, saml_response: Optional[str] = None, + tenant_id: Optional[str] = None, ) -> Optional[Saml2Client]: """Create a new Saml2Config object with the given config and return an initialized Saml2Client using the config object. The settings are read from django settings key: SAML2_AUTH. @@ -178,6 +179,7 @@ def get_saml_client( to the given user identifier, either email or username. Defaults to None. user_id (str or None): User identifier: username or email. Defaults to None. saml_response (str or None): decoded XML SAML response. + tenant_id (typing.Optional[str], optional): Tenant ID used for the custom ACS and Entity ID hooks. Defaults to None. Raises: SAMLAuthError: Re-raise any exception raised by Saml2Config or Saml2Client @@ -206,7 +208,7 @@ def get_saml_client( }, ) - acs_url = get_custom_acs_url() + acs_url = get_custom_acs_url(tenant_id) if not acs_url: # get_reverse raises an exception if the view is not found, so we can safely ignore type errors acs_url = domain + get_reverse([acs, "acs", "django_saml2_auth:acs"]) # type: ignore @@ -246,6 +248,11 @@ def get_saml_client( if entity_id: saml_settings["entityid"] = entity_id + get_custom_entity_id_hook = dictor(settings.SAML2_AUTH, "TRIGGER.GET_CUSTOM_ENTITY_ID") + + if get_custom_entity_id_hook: + saml_settings["entityid"] = run_hook(get_custom_entity_id_hook, tenant_id=tenant_id) + name_id_format = saml2_auth_settings.get("NAME_ID_FORMAT") if name_id_format: saml_settings["service"]["sp"]["name_id_policy_format"] = name_id_format @@ -298,7 +305,7 @@ def get_saml_client( def decode_saml_response( - request: HttpRequest, acs: Callable[..., HttpResponse] + request: HttpRequest, acs: Callable[..., HttpResponse], tenant_id: Optional[str] = None ) -> Union[HttpResponseRedirect, Optional[AuthnResponse], None]: """Given a request, the authentication response inside the SAML response body is parsed, decoded and returned. If there are any issues parsing the request, the identity or the issuer, @@ -307,6 +314,7 @@ def decode_saml_response( Args: request (HttpRequest): Django request object from identity provider (IdP) acs (Callable[..., HttpResponse]): The acs endpoint + tenant_id (typing.Optional[str], optional): Tenant ID used for the custom ACS and Entity ID hooks. Defaults to None. Raises: SAMLAuthError: There was no response from SAML client. @@ -335,7 +343,7 @@ def decode_saml_response( saml_response = base64.b64decode(response).decode("UTF-8") except Exception: saml_response = None - saml_client = get_saml_client(get_assertion_url(request), acs, saml_response=saml_response) + saml_client = get_saml_client(get_assertion_url(request), acs, saml_response=saml_response, tenant_id=tenant_id) if not saml_client: raise SAMLAuthError( "There was an error creating the SAML client.", @@ -478,4 +486,4 @@ def extract_user_identity( return run_hook(extract_user_identity_trigger, user, authn_response) # type: ignore # If there is no custom trigger, the user identity is returned as is. - return user + return user \ No newline at end of file diff --git a/django_saml2_auth/tests/test_saml.py b/django_saml2_auth/tests/test_saml.py index ce8f813..b6e19f4 100644 --- a/django_saml2_auth/tests/test_saml.py +++ b/django_saml2_auth/tests/test_saml.py @@ -118,9 +118,15 @@ def get_metadata_auto_conf_urls( def get_custom_assertion_url(): return "https://example.com/custom-tenant/acs" +def get_custom_assertion_url_with_param(tenant_id: str): + return f"https://example.com/{tenant_id}/acs" -GET_CUSTOM_ASSERTION_URL = "django_saml2_auth.tests.test_saml.get_custom_assertion_url" +def get_custom_entity_id_url(tenant_id: str): + return f"https://example.com/sso/{tenant_id}" +GET_CUSTOM_ASSERTION_URL = "django_saml2_auth.tests.test_saml.get_custom_assertion_url" +GET_CUSTOM_ASSERTION_URL_WITH_PARAM = "django_saml2_auth.tests.test_saml.get_custom_assertion_url_with_param" +GET_CUSTOM_ENTITY_ID = "django_saml2_auth.tests.test_saml.get_custom_entity_id_url" def mock_extract_user_identity( user: Dict[str, Optional[Any]], authn_response: AuthnResponse @@ -479,6 +485,30 @@ def test_get_saml_client_success_with_custom_assertion_url_hook(settings: Settin "sp", ) +def test_get_saml_client_success_with_custom_assertion_url_and_param_hook(settings: SettingsWrapper): + settings.SAML2_AUTH["METADATA_LOCAL_FILE_PATH"] = "django_saml2_auth/tests/metadata.xml" + + settings.SAML2_AUTH["TRIGGER"]["GET_CUSTOM_ASSERTION_URL"] = GET_CUSTOM_ASSERTION_URL_WITH_PARAM + + result = get_saml_client("example.com", acs, "test@example.com", tenant_id="custom-tenant") + assert result is not None + assert "https://example.com/custom-tenant/acs" in result.config.endpoint( + "assertion_consumer_service", + BINDING_HTTP_POST, + "sp", + ) + + +def test_get_saml_client_success_with_custom_entity_id_hook(settings: SettingsWrapper): + settings.SAML2_AUTH["METADATA_LOCAL_FILE_PATH"] = "django_saml2_auth/tests/metadata.xml" + + settings.SAML2_AUTH["TRIGGER"]["GET_CUSTOM_ENTITY_ID"] = GET_CUSTOM_ENTITY_ID + + result = get_saml_client("example.com", acs, "test@example.com", tenant_id="custom-tenant") + assert result is not None + assert "https://example.com/sso/custom-tenant" == result.config.entityid + + @responses.activate def test_decode_saml_response_success( settings: SettingsWrapper, diff --git a/django_saml2_auth/utils.py b/django_saml2_auth/utils.py index ae7f6fe..5700989 100644 --- a/django_saml2_auth/utils.py +++ b/django_saml2_auth/utils.py @@ -91,7 +91,13 @@ def run_hook( }, ) try: - result = getattr(cls, path[-1])(*args, **kwargs) + func: Callable = getattr(cls, path[-1]) + + if func.__code__.co_argcount > 0: # backwards compatibility with existing hooks with no parameters + result = func(*args, **kwargs) + else: + result = func() + except SAMLAuthError as exc: # Re-raise the exception raise exc @@ -196,18 +202,21 @@ def handle_exception(exc: Exception, request: HttpRequest) -> HttpResponse: return render(request, "django_saml2_auth/error.html", context=context, status=status) @wraps(function) - def wrapper(request: HttpRequest) -> HttpResponse: + def wrapper(request: HttpRequest, **kwargs: Optional[Mapping[str, Any]]) -> HttpResponse: """Decorated function is wrapped and called here Args: request ([type]): [description] + Keyword Args: + **kwargs: Additional keyword arguments + Returns: HttpResponse: Either a redirect or a response with error details """ result = None try: - result = function(request) + result = function(request, **kwargs) except (SAMLAuthError, Exception) as exc: result = handle_exception(exc, request) return result diff --git a/django_saml2_auth/views.py b/django_saml2_auth/views.py index b1ad374..713e2f9 100644 --- a/django_saml2_auth/views.py +++ b/django_saml2_auth/views.py @@ -85,7 +85,7 @@ def denied(request: HttpRequest) -> HttpResponse: @csrf_exempt @exception_handler -def acs(request: HttpRequest): +def acs(request: HttpRequest, tenant_id: Optional[str] = None): """Assertion Consumer Service is SAML terminology for the location at a ServiceProvider that accepts messages (or SAML artifacts) for the purpose of establishing a session based on an assertion. Assertion is a signed authentication request from identity provider (IdP) @@ -93,7 +93,7 @@ def acs(request: HttpRequest): Args: request (HttpRequest): Incoming request from identity provider (IdP) for authentication - + tenant_id (typing.Optional[str], optional): Tenant ID used for the custom ACS and Entity ID hooks. Defaults to None. Exceptions: SAMLAuthError: The target user is inactive. @@ -106,7 +106,7 @@ def acs(request: HttpRequest): """ saml2_auth_settings = settings.SAML2_AUTH - authn_response = decode_saml_response(request, acs) + authn_response = decode_saml_response(request, acs, tenant_id=tenant_id) # decode_saml_response() will raise SAMLAuthError if the response is invalid, # so we can safely ignore the type check here. user = extract_user_identity(authn_response) # type: ignore @@ -263,12 +263,13 @@ def sp_initiated_login(request: HttpRequest) -> HttpResponseRedirect: @exception_handler -def signin(request: HttpRequest) -> HttpResponseRedirect: +def signin(request: HttpRequest, tenant_id: Optional[str] = None) -> HttpResponseRedirect: """Custom sign-in view for SP-initiated SSO. This will be deprecated in the future in favor of sp_initiated_login. Args: request (HttpRequest): Incoming request from service provider (SP) for authentication. + tenant_id (typing.Optional[str], optional): Tenant ID used for the custom ACS and Entity ID hooks. Defaults to None. Raises: SAMLAuthError: The next URL is invalid. @@ -306,7 +307,7 @@ def signin(request: HttpRequest) -> HttpResponseRedirect: request.session["login_next_url"] = next_url - saml_client = get_saml_client(get_assertion_url(request), acs) + saml_client = get_saml_client(get_assertion_url(request), acs, tenant_id=tenant_id) _, info = saml_client.prepare_for_authenticate(relay_state=next_url) # type: ignore redirect_url = dict(info["headers"]).get("Location", "") From 688254b5d2aaf4d58da2f58699b37870ff4807ad Mon Sep 17 00:00:00 2001 From: j0x539 Date: Mon, 25 Nov 2024 16:45:55 +0000 Subject: [PATCH 2/3] Add newline to end of file --- django_saml2_auth/saml.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/django_saml2_auth/saml.py b/django_saml2_auth/saml.py index 856a61a..20ab3ac 100644 --- a/django_saml2_auth/saml.py +++ b/django_saml2_auth/saml.py @@ -486,4 +486,4 @@ def extract_user_identity( return run_hook(extract_user_identity_trigger, user, authn_response) # type: ignore # If there is no custom trigger, the user identity is returned as is. - return user \ No newline at end of file + return user From a3581fe68f9f5ab6702b0927aea74a5585cf5366 Mon Sep 17 00:00:00 2001 From: j0x539 Date: Tue, 26 Nov 2024 10:00:47 +0000 Subject: [PATCH 3/3] Suppress mypy warning for run_hook --- django_saml2_auth/saml.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/django_saml2_auth/saml.py b/django_saml2_auth/saml.py index 20ab3ac..c9448d6 100644 --- a/django_saml2_auth/saml.py +++ b/django_saml2_auth/saml.py @@ -158,7 +158,7 @@ def get_metadata( def get_custom_acs_url(tenant_id: Optional[str] = None) -> Optional[str]: get_custom_acs_url_hook = dictor(settings.SAML2_AUTH, "TRIGGER.GET_CUSTOM_ASSERTION_URL") - return run_hook(get_custom_acs_url_hook, tenant_id=tenant_id) if get_custom_acs_url_hook else None + return run_hook(get_custom_acs_url_hook, tenant_id=tenant_id) if get_custom_acs_url_hook else None # type: ignore def get_saml_client( @@ -251,7 +251,7 @@ def get_saml_client( get_custom_entity_id_hook = dictor(settings.SAML2_AUTH, "TRIGGER.GET_CUSTOM_ENTITY_ID") if get_custom_entity_id_hook: - saml_settings["entityid"] = run_hook(get_custom_entity_id_hook, tenant_id=tenant_id) + saml_settings["entityid"] = run_hook(get_custom_entity_id_hook, tenant_id=tenant_id) # type: ignore name_id_format = saml2_auth_settings.get("NAME_ID_FORMAT") if name_id_format: