From 8b889ec0bb29f2f39067bf3a7f472affccecb00c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gregor=20Jer=C5=A1e?= Date: Sun, 22 Oct 2023 21:34:20 +0200 Subject: [PATCH] Allow modifying entitiy values and add some helper methods --- src/resdk/query.py | 62 +++++++++++++++++++++++------- src/resdk/resolwe.py | 15 ++++++-- src/resdk/resources/annotations.py | 53 ++++++++++++++++--------- src/resdk/resources/base.py | 9 ++++- src/resdk/resources/sample.py | 5 +++ 5 files changed, 108 insertions(+), 36 deletions(-) diff --git a/src/resdk/query.py b/src/resdk/query.py index b5947f14..1e58ae62 100644 --- a/src/resdk/query.py +++ b/src/resdk/query.py @@ -15,7 +15,13 @@ import tqdm -from resdk.resources import DescriptorSchema, Process +from resdk.resources import ( + AnnotationField, + AnnotationValue, + DescriptorSchema, + Process, + Sample, +) from resdk.resources.base import BaseResource @@ -92,6 +98,10 @@ def __init__(self, resolwe, resource, slug_field="slug"): self.logger = logging.getLogger(__name__) + def _non_string_iterable(self, item) -> bool: + """Return True when item is iterable but not string.""" + return isinstance(item, collections.abc.Iterable) and not isinstance(item, str) + def __getitem__(self, index): """Retrieve an item or slice from the set of results.""" if not isinstance(index, (slice, int)): @@ -159,10 +169,10 @@ def _dehydrate_resources(self, obj): """Iterate through object and replace all objects with their ids.""" if isinstance(obj, BaseResource): return obj.id - if isinstance(obj, list): - return [self._dehydrate_resources(element) for element in obj] if isinstance(obj, dict): return {key: self._dehydrate_resources(value) for key, value in obj.items()} + if self._non_string_iterable(obj): + return [self._dehydrate_resources(element) for element in obj] return obj @@ -172,10 +182,8 @@ def _add_filter(self, filter_): # 'sample' is called 'entity' in the backend. 
key = key.replace("sample", "entity") value = self._dehydrate_resources(value) - - if isinstance(value, list): + if self._non_string_iterable(value): value = ",".join(map(str, value)) - if self.resource.query_method == "GET": self._filters[key].append(value) elif self.resource.query_method == "POST": @@ -382,6 +390,18 @@ def iterate(self, chunk_size=100, show_progress=False): yield obj +class AnnotationFieldQuery(ResolweQuery): + """Add additional method to the annotation field query.""" + + def from_path(self, full_path: str) -> "AnnotationField": + """Get the AnnotationField from full path. + + :raises LookupError: when field at the specified path does not exist. + """ + group_name, field_name = full_path.split(".", maxsplit=1) + return self.get(name=field_name, group__name=group_name) + + class AnnotationValueQuery(ResolweQuery): """Populate Annotation fields with a single query.""" @@ -393,11 +413,27 @@ def _fetch(self): # Execute the query in a single request. super()._fetch() - # Get corresponding annotation field details in a single query. - field_ids = [value.field_id for value in self._cache] - fields = self.resolwe.annotation_field.filter(id__in=field_ids) - fields_map = {field.id: field for field in fields} + missing = collections.defaultdict(list) + for value in (item for item in self._cache if item._field is None): + missing[value.field_id].append(value) + if missing: + # Fetch field details in a single query; several values can share a field. + for field in self.resolwe.annotation_field.filter(id__in=missing.keys()): + for value in missing[field.id]: + value._field = field + value._original_values["field"] = field._original_values + + def from_path(self, full_path: str) -> "AnnotationValue": + """Get the AnnotationValue from full path. - # Set the fields on the AnnotationValue instances. - for value in self._cache: - value._field = fields_map[value.field_id] + :raises LookupError: when field at the specified path does not exist. 
+ """ + group_name, field_name = full_path.split(".", maxsplit=1) + return self.get(field__name=field_name, field__group__name=group_name) + + def create_from_path( + self, sample: Sample, full_path: str, value + ) -> AnnotationValue: + """Create annotation value.""" + field = self.resolwe.annotation_field.from_path(full_path) + return self.create(entity=sample, value=value, field=field) diff --git a/src/resdk/resolwe.py b/src/resdk/resolwe.py index 566fee03..08d154f4 100644 --- a/src/resdk/resolwe.py +++ b/src/resdk/resolwe.py @@ -26,7 +26,7 @@ from .constants import CHUNK_SIZE from .exceptions import ValidationError, handle_http_exception -from .query import AnnotationValueQuery, ResolweQuery +from .query import AnnotationFieldQuery, AnnotationValueQuery, ResolweQuery from .resources import ( AnnotationField, AnnotationValue, @@ -94,9 +94,17 @@ class Resolwe: """ + # Map resource to its query class. Default is ResolweQuery, only overrides must + # be listed here. + resource_query_class = { + AnnotationValue: AnnotationValueQuery, + AnnotationField: AnnotationFieldQuery, + } + # Map resource class to ResolweQuery name resource_query_mapping = { AnnotationField: "annotation_field", + AnnotationValue: "annotation_value", Data: "data", Collection: "collection", Sample: "sample", @@ -169,12 +177,11 @@ def _initialize_queries(self): """Initialize ResolweQuery's.""" for resource, query_name in self.resource_query_mapping.items(): slug_field = self.slug_field_mapping.get(query_name, "slug") - query = ResolweQuery(self, resource, slug_field=slug_field) + QueryClass = self.resource_query_class.get(resource, ResolweQuery) + query = QueryClass(self, resource, slug_field=slug_field) if query_name in self.query_filter_mapping: query = query.filter(**self.query_filter_mapping[query_name]) setattr(self, query_name, query) - # Use custon query to reduce the number of queries. 
- setattr(self, "annotation_value", AnnotationValueQuery(self, AnnotationValue)) def _login( self, diff --git a/src/resdk/resources/annotations.py b/src/resdk/resources/annotations.py index a74f3218..fc244281 100644 --- a/src/resdk/resources/annotations.py +++ b/src/resdk/resources/annotations.py @@ -1,7 +1,7 @@ """Annotatitons resources.""" import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Union from .base import BaseResource @@ -19,8 +19,9 @@ class AnnotationGroup(BaseResource): READ_ONLY_FIELDS = BaseResource.READ_ONLY_FIELDS + ("name", "sort_order", "label") - def __init__(self, resolwe: "Resolwe", **model_data: dict): + def __init__(self, resolwe: "Resolwe", **model_data): """Initialize the instance. + :param resolwe: Resolwe instance :param model_data: Resource model data """ @@ -41,12 +42,13 @@ class AnnotationField(BaseResource): "sort_order", "type", "validator_regex", - "vocubalary", + "vocabulary", "required", ) - def __init__(self, resolwe: "Resolwe", **model_data: dict): + def __init__(self, resolwe: "Resolwe", **model_data): """Initialize the instance. + :param resolwe: Resolwe instance :param model_data: Resource model data """ @@ -56,12 +58,15 @@ def __init__(self, resolwe: "Resolwe", **model_data: dict): super().__init__(resolwe, **model_data) @property - def group(self): + def group(self) -> AnnotationGroup: """Get annotation group.""" + assert ( + self._group is not None + ), "AnnotationGroup must be set before it can be used." 
return self._group @group.setter - def group(self, payload): + def group(self, payload: dict): """Set annotation group.""" self._resource_setter(payload, AnnotationGroup, "_group") @@ -71,37 +76,49 @@ class AnnotationValue(BaseResource): endpoint = "annotation_value" - READ_ONLY_FIELDS = BaseResource.READ_ONLY_FIELDS + ( - "field", - "entity", - "value", - "label", - ) + READ_ONLY_FIELDS = BaseResource.READ_ONLY_FIELDS + ("label",) + + UPDATE_PROTECTED_FIELDS = BaseResource.UPDATE_PROTECTED_FIELDS + ("entity", "field") - def __init__(self, resolwe: "Resolwe", **model_data: dict): + WRITABLE_FIELDS = BaseResource.WRITABLE_FIELDS + ("value",) + + def __init__(self, resolwe: "Resolwe", **model_data): """Initialize the instance. + :param resolwe: Resolwe instance :param model_data: Resource model data """ self.logger = logging.getLogger(__name__) #: annotation field - self._field = None - self.field_id = None + self._field: Optional[AnnotationField] = None + self.field_id: Optional[int] = None super().__init__(resolwe, **model_data) @property - def field(self): + def field(self) -> AnnotationField: """Get annotation field.""" if self._field is None: assert ( self.field_id is not None ), "AnnotationField must be set before it can be used." self._field = AnnotationField(self.resolwe, id=self.field_id) + # The field is read-only but we have to modify original values here so save + # can detect there were no changes. 
+ self._original_values["field"] = self._field._original_values return self._field @field.setter - def field(self, payload): + def field(self, payload: Union[int, AnnotationField, dict]): """Set annotation field.""" - self.field_id = payload + field_id = None + if isinstance(payload, int): + field_id = payload + elif isinstance(payload, dict): + field_id = payload["id"] + elif isinstance(payload, AnnotationField): + field_id = payload.id + if field_id != self.field_id: + self._field = None + self.field_id = field_id diff --git a/src/resdk/resources/base.py b/src/resdk/resources/base.py index 0b09fb09..28353ff4 100644 --- a/src/resdk/resources/base.py +++ b/src/resdk/resources/base.py @@ -102,8 +102,8 @@ def save(self): def field_changed(field_name): """Check if local field value is different from the server.""" - original_value = self._original_values.get(field_name, None) current_value = getattr(self, field_name, None) + original_value = self._original_values.get(field_name, None) if isinstance(current_value, BaseResource) and original_value: # TODO: Check that current and original are instances of the same resource class @@ -152,6 +152,13 @@ def assert_fields_unchanged(field_names): if "sample" in payload: payload["entity"] = payload.pop("sample") + from .annotations import AnnotationValue + + # Annotation models have primarykey serializer. 
+ if isinstance(self, AnnotationValue): + payload["field"] = payload["field"]["id"] + payload["entity"] = payload["entity"]["id"] + response = self.api.post(payload) self._update_fields(response) diff --git a/src/resdk/resources/sample.py b/src/resdk/resources/sample.py index a42905e0..57c8f3c7 100644 --- a/src/resdk/resources/sample.py +++ b/src/resdk/resources/sample.py @@ -251,3 +251,8 @@ def duplicate(self, inherit_collection=False): ) background_task = BackgroundTask(resolwe=self.resolwe, **task_data) return self.resolwe.sample.get(id__in=background_task.result()) + + @property + def annotations(self): + """Get the annotations for the given sample.""" + return self.resolwe.annotation_value.filter(entity=self.id)