Skip to content

Commit

Permalink
Merge pull request #788 from UW-GAC/maint/allauth-upgrade
Browse files Browse the repository at this point in the history
django-allauth upgrade
  • Loading branch information
jmcarson authored Oct 25, 2024
2 parents 7c5df7a + 7277245 commit b590c2a
Show file tree
Hide file tree
Showing 11 changed files with 385 additions and 98 deletions.
1 change: 1 addition & 0 deletions config/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@
"maintenance_mode.middleware.MaintenanceModeMiddleware",
"simple_history.middleware.HistoryRequestMiddleware",
"django_htmx.middleware.HtmxMiddleware",
"allauth.account.middleware.AccountMiddleware",
]

# STATIC
Expand Down
12 changes: 11 additions & 1 deletion primed/drupal_oauth_provider/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

from allauth.account.models import EmailAddress
from allauth.socialaccount import app_settings, providers
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.providers.base import ProviderAccount
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured

from .views import CustomAdapter

logger = logging.getLogger(__name__)

DRUPAL_PROVIDER_ID = "drupal_oauth_provider"
Expand All @@ -24,9 +27,16 @@ class CustomAccount(ProviderAccount):


class CustomProvider(OAuth2Provider):
id = "drupal_oauth_provider"
id = DRUPAL_PROVIDER_ID
name = OVERRIDE_NAME
account_class = CustomAccount
oauth2_adapter_class = CustomAdapter
supports_token_authentication = True

def __init__(self, request, app=None):
if app is None:
app = get_adapter().get_app(request, self.id)
super().__init__(request, app=app)

def extract_uid(self, data):
return str(data["sub"])
Expand Down
244 changes: 225 additions & 19 deletions primed/drupal_oauth_provider/tests.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,27 @@
import base64
import copy
import datetime
import hashlib
import json
import sys
from urllib.parse import parse_qs, urlparse

import jwt
import requests
from allauth.socialaccount import app_settings
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.models import SocialAccount, SocialApp, SocialToken
from allauth.socialaccount.providers.oauth2.client import OAuth2Error
from allauth.socialaccount.tests import OAuth2TestsMixin
from allauth.tests import MockedResponse, TestCase
from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.messages.storage.fallback import FallbackStorage
from django.contrib.sites.models import Site
from django.core.exceptions import ImproperlyConfigured
from django.test import RequestFactory
from django.test.utils import override_settings
from django.urls import reverse

from .provider import CustomProvider

Expand Down Expand Up @@ -65,25 +81,43 @@ def sign_id_token(payload):


# Mocked version of the test data from /oauth/jwks
KEY_SERVER_RESP_JSON = json.dumps(
{
"keys": [
{
"kty": TESTING_JWT_KEYSET["kty"],
"n": TESTING_JWT_KEYSET["n"],
"e": TESTING_JWT_KEYSET["e"],
}
]
}
)
KEY_SERVER_RESP = {
"keys": [
{
"kty": TESTING_JWT_KEYSET["kty"],
"n": TESTING_JWT_KEYSET["n"],
"e": TESTING_JWT_KEYSET["e"],
}
]
}
KEY_SERVER_RESP_INVALID = copy.deepcopy(KEY_SERVER_RESP)
KEY_SERVER_RESP_INVALID["keys"][0]["kty"] = "nuts"
KEY_SERVER_RESP_JSON = json.dumps(KEY_SERVER_RESP)
KEY_SERVER_RESP_JSON_INVALID = json.dumps(KEY_SERVER_RESP_INVALID)
print(f"KEY_RESP_VALID: {KEY_SERVER_RESP_JSON}", file=sys.stderr)


# disable token storing for testing as it conflicts with drupals use
# of tokens for user info
@override_settings(SOCIALACCOUNT_STORE_TOKENS=False)
@override_settings(SOCIALACCOUNT_STORE_TOKENS=True)
class CustomProviderTests(OAuth2TestsMixin, TestCase):
provider_id = CustomProvider.id

def setUp(self):
super(CustomProviderTests, self).setUp()
self.factory = RequestFactory()
# workaround to create a session. see:
# https://code.djangoproject.com/ticket/11475
User = get_user_model()
user = User.objects.create_user("testuser", "[email protected]", "testpw")
self.client.login(username="testuser", password="testpw")
self.setup_time = datetime.datetime.now(datetime.timezone.utc)

# Create a social account for testing
self.social_account = SocialAccount.objects.create(
provider=self.provider.id, user=user, uid="1234", extra_data={}
)

# Provide two mocked responses, first is to the public key request
# second is used for the profile request for extra data
def get_mocked_response(self):
Expand All @@ -101,21 +135,83 @@ def get_mocked_response(self):
),
]

# This login response mimics drupals in that it contains a set of scopes
# and the uid which has the name sub
def get_login_response_json(self, with_refresh_token=True):
now = datetime.datetime.now(datetime.timezone.utc)
def login(self, resp_mock=None, process="login", with_refresh_token=True):
"""
Unfortunately due to how our provider works we need to alter
this test login function as the default one fails.
"""
with self.mocked_response():
resp = self.client.post(self.provider.get_login_url(self.request, process=process))
p = urlparse(resp["location"])
q = parse_qs(p.query)
pkce_enabled = app_settings.PROVIDERS.get(self.app.provider, {}).get(
"OAUTH_PKCE_ENABLED", self.provider.pkce_enabled_default
)

self.assertEqual("code_challenge" in q, pkce_enabled)
self.assertEqual("code_challenge_method" in q, pkce_enabled)
if pkce_enabled:
code_challenge = q["code_challenge"][0]
self.assertEqual(q["code_challenge_method"][0], "S256")

complete_url = self.provider.get_callback_url()
self.assertGreater(q["redirect_uri"][0].find(complete_url), 0)
response_json = self.get_login_response_json(with_refresh_token=with_refresh_token)

resp_mocks = resp_mock if isinstance(resp_mock, list) else ([resp_mock] if resp_mock is not None else [])

with self.mocked_response(
MockedResponse(200, response_json, {"content-type": "application/json"}),
*resp_mocks,
):
resp = self.client.get(complete_url, self.get_complete_parameters(q))

# Find the access token POST request, and assert that it contains
# the correct code_verifier if and only if PKCE is enabled
request_calls = requests.Session.request.call_args_list

for args, kwargs in request_calls:
data = kwargs.get("data", {})
if (
args
and args[0] == "POST"
and isinstance(data, dict)
and data.get("redirect_uri", "").endswith(complete_url)
):
self.assertEqual("code_verifier" in data, pkce_enabled)

if pkce_enabled:
hashed_code_verifier = hashlib.sha256(data["code_verifier"].encode("ascii"))
expected_code_challenge = (
base64.urlsafe_b64encode(hashed_code_verifier.digest()).rstrip(b"=").decode()
)
self.assertEqual(code_challenge, expected_code_challenge)

return resp

def get_id_token(self):
app = get_adapter().get_app(request=None, provider=self.provider_id)
allowed_audience = app.client_id
id_token = sign_id_token(
return sign_id_token(
{
"exp": now + datetime.timedelta(hours=1),
"iat": now,
"exp": self.setup_time + datetime.timedelta(hours=1),
"iat": self.setup_time,
"aud": allowed_audience,
"scope": ["authenticated", "oauth_client_user"],
"sub": 20122,
}
)

def get_access_token(self) -> str:
return self.get_id_token()

def get_expected_to_str(self):
return "[email protected]"

# This login response mimics drupals in that it contains a set of scopes
# and the uid which has the name sub
def get_login_response_json(self, with_refresh_token=True):
id_token = self.get_id_token()
response_data = {
"access_token": id_token,
"expires_in": 3600,
Expand All @@ -125,3 +221,113 @@ def get_login_response_json(self, with_refresh_token=True):
if with_refresh_token:
response_data["refresh_token"] = "testrf"
return json.dumps(response_data)

def test_authentication_error(self):
# Create a request

request = self.factory.get(reverse("drupal_oauth_provider_login"))

# Add session and messages middleware
from django.contrib.sessions.middleware import SessionMiddleware

middleware = SessionMiddleware(lambda x: x)
middleware.process_request(request)
request.session.save()

# Add messages support

messages = FallbackStorage(request)
setattr(request, "_messages", messages)

# Create adapter instance
from primed.drupal_oauth_provider.views import CustomAdapter

adapter = CustomAdapter(request)
# Create a SocialToken instance
token = SocialToken(app=self.app, account=self.social_account, token="invalid_token")

with self.assertRaisesRegex(OAuth2Error, "Invalid id_token"):
# Simulate the error condition of a bad token
with self.mocked_response(self.get_mocked_response()[0]):
adapter.complete_login(request, app=self.app, token=token, response={"error": "invalid_grant"})

with self.assertRaisesRegex(OAuth2Error, "Error retrieving drupal public key"):
# Simulate the error condition of invalid json
with self.mocked_response(MockedResponse(200, "[lkjsdd]")):
adapter.complete_login(request, app=self.app, token=token, response={"error": "invalid_grant"})

with self.assertRaisesRegex(OAuth2Error, "failed to convert jwk"):
# Simulate the error condition of invalid jwk
with self.mocked_response(
MockedResponse(200, KEY_SERVER_RESP_JSON_INVALID),
):
adapter.complete_login(request, app=self.app, token=token, response={"error": "invalid_grant"})


class TestProviderConfig(TestCase):
def setUp(self):
# workaround to create a session. see:
# https://code.djangoproject.com/ticket/11475
current_site = Site.objects.get_current()
app = SocialApp.objects.create(
provider=CustomProvider.id,
name=CustomProvider.id,
client_id="app123id",
key=CustomProvider.id,
secret="dummy",
)
self.app = app
self.app.sites.add(current_site)

def test_custom_provider_no_app(self):
rf = RequestFactory()
request = rf.get("/fake-url/")
provider = CustomProvider(request)
assert provider.app is not None

def test_custom_provider_scope_config(self):
custom_provider_settings = settings.SOCIALACCOUNT_PROVIDERS
rf = RequestFactory()
request = rf.get("/fake-url/")
custom_provider_settings["drupal_oauth_provider"]["SCOPES"] = None
with override_settings(SOCIALACCOUNT_PROVIDERS=custom_provider_settings):
with self.assertRaises(ImproperlyConfigured):
CustomProvider(request, app=self.app).get_provider_scope_config()

def test_custom_provider_scope_config_not_list(self):
custom_provider_settings = settings.SOCIALACCOUNT_PROVIDERS
rf = RequestFactory()
request = rf.get("/fake-url/")
custom_provider_settings["drupal_oauth_provider"]["SCOPES"] = {"not_a_list": 1}
with override_settings(SOCIALACCOUNT_PROVIDERS=custom_provider_settings):
with self.assertRaises(ImproperlyConfigured):
CustomProvider(request, app=self.app).get_provider_scope_config()

def test_custom_provider_scope_detail_config(self):
custom_provider_settings = settings.SOCIALACCOUNT_PROVIDERS
rf = RequestFactory()
request = rf.get("/fake-url/")
custom_provider_settings["drupal_oauth_provider"]["SCOPES"] = [
{
"z_drupal_machine_name": "X",
"request_scope": True,
"django_group_name": "Z",
}
]
with override_settings(SOCIALACCOUNT_PROVIDERS=custom_provider_settings):
with self.assertRaises(ImproperlyConfigured):
CustomProvider(request, app=self.app).get_provider_managed_scope_status()

def test_custom_provider_has_scope(self):
custom_provider_settings = settings.SOCIALACCOUNT_PROVIDERS
rf = RequestFactory()
request = rf.get("/fake-url/")
custom_provider_settings["drupal_oauth_provider"]["SCOPES"] = [
{
"drupal_machine_name": "X",
"request_scope": True,
"django_group_name": "Z",
}
]
with override_settings(SOCIALACCOUNT_PROVIDERS=custom_provider_settings):
CustomProvider(request, app=self.app).get_provider_managed_scope_status(scopes_granted=["X"])
11 changes: 3 additions & 8 deletions primed/drupal_oauth_provider/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,11 @@
OAuth2LoginView,
)

from .provider import CustomProvider

logger = logging.getLogger(__name__)


class CustomAdapter(OAuth2Adapter):
provider_id = CustomProvider.id
provider_id = "drupal_oauth_provider"

provider_settings = app_settings.PROVIDERS.get(provider_id, {})

Expand Down Expand Up @@ -60,7 +58,8 @@ def get_public_key(self, headers):
try:
public_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(public_key_jwk))
except Exception as e:
logger.error(f"[get_public_key] failed to convert jwk to public key {e}")
logger.error(f"[get_public_key] failed to convert jwk {public_key_jwk} to public key {e}")
raise OAuth2Error(f"[get_public_key] failed to convert jwk {public_key_jwk} to public key {e}")
else:
return public_key

Expand All @@ -74,7 +73,6 @@ def get_scopes_from_token(self, id_token, headers):
scopes = None

try:
unverified_header = jwt.get_unverified_header(id_token.token)
token_payload = jwt.decode(
id_token.token,
public_key,
Expand All @@ -85,9 +83,6 @@ def get_scopes_from_token(self, id_token, headers):
except jwt.PyJWTError as e:
logger.error(f"Invalid id_token {e} {id_token.token}")
raise OAuth2Error("Invalid id_token") from e
except Exception as e:
logger.error(f"Other exception parsing token {e} header {unverified_header} token {id_token}")
raise OAuth2Error("Error when decoding token {e}")
else:
scopes = token_payload.get("scope")

Expand Down
4 changes: 2 additions & 2 deletions primed/templates/socialaccount/authentication_error.html
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
{% extends "socialaccount/base.html" %}
{% extends "base.html" %}

{% load i18n %}

{% block head_title %}{% trans "Social Network Login Failure" %}{% endblock %}
{% block title %}{% trans "Social Network Login Failure" %}{% endblock %}

{% block content %}
<h1>{% trans "Social Network Login Failure" %}</h1>
Expand Down
Loading

0 comments on commit b590c2a

Please sign in to comment.