diff --git a/enterprise_catalog/apps/api/base/__init__.py b/enterprise_catalog/apps/api/base/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/enterprise_catalog/apps/api/base/tests/__init__.py b/enterprise_catalog/apps/api/base/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/enterprise_catalog/apps/api/base/tests/enterprise_customer_views.py b/enterprise_catalog/apps/api/base/tests/enterprise_customer_views.py new file mode 100644 index 00000000..dd5ad094 --- /dev/null +++ b/enterprise_catalog/apps/api/base/tests/enterprise_customer_views.py @@ -0,0 +1,62 @@ +from rest_framework.reverse import reverse + +from enterprise_catalog.apps.api.v1.tests.mixins import APITestMixin +from enterprise_catalog.apps.catalog.models import ( + CatalogQuery, + ContentMetadata, + EnterpriseCatalog, +) +from enterprise_catalog.apps.catalog.tests.factories import ( + EnterpriseCatalogFactory, +) + + +class BaseEnterpriseCustomerViewSetTests(APITestMixin): + """ + Tests for the EnterpriseCustomerViewSet + """ + def setUp(self): + super().setUp() + # clean up any stale test objects + CatalogQuery.objects.all().delete() + ContentMetadata.objects.all().delete() + EnterpriseCatalog.objects.all().delete() + + self.enterprise_catalog = EnterpriseCatalogFactory(enterprise_uuid=self.enterprise_uuid) + + # Set up catalog.has_learner_access permissions + self.set_up_catalog_learner() + + def tearDown(self): + super().tearDown() + # clean up any stale test objects + CatalogQuery.objects.all().delete() + ContentMetadata.objects.all().delete() + EnterpriseCatalog.objects.all().delete() + + def _get_contains_content_base_url(self, enterprise_uuid=None): + """ + Helper to construct the base url for the contains_content_items endpoint + """ + return reverse( + 'api:v1:enterprise-customer-contains-content-items', + kwargs={'enterprise_uuid': enterprise_uuid or self.enterprise_uuid}, + ) + + def _get_filter_content_base_url(self, enterprise_uuid=None): + """ + Helper to construct the base url for the filter_content_items endpoint + """ + return reverse( + 'api:v1:enterprise-customer-filter-content-items', + kwargs={'enterprise_uuid': enterprise_uuid or self.enterprise_uuid}, + ) + + def _get_generate_diff_base_url(self, enterprise_catalog_uuid=None): + """ + Helper to construct the base url for the catalog `generate_diff` endpoint + """ + return reverse( + 'api:v1:generate-catalog-diff', + kwargs={'uuid': enterprise_catalog_uuid or self.enterprise_catalog.uuid}, + ) diff --git a/enterprise_catalog/apps/api/v1/tests/test_enterprise_customer_views.py b/enterprise_catalog/apps/api/v1/tests/test_enterprise_customer_views.py index de738e47..8360b17a 100644 --- a/enterprise_catalog/apps/api/v1/tests/test_enterprise_customer_views.py +++ b/enterprise_catalog/apps/api/v1/tests/test_enterprise_customer_views.py @@ -5,74 +5,21 @@ import pytest import pytz from rest_framework import status -from rest_framework.reverse import reverse -from enterprise_catalog.apps.api.v1.tests.mixins import APITestMixin +from enterprise_catalog.apps.api.base.tests.enterprise_customer_views import BaseEnterpriseCustomerViewSetTests from enterprise_catalog.apps.catalog.constants import ( RESTRICTED_RUNS_ALLOWED_KEY, ) -from enterprise_catalog.apps.catalog.models import ( - CatalogQuery, - ContentMetadata, - EnterpriseCatalog, -) from enterprise_catalog.apps.catalog.tests.factories import ( ContentMetadataFactory, EnterpriseCatalogFactory, ) -class EnterpriseCustomerViewSetTests(APITestMixin): +class EnterpriseCustomerViewSetTests(BaseEnterpriseCustomerViewSetTests): """ Tests for the EnterpriseCustomerViewSet """ - - def setUp(self): - super().setUp() - # clean up any stale test objects - CatalogQuery.objects.all().delete() - ContentMetadata.objects.all().delete() - EnterpriseCatalog.objects.all().delete() - - self.enterprise_catalog = EnterpriseCatalogFactory(enterprise_uuid=self.enterprise_uuid) - - # Set up catalog.has_learner_access permissions - self.set_up_catalog_learner() - - def tearDown(self): - super().tearDown() - # clean up any stale test objects - CatalogQuery.objects.all().delete() - ContentMetadata.objects.all().delete() - EnterpriseCatalog.objects.all().delete() - - def _get_contains_content_base_url(self, enterprise_uuid=None): - """ - Helper to construct the base url for the contains_content_items endpoint - """ - return reverse( - 'api:v1:enterprise-customer-contains-content-items', - kwargs={'enterprise_uuid': enterprise_uuid or self.enterprise_uuid}, - ) - - def _get_filter_content_base_url(self, enterprise_uuid=None): - """ - Helper to construct the base url for the filter_content_items endpoint - """ - return reverse( - 'api:v1:enterprise-customer-filter-content-items', - kwargs={'enterprise_uuid': enterprise_uuid or self.enterprise_uuid}, - ) - - def _get_generate_diff_base_url(self, enterprise_catalog_uuid=None): - """ - Helper to construct the base url for the catalog `generate_diff` endpoint - """ - return reverse( - 'api:v1:generate-catalog-diff', - kwargs={'uuid': enterprise_catalog_uuid or self.enterprise_catalog.uuid}, - ) - def test_generate_diff_unauthorized_non_catalog_learner(self): """ Verify the generate_diff endpoint rejects users that are not catalog learners diff --git a/enterprise_catalog/apps/api/v1/views/enterprise_customer.py b/enterprise_catalog/apps/api/v1/views/enterprise_customer.py index 1cf6602e..cb7418cf 100644 --- a/enterprise_catalog/apps/api/v1/views/enterprise_customer.py +++ b/enterprise_catalog/apps/api/v1/views/enterprise_customer.py @@ -64,6 +64,18 @@ def get_permission_object(self): """ return self.kwargs.get('enterprise_uuid') + def filter_content_keys(self, catalog, content_keys): + return catalog.filter_content_keys(content_keys) + + def contains_content_keys(self, catalog, content_keys): + return catalog.contains_content_keys(content_keys) + + def get_metadata_by_uuid(self, catalog, content_uuid): + return catalog.content_metadata.filter(content_uuid=content_uuid) + + def get_metadata_by_content_key(self, catalog, content_key): + return catalog.get_matching_content(content_keys=[content_key]) + @method_decorator(require_at_least_one_query_parameter('course_run_ids', 'program_uuids')) @action(detail=True) def contains_content_items(self, request, enterprise_uuid, course_run_ids, program_uuids, **kwargs): @@ -105,9 +117,9 @@ def contains_content_items(self, request, enterprise_uuid, course_run_ids, progr any_catalog_contains_content_items = False catalogs_that_contain_course = [] + content_keys = requested_course_or_run_keys + program_uuids for catalog in customer_catalogs: - contains_content_items = catalog.contains_content_keys(requested_course_or_run_keys + program_uuids) - if contains_content_items: + if contains_content_items := self.contains_content_keys(catalog, content_keys): any_catalog_contains_content_items = True if not (get_catalogs_containing_specified_content_ids or get_catalog_list): # Break as soon as we find a catalog that contains the specified content @@ -136,8 +148,7 @@ def filter_content_items(self, request, enterprise_uuid, **kwargs): filtered_content_keys = set() for catalog in customer_catalogs: - items_included = catalog.filter_content_keys(content_keys) - if items_included: + if items_included := self.filter_content_keys(catalog, content_keys): filtered_content_keys = filtered_content_keys.union(items_included) response_data = { @@ -164,8 +175,7 @@ def get_metadata_item_serializer(self): # identifier is a valid UUID. content_uuid = uuid.UUID(content_identifier) for catalog in enterprise_catalogs: - content_with_uuid = catalog.content_metadata.filter(content_uuid=content_uuid) - if content_with_uuid: + if content_with_uuid := self.get_metadata_by_uuid(catalog, content_uuid): return ContentMetadataSerializer( content_with_uuid.first(), context={'enterprise_catalog': catalog, **serializer_context}, @@ -173,8 +183,7 @@ def get_metadata_item_serializer(self): except ValueError: # Otherwise, search for matching metadata as a content key for catalog in enterprise_catalogs: - content_with_key = catalog.get_matching_content(content_keys=[content_identifier]) - if content_with_key: + if content_with_key := self.get_metadata_by_content_key(catalog, content_identifier): return ContentMetadataSerializer( content_with_key.first(), context={'enterprise_catalog': catalog, **serializer_context}, diff --git a/enterprise_catalog/apps/api/v2/views/enterprise_catalog_get_content_metadata.py b/enterprise_catalog/apps/api/v2/views/enterprise_catalog_get_content_metadata.py index 520c9b51..5803ca99 100644 --- a/enterprise_catalog/apps/api/v2/views/enterprise_catalog_get_content_metadata.py +++ b/enterprise_catalog/apps/api/v2/views/enterprise_catalog_get_content_metadata.py @@ -1,4 +1,4 @@ -from asyncio.log import logger +import logging from enterprise_catalog.apps.api.v1.views.enterprise_catalog_get_content_metadata import ( EnterpriseCatalogGetContentMetadata, @@ -6,6 +6,9 @@ from enterprise_catalog.apps.api.v2.utils import is_any_course_run_active +logger = logging.getLogger(__name__) + + class EnterpriseCatalogGetContentMetadataV2(EnterpriseCatalogGetContentMetadata): """ View for retrieving all the content metadata associated with a catalog. diff --git a/enterprise_catalog/apps/api/v2/views/enterprise_customer.py b/enterprise_catalog/apps/api/v2/views/enterprise_customer.py new file mode 100644 index 00000000..32d37acd --- /dev/null +++ b/enterprise_catalog/apps/api/v2/views/enterprise_customer.py @@ -0,0 +1,25 @@ +import logging + +from enterprise_catalog.apps.api.v1.views.enterprise_customer import ( + EnterpriseCustomerViewSet, +) + + +logger = logging.getLogger(__name__) + + +class EnterpriseCustomerViewSetV2(EnterpriseCustomerViewSet): + """ + V2 views for content metadata and catalog-content inclusion for retrieving. + """ + def get_metadata_by_uuid(self, catalog, content_uuid): + return catalog.content_metadata_with_restricted.filter(content_uuid=content_uuid) + + def get_metadata_by_content_key(self, catalog, content_key): + return catalog.get_matching_content(content_keys=[content_key], include_restricted=True) + + def filter_content_keys(self, catalog, content_keys): + return catalog.filter_content_keys(content_keys, include_restricted=True) + + def contains_content_keys(self, catalog, content_keys): + return catalog.contains_content_keys(content_keys, include_restricted=True)