Skip to content

Commit

Permalink
Add generic filtering which respects permissions
Browse files Browse the repository at this point in the history
  • Loading branch information
gregorjerse committed Dec 4, 2024
1 parent b44efb3 commit 6405014
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 68 deletions.
1 change: 1 addition & 0 deletions docs/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Added
-----
- Expose ``status`` on ``collection`` and ``entity`` viewset and allow
filtering and sorting by it
- Add generic filtering by related objects which respects permissions

Changed
-------
Expand Down
182 changes: 114 additions & 68 deletions resolwe/flow/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
import re
import types
from copy import deepcopy
from functools import partial
from typing import Callable, Union
from dataclasses import dataclass
from functools import partial, reduce
from typing import Any, Optional, Union

from django.contrib.auth import get_user_model
from django.contrib.auth.models import Group
from django.contrib.postgres.search import SearchQuery, SearchRank
from django.core.exceptions import ValidationError
from django.db.models import Count, F, ForeignKey, Q, Subquery
from django.db.models import Count, F, ForeignKey, Model, Q, Subquery
from django.db.models.query import QuerySet
from django_filters import rest_framework as filters
from django_filters.constants import EMPTY_VALUES
Expand Down Expand Up @@ -239,18 +240,115 @@ class ExtendedFilter:
"""

def __new__(mcs, name, bases, namespace):
"""Inject extensions into the filter."""
class_path = "{}.{}".format(namespace["__module__"], name)
for extension in composer.get_extensions(class_path):
for name in dir(extension):
if name.startswith("__"):
continue
namespace[name] = getattr(extension, name)
def __new__(
mcs: type[object],
name: str,
bases: tuple[type[object]],
namespace: dict[str, Any],
):
"""Inject filters."""

def inject_composer_extensions(namespace: dict[str, Any], name: str):
"""Inject filters from composer extensions."""
class_path = "{}.{}".format(namespace["__module__"], name)
for extension in composer.get_extensions(class_path):
for name in dir(extension):
if name.startswith("__"):
continue
namespace[name] = getattr(extension, name)

def get_related_model_from_path(related_path: str) -> type[Model]:
"""Get the related model from the path."""
return reduce(
lambda Model, part: Model._meta.get_field(part).related_model,
related_path.split("__"),
namespace["Meta"].model,
)

def add_filters_with_permissions(
prefix: str, related_path: str, base_filter: filters.FilterSet
):
"""Add filters on related objects with permissions.
All the filters belonging to the BaseFilter are added with the given prefix.
"""

def filter_permissions(
self,
qs: QuerySet,
value: str,
original_filter: filters.Filter,
original_model: Model,
):
"""Apply the filter and respect permissions."""

# Do not filter when value is empty. At least one of the values must be
# non-empty since form in the AnnotationValueFilter class requires it.
if value in EMPTY_VALUES:
return qs

visible_objects = list(
original_filter.filter(original_model.objects.all(), value)
.filter_for_user(self.parent.request.user)
.values_list("pk", flat=True)
)
return qs.filter(**{f"{related_path}__in": visible_objects})

# Add all filters from EntityFilter to namespaces before creating class.
BaseModel = get_related_model_from_path(related_path)
for filter_name, filter in base_filter.get_filters().items():
new_filter_name = f"{prefix}__{filter_name}"
if filter_name == "id" or filter_name.startswith("id__"):
new_filter_name = f"{prefix}{filter_name[2:]}"

filter_copy = deepcopy(filter)
filter_copy.field_name = f"{prefix}__{filter_copy.field_name}"
filter_method = partial(
filter_permissions, original_filter=filter, original_model=BaseModel
)
# Bind the new_filter to filter instance and set it as new filter.
filter_copy.filter = types.MethodType(filter_method, filter_copy)
namespace[new_filter_name] = filter_copy
# If filter uses a method, add it to the namespace as well.
if filter_copy.method is not None:
namespace[filter_copy.method] = deepcopy(
getattr(base_filter, filter_copy.method)
)

def inject_related_permissions(namespace: dict):
"""Parse the namespace and add filters with permissions."""
for prefix, value in list(namespace.items()):
if isinstance(value, FilterRelatedWithPermissions):
namespace.pop(prefix)
add_filters_with_permissions(
prefix, value.related_path or prefix, value.BaseFilter
)

inject_related_permissions(namespace)
inject_composer_extensions(namespace, name)
return super().__new__(mcs, name, bases, namespace)


@dataclass
class FilterRelatedWithPermissions:
"""Base class for filters with permissions.
The class using this must have FilterRelatedWithPermissionsMeta as its meta class.
The filter is then defined as:
prefix = FilterRelatedWithPermissions(BaseFilter, related_path="related__path")
All the filters from the BaseFilter (with the given prefix) are added and
permissions on the related objects are respected.
The optional parameter `related__path` is used to specify the path of the related
object (if different from the prefix).
"""

BaseFilter: type[filters.FilterSet]
related_path: Optional[str] = None


class BaseResolweFilter(
CheckQueryParamsMixin, filters.FilterSet, metaclass=ResolweFilterMetaclass
):
Expand Down Expand Up @@ -551,64 +649,11 @@ def get_ordering(self, request, queryset, view):
return self.get_default_ordering(view)


class AnnotationValueFieldMetaclass(ResolweFilterMetaclass):
"""Add all entity filters prefixed with 'entity'."""

def __new__(mcs, name, bases, namespace):
"""Inject extensions into the filter."""

def filter_permissions(
self,
qs: QuerySet,
value: str,
original_entity_filter: Callable[[QuerySet, str], QuerySet],
):
"""Respect permissions on entities."""
# Do not filter when value is empty. At least one of the values must be
# non-empty since form in the AnnotationValueFilter class requires it.
if value in EMPTY_VALUES:
return qs

# Filter the entities using the original entity filter and permissions.
visible_entities = list(
original_entity_filter(Entity.objects.all(), value)
.filter_for_user(self.parent.request.user)
.values_list("pk", flat=True)
)
return qs.filter(**{f"{entity_path}__in": visible_entities})

entity_path = {
"AnnotationValueFilter": "entity",
"AnnotationFieldFilter": "values__entity",
}[name]
# Add all filters from EntityFilter to namespaces before creating class.
for filter_name, filter in EntityFilter.get_filters().items():
new_name = f"entity__{filter_name}"
if filter_name == "id" or filter_name.startswith("id__"):
new_name = "entity" + filter_name[2:]
filter_copy = deepcopy(filter)
filter_copy.field_name = f"{entity_path}__{filter_copy.field_name}"
filter_method = partial(
filter_permissions,
original_entity_filter=filter.filter,
)
# Bind the new_filter to filter instance and set it as new filter.
filter_copy.filter = types.MethodType(filter_method, filter_copy)
namespace[new_name] = filter_copy
# If filter uses a method, add it to the namespace as well.
if filter_copy.method is not None:
namespace[filter_copy.method] = deepcopy(
getattr(EntityFilter, filter_copy.method)
)

# Create class with added filters.
klass = ResolweFilterMetaclass.__new__(mcs, name, bases, namespace)
return klass


class AnnotationFieldFilter(BaseResolweFilter, metaclass=AnnotationValueFieldMetaclass):
class AnnotationFieldFilter(BaseResolweFilter):
"""Filter the AnnotationField endpoint."""

entity = FilterRelatedWithPermissions(EntityFilter, related_path="values__entity")

@classmethod
def filter_for_field(cls, field, field_name, lookup_expr=None):
"""Add permission check for collections lookups.
Expand Down Expand Up @@ -688,9 +733,10 @@ class Meta(BaseResolweFilter.Meta):
}


class AnnotationValueFilter(BaseResolweFilter, metaclass=AnnotationValueFieldMetaclass):
class AnnotationValueFilter(BaseResolweFilter):
"""Filter the AnnotationValue endpoint."""

entity = FilterRelatedWithPermissions(EntityFilter)
label = filters.CharFilter(method="filter_by_label")

def filter_by_label(self, queryset: QuerySet, name: str, value: str):
Expand Down

0 comments on commit 6405014

Please sign in to comment.