diff --git a/docs/CHANGELOG.rst b/docs/CHANGELOG.rst index d4873fa35..187ea9670 100644 --- a/docs/CHANGELOG.rst +++ b/docs/CHANGELOG.rst @@ -14,6 +14,7 @@ Added ----- - Expose ``status`` on ``collection`` and ``entity`` viewset and allow filtering and sorting by it +- Add generic filtering by related objects which respects permissions Changed ------- diff --git a/resolwe/flow/filters.py b/resolwe/flow/filters.py index 7cd2d60f2..5d83381ce 100644 --- a/resolwe/flow/filters.py +++ b/resolwe/flow/filters.py @@ -9,14 +9,15 @@ import re import types from copy import deepcopy -from functools import partial -from typing import Callable, Union +from dataclasses import dataclass +from functools import partial, reduce +from typing import Any, Optional, Union from django.contrib.auth import get_user_model from django.contrib.auth.models import Group from django.contrib.postgres.search import SearchQuery, SearchRank from django.core.exceptions import ValidationError -from django.db.models import Count, F, ForeignKey, Q, Subquery +from django.db.models import Count, F, ForeignKey, Model, Q, Subquery from django.db.models.query import QuerySet from django_filters import rest_framework as filters from django_filters.constants import EMPTY_VALUES @@ -239,18 +240,115 @@ class ExtendedFilter: """ - def __new__(mcs, name, bases, namespace): - """Inject extensions into the filter.""" - class_path = "{}.{}".format(namespace["__module__"], name) - for extension in composer.get_extensions(class_path): - for name in dir(extension): - if name.startswith("__"): - continue - namespace[name] = getattr(extension, name) + def __new__( + mcs: type[object], + name: str, + bases: tuple[type[object]], + namespace: dict[str, Any], + ): + """Inject filters.""" + + def inject_composer_extensions(namespace: dict[str, Any], name: str): + """Inject filters from composer extensions.""" + class_path = "{}.{}".format(namespace["__module__"], name) + for extension in composer.get_extensions(class_path): + for name in dir(extension): + if name.startswith("__"): + continue + namespace[name] = getattr(extension, name) + + def get_related_model_from_path(related_path: str) -> type[Model]: + """Get the related model from the path.""" + return reduce( + lambda Model, part: Model._meta.get_field(part).related_model, + related_path.split("__"), + namespace["Meta"].model, + ) + + def add_filters_with_permissions( + prefix: str, related_path: str, base_filter: filters.FilterSet + ): + """Add filters on related objects with permissions. + All the filters belonging to the BaseFilter are added with the given prefix. + """ + + def filter_permissions( + self, + qs: QuerySet, + value: str, + original_filter: filters.Filter, + original_model: Model, + ): + """Apply the filter and respect permissions.""" + + # Do not filter when value is empty. At least one of the values must be + # non-empty since form in the AnnotationValueFilter class requires it. + if value in EMPTY_VALUES: + return qs + + visible_objects = list( + original_filter.filter(original_model.objects.all(), value) + .filter_for_user(self.parent.request.user) + .values_list("pk", flat=True) + ) + return qs.filter(**{f"{related_path}__in": visible_objects}) + + # Add all filters from EntityFilter to namespaces before creating class. + BaseModel = get_related_model_from_path(related_path) + for filter_name, filter in base_filter.get_filters().items(): + new_filter_name = f"{prefix}__{filter_name}" + if filter_name == "id" or filter_name.startswith("id__"): + new_filter_name = f"{prefix}{filter_name[2:]}" + + filter_copy = deepcopy(filter) + filter_copy.field_name = f"{prefix}__{filter_copy.field_name}" + filter_method = partial( + filter_permissions, original_filter=filter, original_model=BaseModel + ) + # Bind the new_filter to filter instance and set it as new filter. + filter_copy.filter = types.MethodType(filter_method, filter_copy) + namespace[new_filter_name] = filter_copy + # If filter uses a method, add it to the namespace as well. + if filter_copy.method is not None: + namespace[filter_copy.method] = deepcopy( + getattr(base_filter, filter_copy.method) + ) + + def inject_related_permissions(namespace: dict): + """Parse the namespace and add filters with permissions.""" + for prefix, value in list(namespace.items()): + if isinstance(value, FilterRelatedWithPermissions): + namespace.pop(prefix) + add_filters_with_permissions( + prefix, value.related_path or prefix, value.BaseFilter + ) + + inject_related_permissions(namespace) + inject_composer_extensions(namespace, name) return super().__new__(mcs, name, bases, namespace) +@dataclass +class FilterRelatedWithPermissions: + """Base class for filters with permissions. + + The class using this must have FilterRelatedWithPermissionsMeta as its meta class. + The filter is then defined as: + + prefix = FilterRelatedWithPermissions(BaseFilter, related_path="related__path") + + All the filters from the BaseFilter (with the given prefix) are added and + permissions on the related objects are respected. + + The optional parameter `related__path` is used to specify the path of the related + object (if different from the prefix). + """ + + BaseFilter: type[filters.FilterSet] + related_path: Optional[str] = None + + class BaseResolweFilter( CheckQueryParamsMixin, filters.FilterSet, metaclass=ResolweFilterMetaclass ): @@ -551,64 +649,11 @@ def get_ordering(self, request, queryset, view): return self.get_default_ordering(view) -class AnnotationValueFieldMetaclass(ResolweFilterMetaclass): - """Add all entity filters prefixed with 'entity'.""" - - def __new__(mcs, name, bases, namespace): - """Inject extensions into the filter.""" - - def filter_permissions( - self, - qs: QuerySet, - value: str, - original_entity_filter: Callable[[QuerySet, str], QuerySet], - ): - """Respect permissions on entities.""" - # Do not filter when value is empty. At least one of the values must be - # non-empty since form in the AnnotationValueFilter class requires it. - if value in EMPTY_VALUES: - return qs - - # Filter the entities using the original entity filter and permissions. - visible_entities = list( - original_entity_filter(Entity.objects.all(), value) - .filter_for_user(self.parent.request.user) - .values_list("pk", flat=True) - ) - return qs.filter(**{f"{entity_path}__in": visible_entities}) - - entity_path = { - "AnnotationValueFilter": "entity", - "AnnotationFieldFilter": "values__entity", - }[name] - # Add all filters from EntityFilter to namespaces before creating class. - for filter_name, filter in EntityFilter.get_filters().items(): - new_name = f"entity__{filter_name}" - if filter_name == "id" or filter_name.startswith("id__"): - new_name = "entity" + filter_name[2:] - filter_copy = deepcopy(filter) - filter_copy.field_name = f"{entity_path}__{filter_copy.field_name}" - filter_method = partial( - filter_permissions, - original_entity_filter=filter.filter, - ) - # Bind the new_filter to filter instance and set it as new filter. - filter_copy.filter = types.MethodType(filter_method, filter_copy) - namespace[new_name] = filter_copy - # If filter uses a method, add it to the namespace as well. - if filter_copy.method is not None: - namespace[filter_copy.method] = deepcopy( - getattr(EntityFilter, filter_copy.method) - ) - - # Create class with added filters. - klass = ResolweFilterMetaclass.__new__(mcs, name, bases, namespace) - return klass - - -class AnnotationFieldFilter(BaseResolweFilter, metaclass=AnnotationValueFieldMetaclass): +class AnnotationFieldFilter(BaseResolweFilter): """Filter the AnnotationField endpoint.""" + entity = FilterRelatedWithPermissions(EntityFilter, related_path="values__entity") + @classmethod def filter_for_field(cls, field, field_name, lookup_expr=None): """Add permission check for collections lookups. @@ -688,9 +733,10 @@ class Meta(BaseResolweFilter.Meta): } -class AnnotationValueFilter(BaseResolweFilter, metaclass=AnnotationValueFieldMetaclass): +class AnnotationValueFilter(BaseResolweFilter): """Filter the AnnotationValue endpoint.""" + entity = FilterRelatedWithPermissions(EntityFilter) label = filters.CharFilter(method="filter_by_label") def filter_by_label(self, queryset: QuerySet, name: str, value: str):