Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add generic filtering which respects permissions #1175

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Added
-----
- Expose ``status`` on ``collection`` and ``entity`` viewset and allow
filtering and sorting by it
- Add generic filtering by related objects which respects permissions

Changed
-------
Expand Down
182 changes: 114 additions & 68 deletions resolwe/flow/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
import re
import types
from copy import deepcopy
from functools import partial
from typing import Callable, Union
from dataclasses import dataclass
from functools import partial, reduce
from typing import Any, Optional, Union

from django.contrib.auth import get_user_model
from django.contrib.auth.models import Group
from django.contrib.postgres.search import SearchQuery, SearchRank
from django.core.exceptions import ValidationError
from django.db.models import Count, F, ForeignKey, Q, Subquery
from django.db.models import Count, F, ForeignKey, Model, Q, Subquery
from django.db.models.query import QuerySet
from django_filters import rest_framework as filters
from django_filters.constants import EMPTY_VALUES
Expand Down Expand Up @@ -239,18 +240,115 @@ class ExtendedFilter:

"""

def __new__(mcs, name, bases, namespace):
"""Inject extensions into the filter."""
class_path = "{}.{}".format(namespace["__module__"], name)
for extension in composer.get_extensions(class_path):
for name in dir(extension):
if name.startswith("__"):
continue
namespace[name] = getattr(extension, name)
def __new__(
mcs: type[object],
name: str,
bases: tuple[type[object]],
namespace: dict[str, Any],
):
"""Inject filters."""

def inject_composer_extensions(namespace: dict[str, Any], name: str):
"""Inject filters from composer extensions."""
class_path = "{}.{}".format(namespace["__module__"], name)
for extension in composer.get_extensions(class_path):
for name in dir(extension):
if name.startswith("__"):
continue
namespace[name] = getattr(extension, name)

def get_related_model_from_path(related_path: str) -> type[Model]:
"""Get the related model from the path."""
return reduce(
lambda Model, part: Model._meta.get_field(part).related_model,
related_path.split("__"),
namespace["Meta"].model,
)

def add_filters_with_permissions(
prefix: str, related_path: str, base_filter: filters.FilterSet
):
"""Add filters on related objects with permissions.

All the filters belonging to the BaseFilter are added with the given prefix.
"""

def filter_permissions(
self,
qs: QuerySet,
value: str,
original_filter: filters.Filter,
original_model: Model,
):
"""Apply the filter and respect permissions."""

# Do not filter when value is empty. At least one of the values must be
# non-empty since form in the AnnotationValueFilter class requires it.
if value in EMPTY_VALUES:
return qs

visible_objects = list(
original_filter.filter(original_model.objects.all(), value)
.filter_for_user(self.parent.request.user)
.values_list("pk", flat=True)
)
return qs.filter(**{f"{related_path}__in": visible_objects})

# Add all filters from EntityFilter to namespaces before creating class.
BaseModel = get_related_model_from_path(related_path)
for filter_name, filter in base_filter.get_filters().items():
new_filter_name = f"{prefix}__{filter_name}"
if filter_name == "id" or filter_name.startswith("id__"):
new_filter_name = f"{prefix}{filter_name[2:]}"

filter_copy = deepcopy(filter)
filter_copy.field_name = f"{prefix}__{filter_copy.field_name}"
filter_method = partial(
filter_permissions, original_filter=filter, original_model=BaseModel
)
# Bind the new_filter to filter instance and set it as new filter.
filter_copy.filter = types.MethodType(filter_method, filter_copy)
namespace[new_filter_name] = filter_copy
# If filter uses a method, add it to the namespace as well.
if filter_copy.method is not None:
namespace[filter_copy.method] = deepcopy(
getattr(base_filter, filter_copy.method)
)

def inject_related_permissions(namespace: dict):
"""Parse the namespace and add filters with permissions."""
for prefix, value in list(namespace.items()):
if isinstance(value, FilterRelatedWithPermissions):
namespace.pop(prefix)
add_filters_with_permissions(
prefix, value.related_path or prefix, value.BaseFilter
)

inject_related_permissions(namespace)
inject_composer_extensions(namespace, name)
return super().__new__(mcs, name, bases, namespace)


@dataclass
class FilterRelatedWithPermissions:
"""Base class for filters with permissions.

The class using this must have FilterRelatedWithPermissionsMeta as its meta class.
The filter is then defined as:

prefix = FilterRelatedWithPermissions(BaseFilter, related_path="related__path")

All the filters from the BaseFilter (with the given prefix) are added and
permissions on the related objects are respected.

The optional parameter `related__path` is used to specify the path of the related
object (if different from the prefix).
"""

BaseFilter: type[filters.FilterSet]
related_path: Optional[str] = None


class BaseResolweFilter(
CheckQueryParamsMixin, filters.FilterSet, metaclass=ResolweFilterMetaclass
):
Expand Down Expand Up @@ -551,64 +649,11 @@ def get_ordering(self, request, queryset, view):
return self.get_default_ordering(view)


class AnnotationValueFieldMetaclass(ResolweFilterMetaclass):
"""Add all entity filters prefixed with 'entity'."""

def __new__(mcs, name, bases, namespace):
"""Inject extensions into the filter."""

def filter_permissions(
self,
qs: QuerySet,
value: str,
original_entity_filter: Callable[[QuerySet, str], QuerySet],
):
"""Respect permissions on entities."""
# Do not filter when value is empty. At least one of the values must be
# non-empty since form in the AnnotationValueFilter class requires it.
if value in EMPTY_VALUES:
return qs

# Filter the entities using the original entity filter and permissions.
visible_entities = list(
original_entity_filter(Entity.objects.all(), value)
.filter_for_user(self.parent.request.user)
.values_list("pk", flat=True)
)
return qs.filter(**{f"{entity_path}__in": visible_entities})

entity_path = {
"AnnotationValueFilter": "entity",
"AnnotationFieldFilter": "values__entity",
}[name]
# Add all filters from EntityFilter to namespaces before creating class.
for filter_name, filter in EntityFilter.get_filters().items():
new_name = f"entity__{filter_name}"
if filter_name == "id" or filter_name.startswith("id__"):
new_name = "entity" + filter_name[2:]
filter_copy = deepcopy(filter)
filter_copy.field_name = f"{entity_path}__{filter_copy.field_name}"
filter_method = partial(
filter_permissions,
original_entity_filter=filter.filter,
)
# Bind the new_filter to filter instance and set it as new filter.
filter_copy.filter = types.MethodType(filter_method, filter_copy)
namespace[new_name] = filter_copy
# If filter uses a method, add it to the namespace as well.
if filter_copy.method is not None:
namespace[filter_copy.method] = deepcopy(
getattr(EntityFilter, filter_copy.method)
)

# Create class with added filters.
klass = ResolweFilterMetaclass.__new__(mcs, name, bases, namespace)
return klass


class AnnotationFieldFilter(BaseResolweFilter, metaclass=AnnotationValueFieldMetaclass):
class AnnotationFieldFilter(BaseResolweFilter):
"""Filter the AnnotationField endpoint."""

entity = FilterRelatedWithPermissions(EntityFilter, related_path="values__entity")

@classmethod
def filter_for_field(cls, field, field_name, lookup_expr=None):
"""Add permission check for collections lookups.
Expand Down Expand Up @@ -688,9 +733,10 @@ class Meta(BaseResolweFilter.Meta):
}


class AnnotationValueFilter(BaseResolweFilter, metaclass=AnnotationValueFieldMetaclass):
class AnnotationValueFilter(BaseResolweFilter):
"""Filter the AnnotationValue endpoint."""

entity = FilterRelatedWithPermissions(EntityFilter)
label = filters.CharFilter(method="filter_by_label")

def filter_by_label(self, queryset: QuerySet, name: str, value: str):
Expand Down
Loading