Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add custom parameter support to hooks for multi tenant use cases #352

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ Some of the following settings are related to how this module operates. The rest
| **TRIGGER.CUSTOM\_CREATE\_JWT** | A hook function to create a custom JWT for the user. This method will be called instead of the `create_jwt_token` default function and should return the token. This method accepts one parameter: `user`. | `str` | `None` | `my_app.models.users.create_custom_token` |
| **TRIGGER.CUSTOM\_TOKEN\_QUERY** | A hook function to create a custom query params with the JWT for the user. This method will be called after `CUSTOM_CREATE_JWT` to populate a query and attach it to a URL; should return the query params containing the token (e.g., `?token=encoded.jwt.token`). This method accepts one parameter: `token`. | `str` | `None` | `my_app.models.users.get_custom_token_query` |
| **TRIGGER.GET\_CUSTOM\_ASSERTION\_URL** | A hook function to get the assertion URL dynamically. Useful when you have dynamic routing, multi-tenant setup and etc. Overrides `ASSERTION_URL`. | `str` | `None` | `my_app.utils.get_custom_assertion_url` |
| **TRIGGER.GET\_CUSTOM\_ENTITY\_ID** | A hook function to get the Entity ID dynamically. Useful when you have dynamic routing, multi-tenant setup and etc. Overrides `ENTITY_ID`. | `str` | `None` | `my_app.utils.get_custom_entity_id_url` |
| **TRIGGER.GET\_CUSTOM\_FRONTEND\_URL** | A hook function to get a dynamic `FRONTEND_URL` dynamically (see below for more details). Overrides `FRONTEND_URL`. Acceots one parameter: `relay_state`. | `str` | `None` | `my_app.utils.get_custom_frontend_url` |
| **ASSERTION\_URL** | A URL to validate incoming SAML responses against. By default, `django-saml2-auth` will validate the SAML response's Service Provider address against the actual HTTP request's host and scheme. If this value is set, it will validate against `ASSERTION_URL` instead - perfect for when Django is running behind a reverse proxy. This will only allow to customize the domain part of the URL, for more customization use `GET_CUSTOM_ASSERTION_URL`. | `str` | `None` | `https://example.com` |
| **ENTITY\_ID** | The optional entity ID string to be passed in the 'Issuer' element of authentication request, if required by the IDP. | `str` | `None` | `https://exmaple.com/sso/acs` |
Expand Down
18 changes: 13 additions & 5 deletions django_saml2_auth/saml.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,16 +156,17 @@ def get_metadata(
)


def get_custom_acs_url() -> Optional[str]:
def get_custom_acs_url(tenant_id: Optional[str] = None) -> Optional[str]:
get_custom_acs_url_hook = dictor(settings.SAML2_AUTH, "TRIGGER.GET_CUSTOM_ASSERTION_URL")
return run_hook(get_custom_acs_url_hook) if get_custom_acs_url_hook else None
return run_hook(get_custom_acs_url_hook, tenant_id=tenant_id) if get_custom_acs_url_hook else None # type: ignore


def get_saml_client(
domain: str,
acs: Callable[..., HttpResponse],
user_id: Optional[str] = None,
saml_response: Optional[str] = None,
tenant_id: Optional[str] = None,
) -> Optional[Saml2Client]:
"""Create a new Saml2Config object with the given config and return an initialized Saml2Client
using the config object. The settings are read from django settings key: SAML2_AUTH.
Expand All @@ -178,6 +179,7 @@ def get_saml_client(
to the given user identifier, either email or username. Defaults to None.
user_id (str or None): User identifier: username or email. Defaults to None.
saml_response (str or None): decoded XML SAML response.
tenant_id (typing.Optional[str], optional): Tenant ID used for the custom ACS and Entity ID hooks. Defaults to None.

Raises:
SAMLAuthError: Re-raise any exception raised by Saml2Config or Saml2Client
Expand Down Expand Up @@ -206,7 +208,7 @@ def get_saml_client(
},
)

acs_url = get_custom_acs_url()
acs_url = get_custom_acs_url(tenant_id)
if not acs_url:
# get_reverse raises an exception if the view is not found, so we can safely ignore type errors
acs_url = domain + get_reverse([acs, "acs", "django_saml2_auth:acs"]) # type: ignore
Expand Down Expand Up @@ -246,6 +248,11 @@ def get_saml_client(
if entity_id:
saml_settings["entityid"] = entity_id

get_custom_entity_id_hook = dictor(settings.SAML2_AUTH, "TRIGGER.GET_CUSTOM_ENTITY_ID")

if get_custom_entity_id_hook:
saml_settings["entityid"] = run_hook(get_custom_entity_id_hook, tenant_id=tenant_id) # type: ignore

name_id_format = saml2_auth_settings.get("NAME_ID_FORMAT")
if name_id_format:
saml_settings["service"]["sp"]["name_id_policy_format"] = name_id_format
Expand Down Expand Up @@ -298,7 +305,7 @@ def get_saml_client(


def decode_saml_response(
request: HttpRequest, acs: Callable[..., HttpResponse]
request: HttpRequest, acs: Callable[..., HttpResponse], tenant_id: Optional[str] = None
) -> Union[HttpResponseRedirect, Optional[AuthnResponse], None]:
"""Given a request, the authentication response inside the SAML response body is parsed,
decoded and returned. If there are any issues parsing the request, the identity or the issuer,
Expand All @@ -307,6 +314,7 @@ def decode_saml_response(
Args:
request (HttpRequest): Django request object from identity provider (IdP)
acs (Callable[..., HttpResponse]): The acs endpoint
tenant_id (typing.Optional[str], optional): Tenant ID used for the custom ACS and Entity ID hooks. Defaults to None.

Raises:
SAMLAuthError: There was no response from SAML client.
Expand Down Expand Up @@ -335,7 +343,7 @@ def decode_saml_response(
saml_response = base64.b64decode(response).decode("UTF-8")
except Exception:
saml_response = None
saml_client = get_saml_client(get_assertion_url(request), acs, saml_response=saml_response)
saml_client = get_saml_client(get_assertion_url(request), acs, saml_response=saml_response, tenant_id=tenant_id)
if not saml_client:
raise SAMLAuthError(
"There was an error creating the SAML client.",
Expand Down
32 changes: 31 additions & 1 deletion django_saml2_auth/tests/test_saml.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,15 @@ def get_metadata_auto_conf_urls(
def get_custom_assertion_url():
return "https://example.com/custom-tenant/acs"

def get_custom_assertion_url_with_param(tenant_id: str):
return f"https://example.com/{tenant_id}/acs"

GET_CUSTOM_ASSERTION_URL = "django_saml2_auth.tests.test_saml.get_custom_assertion_url"
def get_custom_entity_id_url(tenant_id: str):
return f"https://example.com/sso/{tenant_id}"

GET_CUSTOM_ASSERTION_URL = "django_saml2_auth.tests.test_saml.get_custom_assertion_url"
GET_CUSTOM_ASSERTION_URL_WITH_PARAM = "django_saml2_auth.tests.test_saml.get_custom_assertion_url_with_param"
GET_CUSTOM_ENTITY_ID = "django_saml2_auth.tests.test_saml.get_custom_entity_id_url"

def mock_extract_user_identity(
user: Dict[str, Optional[Any]], authn_response: AuthnResponse
Expand Down Expand Up @@ -479,6 +485,30 @@ def test_get_saml_client_success_with_custom_assertion_url_hook(settings: Settin
"sp",
)

def test_get_saml_client_success_with_custom_assertion_url_and_param_hook(settings: SettingsWrapper):
settings.SAML2_AUTH["METADATA_LOCAL_FILE_PATH"] = "django_saml2_auth/tests/metadata.xml"

settings.SAML2_AUTH["TRIGGER"]["GET_CUSTOM_ASSERTION_URL"] = GET_CUSTOM_ASSERTION_URL_WITH_PARAM

result = get_saml_client("example.com", acs, "[email protected]", tenant_id="custom-tenant")
assert result is not None
assert "https://example.com/custom-tenant/acs" in result.config.endpoint(
"assertion_consumer_service",
BINDING_HTTP_POST,
"sp",
)


def test_get_saml_client_success_with_custom_entity_id_hook(settings: SettingsWrapper):
settings.SAML2_AUTH["METADATA_LOCAL_FILE_PATH"] = "django_saml2_auth/tests/metadata.xml"

settings.SAML2_AUTH["TRIGGER"]["GET_CUSTOM_ENTITY_ID"] = GET_CUSTOM_ENTITY_ID

result = get_saml_client("example.com", acs, "[email protected]", tenant_id="custom-tenant")
assert result is not None
assert "https://example.com/sso/custom-tenant" == result.config.entityid


@responses.activate
def test_decode_saml_response_success(
settings: SettingsWrapper,
Expand Down
15 changes: 12 additions & 3 deletions django_saml2_auth/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,13 @@ def run_hook(
},
)
try:
result = getattr(cls, path[-1])(*args, **kwargs)
func: Callable = getattr(cls, path[-1])

if func.__code__.co_argcount > 0: # backwards compatibility with existing hooks with no parameters
result = func(*args, **kwargs)
else:
result = func()

except SAMLAuthError as exc:
# Re-raise the exception
raise exc
Expand Down Expand Up @@ -196,18 +202,21 @@ def handle_exception(exc: Exception, request: HttpRequest) -> HttpResponse:
return render(request, "django_saml2_auth/error.html", context=context, status=status)

@wraps(function)
def wrapper(request: HttpRequest) -> HttpResponse:
def wrapper(request: HttpRequest, **kwargs: Optional[Mapping[str, Any]]) -> HttpResponse:
"""Decorated function is wrapped and called here

Args:
request ([type]): [description]

Keyword Args:
**kwargs: Additional keyword arguments

Returns:
HttpResponse: Either a redirect or a response with error details
"""
result = None
try:
result = function(request)
result = function(request, **kwargs)
except (SAMLAuthError, Exception) as exc:
result = handle_exception(exc, request)
return result
Expand Down
11 changes: 6 additions & 5 deletions django_saml2_auth/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,15 @@ def denied(request: HttpRequest) -> HttpResponse:

@csrf_exempt
@exception_handler
def acs(request: HttpRequest):
def acs(request: HttpRequest, tenant_id: Optional[str] = None):
"""Assertion Consumer Service is SAML terminology for the location at a ServiceProvider that
accepts <samlp:Response> messages (or SAML artifacts) for the purpose of establishing a session
based on an assertion. Assertion is a signed authentication request from identity provider (IdP)
to acs endpoint.

Args:
request (HttpRequest): Incoming request from identity provider (IdP) for authentication

tenant_id (typing.Optional[str], optional): Tenant ID used for the custom ACS and Entity ID hooks. Defaults to None.
Exceptions:
SAMLAuthError: The target user is inactive.

Expand All @@ -106,7 +106,7 @@ def acs(request: HttpRequest):
"""
saml2_auth_settings = settings.SAML2_AUTH

authn_response = decode_saml_response(request, acs)
authn_response = decode_saml_response(request, acs, tenant_id=tenant_id)
# decode_saml_response() will raise SAMLAuthError if the response is invalid,
# so we can safely ignore the type check here.
user = extract_user_identity(authn_response) # type: ignore
Expand Down Expand Up @@ -263,12 +263,13 @@ def sp_initiated_login(request: HttpRequest) -> HttpResponseRedirect:


@exception_handler
def signin(request: HttpRequest) -> HttpResponseRedirect:
def signin(request: HttpRequest, tenant_id: Optional[str] = None) -> HttpResponseRedirect:
"""Custom sign-in view for SP-initiated SSO. This will be deprecated in the future
in favor of sp_initiated_login.

Args:
request (HttpRequest): Incoming request from service provider (SP) for authentication.
tenant_id (typing.Optional[str], optional): Tenant ID used for the custom ACS and Entity ID hooks. Defaults to None.

Raises:
SAMLAuthError: The next URL is invalid.
Expand Down Expand Up @@ -306,7 +307,7 @@ def signin(request: HttpRequest) -> HttpResponseRedirect:

request.session["login_next_url"] = next_url

saml_client = get_saml_client(get_assertion_url(request), acs)
saml_client = get_saml_client(get_assertion_url(request), acs, tenant_id=tenant_id)
_, info = saml_client.prepare_for_authenticate(relay_state=next_url) # type: ignore

redirect_url = dict(info["headers"]).get("Location", "")
Expand Down
Loading