Skip to content

Commit

Permalink
sortable institution-users
Browse files Browse the repository at this point in the history
  • Loading branch information
aaxelb committed Sep 10, 2024
1 parent 48da0fb commit 9d13299
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 3 deletions.
32 changes: 29 additions & 3 deletions api/base/elasticsearch_dsl_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import elasticsearch_dsl as edsl
from rest_framework import generics
from rest_framework import exceptions as drf_exceptions
from rest_framework.settings import api_settings as drf_settings

from api.base.filters import FilterMixin
from api.base.views import JSONAPIBaseView
Expand All @@ -11,6 +13,9 @@
class ElasticsearchListView(FilterMixin, JSONAPIBaseView, generics.ListAPIView, abc.ABC):
'''abstract view class using `elasticsearch_dsl.Search` as a queryset-analogue
'''
default_ordering: str | None = None
ordering_fields: frozenset[str] = frozenset()

@abc.abstractmethod
def get_default_search(self) -> edsl.Search:
...
Expand All @@ -19,8 +24,8 @@ def get_default_search(self) -> edsl.Search:
# beware! inheritance shenanigans below

# override FilterMixin to disable all operators besides 'eq' and 'ne'
MATCHABLE_FIELDS = () # type: ignore[assignment]
COMPARABLE_FIELDS = () # type: ignore[assignment]
MATCHABLE_FIELDS = ()
COMPARABLE_FIELDS = ()
DEFAULT_OPERATOR_OVERRIDES = {}
# (if you want to add fulltext-search or range-filter support, remove the override
# and update `__add_search_filter` to handle those operators -- tho note that the
Expand All @@ -35,7 +40,7 @@ def get_default_search(self) -> edsl.Search:

# override rest_framework.generics.GenericAPIView
def get_queryset(self):
_search = self.get_default_search()
_search = self.__add_sort(self.get_default_search())
# using parsing logic from FilterMixin (oddly nested dict and all)
for _parsed_param in self.parse_query_params(self.request.query_params).values():
for _parsed_filter in _parsed_param.values():
Expand All @@ -50,6 +55,27 @@ def get_queryset(self):
###
# private methods

def __add_sort(self, search: edsl.Search) -> edsl.Search:
_elastic_sort = self.__get_elastic_sort()
return (search if _elastic_sort is None else search.sort(_elastic_sort))

def __get_elastic_sort(self) -> str | None:
_sort_param = self.request.query_params.get(drf_settings.ORDERING_PARAM, self.default_ordering)
if not _sort_param:
return None
_sort_field, _ascending = (
(_sort_param[1:], False)
if _sort_param.startswith('-')
else (_sort_param, True)
)
if _sort_field not in self.ordering_fields:
raise drf_exceptions.ValidationError(
f'invalid value for {drf_settings.ORDERING_PARAM} query param (valid values: {", ".join(self.ordering_fields)})'
)
_serializer_field = self.get_serializer().fields[_sort_field]
_elastic_sort_field = _serializer_field.source
return (_elastic_sort_field if _ascending else f'-{_elastic_sort_field}')

def __add_search_filter(
self,
search: edsl.Search,
Expand Down
15 changes: 15 additions & 0 deletions api/institutions/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,21 @@ class _NewInstitutionUserMetricsList(InstitutionMixin, ElasticsearchListView):

serializer_class = NewInstitutionUserMetricsSerializer

default_ordering = '-storage_usage_bytes'
ordering_fields = frozenset((
'user_name',
'department',
'month_last_login',
'account_creation_date',
'public_projects',
'private_projects',
'public_registration_count',
'embargoed_registration_count',
'published_preprint_count',
'public_file_count',
'storage_byte_count',
))

def get_default_search(self):
_yearmonth = InstitutionalUserReport.most_recent_yearmonth()
if _yearmonth is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,15 @@ def test_filter_reports(self, app, url, institutional_admin, institution, report
assert _resp.status_code == 200
assert set(_user_ids(_resp)) == _expected_user_ids

@pytest.mark.es
def test_sort_reports(self, app, url, institutional_admin, institution, reports, unshown_reports):
for _query, _expected_user_id_list in (
({'sort': 'storage_byte_count'}, ['u_sparse', 'u_orc', 'u_blargl', 'u_orcomma']),
({'sort': '-storage_byte_count'}, ['u_orcomma', 'u_blargl', 'u_orc', 'u_sparse']),
):
_resp = app.get(f'{url}?{urlencode(_query)}', auth=institutional_admin.auth)
assert _resp.status_code == 200
assert list(_user_ids(_resp)) == _expected_user_id_list

def _user_ids(api_response):
for _datum in api_response.json['data']:
Expand Down

0 comments on commit 9d13299

Please sign in to comment.