diff --git a/.github/workflows/pip-compile.yml b/.github/workflows/pip-compile.yml index 9f906bd1..72bd62a0 100644 --- a/.github/workflows/pip-compile.yml +++ b/.github/workflows/pip-compile.yml @@ -23,7 +23,7 @@ jobs: python-version: "3.10" - name: Update requirements files - uses: UW-GAC/pip-tools-actions/update-requirements-files@v0.1 + uses: UW-GAC/pip-tools-actions/update-requirements-files@v0.2 with: requirements_files: |- requirements/requirements.in diff --git a/config/settings/base.py b/config/settings/base.py index 3ed866c2..2f98275c 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -160,6 +160,7 @@ "maintenance_mode.middleware.MaintenanceModeMiddleware", "simple_history.middleware.HistoryRequestMiddleware", "django_htmx.middleware.HtmxMiddleware", + "allauth.account.middleware.AccountMiddleware", ] # STATIC diff --git a/primed/drupal_oauth_provider/provider.py b/primed/drupal_oauth_provider/provider.py index 92c84dd0..a43370cd 100644 --- a/primed/drupal_oauth_provider/provider.py +++ b/primed/drupal_oauth_provider/provider.py @@ -2,11 +2,14 @@ from allauth.account.models import EmailAddress from allauth.socialaccount import app_settings, providers +from allauth.socialaccount.adapter import get_adapter from allauth.socialaccount.providers.base import ProviderAccount from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider from django.conf import settings from django.core.exceptions import ImproperlyConfigured +from .views import CustomAdapter + logger = logging.getLogger(__name__) DRUPAL_PROVIDER_ID = "drupal_oauth_provider" @@ -24,9 +27,16 @@ class CustomAccount(ProviderAccount): class CustomProvider(OAuth2Provider): - id = "drupal_oauth_provider" + id = DRUPAL_PROVIDER_ID name = OVERRIDE_NAME account_class = CustomAccount + oauth2_adapter_class = CustomAdapter + supports_token_authentication = True + + def __init__(self, request, app=None): + if app is None: + app = get_adapter().get_app(request, self.id) + super().__init__(request, app=app) def extract_uid(self, data): return str(data["sub"]) diff --git a/primed/drupal_oauth_provider/tests.py b/primed/drupal_oauth_provider/tests.py index 5ccfffbc..e59eb6f2 100644 --- a/primed/drupal_oauth_provider/tests.py +++ b/primed/drupal_oauth_provider/tests.py @@ -1,11 +1,27 @@ +import base64 +import copy import datetime +import hashlib import json +import sys +from urllib.parse import parse_qs, urlparse import jwt +import requests +from allauth.socialaccount import app_settings from allauth.socialaccount.adapter import get_adapter +from allauth.socialaccount.models import SocialAccount, SocialApp, SocialToken +from allauth.socialaccount.providers.oauth2.client import OAuth2Error from allauth.socialaccount.tests import OAuth2TestsMixin from allauth.tests import MockedResponse, TestCase +from django.conf import settings +from django.contrib.auth import get_user_model +from django.contrib.messages.storage.fallback import FallbackStorage +from django.contrib.sites.models import Site +from django.core.exceptions import ImproperlyConfigured +from django.test import RequestFactory from django.test.utils import override_settings +from django.urls import reverse from .provider import CustomProvider @@ -65,25 +81,43 @@ def sign_id_token(payload): # Mocked version of the test data from /oauth/jwks -KEY_SERVER_RESP_JSON = json.dumps( - { - "keys": [ - { - "kty": TESTING_JWT_KEYSET["kty"], - "n": TESTING_JWT_KEYSET["n"], - "e": TESTING_JWT_KEYSET["e"], - } - ] - } -) +KEY_SERVER_RESP = { + "keys": [ + { + "kty": TESTING_JWT_KEYSET["kty"], + "n": TESTING_JWT_KEYSET["n"], + "e": TESTING_JWT_KEYSET["e"], + } + ] +} +KEY_SERVER_RESP_INVALID = copy.deepcopy(KEY_SERVER_RESP) +KEY_SERVER_RESP_INVALID["keys"][0]["kty"] = "nuts" +KEY_SERVER_RESP_JSON = json.dumps(KEY_SERVER_RESP) +KEY_SERVER_RESP_JSON_INVALID = json.dumps(KEY_SERVER_RESP_INVALID) +print(f"KEY_RESP_VALID: {KEY_SERVER_RESP_JSON}", file=sys.stderr) # disable token storing for testing as it conflicts with drupals use # of tokens for user info -@override_settings(SOCIALACCOUNT_STORE_TOKENS=False) +@override_settings(SOCIALACCOUNT_STORE_TOKENS=True) class CustomProviderTests(OAuth2TestsMixin, TestCase): provider_id = CustomProvider.id + def setUp(self): + super(CustomProviderTests, self).setUp() + self.factory = RequestFactory() + # workaround to create a session. see: + # https://code.djangoproject.com/ticket/11475 + User = get_user_model() + user = User.objects.create_user("testuser", "testuser@testuser.com", "testpw") + self.client.login(username="testuser", password="testpw") + self.setup_time = datetime.datetime.now(datetime.timezone.utc) + + # Create a social account for testing + self.social_account = SocialAccount.objects.create( + provider=self.provider.id, user=user, uid="1234", extra_data={} + ) + # Provide two mocked responses, first is to the public key request # second is used for the profile request for extra data def get_mocked_response(self): @@ -101,21 +135,83 @@ def get_mocked_response(self): ), ] - # This login response mimics drupals in that it contains a set of scopes - # and the uid which has the name sub - def get_login_response_json(self, with_refresh_token=True): - now = datetime.datetime.now(datetime.timezone.utc) + def login(self, resp_mock=None, process="login", with_refresh_token=True): + """ + Unfortunately due to how our provider works we need to alter + this test login function as the default one fails. + """ + with self.mocked_response(): + resp = self.client.post(self.provider.get_login_url(self.request, process=process)) + p = urlparse(resp["location"]) + q = parse_qs(p.query) + pkce_enabled = app_settings.PROVIDERS.get(self.app.provider, {}).get( + "OAUTH_PKCE_ENABLED", self.provider.pkce_enabled_default + ) + + self.assertEqual("code_challenge" in q, pkce_enabled) + self.assertEqual("code_challenge_method" in q, pkce_enabled) + if pkce_enabled: + code_challenge = q["code_challenge"][0] + self.assertEqual(q["code_challenge_method"][0], "S256") + + complete_url = self.provider.get_callback_url() + self.assertGreater(q["redirect_uri"][0].find(complete_url), 0) + response_json = self.get_login_response_json(with_refresh_token=with_refresh_token) + + resp_mocks = resp_mock if isinstance(resp_mock, list) else ([resp_mock] if resp_mock is not None else []) + + with self.mocked_response( + MockedResponse(200, response_json, {"content-type": "application/json"}), + *resp_mocks, + ): + resp = self.client.get(complete_url, self.get_complete_parameters(q)) + + # Find the access token POST request, and assert that it contains + # the correct code_verifier if and only if PKCE is enabled + request_calls = requests.Session.request.call_args_list + + for args, kwargs in request_calls: + data = kwargs.get("data", {}) + if ( + args + and args[0] == "POST" + and isinstance(data, dict) + and data.get("redirect_uri", "").endswith(complete_url) + ): + self.assertEqual("code_verifier" in data, pkce_enabled) + + if pkce_enabled: + hashed_code_verifier = hashlib.sha256(data["code_verifier"].encode("ascii")) + expected_code_challenge = ( + base64.urlsafe_b64encode(hashed_code_verifier.digest()).rstrip(b"=").decode() + ) + self.assertEqual(code_challenge, expected_code_challenge) + + return resp + + def get_id_token(self): app = get_adapter().get_app(request=None, provider=self.provider_id) allowed_audience = app.client_id - id_token = sign_id_token( + return sign_id_token( { - "exp": now + datetime.timedelta(hours=1), - "iat": now, + "exp": self.setup_time + datetime.timedelta(hours=1), + "iat": self.setup_time, "aud": allowed_audience, "scope": ["authenticated", "oauth_client_user"], "sub": 20122, } ) + + def get_access_token(self) -> str: + return self.get_id_token() + + def get_expected_to_str(self): + return "test@testmaster.net" + + # This login response mimics drupals in that it contains a set of scopes + # and the uid which has the name sub + def get_login_response_json(self, with_refresh_token=True): + id_token = self.get_id_token() response_data = { "access_token": id_token, "expires_in": 3600, @@ -125,3 +221,113 @@ def get_login_response_json(self, with_refresh_token=True): if with_refresh_token: response_data["refresh_token"] = "testrf" return json.dumps(response_data) + + def test_authentication_error(self): + # Create a request + + request = self.factory.get(reverse("drupal_oauth_provider_login")) + + # Add session and messages middleware + from django.contrib.sessions.middleware import SessionMiddleware + + middleware = SessionMiddleware(lambda x: x) + middleware.process_request(request) + request.session.save() + + # Add messages support + + messages = FallbackStorage(request) + setattr(request, "_messages", messages) + + # Create adapter instance + from primed.drupal_oauth_provider.views import CustomAdapter + + adapter = CustomAdapter(request) + # Create a SocialToken instance + token = SocialToken(app=self.app, account=self.social_account, token="invalid_token") + + with self.assertRaisesRegex(OAuth2Error, "Invalid id_token"): + # Simulate the error condition of a bad token + with self.mocked_response(self.get_mocked_response()[0]): + adapter.complete_login(request, app=self.app, token=token, response={"error": "invalid_grant"}) + + with self.assertRaisesRegex(OAuth2Error, "Error retrieving drupal public key"): + # Simulate the error condition of invalid json + with self.mocked_response(MockedResponse(200, "[lkjsdd]")): + adapter.complete_login(request, app=self.app, token=token, response={"error": "invalid_grant"}) + + with self.assertRaisesRegex(OAuth2Error, "failed to convert jwk"): + # Simulate the error condition of invalid jwk + with self.mocked_response( + MockedResponse(200, KEY_SERVER_RESP_JSON_INVALID), + ): + adapter.complete_login(request, app=self.app, token=token, response={"error": "invalid_grant"}) + + +class TestProviderConfig(TestCase): + def setUp(self): + # workaround to create a session. see: + # https://code.djangoproject.com/ticket/11475 + current_site = Site.objects.get_current() + app = SocialApp.objects.create( + provider=CustomProvider.id, + name=CustomProvider.id, + client_id="app123id", + key=CustomProvider.id, + secret="dummy", + ) + self.app = app + self.app.sites.add(current_site) + + def test_custom_provider_no_app(self): + rf = RequestFactory() + request = rf.get("/fake-url/") + provider = CustomProvider(request) + assert provider.app is not None + + def test_custom_provider_scope_config(self): + custom_provider_settings = settings.SOCIALACCOUNT_PROVIDERS + rf = RequestFactory() + request = rf.get("/fake-url/") + custom_provider_settings["drupal_oauth_provider"]["SCOPES"] = None + with override_settings(SOCIALACCOUNT_PROVIDERS=custom_provider_settings): + with self.assertRaises(ImproperlyConfigured): + CustomProvider(request, app=self.app).get_provider_scope_config() + + def test_custom_provider_scope_config_not_list(self): + custom_provider_settings = settings.SOCIALACCOUNT_PROVIDERS + rf = RequestFactory() + request = rf.get("/fake-url/") + custom_provider_settings["drupal_oauth_provider"]["SCOPES"] = {"not_a_list": 1} + with override_settings(SOCIALACCOUNT_PROVIDERS=custom_provider_settings): + with self.assertRaises(ImproperlyConfigured): + CustomProvider(request, app=self.app).get_provider_scope_config() + + def test_custom_provider_scope_detail_config(self): + custom_provider_settings = settings.SOCIALACCOUNT_PROVIDERS + rf = RequestFactory() + request = rf.get("/fake-url/") + custom_provider_settings["drupal_oauth_provider"]["SCOPES"] = [ + { + "z_drupal_machine_name": "X", + "request_scope": True, + "django_group_name": "Z", + } + ] + with override_settings(SOCIALACCOUNT_PROVIDERS=custom_provider_settings): + with self.assertRaises(ImproperlyConfigured): + CustomProvider(request, app=self.app).get_provider_managed_scope_status() + + def test_custom_provider_has_scope(self): + custom_provider_settings = settings.SOCIALACCOUNT_PROVIDERS + rf = RequestFactory() + request = rf.get("/fake-url/") + custom_provider_settings["drupal_oauth_provider"]["SCOPES"] = [ + { + "drupal_machine_name": "X", + "request_scope": True, + "django_group_name": "Z", + } + ] + with override_settings(SOCIALACCOUNT_PROVIDERS=custom_provider_settings): + CustomProvider(request, app=self.app).get_provider_managed_scope_status(scopes_granted=["X"]) diff --git a/primed/drupal_oauth_provider/views.py b/primed/drupal_oauth_provider/views.py index c8a78575..16a69e7f 100644 --- a/primed/drupal_oauth_provider/views.py +++ b/primed/drupal_oauth_provider/views.py @@ -12,13 +12,11 @@ OAuth2LoginView, ) -from .provider import CustomProvider - logger = logging.getLogger(__name__) class CustomAdapter(OAuth2Adapter): - provider_id = CustomProvider.id + provider_id = "drupal_oauth_provider" provider_settings = app_settings.PROVIDERS.get(provider_id, {}) @@ -60,7 +58,8 @@ def get_public_key(self, headers): try: public_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(public_key_jwk)) except Exception as e: - logger.error(f"[get_public_key] failed to convert jwk to public key {e}") + logger.error(f"[get_public_key] failed to convert jwk {public_key_jwk} to public key {e}") + raise OAuth2Error(f"[get_public_key] failed to convert jwk {public_key_jwk} to public key {e}") else: return public_key @@ -74,7 +73,6 @@ def get_scopes_from_token(self, id_token, headers): scopes = None try: - unverified_header = jwt.get_unverified_header(id_token.token) token_payload = jwt.decode( id_token.token, public_key, @@ -85,9 +83,6 @@ def get_scopes_from_token(self, id_token, headers): except jwt.PyJWTError as e: logger.error(f"Invalid id_token {e} {id_token.token}") raise OAuth2Error("Invalid id_token") from e - except Exception as e: - logger.error(f"Other exception parsing token {e} header {unverified_header} token {id_token}") - raise OAuth2Error("Error when decoding token {e}") else: scopes = token_payload.get("scope") diff --git a/primed/templates/socialaccount/authentication_error.html b/primed/templates/socialaccount/authentication_error.html index 747a4513..c4d9bd56 100644 --- a/primed/templates/socialaccount/authentication_error.html +++ b/primed/templates/socialaccount/authentication_error.html @@ -1,8 +1,8 @@ -{% extends "socialaccount/base.html" %} +{% extends "base.html" %} {% load i18n %} -{% block head_title %}{% trans "Social Network Login Failure" %}{% endblock %} +{% block title %}{% trans "Social Network Login Failure" %}{% endblock %} {% block content %}
- {{brand.name}} -
-{% endfor %} -{% endif %}