Skip to content

Commit

Permalink
Allow modifying entitiy values
Browse files Browse the repository at this point in the history
and add some helper methods
  • Loading branch information
gregorjerse committed Oct 24, 2023
1 parent 90d3ce2 commit 8b889ec
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 36 deletions.
62 changes: 49 additions & 13 deletions src/resdk/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@

import tqdm

from resdk.resources import DescriptorSchema, Process
from resdk.resources import (
AnnotationField,
AnnotationValue,
DescriptorSchema,
Process,
Sample,
)
from resdk.resources.base import BaseResource


Expand Down Expand Up @@ -92,6 +98,10 @@ def __init__(self, resolwe, resource, slug_field="slug"):

self.logger = logging.getLogger(__name__)

def _non_string_iterable(self, item) -> bool:
"""Return thur when item is iterable but not string."""
return isinstance(item, collections.abc.Iterable) and not isinstance(item, str)

def __getitem__(self, index):
"""Retrieve an item or slice from the set of results."""
if not isinstance(index, (slice, int)):
Expand Down Expand Up @@ -159,10 +169,10 @@ def _dehydrate_resources(self, obj):
"""Iterate through object and replace all objects with their ids."""
if isinstance(obj, BaseResource):
return obj.id
if isinstance(obj, list):
return [self._dehydrate_resources(element) for element in obj]
if isinstance(obj, dict):
return {key: self._dehydrate_resources(value) for key, value in obj.items()}
if self._non_string_iterable(obj):
return [self._dehydrate_resources(element) for element in obj]

return obj

Expand All @@ -172,10 +182,8 @@ def _add_filter(self, filter_):
# 'sample' is called 'entity' in the backend.
key = key.replace("sample", "entity")
value = self._dehydrate_resources(value)

if isinstance(value, list):
if self._non_string_iterable(value):
value = ",".join(map(str, value))

if self.resource.query_method == "GET":
self._filters[key].append(value)
elif self.resource.query_method == "POST":
Expand Down Expand Up @@ -382,6 +390,18 @@ def iterate(self, chunk_size=100, show_progress=False):
yield obj


class AnnotationFieldQuery(ResolweQuery):
"""Add additional method to the annotation field query."""

def from_path(self, full_path: str) -> "AnnotationField":
"""Get the AnnotationField from full path.
:raises LookupError: when field at the specified path does not exist.
"""
group_name, field_name = full_path.split(".", maxsplit=1)
return self.get(name=field_name, group__name=group_name)


class AnnotationValueQuery(ResolweQuery):
"""Populate Annotation fields with a single query."""

Expand All @@ -393,11 +413,27 @@ def _fetch(self):
# Execute the query in a single request.
super()._fetch()

# Get corresponding annotation field details in a single query.
field_ids = [value.field_id for value in self._cache]
fields = self.resolwe.annotation_field.filter(id__in=field_ids)
fields_map = {field.id: field for field in fields}
missing = {
value.field_id: value for value in self._cache if value._field is None
}
if missing:
# Get corresponding annotation field details in a single query and attach it to
# the values.
for field in self.resolwe.annotation_field.filter(id__in=missing.keys()):
missing[field.id]._field = field
missing[field.id]._original_values["field"] = field._original_values

def from_path(self, full_path: str) -> "AnnotationValue":
"""Get the AnnotationValue from full path.
# Set the fields on the AnnotationValue instances.
for value in self._cache:
value._field = fields_map[value.field_id]
:raises LookupError: when field at the specified path does not exist.
"""
group_name, field_name = full_path.split(".", maxsplit=1)
return self.get(field__name=field_name, field__group__name=group_name)

def create_from_path(
self, sample: Sample, full_path: str, value
) -> AnnotationValue:
"""Create annotation value."""
field = self.resolwe.annotation_field.from_path(full_path)
return self.create(entity=sample, value=value, field=field)
15 changes: 11 additions & 4 deletions src/resdk/resolwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from .constants import CHUNK_SIZE
from .exceptions import ValidationError, handle_http_exception
from .query import AnnotationValueQuery, ResolweQuery
from .query import AnnotationFieldQuery, AnnotationValueQuery, ResolweQuery
from .resources import (
AnnotationField,
AnnotationValue,
Expand Down Expand Up @@ -94,9 +94,17 @@ class Resolwe:
"""

# Map between resource and Query map. Default in ResorweQuery, only overrides must
# be listed here.
resource_query_class = {
AnnotationValue: AnnotationValueQuery,
AnnotationField: AnnotationFieldQuery,
}

# Map resource class to ResolweQuery name
resource_query_mapping = {
AnnotationField: "annotation_field",
AnnotationValue: "annotation_value",
Data: "data",
Collection: "collection",
Sample: "sample",
Expand Down Expand Up @@ -169,12 +177,11 @@ def _initialize_queries(self):
"""Initialize ResolweQuery's."""
for resource, query_name in self.resource_query_mapping.items():
slug_field = self.slug_field_mapping.get(query_name, "slug")
query = ResolweQuery(self, resource, slug_field=slug_field)
QueryClass = self.resource_query_class.get(resource, ResolweQuery)
query = QueryClass(self, resource, slug_field=slug_field)
if query_name in self.query_filter_mapping:
query = query.filter(**self.query_filter_mapping[query_name])
setattr(self, query_name, query)
# Use custon query to reduce the number of queries.
setattr(self, "annotation_value", AnnotationValueQuery(self, AnnotationValue))

def _login(
self,
Expand Down
53 changes: 35 additions & 18 deletions src/resdk/resources/annotations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Annotatitons resources."""

import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional, Union

from .base import BaseResource

Expand All @@ -19,8 +19,9 @@ class AnnotationGroup(BaseResource):

READ_ONLY_FIELDS = BaseResource.READ_ONLY_FIELDS + ("name", "sort_order", "label")

def __init__(self, resolwe: "Resolwe", **model_data: dict):
def __init__(self, resolwe: "Resolwe", **model_data):
"""Initialize the instance.
:param resolwe: Resolwe instance
:param model_data: Resource model data
"""
Expand All @@ -41,12 +42,13 @@ class AnnotationField(BaseResource):
"sort_order",
"type",
"validator_regex",
"vocubalary",
"vocabulary",
"required",
)

def __init__(self, resolwe: "Resolwe", **model_data: dict):
def __init__(self, resolwe: "Resolwe", **model_data):
"""Initialize the instance.
:param resolwe: Resolwe instance
:param model_data: Resource model data
"""
Expand All @@ -56,12 +58,15 @@ def __init__(self, resolwe: "Resolwe", **model_data: dict):
super().__init__(resolwe, **model_data)

@property
def group(self):
def group(self) -> AnnotationGroup:
"""Get annotation group."""
assert (
self._group is not None
), "AnnotationGroup must be set before it can be used."
return self._group

@group.setter
def group(self, payload):
def group(self, payload: dict):
"""Set annotation group."""
self._resource_setter(payload, AnnotationGroup, "_group")

Expand All @@ -71,37 +76,49 @@ class AnnotationValue(BaseResource):

endpoint = "annotation_value"

READ_ONLY_FIELDS = BaseResource.READ_ONLY_FIELDS + (
"field",
"entity",
"value",
"label",
)
READ_ONLY_FIELDS = BaseResource.READ_ONLY_FIELDS + ("label",)

UPDATE_PROTECTED_FIELDS = BaseResource.UPDATE_PROTECTED_FIELDS + ("entity", "field")

def __init__(self, resolwe: "Resolwe", **model_data: dict):
WRITABLE_FIELDS = BaseResource.WRITABLE_FIELDS + ("value",)

def __init__(self, resolwe: "Resolwe", **model_data):
"""Initialize the instance.
:param resolwe: Resolwe instance
:param model_data: Resource model data
"""
self.logger = logging.getLogger(__name__)

#: annotation field
self._field = None
self.field_id = None
self._field: Optional[AnnotationField] = None
self.field_id: Optional[int] = None

super().__init__(resolwe, **model_data)

@property
def field(self):
def field(self) -> AnnotationField:
"""Get annotation field."""
if self._field is None:
assert (
self.field_id is not None
), "AnnotationField must be set before it can be used."
self._field = AnnotationField(self.resolwe, id=self.field_id)
# The field is read-only but we have to modify original values here so save
# can detect there were no changes.
self._original_values["field"] = self._field._original_values
return self._field

@field.setter
def field(self, payload):
def field(self, payload: Union[int, AnnotationField, dict]):
"""Set annotation field."""
self.field_id = payload
field_id = None
if isinstance(payload, int):
field_id = payload
elif isinstance(payload, dict):
field_id = payload["id"]
elif isinstance(payload, AnnotationField):
field_id = payload.id
if field_id != self.field_id:
self._field = None
self.field_id = field_id
9 changes: 8 additions & 1 deletion src/resdk/resources/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ def save(self):

def field_changed(field_name):
"""Check if local field value is different from the server."""
original_value = self._original_values.get(field_name, None)
current_value = getattr(self, field_name, None)
original_value = self._original_values.get(field_name, None)

if isinstance(current_value, BaseResource) and original_value:
# TODO: Check that current and original are instances of the same resource class
Expand Down Expand Up @@ -152,6 +152,13 @@ def assert_fields_unchanged(field_names):
if "sample" in payload:
payload["entity"] = payload.pop("sample")

from .annotations import AnnotationValue

# Annotation models have primarykey serializer.
if isinstance(self, AnnotationValue):
payload["field"] = payload["field"]["id"]
payload["entity"] = payload["entity"]["id"]

response = self.api.post(payload)
self._update_fields(response)

Expand Down
5 changes: 5 additions & 0 deletions src/resdk/resources/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,8 @@ def duplicate(self, inherit_collection=False):
)
background_task = BackgroundTask(resolwe=self.resolwe, **task_data)
return self.resolwe.sample.get(id__in=background_task.result())

@property
def annotations(self):
"""Get the annotations for the given sample."""
return self.resolwe.annotation_value.filter(entity=self.id)

0 comments on commit 8b889ec

Please sign in to comment.