Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

django-allauth upgrade integration fixes #769

Merged
merged 8 commits into from
Oct 18, 2024
1 change: 1 addition & 0 deletions config/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@
"maintenance_mode.middleware.MaintenanceModeMiddleware",
"simple_history.middleware.HistoryRequestMiddleware",
"django_htmx.middleware.HtmxMiddleware",
"allauth.account.middleware.AccountMiddleware",
]

# STATIC
Expand Down
10 changes: 10 additions & 0 deletions gregor_django/drupal_oauth_provider/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

from allauth.account.models import EmailAddress
from allauth.socialaccount import app_settings, providers
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.providers.base import ProviderAccount
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured

from .views import CustomAdapter

logger = logging.getLogger(__name__)

DRUPAL_PROVIDER_ID = "drupal_oauth_provider"
Expand All @@ -27,6 +30,13 @@ class CustomProvider(OAuth2Provider):
id = DRUPAL_PROVIDER_ID
name = OVERRIDE_NAME
account_class = CustomAccount
oauth2_adapter_class = CustomAdapter
supports_token_authentication = True

def __init__(self, request, app=None):
if app is None:
app = get_adapter().get_app(request, self.id)
super().__init__(request, app=app)

def extract_uid(self, data):
return str(data["sub"])
Expand Down
121 changes: 110 additions & 11 deletions gregor_django/drupal_oauth_provider/tests.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
import base64
import datetime
import hashlib
import json
from urllib.parse import parse_qs, urlparse

import jwt
import requests
from allauth.socialaccount import app_settings
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.models import SocialApp
from allauth.socialaccount.tests import OAuth2TestsMixin
from allauth.tests import MockedResponse, TestCase
from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.sites.models import Site
from django.core.exceptions import ImproperlyConfigured
from django.test import RequestFactory
from django.test.utils import override_settings
Expand Down Expand Up @@ -83,10 +91,19 @@ def sign_id_token(payload):

# disable token storing for testing as it conflicts with drupals use
# of tokens for user info
@override_settings(SOCIALACCOUNT_STORE_TOKENS=False)
@override_settings(SOCIALACCOUNT_STORE_TOKENS=True)
class CustomProviderTests(OAuth2TestsMixin, TestCase):
provider_id = CustomProvider.id

def setUp(self):
super(CustomProviderTests, self).setUp()
# workaround to create a session. see:
# https://code.djangoproject.com/ticket/11475
User = get_user_model()
User.objects.create_user("testuser", "[email protected]", "testpw")
self.client.login(username="testuser", password="testpw")
self.setup_time = datetime.datetime.now(datetime.timezone.utc)

# Provide two mocked responses, first is to the public key request
# second is used for the profile request for extra data
def get_mocked_response(self):
Expand All @@ -104,21 +121,83 @@ def get_mocked_response(self):
),
]

# This login response mimics drupals in that it contains a set of scopes
# and the uid which has the name sub
def get_login_response_json(self, with_refresh_token=True):
now = datetime.datetime.now(datetime.timezone.utc)
def login(self, resp_mock=None, process="login", with_refresh_token=True):
"""
Unfortunately due to how our provider works we need to alter
this test login function as the default one fails.
"""
with self.mocked_response():
resp = self.client.post(self.provider.get_login_url(self.request, process=process))
p = urlparse(resp["location"])
q = parse_qs(p.query)
pkce_enabled = app_settings.PROVIDERS.get(self.app.provider, {}).get(
"OAUTH_PKCE_ENABLED", self.provider.pkce_enabled_default
)

self.assertEqual("code_challenge" in q, pkce_enabled)
self.assertEqual("code_challenge_method" in q, pkce_enabled)
if pkce_enabled:
code_challenge = q["code_challenge"][0]
self.assertEqual(q["code_challenge_method"][0], "S256")

complete_url = self.provider.get_callback_url()
self.assertGreater(q["redirect_uri"][0].find(complete_url), 0)
response_json = self.get_login_response_json(with_refresh_token=with_refresh_token)

resp_mocks = resp_mock if isinstance(resp_mock, list) else ([resp_mock] if resp_mock is not None else [])

with self.mocked_response(
MockedResponse(200, response_json, {"content-type": "application/json"}),
*resp_mocks,
):
resp = self.client.get(complete_url, self.get_complete_parameters(q))

# Find the access token POST request, and assert that it contains
# the correct code_verifier if and only if PKCE is enabled
request_calls = requests.Session.request.call_args_list

for args, kwargs in request_calls:
data = kwargs.get("data", {})
if (
args
and args[0] == "POST"
and isinstance(data, dict)
and data.get("redirect_uri", "").endswith(complete_url)
):
self.assertEqual("code_verifier" in data, pkce_enabled)

if pkce_enabled:
hashed_code_verifier = hashlib.sha256(data["code_verifier"].encode("ascii"))
expected_code_challenge = (
base64.urlsafe_b64encode(hashed_code_verifier.digest()).rstrip(b"=").decode()
)
self.assertEqual(code_challenge, expected_code_challenge)

return resp

def get_id_token(self):
app = get_adapter().get_app(request=None, provider=self.provider_id)
allowed_audience = app.client_id
id_token = sign_id_token(
return sign_id_token(
{
"exp": now + datetime.timedelta(hours=1),
"iat": now,
"exp": self.setup_time + datetime.timedelta(hours=1),
"iat": self.setup_time,
"aud": allowed_audience,
"scope": ["authenticated", "oauth_client_user"],
"sub": 20122,
}
)

def get_access_token(self) -> str:
return self.get_id_token()

def get_expected_to_str(self):
return "[email protected]"

# This login response mimics drupals in that it contains a set of scopes
# and the uid which has the name sub
def get_login_response_json(self, with_refresh_token=True):
id_token = self.get_id_token()
response_data = {
"access_token": id_token,
"expires_in": 3600,
Expand All @@ -131,14 +210,34 @@ def get_login_response_json(self, with_refresh_token=True):


class TestProviderConfig(TestCase):
def setUp(self):
# workaround to create a session. see:
# https://code.djangoproject.com/ticket/11475
current_site = Site.objects.get_current()
app = SocialApp.objects.create(
provider=CustomProvider.id,
name=CustomProvider.id,
client_id="app123id",
key=CustomProvider.id,
secret="dummy",
)
self.app = app
self.app.sites.add(current_site)

def test_custom_provider_no_app(self):
rf = RequestFactory()
request = rf.get("/fake-url/")
provider = CustomProvider(request)
assert provider.app is not None

def test_custom_provider_scope_config(self):
custom_provider_settings = settings.SOCIALACCOUNT_PROVIDERS
rf = RequestFactory()
request = rf.get("/fake-url/")
custom_provider_settings["drupal_oauth_provider"]["SCOPES"] = None
with override_settings(SOCIALACCOUNT_PROVIDERS=custom_provider_settings):
with self.assertRaises(ImproperlyConfigured):
CustomProvider(request).get_provider_scope_config()
CustomProvider(request, app=self.app).get_provider_scope_config()

def test_custom_provider_scope_detail_config(self):
custom_provider_settings = settings.SOCIALACCOUNT_PROVIDERS
Expand All @@ -153,7 +252,7 @@ def test_custom_provider_scope_detail_config(self):
]
with override_settings(SOCIALACCOUNT_PROVIDERS=custom_provider_settings):
with self.assertRaises(ImproperlyConfigured):
CustomProvider(request).get_provider_managed_scope_status()
CustomProvider(request, app=self.app).get_provider_managed_scope_status()

def test_custom_provider_has_scope(self):
custom_provider_settings = settings.SOCIALACCOUNT_PROVIDERS
Expand All @@ -167,4 +266,4 @@ def test_custom_provider_has_scope(self):
}
]
with override_settings(SOCIALACCOUNT_PROVIDERS=custom_provider_settings):
CustomProvider(request).get_provider_managed_scope_status(scopes_granted=["X"])
CustomProvider(request, app=self.app).get_provider_managed_scope_status(scopes_granted=["X"])
4 changes: 1 addition & 3 deletions gregor_django/drupal_oauth_provider/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,11 @@
OAuth2LoginView,
)

from .provider import CustomProvider

logger = logging.getLogger(__name__)


class CustomAdapter(OAuth2Adapter):
provider_id = CustomProvider.id
provider_id = "drupal_oauth_provider"

provider_settings = app_settings.PROVIDERS.get(provider_id, {})

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
{% extends "socialaccount/base.html" %}
{% extends "base.html" %}

{% load i18n %}

{% block head_title %}{% trans "Social Network Login Failure" %}{% endblock %}
{% block title %}{% trans "Social Network Login Failure" %}{% endblock %}

{% block content %}
<h1>{% trans "Social Network Login Failure" %}</h1>
Expand Down
15 changes: 9 additions & 6 deletions gregor_django/users/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,20 @@ def update_user_info(self, user, extra_data: Dict):
user_changed = False
if user.name != full_name:
logger.info(
f"[SocialAccountAdatpter:update_user_name] user {user} " f"name updated from {user.name} to {full_name}"
f"[SocialAccountAdatpter:update_user_info] user {user} " f"name updated from {user.name} to {full_name}"
)
user.name = full_name
user_changed = True
if user.username != drupal_username:
logger.info(
f"[SocialAccountAdatpter:update_user_name] user {user} "
f"[SocialAccountAdatpter:update_user_info] user {user} "
f"username updated from {user.username} to {drupal_username}"
)
user.username = drupal_username
user_changed = True
if user.email != drupal_email:
logger.info(
f"[SocialAccountAdatpter:update_user_name] user {user}"
f"[SocialAccountAdatpter:update_user_info] user {user}"
f" email updated from {user.email} to {drupal_email}"
)
user.email = drupal_email
Expand Down Expand Up @@ -186,10 +186,13 @@ def update_user_data(self, sociallogin: Any):
self.update_user_partner_groups(user, extra_data)
self.update_user_groups(user, extra_data)

def authentication_error(self, request, provider_id, error, exception, extra_context):
def on_authentication_error(self, request, provider_id, error, exception, extra_context):
"""
Invoked when there is an error in auth cycle.
Log so we know what is going on.
"""
logger.error(f"[SocialAccountAdapter:authentication_error] Error {error} Exception: {exception}")
super().authentication_error(request, provider_id, error, exception, extra_context)
logger.error(
f"[SocialAccountAdapter:on_authentication_error] Provider: {provider_id} "
f"Error {error} Exception: {exception} extra {extra_context}"
)
super().on_authentication_error(request, provider_id, error, exception, extra_context)
Loading