diff --git a/api/base/elasticsearch_dsl_views.py b/api/base/elasticsearch_dsl_views.py index 54d6afb885e..0c9388be126 100644 --- a/api/base/elasticsearch_dsl_views.py +++ b/api/base/elasticsearch_dsl_views.py @@ -3,6 +3,8 @@ import elasticsearch_dsl as edsl from rest_framework import generics +from rest_framework import exceptions as drf_exceptions +from rest_framework.settings import api_settings as drf_settings from api.base.filters import FilterMixin from api.base.views import JSONAPIBaseView @@ -11,6 +13,9 @@ class ElasticsearchListView(FilterMixin, JSONAPIBaseView, generics.ListAPIView, abc.ABC): '''abstract view class using `elasticsearch_dsl.Search` as a queryset-analogue ''' + default_ordering: str | None = None + ordering_fields: frozenset[str] = frozenset() + @abc.abstractmethod def get_default_search(self) -> edsl.Search: ... @@ -19,8 +24,8 @@ def get_default_search(self) -> edsl.Search: # beware! inheritance shenanigans below # override FilterMixin to disable all operators besides 'eq' and 'ne' - MATCHABLE_FIELDS = () # type: ignore[assignment] - COMPARABLE_FIELDS = () # type: ignore[assignment] + MATCHABLE_FIELDS = () + COMPARABLE_FIELDS = () DEFAULT_OPERATOR_OVERRIDES = {} # (if you want to add fulltext-search or range-filter support, remove the override # and update `__add_search_filter` to handle those operators -- tho note that the @@ -35,7 +40,7 @@ def get_default_search(self) -> edsl.Search: # override rest_framework.generics.GenericAPIView def get_queryset(self): - _search = self.get_default_search() + _search = self.__add_sort(self.get_default_search()) # using parsing logic from FilterMixin (oddly nested dict and all) for _parsed_param in self.parse_query_params(self.request.query_params).values(): for _parsed_filter in _parsed_param.values(): @@ -50,6 +55,27 @@ def get_queryset(self): ### # private methods + def __add_sort(self, search: edsl.Search) -> edsl.Search: + _elastic_sort = self.__get_elastic_sort() + return (search if _elastic_sort is None else search.sort(_elastic_sort)) + + def __get_elastic_sort(self) -> str | None: + _sort_param = self.request.query_params.get(drf_settings.ORDERING_PARAM, self.default_ordering) + if not _sort_param: + return None + _sort_field, _ascending = ( + (_sort_param[1:], False) + if _sort_param.startswith('-') + else (_sort_param, True) + ) + if _sort_field not in self.ordering_fields: + raise drf_exceptions.ValidationError( + f'invalid value for {drf_settings.ORDERING_PARAM} query param (valid values: {", ".join(self.ordering_fields)})' + ) + _serializer_field = self.get_serializer().fields[_sort_field] + _elastic_sort_field = _serializer_field.source + return (_elastic_sort_field if _ascending else f'-{_elastic_sort_field}') + def __add_search_filter( self, search: edsl.Search, diff --git a/api/institutions/views.py b/api/institutions/views.py index 8abfe9e48a8..d2ed75df937 100644 --- a/api/institutions/views.py +++ b/api/institutions/views.py @@ -545,6 +545,21 @@ class _NewInstitutionUserMetricsList(InstitutionMixin, ElasticsearchListView): serializer_class = NewInstitutionUserMetricsSerializer + default_ordering = '-storage_usage_bytes' + ordering_fields = frozenset(( + 'user_name', + 'department', + 'month_last_login', + 'account_creation_date', + 'public_projects', + 'private_projects', + 'public_registration_count', + 'embargoed_registration_count', + 'published_preprint_count', + 'public_file_count', + 'storage_byte_count', + )) + def get_default_search(self): _yearmonth = InstitutionalUserReport.most_recent_yearmonth() if _yearmonth is None: diff --git a/api_tests/institutions/views/test_institution_user_metric_list.py b/api_tests/institutions/views/test_institution_user_metric_list.py index abee2c68f76..f3d20a3ba66 100644 --- a/api_tests/institutions/views/test_institution_user_metric_list.py +++ b/api_tests/institutions/views/test_institution_user_metric_list.py @@ -386,6 +386,15 @@ def test_filter_reports(self, app, url, institutional_admin, institution, report assert _resp.status_code == 200 assert set(_user_ids(_resp)) == _expected_user_ids + @pytest.mark.es + def test_sort_reports(self, app, url, institutional_admin, institution, reports, unshown_reports): + for _query, _expected_user_id_list in ( + ({'sort': 'storage_byte_count'}, ['u_sparse', 'u_orc', 'u_blargl', 'u_orcomma']), + ({'sort': '-storage_byte_count'}, ['u_orcomma', 'u_blargl', 'u_orc', 'u_sparse']), + ): + _resp = app.get(f'{url}?{urlencode(_query)}', auth=institutional_admin.auth) + assert _resp.status_code == 200 + assert list(_user_ids(_resp)) == _expected_user_id_list def _user_ids(api_response): for _datum in api_response.json['data']: