Skip to content

Commit

Permalink
fix: adding sorting/filtering to members endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
kiram15 committed Nov 20, 2024
1 parent ad4d01c commit 2805fd8
Show file tree
Hide file tree
Showing 7 changed files with 290 additions and 2 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ Unreleased
----------
* nothing unreleased

[4.33.1]
--------
* feat: Creating enterprise customer members endpoint for admin portal

[4.32.0]
--------
* feat: create DefaultEnterpriseEnrollmentRealization objects in bulk enrollment API, when applicable.
Expand Down
2 changes: 1 addition & 1 deletion enterprise/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
Your project description goes here.
"""

__version__ = "4.32.0"
__version__ = "4.33.1"
41 changes: 41 additions & 0 deletions enterprise/api/v1/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1900,6 +1900,47 @@ def get_role_assignments(self, obj):
return None


class EnterpriseMembersSerializer(serializers.Serializer):
"""
Serializer for EnterpriseCustomerUser model with additions.
"""
class Meta:
model = models.EnterpriseCustomerUser
fields = (
'enterprise_customer_user',
'enrollments',
'full_name',
)

# enterprise_customer_user = UserSerializer(source="user", required=False, default=None)
enterprise_customer_user = serializers.SerializerMethodField()
enrollments = serializers.SerializerMethodField()

def get_enrollments(self, obj):
"""
Fetch all of user's enterprise enrollments
"""
if hasattr(obj, 'user_id'):
user_id = obj.user_id
enrollments = models.EnterpriseCourseEnrollment.objects.filter(
enterprise_customer_user=user_id,
)
return len(enrollments)
return 0

def get_enterprise_customer_user(self, obj):
"""
Return either the member's name and email if it's the case that the member is realized, otherwise just email
"""
if user := obj:
return {
"email": user[0],
"joined_org": user[1].strftime("%b %d, %Y"),
"name": user[2],
}
return None


class DefaultEnterpriseEnrollmentIntentionSerializer(serializers.ModelSerializer):
"""
Serializer for the DefaultEnterpriseEnrollmentIntention model.
Expand Down
6 changes: 6 additions & 0 deletions enterprise/api/v1/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
enterprise_customer_branding_configuration,
enterprise_customer_catalog,
enterprise_customer_invite_key,
enterprise_customer_members,
enterprise_customer_reporting,
enterprise_customer_sso_configuration,
enterprise_customer_support,
Expand Down Expand Up @@ -211,6 +212,11 @@
),
name='enterprise-customer-support'
),
re_path(
r'^enterprise-customer-members/(?P<enterprise_uuid>[A-Za-z0-9-]+)$',
enterprise_customer_members.EnterpriseCustomerMembersViewSet.as_view({'get': 'get_members'}),
name='enterprise-customer-members'
),
]

urlpatterns += router.urls
115 changes: 115 additions & 0 deletions enterprise/api/v1/views/enterprise_customer_members.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""
Views for the ``enterprise-customer-members`` API endpoint.
"""

from collections import OrderedDict

from rest_framework import permissions, response, status
from rest_framework.pagination import PageNumberPagination

from django.core.exceptions import ValidationError
from django.db import connection

from enterprise import models
from enterprise.api.v1 import serializers
from enterprise.api.v1.views.base_views import EnterpriseReadOnlyModelViewSet
from enterprise.logging import getEnterpriseLogger

LOGGER = getEnterpriseLogger(__name__)


class EnterpriseCustomerMembersPaginator(PageNumberPagination):
"""Custom paginator for the enterprise customer members."""

page_size = 10

def get_paginated_response(self, data):
"""Return a paginated style `Response` object for the given output data."""
return response.Response(
OrderedDict(
[
("count", self.page.paginator.count),
("num_pages", self.page.paginator.num_pages),
("next", self.get_next_link()),
("previous", self.get_previous_link()),
("results", data),
]
)
)

def paginate_queryset(self, queryset, request, view=None):
"""
Paginate a queryset if required, either returning a page object,
or `None` if pagination is not configured for this view.
"""
if isinstance(queryset, filter):
queryset = list(queryset)

return super().paginate_queryset(queryset, request, view)


class EnterpriseCustomerMembersViewSet(EnterpriseReadOnlyModelViewSet):
"""
API views for the ``enterprise-customer-members`` API endpoint.
"""
queryset = models.EnterpriseCustomerUser.objects.all()
serializer_class = serializers.EnterpriseMembersSerializer

permission_classes = (permissions.IsAuthenticated,)
paginator = EnterpriseCustomerMembersPaginator()

def get_members(self, request, *args, **kwargs):
"""
Get all members associated with that enterprise customer
"""
enterprise_uuid = kwargs.get("enterprise_uuid", None)
# Raw sql is picky about uuid format
uuid_no_dashes = str(enterprise_uuid).replace("-", "")
users = []
user_query = self.request.query_params.get("user_query", None)

# On logistration, the name field of auth_userprofile is populated, but if it's not
# filled in, we check the auth_user model for it's first/last name fields
# https://2u-internal.atlassian.net/wiki/spaces/ENGAGE/pages/747143186/Use+of+full+name+in+edX#Data-on-Name-Field
query = """
WITH users AS (
SELECT
au.email,
au.date_joined,
coalesce(NULLIF(aup.name, ''), concat(au.first_name, ' ', au.last_name)) as full_name
FROM enterprise_enterprisecustomeruser ecu
INNER JOIN auth_user as au on ecu.user_id = au.id
LEFT JOIN auth_userprofile as aup on au.id = aup.user_id
WHERE ecu.enterprise_customer_id = %s
) SELECT * FROM users {user_query_filter} ORDER BY full_name;
"""
try:
with connection.cursor() as cursor:
if user_query:
like_user_query = f"%{user_query}%"
sql_to_execute = query.format(
user_query_filter="WHERE full_name LIKE %s OR email LIKE %s"
)
cursor.execute(
sql_to_execute,
[uuid_no_dashes, like_user_query, like_user_query],
)
else:
sql_to_execute = query.format(user_query_filter="")
cursor.execute(sql_to_execute, [uuid_no_dashes])
users.extend(cursor.fetchall())

except ValidationError:
# did not find UUID match in either EnterpriseCustomerUser
return response.Response(
{"detail": "Could not find enterprise uuid {}".format(enterprise_uuid)},
status=status.HTTP_404_NOT_FOUND,
)

# paginate the queryset
users_page = self.paginator.paginate_queryset(users, request, view=self)

# serialize the paged dataset
serializer = serializers.EnterpriseMembersSerializer(users_page, many=True)
return self.paginator.get_paginated_response(serializer.data)
56 changes: 55 additions & 1 deletion tests/test_enterprise/api/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
EnterpriseCustomerReportingConfigurationSerializer,
EnterpriseCustomerSerializer,
EnterpriseCustomerUserReadOnlySerializer,
EnterpriseMembersSerializer,
EnterpriseUserSerializer,
ImmutableStateSerializer,
)
Expand Down Expand Up @@ -455,7 +456,7 @@ def setUp(self):

super().setUp()

# setup Enteprise Customer
# setup Enterprise Customer
self.user_1 = factories.UserFactory()
self.user_2 = factories.UserFactory()
self.enterprise_customer_user_1 = factories.EnterpriseCustomerUserFactory(user_id=self.user_1.id)
Expand Down Expand Up @@ -558,3 +559,56 @@ def test_serialize_pending_users(self):
serialized_pending_admin_user = serializer.data

self.assertEqual(expected_pending_admin_user, serialized_pending_admin_user)


class TestEnterpriseMembersSerializer(TestCase):
"""
Tests for EnterpriseMembersSerializer.
"""
def setUp(self):
super().setUp()

# setup Enterprise Customer
self.user_1 = factories.UserFactory()
self.user_2 = factories.UserFactory()
self.enterprise_customer_user_1 = factories.EnterpriseCustomerUserFactory(user_id=self.user_1.id)
self.enterprise_customer_user_2 = factories.EnterpriseCustomerUserFactory(user_id=self.user_2.id)
self.enterprise_customer_1 = self.enterprise_customer_user_1.enterprise_customer
self.enterprise_customer_2 = self.enterprise_customer_user_2.enterprise_customer

self.enrollment_1 = factories.EnterpriseCourseEnrollmentFactory(
enterprise_customer_user=self.enterprise_customer_user_1,
)
self.enrollment_2 = factories.EnterpriseCourseEnrollmentFactory(
enterprise_customer_user=self.enterprise_customer_user_1,
)
self.enrollment_3 = factories.EnterpriseCourseEnrollmentFactory(
enterprise_customer_user=self.enterprise_customer_user_2,
)

def test_serialize_users(self):
expected_user = {
'enrollments': 2,
'enterprise_customer_user': {
'email': self.user_1.email,
'joined_org': self.user_1.date_joined.strftime("%b %d, %Y"),
'name': (self.user_1.first_name + ' ' + self.user_1.last_name),
},
}
serializer = EnterpriseMembersSerializer(self.enterprise_customer_user_1)
serialized_user = serializer.data

self.assertEqual(serialized_user, expected_user)

expected_user_2 = {
'enrollments': 1,
'enterprise_customer_user': {
'email': self.user_2.email,
'joined_org': self.user_2.date_joined.strftime("%b %d, %Y"),
'name': self.user_2.first_name + ' ' + self.user_2.last_name,
},
}

serializer = EnterpriseMembersSerializer(self.enterprise_customer_user_2)
serialized_user = serializer.data
self.assertEqual(serialized_user, expected_user_2)
68 changes: 68 additions & 0 deletions tests/test_enterprise/api/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -9974,6 +9974,74 @@ def test_list_users_filtered(self):
assert response.json().get('count') == 1


@ddt.ddt
@mark.django_db
class TestEnterpriseCustomerMembers(BaseTestEnterpriseAPIViews):
"""
Test enterprise customer members list endpoint
"""
ECM_ENDPOINT = 'enterprise-customer-members'
ECM_KWARG = 'enterprise_uuid'

def test_get_enterprise_org_members(self):
"""
Assert whether the response is valid.
"""
user_1 = factories.UserFactory(first_name="Rhaenyra", last_name="Targaryen")
user_2 = factories.UserFactory(first_name="Jace", last_name="Targaryen")
user_3 = factories.UserFactory(first_name="Alicent", last_name="Hightower")
user_4 = factories.UserFactory(first_name="Helaena", last_name="Targaryen")
user_5 = factories.UserFactory(first_name="Laenor", last_name="Velaryon")

enterprise_customer = factories.EnterpriseCustomerFactory(uuid=FAKE_UUIDS[0])
factories.EnterpriseCustomerUserFactory(
user_id=user_1.id,
enterprise_customer=enterprise_customer
)
factories.EnterpriseCustomerUserFactory(
user_id=user_2.id,
enterprise_customer=enterprise_customer
)
factories.EnterpriseCustomerUserFactory(
user_id=user_3.id,
enterprise_customer=enterprise_customer
)
factories.EnterpriseCustomerUserFactory(
user_id=user_4.id,
enterprise_customer=enterprise_customer
)
factories.EnterpriseCustomerUserFactory(
user_id=user_5.id,
enterprise_customer=enterprise_customer
)

# Test invalid UUID
url = reverse(self.ECM_ENDPOINT, kwargs={self.ECM_KWARG: 123})
response = self.client.get(settings.TEST_SERVER + url)
self.assertEqual(response.status_code, 404)

# Test valid UUID
url = reverse(self.ECM_ENDPOINT, kwargs={self.ECM_KWARG: enterprise_customer.uuid})
response = self.client.get(settings.TEST_SERVER + url)
data = response.json().get('results')

# list should be sorted alphabetically by name
self.assertEqual(data[0]['enterprise_customer_user']['name'], (user_3.first_name + ' ' + user_3.last_name))
self.assertEqual(data[1]['enterprise_customer_user']['name'], (user_4.first_name + ' ' + user_4.last_name))
self.assertEqual(data[2]['enterprise_customer_user']['name'], (user_2.first_name + ' ' + user_2.last_name))
self.assertEqual(data[3]['enterprise_customer_user']['name'], (user_5.first_name + ' ' + user_5.last_name))
self.assertEqual(data[4]['enterprise_customer_user']['name'], (user_1.first_name + ' ' + user_1.last_name))

# use user query to filter by name
name_query = f'?user_query={user_2.first_name}'
url = reverse(self.ECM_ENDPOINT, kwargs={self.ECM_KWARG: enterprise_customer.uuid})
url = url + name_query
response = self.client.get(settings.TEST_SERVER + url)
data = response.json().get('results')
self.assertEqual(len(data), 1)
self.assertEqual(data[0]['enterprise_customer_user']['name'], (user_2.first_name + ' ' + user_2.last_name))


@ddt.ddt
@mark.django_db
class TestDefaultEnterpriseEnrollmentIntentionViewSet(BaseTestEnterpriseAPIViews):
Expand Down

0 comments on commit 2805fd8

Please sign in to comment.