Skip to content

Commit

Permalink
feat: Add improvements to populate_product_catalog_command
Browse files Browse the repository at this point in the history
  • Loading branch information
AfaqShuaib09 committed Sep 5, 2024
1 parent ec8fc50 commit a6f707d
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from django.conf import settings
from django.core.management import BaseCommand, CommandError
from django.db.models import Prefetch
from django.db.models import Count, Prefetch, Q

from course_discovery.apps.course_metadata.gspread_client import GspreadClient
from course_discovery.apps.course_metadata.models import Course, CourseType, Program, SubjectTranslation
Expand Down Expand Up @@ -74,7 +74,7 @@ def get_products(self, product_type, product_source):
]

if product_type in ['executive_education', 'bootcamp', 'ocm_course']:
queryset = Course.objects.available()
queryset = Course.objects.available().select_related('partner', 'type')

if product_type == 'ocm_course':
queryset = queryset.filter(type__slug__in=ocm_course_catalog_types)
Expand All @@ -88,6 +88,10 @@ def get_products(self, product_type, product_source):
if product_source:
queryset = queryset.filter(product_source__slug=product_source)

queryset = queryset.annotate(
num_orgs=Count('authoring_organizations')
).filter(Q(num_orgs__gt=0) & Q(image__isnull=False) & ~Q(image=''))

# Prefetch Spanish translations of subjects
subject_translations = Prefetch(
'subjects__translations',
Expand All @@ -106,6 +110,10 @@ def get_products(self, product_type, product_source):
if product_source:
queryset = queryset.filter(product_source__slug=product_source)

queryset = queryset.annotate(
num_orgs=Count('authoring_organizations')
).filter(Q(num_orgs__gt=0) & Q(card_image__isnull=False) & ~Q(card_image=''))

subject_translations = Prefetch(
'courses__subjects__translations',
queryset=SubjectTranslation.objects.filter(language_code='es'),
Expand Down Expand Up @@ -137,7 +145,7 @@ def get_transformed_data(self, product, product_type):
authoring_orgs = product.authoring_organizations.all()

data = {
"UUID": str(product.uuid),
"UUID": str(product.uuid.hex),
"Title": product.title,
"Organizations Name": ", ".join(org.name for org in authoring_orgs),
"Organizations Logo": ", ".join(org.logo_image.url for org in authoring_orgs if org.logo_image),
Expand All @@ -151,7 +159,7 @@ def get_transformed_data(self, product, product_type):
translation.name for subject in product.subjects.all()
for translation in subject.spanish_translations
),
"Languages": product.languages_codes,
"Languages": product.languages_codes(),
"Marketing Image": product.image.url if product.image else "",
})
elif product_type == 'degree':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import csv
from tempfile import NamedTemporaryFile

import factory
import mock
from django.core.management import CommandError, call_command
from django.test import TestCase
Expand All @@ -12,15 +13,16 @@
from course_discovery.apps.course_metadata.management.commands.populate_product_catalog import Command
from course_discovery.apps.course_metadata.models import Course, CourseType, ProgramType
from course_discovery.apps.course_metadata.tests.factories import (
CourseFactory, CourseRunFactory, CourseTypeFactory, DegreeFactory, PartnerFactory, ProgramTypeFactory, SeatFactory,
SourceFactory
CourseFactory, CourseRunFactory, CourseTypeFactory, DegreeFactory, OrganizationFactory, PartnerFactory,
ProgramTypeFactory, SeatFactory, SourceFactory
)


class PopulateProductCatalogCommandTests(TestCase):
def setUp(self):
super().setUp()
self.partner = PartnerFactory.create()
self.organization = OrganizationFactory(partner=self.partner)
self.course_type = CourseTypeFactory(slug=CourseType.AUDIT)
self.source = SourceFactory.create(slug="edx")
self.courses = CourseFactory.create_batch(
Expand All @@ -29,6 +31,7 @@ def setUp(self):
partner=self.partner,
additional_metadata=None,
type=self.course_type,
authoring_organizations=[self.organization]
)
self.course_run = CourseRunFactory(
course=Course.objects.all()[0],
Expand All @@ -45,6 +48,8 @@ def setUp(self):
partner=self.partner,
additional_metadata=None,
type=self.program_type,
authoring_organizations=[self.organization],
card_image=factory.django.ImageField()
)

def test_populate_product_catalog(self):
Expand Down Expand Up @@ -94,11 +99,11 @@ def test_populate_product_catalog_for_degrees(self):

for degree in self.degrees:
with self.subTest(degree=degree):
matching_rows = [row for row in rows if row["UUID"] == str(degree.uuid)]
matching_rows = [row for row in rows if row["UUID"] == str(degree.uuid.hex)]
self.assertEqual(len(matching_rows), 1)

row = matching_rows[0]
self.assertEqual(row["UUID"], str(degree.uuid))
self.assertEqual(row["UUID"], str(degree.uuid.hex))
self.assertEqual(row["Title"], degree.title)
self.assertIn("Organizations Name", row)
self.assertIn("Organizations Logo", row)
Expand Down Expand Up @@ -161,7 +166,19 @@ def test_populate_product_catalog_excludes_non_marketable_degrees(self):
type=self.program_type,
status=ProgramStatus.Active,
marketing_slug="valid-marketing-slug",
title="Marketable Degree"
title="Marketable Degree",
authoring_organizations=[self.organization],
card_image=factory.django.ImageField()
)

marketable_degree_2 = DegreeFactory.create(
product_source=self.source,
partner=self.partner,
additional_metadata=None,
type=self.program_type,
status=ProgramStatus.Active,
marketing_slug="valid-marketing-slug",
title="Marketable Degree 2 - Without Authoring Orgs"
)

with NamedTemporaryFile() as output_csv:
Expand All @@ -180,15 +197,23 @@ def test_populate_product_catalog_excludes_non_marketable_degrees(self):
# Check that non-marketable degrees are not in the CSV
for degree in non_marketable_degrees:
with self.subTest(degree=degree):
matching_rows = [
row for row in rows if row["UUID"] == str(degree.uuid)
]
self.assertEqual(len(matching_rows), 0,
f"Non-marketable degree '{degree.title}' should not be in the CSV")
matching_rows = [row for row in rows if row["UUID"] == str(degree.uuid)]
self.assertEqual(
len(matching_rows), 0, f"Non-marketable degree '{degree.title}' should not be in the CSV"
)

# Check that the marketable degree without authoring orgs is not in the CSV
matching_rows = [
row for row in rows if row["UUID"] == str(marketable_degree_2.uuid.hex)
]
self.assertEqual(
len(matching_rows), 0,
f"Marketable degree '{marketable_degree_2.title}' without authoring orgs should not be in the CSV"
)

# Check that the marketable degree is in the CSV
matching_rows = [
row for row in rows if row["UUID"] == str(marketable_degree.uuid)
row for row in rows if row["UUID"] == str(marketable_degree.uuid.hex)
]
self.assertEqual(len(matching_rows), 1,
f"Marketable degree '{marketable_degree.title}' should be in the CSV")
Expand Down Expand Up @@ -244,7 +269,7 @@ def test_get_transformed_data(self):
product_authoring_orgs = product.authoring_organizations.all()
transformed_prod_data = command.get_transformed_data(product, "ocm_course")
assert transformed_prod_data == {
"UUID": str(product.uuid),
"UUID": str(product.uuid.hex),
"Title": product.title,
"Organizations Name": ", ".join(
org.name for org in product_authoring_orgs
Expand All @@ -257,7 +282,7 @@ def test_get_transformed_data(self):
"Organizations Abbr": ", ".join(
org.key for org in product_authoring_orgs
),
"Languages": product.languages_codes,
"Languages": product.languages_codes(),
"Subjects": ", ".join(subject.name for subject in product.subjects.all()),
"Subjects Spanish": ", ".join(
translation.name
Expand All @@ -277,7 +302,7 @@ def test_get_transformed_data_for_degree(self):
product_authoring_orgs = product.authoring_organizations.all()
transformed_prod_data = command.get_transformed_data(product, "degree")
assert transformed_prod_data == {
"UUID": str(product.uuid),
"UUID": str(product.uuid.hex),
"Title": product.title,
"Organizations Name": ", ".join(org.name for org in product_authoring_orgs),
"Organizations Logo": ", ".join(
Expand Down
22 changes: 18 additions & 4 deletions course_discovery/apps/course_metadata/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1709,14 +1709,28 @@ def languages(self, exclude_inactive_runs=False):
if course_run.language is not None
})

@property
def languages_codes(self):
def languages_codes(self, exclude_inactive_runs=False):
"""
Returns a string of languages codes used in this course. The languages codes are separated by comma.
This property will ignore restricted runs and course runs with no language set.
Arguments:
exclude_inactive_runs (bool): whether to exclude inactive runs
"""
filtered_course_runs = self.active_course_runs.filter(language__isnull=False, restricted_run__isnull=True)
return ','.join(course_run.language.code for course_run in filtered_course_runs)
if exclude_inactive_runs:
language_codes = set(
course_run.language.code for course_run in self.active_course_runs.filter(
language__isnull=False, restricted_run__isnull=True
)
)
else:
language_codes = set(
course_run.language.code for course_run in self.course_runs.filter(
language__isnull=False, restricted_run__isnull=True
)
)

return ", ".join(sorted(language_codes))

@property
def first_enrollable_paid_seat_price(self):
Expand Down
24 changes: 24 additions & 0 deletions course_discovery/apps/course_metadata/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,30 @@ def test_image_url(self):
course.image = None
assert course.image_url == course.card_image_url

def test_language_codes(self):
partner = factories.PartnerFactory.create()
source = factories.SourceFactory.create(slug="edx")
course = factories.CourseFactory(
product_source=source,
partner=partner,
additional_metadata=None,
)
LanguageTag.objects.create(code='en', name='English')
course_run = CourseRunFactory(
course=Course.objects.all()[0],
status=CourseRunStatus.Published,
language=LanguageTag.objects.get(code='en')
)
SeatFactory.create(course_run=course_run)
CourseRunFactory(
course=Course.objects.all()[0],
status=CourseRunStatus.Unpublished,
language=LanguageTag.objects.get(code='es'),
enrollment_end=datetime.datetime.now() - datetime.timedelta(days=5)
)
assert course.languages_codes() == 'en, es'
assert course.languages_codes(exclude_inactive_runs=True) == 'en'

def test_validate_history_created_only_on_change(self):
"""
Validate that course history object would be created if the object is changed otherwise not.
Expand Down

0 comments on commit a6f707d

Please sign in to comment.