-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #788 from UW-GAC/maint/allauth-upgrade
django-allauth upgrade
- Loading branch information
Showing
11 changed files
with
385 additions
and
98 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,27 @@ | ||
import base64 | ||
import copy | ||
import datetime | ||
import hashlib | ||
import json | ||
import sys | ||
from urllib.parse import parse_qs, urlparse | ||
|
||
import jwt | ||
import requests | ||
from allauth.socialaccount import app_settings | ||
from allauth.socialaccount.adapter import get_adapter | ||
from allauth.socialaccount.models import SocialAccount, SocialApp, SocialToken | ||
from allauth.socialaccount.providers.oauth2.client import OAuth2Error | ||
from allauth.socialaccount.tests import OAuth2TestsMixin | ||
from allauth.tests import MockedResponse, TestCase | ||
from django.conf import settings | ||
from django.contrib.auth import get_user_model | ||
from django.contrib.messages.storage.fallback import FallbackStorage | ||
from django.contrib.sites.models import Site | ||
from django.core.exceptions import ImproperlyConfigured | ||
from django.test import RequestFactory | ||
from django.test.utils import override_settings | ||
from django.urls import reverse | ||
|
||
from .provider import CustomProvider | ||
|
||
|
@@ -65,25 +81,43 @@ def sign_id_token(payload): | |
|
||
|
||
# Mocked version of the test data from /oauth/jwks | ||
KEY_SERVER_RESP_JSON = json.dumps( | ||
{ | ||
"keys": [ | ||
{ | ||
"kty": TESTING_JWT_KEYSET["kty"], | ||
"n": TESTING_JWT_KEYSET["n"], | ||
"e": TESTING_JWT_KEYSET["e"], | ||
} | ||
] | ||
} | ||
) | ||
KEY_SERVER_RESP = { | ||
"keys": [ | ||
{ | ||
"kty": TESTING_JWT_KEYSET["kty"], | ||
"n": TESTING_JWT_KEYSET["n"], | ||
"e": TESTING_JWT_KEYSET["e"], | ||
} | ||
] | ||
} | ||
KEY_SERVER_RESP_INVALID = copy.deepcopy(KEY_SERVER_RESP) | ||
KEY_SERVER_RESP_INVALID["keys"][0]["kty"] = "nuts" | ||
KEY_SERVER_RESP_JSON = json.dumps(KEY_SERVER_RESP) | ||
KEY_SERVER_RESP_JSON_INVALID = json.dumps(KEY_SERVER_RESP_INVALID) | ||
print(f"KEY_RESP_VALID: {KEY_SERVER_RESP_JSON}", file=sys.stderr) | ||
|
||
|
||
# disable token storing for testing as it conflicts with drupals use | ||
# of tokens for user info | ||
@override_settings(SOCIALACCOUNT_STORE_TOKENS=False) | ||
@override_settings(SOCIALACCOUNT_STORE_TOKENS=True) | ||
class CustomProviderTests(OAuth2TestsMixin, TestCase): | ||
provider_id = CustomProvider.id | ||
|
||
def setUp(self): | ||
super(CustomProviderTests, self).setUp() | ||
self.factory = RequestFactory() | ||
# workaround to create a session. see: | ||
# https://code.djangoproject.com/ticket/11475 | ||
User = get_user_model() | ||
user = User.objects.create_user("testuser", "[email protected]", "testpw") | ||
self.client.login(username="testuser", password="testpw") | ||
self.setup_time = datetime.datetime.now(datetime.timezone.utc) | ||
|
||
# Create a social account for testing | ||
self.social_account = SocialAccount.objects.create( | ||
provider=self.provider.id, user=user, uid="1234", extra_data={} | ||
) | ||
|
||
# Provide two mocked responses, first is to the public key request | ||
# second is used for the profile request for extra data | ||
def get_mocked_response(self): | ||
|
@@ -101,21 +135,83 @@ def get_mocked_response(self): | |
), | ||
] | ||
|
||
# This login response mimics drupals in that it contains a set of scopes | ||
# and the uid which has the name sub | ||
def get_login_response_json(self, with_refresh_token=True): | ||
now = datetime.datetime.now(datetime.timezone.utc) | ||
def login(self, resp_mock=None, process="login", with_refresh_token=True): | ||
""" | ||
Unfortunately due to how our provider works we need to alter | ||
this test login function as the default one fails. | ||
""" | ||
with self.mocked_response(): | ||
resp = self.client.post(self.provider.get_login_url(self.request, process=process)) | ||
p = urlparse(resp["location"]) | ||
q = parse_qs(p.query) | ||
pkce_enabled = app_settings.PROVIDERS.get(self.app.provider, {}).get( | ||
"OAUTH_PKCE_ENABLED", self.provider.pkce_enabled_default | ||
) | ||
|
||
self.assertEqual("code_challenge" in q, pkce_enabled) | ||
self.assertEqual("code_challenge_method" in q, pkce_enabled) | ||
if pkce_enabled: | ||
code_challenge = q["code_challenge"][0] | ||
self.assertEqual(q["code_challenge_method"][0], "S256") | ||
|
||
complete_url = self.provider.get_callback_url() | ||
self.assertGreater(q["redirect_uri"][0].find(complete_url), 0) | ||
response_json = self.get_login_response_json(with_refresh_token=with_refresh_token) | ||
|
||
resp_mocks = resp_mock if isinstance(resp_mock, list) else ([resp_mock] if resp_mock is not None else []) | ||
|
||
with self.mocked_response( | ||
MockedResponse(200, response_json, {"content-type": "application/json"}), | ||
*resp_mocks, | ||
): | ||
resp = self.client.get(complete_url, self.get_complete_parameters(q)) | ||
|
||
# Find the access token POST request, and assert that it contains | ||
# the correct code_verifier if and only if PKCE is enabled | ||
request_calls = requests.Session.request.call_args_list | ||
|
||
for args, kwargs in request_calls: | ||
data = kwargs.get("data", {}) | ||
if ( | ||
args | ||
and args[0] == "POST" | ||
and isinstance(data, dict) | ||
and data.get("redirect_uri", "").endswith(complete_url) | ||
): | ||
self.assertEqual("code_verifier" in data, pkce_enabled) | ||
|
||
if pkce_enabled: | ||
hashed_code_verifier = hashlib.sha256(data["code_verifier"].encode("ascii")) | ||
expected_code_challenge = ( | ||
base64.urlsafe_b64encode(hashed_code_verifier.digest()).rstrip(b"=").decode() | ||
) | ||
self.assertEqual(code_challenge, expected_code_challenge) | ||
|
||
return resp | ||
|
||
def get_id_token(self): | ||
app = get_adapter().get_app(request=None, provider=self.provider_id) | ||
allowed_audience = app.client_id | ||
id_token = sign_id_token( | ||
return sign_id_token( | ||
{ | ||
"exp": now + datetime.timedelta(hours=1), | ||
"iat": now, | ||
"exp": self.setup_time + datetime.timedelta(hours=1), | ||
"iat": self.setup_time, | ||
"aud": allowed_audience, | ||
"scope": ["authenticated", "oauth_client_user"], | ||
"sub": 20122, | ||
} | ||
) | ||
|
||
def get_access_token(self) -> str: | ||
return self.get_id_token() | ||
|
||
def get_expected_to_str(self): | ||
return "[email protected]" | ||
|
||
# This login response mimics drupals in that it contains a set of scopes | ||
# and the uid which has the name sub | ||
def get_login_response_json(self, with_refresh_token=True): | ||
id_token = self.get_id_token() | ||
response_data = { | ||
"access_token": id_token, | ||
"expires_in": 3600, | ||
|
@@ -125,3 +221,113 @@ def get_login_response_json(self, with_refresh_token=True): | |
if with_refresh_token: | ||
response_data["refresh_token"] = "testrf" | ||
return json.dumps(response_data) | ||
|
||
def test_authentication_error(self): | ||
# Create a request | ||
|
||
request = self.factory.get(reverse("drupal_oauth_provider_login")) | ||
|
||
# Add session and messages middleware | ||
from django.contrib.sessions.middleware import SessionMiddleware | ||
|
||
middleware = SessionMiddleware(lambda x: x) | ||
middleware.process_request(request) | ||
request.session.save() | ||
|
||
# Add messages support | ||
|
||
messages = FallbackStorage(request) | ||
setattr(request, "_messages", messages) | ||
|
||
# Create adapter instance | ||
from primed.drupal_oauth_provider.views import CustomAdapter | ||
|
||
adapter = CustomAdapter(request) | ||
# Create a SocialToken instance | ||
token = SocialToken(app=self.app, account=self.social_account, token="invalid_token") | ||
|
||
with self.assertRaisesRegex(OAuth2Error, "Invalid id_token"): | ||
# Simulate the error condition of a bad token | ||
with self.mocked_response(self.get_mocked_response()[0]): | ||
adapter.complete_login(request, app=self.app, token=token, response={"error": "invalid_grant"}) | ||
|
||
with self.assertRaisesRegex(OAuth2Error, "Error retrieving drupal public key"): | ||
# Simulate the error condition of invalid json | ||
with self.mocked_response(MockedResponse(200, "[lkjsdd]")): | ||
adapter.complete_login(request, app=self.app, token=token, response={"error": "invalid_grant"}) | ||
|
||
with self.assertRaisesRegex(OAuth2Error, "failed to convert jwk"): | ||
# Simulate the error condition of invalid jwk | ||
with self.mocked_response( | ||
MockedResponse(200, KEY_SERVER_RESP_JSON_INVALID), | ||
): | ||
adapter.complete_login(request, app=self.app, token=token, response={"error": "invalid_grant"}) | ||
|
||
|
||
class TestProviderConfig(TestCase): | ||
def setUp(self): | ||
# workaround to create a session. see: | ||
# https://code.djangoproject.com/ticket/11475 | ||
current_site = Site.objects.get_current() | ||
app = SocialApp.objects.create( | ||
provider=CustomProvider.id, | ||
name=CustomProvider.id, | ||
client_id="app123id", | ||
key=CustomProvider.id, | ||
secret="dummy", | ||
) | ||
self.app = app | ||
self.app.sites.add(current_site) | ||
|
||
def test_custom_provider_no_app(self): | ||
rf = RequestFactory() | ||
request = rf.get("/fake-url/") | ||
provider = CustomProvider(request) | ||
assert provider.app is not None | ||
|
||
def test_custom_provider_scope_config(self): | ||
custom_provider_settings = settings.SOCIALACCOUNT_PROVIDERS | ||
rf = RequestFactory() | ||
request = rf.get("/fake-url/") | ||
custom_provider_settings["drupal_oauth_provider"]["SCOPES"] = None | ||
with override_settings(SOCIALACCOUNT_PROVIDERS=custom_provider_settings): | ||
with self.assertRaises(ImproperlyConfigured): | ||
CustomProvider(request, app=self.app).get_provider_scope_config() | ||
|
||
def test_custom_provider_scope_config_not_list(self): | ||
custom_provider_settings = settings.SOCIALACCOUNT_PROVIDERS | ||
rf = RequestFactory() | ||
request = rf.get("/fake-url/") | ||
custom_provider_settings["drupal_oauth_provider"]["SCOPES"] = {"not_a_list": 1} | ||
with override_settings(SOCIALACCOUNT_PROVIDERS=custom_provider_settings): | ||
with self.assertRaises(ImproperlyConfigured): | ||
CustomProvider(request, app=self.app).get_provider_scope_config() | ||
|
||
def test_custom_provider_scope_detail_config(self): | ||
custom_provider_settings = settings.SOCIALACCOUNT_PROVIDERS | ||
rf = RequestFactory() | ||
request = rf.get("/fake-url/") | ||
custom_provider_settings["drupal_oauth_provider"]["SCOPES"] = [ | ||
{ | ||
"z_drupal_machine_name": "X", | ||
"request_scope": True, | ||
"django_group_name": "Z", | ||
} | ||
] | ||
with override_settings(SOCIALACCOUNT_PROVIDERS=custom_provider_settings): | ||
with self.assertRaises(ImproperlyConfigured): | ||
CustomProvider(request, app=self.app).get_provider_managed_scope_status() | ||
|
||
def test_custom_provider_has_scope(self): | ||
custom_provider_settings = settings.SOCIALACCOUNT_PROVIDERS | ||
rf = RequestFactory() | ||
request = rf.get("/fake-url/") | ||
custom_provider_settings["drupal_oauth_provider"]["SCOPES"] = [ | ||
{ | ||
"drupal_machine_name": "X", | ||
"request_scope": True, | ||
"django_group_name": "Z", | ||
} | ||
] | ||
with override_settings(SOCIALACCOUNT_PROVIDERS=custom_provider_settings): | ||
CustomProvider(request, app=self.app).get_provider_managed_scope_status(scopes_granted=["X"]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.