Skip to content

Commit

Permalink
Merge pull request #769 from UW-GAC/maint/allauth-upgrade
Browse files Browse the repository at this point in the history
django-allauth upgrade integration fixes
  • Loading branch information
jmcarson authored Oct 18, 2024
2 parents a4323a8 + 2841a97 commit baedf4b
Show file tree
Hide file tree
Showing 11 changed files with 279 additions and 78 deletions.
1 change: 1 addition & 0 deletions config/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@
"maintenance_mode.middleware.MaintenanceModeMiddleware",
"simple_history.middleware.HistoryRequestMiddleware",
"django_htmx.middleware.HtmxMiddleware",
"allauth.account.middleware.AccountMiddleware",
]

# STATIC
Expand Down
10 changes: 10 additions & 0 deletions gregor_django/drupal_oauth_provider/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

from allauth.account.models import EmailAddress
from allauth.socialaccount import app_settings, providers
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.providers.base import ProviderAccount
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured

from .views import CustomAdapter

logger = logging.getLogger(__name__)

DRUPAL_PROVIDER_ID = "drupal_oauth_provider"
Expand All @@ -27,6 +30,13 @@ class CustomProvider(OAuth2Provider):
id = DRUPAL_PROVIDER_ID
name = OVERRIDE_NAME
account_class = CustomAccount
oauth2_adapter_class = CustomAdapter
supports_token_authentication = True

def __init__(self, request, app=None):
if app is None:
app = get_adapter().get_app(request, self.id)
super().__init__(request, app=app)

def extract_uid(self, data):
return str(data["sub"])
Expand Down
121 changes: 110 additions & 11 deletions gregor_django/drupal_oauth_provider/tests.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
import base64
import datetime
import hashlib
import json
from urllib.parse import parse_qs, urlparse

import jwt
import requests
from allauth.socialaccount import app_settings
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.models import SocialApp
from allauth.socialaccount.tests import OAuth2TestsMixin
from allauth.tests import MockedResponse, TestCase
from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.sites.models import Site
from django.core.exceptions import ImproperlyConfigured
from django.test import RequestFactory
from django.test.utils import override_settings
Expand Down Expand Up @@ -83,10 +91,19 @@ def sign_id_token(payload):

# disable token storing for testing as it conflicts with drupals use
# of tokens for user info
@override_settings(SOCIALACCOUNT_STORE_TOKENS=False)
@override_settings(SOCIALACCOUNT_STORE_TOKENS=True)
class CustomProviderTests(OAuth2TestsMixin, TestCase):
provider_id = CustomProvider.id

def setUp(self):
super(CustomProviderTests, self).setUp()
# workaround to create a session. see:
# https://code.djangoproject.com/ticket/11475
User = get_user_model()
User.objects.create_user("testuser", "[email protected]", "testpw")
self.client.login(username="testuser", password="testpw")
self.setup_time = datetime.datetime.now(datetime.timezone.utc)

# Provide two mocked responses, first is to the public key request
# second is used for the profile request for extra data
def get_mocked_response(self):
Expand All @@ -104,21 +121,83 @@ def get_mocked_response(self):
),
]

# This login response mimics drupals in that it contains a set of scopes
# and the uid which has the name sub
def get_login_response_json(self, with_refresh_token=True):
now = datetime.datetime.now(datetime.timezone.utc)
def login(self, resp_mock=None, process="login", with_refresh_token=True):
"""
Unfortunately due to how our provider works we need to alter
this test login function as the default one fails.
"""
with self.mocked_response():
resp = self.client.post(self.provider.get_login_url(self.request, process=process))
p = urlparse(resp["location"])
q = parse_qs(p.query)
pkce_enabled = app_settings.PROVIDERS.get(self.app.provider, {}).get(
"OAUTH_PKCE_ENABLED", self.provider.pkce_enabled_default
)

self.assertEqual("code_challenge" in q, pkce_enabled)
self.assertEqual("code_challenge_method" in q, pkce_enabled)
if pkce_enabled:
code_challenge = q["code_challenge"][0]
self.assertEqual(q["code_challenge_method"][0], "S256")

complete_url = self.provider.get_callback_url()
self.assertGreater(q["redirect_uri"][0].find(complete_url), 0)
response_json = self.get_login_response_json(with_refresh_token=with_refresh_token)

resp_mocks = resp_mock if isinstance(resp_mock, list) else ([resp_mock] if resp_mock is not None else [])

with self.mocked_response(
MockedResponse(200, response_json, {"content-type": "application/json"}),
*resp_mocks,
):
resp = self.client.get(complete_url, self.get_complete_parameters(q))

# Find the access token POST request, and assert that it contains
# the correct code_verifier if and only if PKCE is enabled
request_calls = requests.Session.request.call_args_list

for args, kwargs in request_calls:
data = kwargs.get("data", {})
if (
args
and args[0] == "POST"
and isinstance(data, dict)
and data.get("redirect_uri", "").endswith(complete_url)
):
self.assertEqual("code_verifier" in data, pkce_enabled)

if pkce_enabled:
hashed_code_verifier = hashlib.sha256(data["code_verifier"].encode("ascii"))
expected_code_challenge = (
base64.urlsafe_b64encode(hashed_code_verifier.digest()).rstrip(b"=").decode()
)
self.assertEqual(code_challenge, expected_code_challenge)

return resp

def get_id_token(self):
app = get_adapter().get_app(request=None, provider=self.provider_id)
allowed_audience = app.client_id
id_token = sign_id_token(
return sign_id_token(
{
"exp": now + datetime.timedelta(hours=1),
"iat": now,
"exp": self.setup_time + datetime.timedelta(hours=1),
"iat": self.setup_time,
"aud": allowed_audience,
"scope": ["authenticated", "oauth_client_user"],
"sub": 20122,
}
)

def get_access_token(self) -> str:
return self.get_id_token()

def get_expected_to_str(self):
return "[email protected]"

# This login response mimics drupals in that it contains a set of scopes
# and the uid which has the name sub
def get_login_response_json(self, with_refresh_token=True):
id_token = self.get_id_token()
response_data = {
"access_token": id_token,
"expires_in": 3600,
Expand All @@ -131,14 +210,34 @@ def get_login_response_json(self, with_refresh_token=True):


class TestProviderConfig(TestCase):
def setUp(self):
# workaround to create a session. see:
# https://code.djangoproject.com/ticket/11475
current_site = Site.objects.get_current()
app = SocialApp.objects.create(
provider=CustomProvider.id,
name=CustomProvider.id,
client_id="app123id",
key=CustomProvider.id,
secret="dummy",
)
self.app = app
self.app.sites.add(current_site)

def test_custom_provider_no_app(self):
rf = RequestFactory()
request = rf.get("/fake-url/")
provider = CustomProvider(request)
assert provider.app is not None

def test_custom_provider_scope_config(self):
custom_provider_settings = settings.SOCIALACCOUNT_PROVIDERS
rf = RequestFactory()
request = rf.get("/fake-url/")
custom_provider_settings["drupal_oauth_provider"]["SCOPES"] = None
with override_settings(SOCIALACCOUNT_PROVIDERS=custom_provider_settings):
with self.assertRaises(ImproperlyConfigured):
CustomProvider(request).get_provider_scope_config()
CustomProvider(request, app=self.app).get_provider_scope_config()

def test_custom_provider_scope_detail_config(self):
custom_provider_settings = settings.SOCIALACCOUNT_PROVIDERS
Expand All @@ -153,7 +252,7 @@ def test_custom_provider_scope_detail_config(self):
]
with override_settings(SOCIALACCOUNT_PROVIDERS=custom_provider_settings):
with self.assertRaises(ImproperlyConfigured):
CustomProvider(request).get_provider_managed_scope_status()
CustomProvider(request, app=self.app).get_provider_managed_scope_status()

def test_custom_provider_has_scope(self):
custom_provider_settings = settings.SOCIALACCOUNT_PROVIDERS
Expand All @@ -167,4 +266,4 @@ def test_custom_provider_has_scope(self):
}
]
with override_settings(SOCIALACCOUNT_PROVIDERS=custom_provider_settings):
CustomProvider(request).get_provider_managed_scope_status(scopes_granted=["X"])
CustomProvider(request, app=self.app).get_provider_managed_scope_status(scopes_granted=["X"])
4 changes: 1 addition & 3 deletions gregor_django/drupal_oauth_provider/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,11 @@
OAuth2LoginView,
)

from .provider import CustomProvider

logger = logging.getLogger(__name__)


class CustomAdapter(OAuth2Adapter):
provider_id = CustomProvider.id
provider_id = "drupal_oauth_provider"

provider_settings = app_settings.PROVIDERS.get(provider_id, {})

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
{% extends "socialaccount/base.html" %}
{% extends "base.html" %}

{% load i18n %}

{% block head_title %}{% trans "Social Network Login Failure" %}{% endblock %}
{% block title %}{% trans "Social Network Login Failure" %}{% endblock %}

{% block content %}
<h1>{% trans "Social Network Login Failure" %}</h1>
Expand Down
15 changes: 9 additions & 6 deletions gregor_django/users/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,20 @@ def update_user_info(self, user, extra_data: Dict):
user_changed = False
if user.name != full_name:
logger.info(
f"[SocialAccountAdatpter:update_user_name] user {user} " f"name updated from {user.name} to {full_name}"
f"[SocialAccountAdatpter:update_user_info] user {user} " f"name updated from {user.name} to {full_name}"
)
user.name = full_name
user_changed = True
if user.username != drupal_username:
logger.info(
f"[SocialAccountAdatpter:update_user_name] user {user} "
f"[SocialAccountAdatpter:update_user_info] user {user} "
f"username updated from {user.username} to {drupal_username}"
)
user.username = drupal_username
user_changed = True
if user.email != drupal_email:
logger.info(
f"[SocialAccountAdatpter:update_user_name] user {user}"
f"[SocialAccountAdatpter:update_user_info] user {user}"
f" email updated from {user.email} to {drupal_email}"
)
user.email = drupal_email
Expand Down Expand Up @@ -186,10 +186,13 @@ def update_user_data(self, sociallogin: Any):
self.update_user_partner_groups(user, extra_data)
self.update_user_groups(user, extra_data)

def authentication_error(self, request, provider_id, error, exception, extra_context):
def on_authentication_error(self, request, provider_id, error, exception, extra_context):
"""
Invoked when there is an error in auth cycle.
Log so we know what is going on.
"""
logger.error(f"[SocialAccountAdapter:authentication_error] Error {error} Exception: {exception}")
super().authentication_error(request, provider_id, error, exception, extra_context)
logger.error(
f"[SocialAccountAdapter:on_authentication_error] Provider: {provider_id} "
f"Error {error} Exception: {exception} extra {extra_context}"
)
super().on_authentication_error(request, provider_id, error, exception, extra_context)
Loading

0 comments on commit baedf4b

Please sign in to comment.