Skip to content

Commit

Permalink
institution-user metrics view (with basic filters)
Browse files Browse the repository at this point in the history
  • Loading branch information
aaxelb committed Sep 9, 2024
1 parent dceb435 commit 36f51d7
Show file tree
Hide file tree
Showing 6 changed files with 225 additions and 9 deletions.
64 changes: 64 additions & 0 deletions api/base/elasticsearch_dsl_views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from __future__ import annotations
import abc

import elasticsearch_dsl as edsl
from rest_framework import generics

from api.base.filters import FilterMixin
from api.base.views import JSONAPIBaseView


class ElasticsearchListView(FilterMixin, JSONAPIBaseView, generics.ListAPIView, abc.ABC):
'''use `elasticsearch_dsl.Search` as a queryset-analogue
'''
@property
@abc.abstractmethod
def elasticsearch_document_class(self) -> type[edsl.Document]:
...

@abc.abstractmethod
def get_default_search(self) -> edsl.Search:
...

###
# beware! rest_framework shenanigans below

filter_backends = () # filtering handled in-view to reuse logic from FilterMixin

# note: because elasticsearch_dsl.Search supports slicing and gives results when iterated on,
# it should work fine with default pagination!

# override rest_framework.generics.GenericAPIView
def get_queryset(self):
_search = self.get_default_search()
# using parsing logic from FilterMixin (oddly nested dict and all)
for _parsed_param in self.parse_query_params(self.request.query_params).values():
for _parsed_filter in _parsed_param.values():
_search = self.__add_search_filter(
_search,
field_name=_parsed_filter['source_field_name'],
operation=_parsed_filter['op'],
value=_parsed_filter['value'],
)
return _search

def __add_search_filter(
self,
search: edsl.Search,
field_name: str,
operation: str,
value: str,
) -> edsl.Search:
match operation: # operations from FilterMixin
case 'eq':
if value == '':
return search.exclude('exists', field=field_name)
return search.filter('term', **{field_name: value})
case 'ne':
if value == '':
return search.filter('exists', field=field_name)
return search.exclude('term', **{field_name: value})
#case ('contains', 'icontains'):
#case ('gt', 'gte', 'lt', 'lte'):
case _:
raise NotImplementedError
4 changes: 2 additions & 2 deletions api/base/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ def _get_field_or_error(self, field_name):
serializer = self.get_serializer()
if field_name not in serializer.fields:
raise InvalidFilterError(detail=f"'{field_name}' is not a valid field for this endpoint.")
if field_name not in getattr(serializer, 'filterable_fields', set()):
raise InvalidFilterFieldError(parameter='filter', value=field_name)
# if field_name not in getattr(serializer, 'filterable_fields', set()):
# raise InvalidFilterFieldError(parameter='filter', value=field_name)
field = serializer.fields[field_name]
# You cannot filter on deprecated fields.
if isinstance(field, ShowIfVersion) and utils.is_deprecated(self.request.version, field.min_version, field.max_version):
Expand Down
18 changes: 16 additions & 2 deletions api/institutions/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
from osf.metrics import InstitutionProjectCounts
from osf.models import OSFUser, Node, Institution, Registration
from osf.metrics import UserInstitutionProjectCounts
from osf.metrics.reports import InstitutionalUserReport
from osf.utils import permissions as osf_permissions

from api.base import permissions as base_permissions
from api.base.filters import ListFilterMixin, FilterMixin
from api.base.elasticsearch_dsl_views import ElasticsearchListView
from api.base.filters import ListFilterMixin
from api.base.views import JSONAPIBaseView
from api.base.serializers import JSONAPISerializer
from api.base.utils import get_object_or_error, get_user_auth
Expand Down Expand Up @@ -528,7 +530,7 @@ def get_default_queryset(self):
return self._make_elasticsearch_results_filterable(search, id=institution._id, department=DEFAULT_ES_NULL_VALUE)


class _NewInstitutionUserMetricsList(InstitutionMixin, FilterMixin, JSONAPIBaseView):
class _NewInstitutionUserMetricsList(InstitutionMixin, ElasticsearchListView):
permission_classes = (
drf_permissions.IsAuthenticatedOrReadOnly,
base_permissions.TokenHasScope,
Expand All @@ -542,6 +544,18 @@ class _NewInstitutionUserMetricsList(InstitutionMixin, FilterMixin, JSONAPIBaseV
view_name = 'institution-user-metrics'

serializer_class = NewInstitutionUserMetricsSerializer
elasticsearch_document_class = InstitutionalUserReport

def get_default_search(self):
_yearmonth = InstitutionalUserReport.most_recent_yearmonth()
if _yearmonth is None:
return []
_search = (
InstitutionalUserReport.search()
.filter('term', report_yearmonth=str(_yearmonth))
.filter('term', institution_id=self.get_institution()._id)
)
return _search


institution_user_metrics_list_view = view_toggled_by_feature_flag(
Expand Down
118 changes: 117 additions & 1 deletion api_tests/institutions/views/test_institution_user_metric_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import csv
from io import StringIO
from random import random
from urllib.parse import urlencode

import pytest
from waffle.testutils import override_flag
Expand All @@ -14,10 +15,11 @@
)

from osf.metrics import UserInstitutionProjectCounts
from osf.metrics.reports import InstitutionalUserReport

@pytest.mark.es
@pytest.mark.django_db
class TestInstitutionUserMetricList:
class TestOldInstitutionUserMetricList:

@pytest.fixture(autouse=True)
def _waffled(self):
Expand Down Expand Up @@ -262,3 +264,117 @@ def test_filter_and_sort(self, app, url, user, user2, user3, admin, user4, popul
assert data[0]['attributes']['department'] == 'Biology dept'
assert data[1]['attributes']['department'] == 'N/A'
assert data[2]['attributes']['department'] == 'Psychology dept'


@pytest.mark.django_db
class TestNewInstitutionUserMetricList:
@pytest.fixture(autouse=True)
def _waffled(self):
with override_flag(osf.features.INSTITUTIONAL_DASHBOARD_2024, active=True):
yield # these tests apply only after institution dashboard improvements

@pytest.fixture()
def institution(self):
return InstitutionFactory()

@pytest.fixture()
def rando(self):
return AuthUserFactory()

@pytest.fixture()
def institutional_admin(self, institution):
_admin_user = AuthUserFactory()
institution.get_group('institutional_admins').user_set.add(_admin_user)
return _admin_user

@pytest.fixture()
def unshown_reports(self, institution):
# unshown because another institution
_another_institution = InstitutionFactory()
_report_factory('2024-08', _another_institution, user_id='nother_inst')
# unshown because old
_report_factory('2024-07', institution, user_id='old')

@pytest.fixture()
def reports(self, institution):
return [
_report_factory(
'2024-08', institution,
user_id='u_sparse',
storage_byte_count=53,
),
_report_factory(
'2024-08', institution,
user_id='u_orc',
orcid_id='5555-4444-3333-2222',
storage_byte_count=8277,
),
_report_factory(
'2024-08', institution,
user_id='u_blargl',
department_name='blargl',
storage_byte_count=34834834,
),
_report_factory(
'2024-08', institution,
user_id='u_orcomma',
orcid_id='4444-3333-2222-1111',
department_name='a department, or so, that happens, incidentally, to have commas',
storage_byte_count=736662999298,
),
]

@pytest.fixture()
def url(self, institution):
return f'/{API_BASE}institutions/{institution._id}/metrics/users/'

def test_anon(self, app, url):
_resp = app.get(url, expect_errors=True)
assert _resp.status_code == 401

def test_rando(self, app, url, rando):
_resp = app.get(url, auth=rando.auth, expect_errors=True)
assert _resp.status_code == 403

@pytest.mark.es
def test_get_empty(self, app, url, institutional_admin):
_resp = app.get(url, auth=institutional_admin.auth)
assert _resp.status_code == 200
assert _resp.json['data'] == []

@pytest.mark.es
def test_get_reports(self, app, url, institutional_admin, institution, reports, unshown_reports):
_resp = app.get(url, auth=institutional_admin.auth)
assert _resp.status_code == 200
assert len(_resp.json['data']) == len(reports)
_expected_user_ids = {_report.user_id for _report in reports}
assert _user_ids(_resp) == _expected_user_ids

@pytest.mark.es
def test_filter_reports(self, app, url, institutional_admin, institution, reports, unshown_reports):
for _query, _expected_user_ids in (
({'filter[department][eq]': 'nunavum'}, set()),
({'filter[department][eq]': 'blargl'}, {'u_blargl'}),
({'filter[department][eq]': 'a department, or so, that happens, incidentally, to have commas'}, {'u_orcomma'}),
({'filter[orcid_id][eq]': ''}, {'u_sparse', 'u_blargl'}),
({'filter[orcid_id][ne]': ''}, {'u_orc', 'u_orcomma'}),
):
_resp = app.get(f'{url}?{urlencode(_query)}', auth=institutional_admin.auth)
assert _resp.status_code == 200
assert _user_ids(_resp) == _expected_user_ids


def _user_ids(api_response):
return {
_datum['relationships']['user']['data']['id']
for _datum in api_response.json['data']
}

def _report_factory(yearmonth, institution, **kwargs):
_report = InstitutionalUserReport(
report_yearmonth=yearmonth,
institution_id=institution._id,
**kwargs,
)
_report.save(refresh=True)
return _report
25 changes: 23 additions & 2 deletions osf/metrics/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class DailyReport(metrics.Metric):
There's something we'd like to know about every so often,
so let's regularly run a report and stash the results here.
"""
UNIQUE_TOGETHER_FIELDS = ('report_date',) # override in subclasses for multiple reports per day
UNIQUE_TOGETHER_FIELDS: tuple[str, ...] = ('report_date',) # override in subclasses for multiple reports per day

report_date = metrics.Date(format='strict_date', required=True)

Expand All @@ -46,6 +46,10 @@ def deserialize(self, data):
return YearMonth.from_str(data)
elif isinstance(data, (datetime.datetime, datetime.date)):
return YearMonth.from_date(data)
elif isinstance(data, int):
# elasticsearch stores dates in milliseconds since the unix epoch
_as_datetime = datetime.datetime.fromtimestamp(data // 1000)
return YearMonth.from_date(_as_datetime)
elif data is None:
return None
else:
Expand All @@ -67,7 +71,7 @@ def serialize(self, data):
class MonthlyReport(metrics.Metric):
"""MonthlyReport (abstract base for report-based metrics that run monthly)
"""
UNIQUE_TOGETHER_FIELDS = ('report_yearmonth',) # override in subclasses for multiple reports per month
UNIQUE_TOGETHER_FIELDS: tuple[str, ...] = ('report_yearmonth',) # override in subclasses for multiple reports per month

report_yearmonth = YearmonthField(required=True)

Expand All @@ -76,6 +80,23 @@ class Meta:
dynamic = metrics.MetaField('strict')
source = metrics.MetaField(enabled=True)

@classmethod
def most_recent_yearmonth(cls, base_search=None) -> YearMonth | None:
_search = base_search or cls.search()
_search = _search.update_from_dict({'size': 0}) # omit hits
_search.aggs.bucket(
'agg_most_recent_yearmonth',
'terms',
field='report_yearmonth',
order={'_key': 'desc'},
size=1,
)
_response = _search.execute()
if not _response.aggregations:
return None
(_bucket,) = _response.aggregations.agg_most_recent_yearmonth.buckets
return _bucket.key

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
assert 'report_yearmonth' in cls.UNIQUE_TOGETHER_FIELDS, f'MonthlyReport subclasses must have "report_yearmonth" in UNIQUE_TOGETHER_FIELDS (on {cls.__qualname__}, got {cls.UNIQUE_TOGETHER_FIELDS})'
Expand Down
5 changes: 3 additions & 2 deletions osf/metrics/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
import dataclasses
import re
import datetime
Expand Down Expand Up @@ -27,12 +28,12 @@ class YearMonth:
YEARMONTH_RE: ClassVar[re.Pattern] = re.compile(r'(?P<year>\d{4})-(?P<month>\d{2})')

@classmethod
def from_date(cls, date):
def from_date(cls, date: datetime.date) -> YearMonth:
assert isinstance(date, (datetime.datetime, datetime.date))
return cls(date.year, date.month)

@classmethod
def from_str(cls, input_str):
def from_str(cls, input_str: str) -> YearMonth:
match = cls.YEARMONTH_RE.fullmatch(input_str)
if match:
return cls(
Expand Down

0 comments on commit 36f51d7

Please sign in to comment.