diff --git a/figures/sites.py b/figures/sites.py index 830f2346..9f987438 100644 --- a/figures/sites.py +++ b/figures/sites.py @@ -12,6 +12,7 @@ from __future__ import absolute_import from django.contrib.auth import get_user_model +from django.contrib.sites import shortcuts as sites_shortcuts from django.contrib.sites.models import Site from django.conf import settings @@ -304,6 +305,24 @@ def _get_all_sites(): return Site.objects.all() +def get_requested_site(request): + """ + From a request return the corresponding site. + + This functions makes use of the `REQUESTED_SITE_BACKEND` setting if configured, otherwise + it defaults to Django's get_current_site(). + + :return Site + """ + backend_path = settings.ENV_TOKENS['FIGURES'].get('REQUESTED_SITE_BACKEND') + if backend_path: + requested_site_backend = import_from_path(backend_path) + requested_site = requested_site_backend(request) + else: + requested_site = sites_shortcuts.get_current_site(request) + return requested_site + + def get_sites(): """ Get a list of sites for Figures purposes in a configurable manner. diff --git a/figures/views.py b/figures/views.py index ad19e1db..d9ec7eed 100644 --- a/figures/views.py +++ b/figures/views.py @@ -7,7 +7,6 @@ from django.conf import settings from django.contrib.auth import get_user_model from django.contrib.auth.decorators import login_required, user_passes_test -import django.contrib.sites.shortcuts from django.contrib.sites.models import Site from django.http import HttpResponseRedirect from django.shortcuts import get_object_or_404, render @@ -170,7 +169,7 @@ class CourseOverviewViewSet(CommonAuthMixin, viewsets.ReadOnlyModelViewSet): lookup_value_regex = settings.COURSE_ID_PATTERN def get_queryset(self): - site = django.contrib.sites.shortcuts.get_current_site(self.request) + site = figures.sites.get_requested_site(self.request) queryset = figures.sites.get_courses_for_site(site) return queryset @@ -194,7 +193,7 @@ def get_course_key(self, course_id_str): def retrieve(self, request, *args, **kwargs): course_key = self.get_course_key( kwargs.get('pk', '')) - site = django.contrib.sites.shortcuts.get_current_site(request) + site = figures.sites.get_requested_site(self.request) if figures.helpers.is_multisite(): if site != figures.sites.get_site_for_course(course_key): # Raising NotFound instead of PermissionDenied @@ -243,7 +242,7 @@ class UserIndexViewSet(CommonAuthMixin, viewsets.ReadOnlyModelViewSet): filter_class = UserFilterSet def get_queryset(self): - site = django.contrib.sites.shortcuts.get_current_site(self.request) + site = figures.sites.get_requested_site(self.request) queryset = figures.sites.get_users_for_site(site) return queryset @@ -256,7 +255,7 @@ class CourseEnrollmentViewSet(CommonAuthMixin, viewsets.ReadOnlyModelViewSet): filter_class = CourseEnrollmentFilter def get_queryset(self): - site = django.contrib.sites.shortcuts.get_current_site(self.request) + site = figures.sites.get_requested_site(self.request) queryset = figures.sites.get_course_enrollments_for_site(site) return queryset @@ -270,7 +269,7 @@ class CourseDailyMetricsViewSet(CommonAuthMixin, viewsets.ModelViewSet): filter_class = CourseDailyMetricsFilter def get_queryset(self): - site = django.contrib.sites.shortcuts.get_current_site(self.request) + site = figures.sites.get_requested_site(self.request) queryset = CourseDailyMetrics.objects.filter(site=site) return queryset @@ -284,7 +283,7 @@ class SiteDailyMetricsViewSet(CommonAuthMixin, viewsets.ModelViewSet): filter_class = SiteDailyMetricsFilter def get_queryset(self): - site = django.contrib.sites.shortcuts.get_current_site(self.request) + site = figures.sites.get_requested_site(self.request) queryset = SiteDailyMetrics.objects.filter(site=site) return queryset @@ -316,7 +315,7 @@ def get(self, request, format=None): # pylint: disable=redefined-builtin ''' Does not yet support multi-tenancy ''' - site = django.contrib.sites.shortcuts.get_current_site(request) + site = figures.sites.get_requested_site(request) date_for = request.query_params.get('date_for') data = self.metrics_method(site=site, date_for=date_for) @@ -348,7 +347,7 @@ class GeneralUserDataViewSet(CommonAuthMixin, viewsets.ReadOnlyModelViewSet): ordering_fields = ['username', 'email', 'profile__name', 'is_active', 'date_joined'] def get_queryset(self): - site = django.contrib.sites.shortcuts.get_current_site(self.request) + site = figures.sites.get_requested_site(self.request) queryset = figures.sites.get_users_for_site(site) return queryset @@ -361,13 +360,13 @@ class LearnerDetailsViewSet(CommonAuthMixin, viewsets.ReadOnlyModelViewSet): filter_class = UserFilterSet def get_queryset(self): - site = django.contrib.sites.shortcuts.get_current_site(self.request) + site = figures.sites.get_requested_site(self.request) queryset = figures.sites.get_users_for_site(site) return queryset def get_serializer_context(self): context = super(LearnerDetailsViewSet, self).get_serializer_context() - context['site'] = django.contrib.sites.shortcuts.get_current_site(self.request) + context['site'] = figures.sites.get_requested_site(self.request) return context @@ -426,7 +425,7 @@ def get_queryset(self): * If no valid course keys are found, then an empty list is returned from this view """ - site = django.contrib.sites.shortcuts.get_current_site(self.request) + site = figures.sites.get_requested_site(self.request) course_keys = figures.sites.get_course_keys_for_site(site) try: param_course_keys = self.query_param_course_keys() @@ -441,7 +440,7 @@ def get_queryset(self): def get_serializer_context(self): context = super(LearnerMetricsViewSetV1, self).get_serializer_context() - context['site'] = django.contrib.sites.shortcuts.get_current_site(self.request) + context['site'] = figures.sites.get_requested_site(self.request) context['course_keys'] = self.query_param_course_keys() return context @@ -493,13 +492,13 @@ def get_queryset(self): * If no valid course keys are found, then an empty list is returned from this view """ - site = django.contrib.sites.shortcuts.get_current_site(self.request) + site = figures.sites.get_requested_site(self.request) course_ids = self.query_param_course_ids() return site_users_enrollment_data(site=site, course_ids=course_ids) def get_serializer_context(self): context = super(LearnerMetricsViewSetV2, self).get_serializer_context() - context['site'] = django.contrib.sites.shortcuts.get_current_site(self.request) + context['site'] = figures.sites.get_requested_site(self.request) context['course_ids'] = self.query_param_course_ids() return context @@ -521,7 +520,7 @@ class EnrollmentMetricsViewSet(CommonAuthMixin, viewsets.ReadOnlyModelViewSet): filter_class = EnrollmentMetricsFilter def get_queryset(self): - site = django.contrib.sites.shortcuts.get_current_site(self.request) + site = figures.sites.get_requested_site(self.request) queryset = LearnerCourseGradeMetrics.objects.filter(site=site) return queryset @@ -534,7 +533,7 @@ def completed_ids(self, request): The default router does not support hyphen in the custom action, so we need to use the underscore until we implement a custom router """ - site = django.contrib.sites.shortcuts.get_current_site(request) + site = figures.sites.get_requested_site(request) qs = self.model.objects.completed_ids_for_site(site=site) page = self.paginate_queryset(qs) if page is not None: @@ -552,7 +551,7 @@ def completed(self, request): Return matching LearnerCourseGradeMetric rows that have completed enrollments """ - site = django.contrib.sites.shortcuts.get_current_site(request) + site = figures.sites.get_requested_site(request) qs = self.model.objects.completed_for_site(site=site) page = self.paginate_queryset(qs) if page is not None: @@ -582,7 +581,7 @@ def site_course_helper(self, pk): except InvalidKeyError: raise NotFound() - site = django.contrib.sites.shortcuts.get_current_site(self.request) + site = figures.sites.get_requested_site(self.request) if figures.helpers.is_multisite(): if site != figures.sites.get_site_for_course(course_id): raise NotFound() @@ -609,7 +608,7 @@ def list(self, request): TODO: NEXT Add query params to get data from previous months TODO: Add paginagation """ - site = django.contrib.sites.shortcuts.get_current_site(request) + site = figures.sites.get_requested_site(request) course_keys = figures.sites.get_course_keys_for_site(site) date_for = datetime.utcnow().date() month_for = '{}/{}'.format(date_for.month, date_for.year) @@ -711,13 +710,13 @@ def list(self, request): Returns site metrics data for current month """ - site = django.contrib.sites.shortcuts.get_current_site(request) + site = figures.sites.get_requested_site(request) data = metrics.get_current_month_site_metrics(site) return Response(data) @list_route() def registered_users(self, request): - site = django.contrib.sites.shortcuts.get_current_site(request) + site = figures.sites.get_requested_site(request) date_for = datetime.utcnow().date() months_back = 6 @@ -735,7 +734,7 @@ def new_users(self, request): """ TODO: Rename the metrics module function to "new_users" to match this """ - site = django.contrib.sites.shortcuts.get_current_site(request) + site = figures.sites.get_requested_site(request) date_for = datetime.utcnow().date() months_back = 6 @@ -750,7 +749,7 @@ def new_users(self, request): @list_route() def course_completions(self, request): - site = django.contrib.sites.shortcuts.get_current_site(request) + site = figures.sites.get_requested_site(request) date_for = datetime.utcnow().date() months_back = 6 @@ -765,7 +764,7 @@ def course_completions(self, request): @list_route() def course_enrollments(self, request): - site = django.contrib.sites.shortcuts.get_current_site(request) + site = figures.sites.get_requested_site(request) date_for = datetime.utcnow().date() months_back = 6 @@ -780,7 +779,7 @@ def course_enrollments(self, request): @list_route() def site_courses(self, request): - site = django.contrib.sites.shortcuts.get_current_site(request) + site = figures.sites.get_requested_site(request) date_for = datetime.utcnow().date() months_back = 6 @@ -795,7 +794,7 @@ def site_courses(self, request): @list_route() def active_users(self, request): - site = django.contrib.sites.shortcuts.get_current_site(request) + site = figures.sites.get_requested_site(request) months_back = 6 active_users = metrics.get_site_mau_history_metrics(site=site, months_back=months_back) @@ -816,7 +815,7 @@ def get_queryset(self): def retrieve(self, request, **kwargs): course_id_str = kwargs.get('pk', '') course_key = CourseKey.from_string(course_id_str.replace(' ', '+')) - site = django.contrib.sites.shortcuts.get_current_site(request) + site = figures.sites.get_requested_site(request) if figures.helpers.is_multisite(): if site != figures.sites.get_site_for_course(course_key): @@ -827,7 +826,7 @@ def retrieve(self, request, **kwargs): return Response(serializer.data) def list(self, request): - site = django.contrib.sites.shortcuts.get_current_site(request) + site = figures.sites.get_requested_site(request) course_overviews = figures.sites.get_courses_for_site(site) data = [] for co in course_overviews: @@ -858,7 +857,7 @@ def list(self, request): We use list instead of retrieve because retrieve requires a resource identifier, like a PK """ - site = django.contrib.sites.shortcuts.get_current_site(request) + site = figures.sites.get_requested_site(request) data = retrieve_live_site_mau_data(site) serializer = self.serializer_class(data) return Response(serializer.data) @@ -872,7 +871,7 @@ class CourseMauMetricsViewSet(CommonAuthMixin, viewsets.ReadOnlyModelViewSet): lookup_value_regex = settings.COURSE_ID_PATTERN def get_queryset(self): - site = django.contrib.sites.shortcuts.get_current_site(self.request) + site = figures.sites.get_requested_site(self.request) queryset = CourseMauMetrics.objects.filter(site=site) return queryset @@ -885,7 +884,7 @@ class SiteMauMetricsViewSet(CommonAuthMixin, viewsets.ReadOnlyModelViewSet): filter_class = SiteMauMetricsFilter def get_queryset(self): - site = django.contrib.sites.shortcuts.get_current_site(self.request) + site = figures.sites.get_requested_site(self.request) queryset = SiteMauMetrics.objects.filter(site=site) return queryset diff --git a/tests/test_sites.py b/tests/test_sites.py index 7b6104e0..a86fe7eb 100644 --- a/tests/test_sites.py +++ b/tests/test_sites.py @@ -437,6 +437,54 @@ def test_student_modules_for_course_enrollment(monkeypatch): assert set(sm) == set(ce_sm) +@pytest.mark.django_db +def test_get_requested_site_default_behaviour(settings): + """ + Test `get_requested_site` returns Django's get_current_site() by default. + """ + example_site = Site.objects.get() # gets the example site + settings.SITE_ID = example_site.id + + current_site = figures.sites.get_requested_site(request=mock.Mock()) + assert current_site == example_site, 'Use Django\'s get_current_site().' + + +@pytest.mark.django_db +def test_get_requested_site_custom_backend(settings): + """ + Test `get_requested_site` can use custom backends. + """ + orange_site = SiteFactory.create(name='orange site') + + settings.ENV_TOKENS = { + 'FIGURES': { + 'REQUESTED_SITE_BACKEND': 'organizations:get_orange_site' + } + } + with mock.patch('organizations.get_orange_site', create=True, return_value=orange_site): + requested_site = figures.sites.get_requested_site(request=mock.Mock()) + assert requested_site == orange_site, 'Should use custom backend.' + + +@pytest.mark.django_db +def test_get_requested_site_broken_backend(settings): + """ + Test `get_requested_site` don't hide errors from custom backends. + + Figures should keep a simple backend implementation without attempting to fix errors in site configuration or + faulty backends. + """ + settings.ENV_TOKENS = { + 'FIGURES': { + 'REQUESTED_SITE_BACKEND': 'organizations:broken_backend' + } + } + with mock.patch('organizations.broken_backend', create=True, side_effect=RuntimeError): + with pytest.raises(RuntimeError): + # Should fail if the REQUESTED_SITE_BACKEND fails + figures.sites.get_requested_site(request=mock.Mock()) + + @pytest.mark.skipif(not organizations_support_sites(), reason='needed only in multisite mode') @pytest.mark.django_db def test_get_sites_default_behaviour():