From 3a0aab8026e5c29c7f899486fa185a0ae7f33847 Mon Sep 17 00:00:00 2001 From: Hamza Shafique Date: Tue, 12 Nov 2024 20:02:40 +0500 Subject: [PATCH] perf: identify and prefetch N+1 queries in search/all for learner pathways --- .../api/v1/tests/test_views/test_search.py | 24 +++++++++----- .../documents/learner_pathway.py | 25 +++++++++++++-- .../apps/learner_pathway/api/serializers.py | 10 ++---- .../apps/learner_pathway/api/v1/urls.py | 4 +-- .../apps/learner_pathway/api/v1/views.py | 31 +++++++++++++++++-- .../apps/learner_pathway/models.py | 14 ++------- 6 files changed, 74 insertions(+), 34 deletions(-) diff --git a/course_discovery/apps/api/v1/tests/test_views/test_search.py b/course_discovery/apps/api/v1/tests/test_views/test_search.py index 3bd5b70a713..db3aa05ca6f 100644 --- a/course_discovery/apps/api/v1/tests/test_views/test_search.py +++ b/course_discovery/apps/api/v1/tests/test_views/test_search.py @@ -766,13 +766,23 @@ def test_results_include_aggregation_key(self): @ddt.data(True, False) def test_learner_pathway_feature_flag(self, include_learner_pathways): """ Verify the include_learner_pathways feature flag works as expected.""" - LearnerPathwayStepFactory(pathway__partner=self.partner) + LearnerPathwayStepFactory.create_batch(5, pathway__partner=self.partner) pathways = LearnerPathway.objects.all() - assert pathways.count() == 1 + assert pathways.count() == 5 query = { 'include_learner_pathways': include_learner_pathways, } + if include_learner_pathways: + expected_result_count = pathways.count() + expected_query_count = 8 + else: + expected_result_count = 0 + expected_query_count = 4 + + with self.assertNumQueries(expected_query_count): + response = self.get_response(query, self.list_path) + response = self.get_response( query, self.list_path @@ -780,11 +790,11 @@ def test_learner_pathway_feature_flag(self, include_learner_pathways): assert response.status_code == 200 response_data = response.json() - if include_learner_pathways: - assert response_data['count'] == 1 - assert response_data['results'][0] == self.serialize_learner_pathway_search(pathways[0]) - else: - assert response_data['count'] == 0 + assert response_data['count'] == expected_result_count + + for pathway in pathways: + assert self.serialize_learner_pathway_search(pathway) in response.data['results'] + class LimitedAggregateSearchViewSetTests( diff --git a/course_discovery/apps/course_metadata/search_indexes/documents/learner_pathway.py b/course_discovery/apps/course_metadata/search_indexes/documents/learner_pathway.py index bda4f549f2c..06503550581 100644 --- a/course_discovery/apps/course_metadata/search_indexes/documents/learner_pathway.py +++ b/course_discovery/apps/course_metadata/search_indexes/documents/learner_pathway.py @@ -1,6 +1,9 @@ from django.conf import settings +from django.db.models import Prefetch from django_elasticsearch_dsl import Index, fields +from course_discovery.apps.course_metadata.choices import CourseRunStatus +from course_discovery.apps.course_metadata.models import CourseRun from course_discovery.apps.learner_pathway.choices import PathwayStatus from course_discovery.apps.learner_pathway.models import LearnerPathway @@ -50,10 +53,26 @@ def prepare_partner(self, obj): def prepare_published(self, obj): return obj.status == PathwayStatus.Active - def get_queryset(self, excluded_restriction_types=None): # pylint: disable=unused-argument + def get_queryset(self, excluded_restriction_types=None): + if excluded_restriction_types is None: + excluded_restriction_types = [] + + course_runs = CourseRun.objects.filter( + status=CourseRunStatus.Published + ).exclude( + restricted_run__restriction_type__in=excluded_restriction_types + ) + return super().get_queryset().prefetch_related( - 'steps', 'steps__learnerpathwaycourse_set', 'steps__learnerpathwayprogram_set', - 'steps__learnerpathwayblock_set', + 'steps', + Prefetch( + 'steps__learnerpathwaycourse_set__course__course_runs', + queryset=course_runs + ), + Prefetch( + 'steps__learnerpathwayprogram_set__program__courses__course_runs', + queryset=course_runs + ) ) def prepare_skill_names(self, obj): diff --git a/course_discovery/apps/learner_pathway/api/serializers.py b/course_discovery/apps/learner_pathway/api/serializers.py index a33724b831c..c7bd8fadb4d 100644 --- a/course_discovery/apps/learner_pathway/api/serializers.py +++ b/course_discovery/apps/learner_pathway/api/serializers.py @@ -20,12 +20,7 @@ class Meta: fields = ('key', 'course_runs') def get_course_runs(self, obj): - excluded_restriction_types = get_excluded_restriction_types(self.context['request']) - return list(obj.course.course_runs.filter( - status=CourseRunStatus.Published - ).exclude( - restricted_run__restriction_type__in=excluded_restriction_types - ).values('key')) + return [{'key': course_run.key} for course_run in obj.course.course_runs.all()] class LearnerPathwayCourseSerializer(LearnerPathwayCourseMinimalSerializer): @@ -87,8 +82,7 @@ def get_card_image_url(self, step): return program.card_image_url def get_courses(self, obj): - excluded_restriction_types = get_excluded_restriction_types(self.context['request']) - return obj.get_linked_courses_and_course_runs(excluded_restriction_types=excluded_restriction_types) + return obj.get_linked_courses_and_course_runs() class LearnerPathwayBlockSerializer(serializers.ModelSerializer): diff --git a/course_discovery/apps/learner_pathway/api/v1/urls.py b/course_discovery/apps/learner_pathway/api/v1/urls.py index d66de00f1ed..21a00182eee 100644 --- a/course_discovery/apps/learner_pathway/api/v1/urls.py +++ b/course_discovery/apps/learner_pathway/api/v1/urls.py @@ -6,8 +6,8 @@ router = routers.SimpleRouter() router.register(r'learner-pathway', views.LearnerPathwayViewSet) router.register(r'learner-pathway-step', views.LearnerPathwayStepViewSet) -router.register(r'learner-pathway-course', views.LearnerPathwayCourseViewSet) -router.register(r'learner-pathway-program', views.LearnerPathwayProgramViewSet) +router.register(r'learner-pathway-course', views.LearnerPathwayCourseViewSet, basename='learner-pathway-course') +router.register(r'learner-pathway-program', views.LearnerPathwayProgramViewSet, basename='learner-pathway-program') router.register(r'learner-pathway-block', views.LearnerPathwayBlocViewSet) urlpatterns = router.urls diff --git a/course_discovery/apps/learner_pathway/api/v1/views.py b/course_discovery/apps/learner_pathway/api/v1/views.py index 5e8031fb6a2..7b6657b504a 100644 --- a/course_discovery/apps/learner_pathway/api/v1/views.py +++ b/course_discovery/apps/learner_pathway/api/v1/views.py @@ -1,7 +1,7 @@ """ API Views for learner_pathway app. """ -from django.db.models import Q +from django.db.models import Prefetch, Q from django_filters.rest_framework import DjangoFilterBackend from rest_framework import status from rest_framework.decorators import action @@ -10,6 +10,9 @@ from rest_framework.viewsets import ReadOnlyModelViewSet from course_discovery.apps.api.pagination import ProxiedPagination +from course_discovery.apps.api.utils import get_excluded_restriction_types +from course_discovery.apps.course_metadata.choices import CourseRunStatus +from course_discovery.apps.course_metadata.models import CourseRun from course_discovery.apps.learner_pathway import models from course_discovery.apps.learner_pathway.api import serializers from course_discovery.apps.learner_pathway.api.filters import PathwayUUIDFilter @@ -81,12 +84,23 @@ class LearnerPathwayCourseViewSet(ReadOnlyModelViewSet): lookup_field = 'uuid' serializer_class = serializers.LearnerPathwayCourseSerializer - queryset = models.LearnerPathwayCourse.objects.all() # Explicitly support PageNumberPagination and LimitOffsetPagination. Future # versions of this API should only support the system default, PageNumberPagination. pagination_class = ProxiedPagination + def get_queryset(self): + excluded_restriction_types = get_excluded_restriction_types(self.request) + return models.LearnerPathwayCourse.objects.prefetch_related( + Prefetch( + 'course__course_runs', + queryset=CourseRun.objects.filter( + status=CourseRunStatus.Published + ).exclude( + restricted_run__restriction_type__in=excluded_restriction_types + )) + ) + class LearnerPathwayProgramViewSet(ReadOnlyModelViewSet): """ @@ -95,12 +109,23 @@ class LearnerPathwayProgramViewSet(ReadOnlyModelViewSet): lookup_field = 'uuid' serializer_class = serializers.LearnerPathwayProgramSerializer - queryset = models.LearnerPathwayProgram.objects.all() # Explicitly support PageNumberPagination and LimitOffsetPagination. Future # versions of this API should only support the system default, PageNumberPagination. pagination_class = ProxiedPagination + def get_queryset(self): + excluded_restriction_types = get_excluded_restriction_types(self.request) + return models.LearnerPathwayProgram.objects.prefetch_related( + Prefetch( + 'program__courses__course_runs', + queryset=CourseRun.objects.filter( + status=CourseRunStatus.Published + ).exclude( + restricted_run__restriction_type__in=excluded_restriction_types + )) + ) + class LearnerPathwayBlocViewSet(ReadOnlyModelViewSet): """ diff --git a/course_discovery/apps/learner_pathway/models.py b/course_discovery/apps/learner_pathway/models.py index 7278f612a3f..a40832dbf58 100644 --- a/course_discovery/apps/learner_pathway/models.py +++ b/course_discovery/apps/learner_pathway/models.py @@ -299,23 +299,15 @@ def get_skills(self) -> [str]: return program_skills - def get_linked_courses_and_course_runs(self, excluded_restriction_types=None) -> [dict]: + def get_linked_courses_and_course_runs(self): """ Returns list of dict where each dict contains a course key linked with program and all its course runs """ - if excluded_restriction_types is None: - excluded_restriction_types = [] courses = [] for course in self.program.courses.all(): - course_runs = list( - course.course_runs.filter( - status=CourseRunStatus.Published - ).exclude( - restricted_run__restriction_type__in=excluded_restriction_types - ).values('key') - ) - courses.append({"key": course.key, "course_runs": course_runs}) + course_runs = [{'key': course_run.key} for course_run in course.course_runs.all()] + courses.append({'key': course.key, 'course_runs': course_runs}) return courses def __str__(self):