Skip to content

Commit

Permalink
Add annotations property to entity
Browse files Browse the repository at this point in the history
in Python processes. It can be used as dictionary to read/set entity
annotations.
  • Loading branch information
gregorjerse committed Oct 9, 2023
1 parent 504926e commit 15bb805
Show file tree
Hide file tree
Showing 10 changed files with 528 additions and 13 deletions.
4 changes: 4 additions & 0 deletions docs/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ Added
-----
- Allow sorting entities by collection name

Changed
-------
- Add ``annotations`` property to the ``Entity`` object in Python processes


===================
36.1.0 - 2023-09-14
Expand Down
13 changes: 5 additions & 8 deletions resolwe/flow/managers/listener/basic_commands_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from resolwe.flow.executors.socket_utils import Message, Response
from resolwe.flow.managers.protocol import ExecutorProtocol
from resolwe.flow.models import Data, DataDependency, Entity, Process, Worker
from resolwe.flow.models import Data, DataDependency, Process, Worker
from resolwe.flow.models.utils import validate_data_object
from resolwe.flow.utils import dict_dot, iterate_fields, iterate_schema
from resolwe.storage.connectors import connectors
Expand Down Expand Up @@ -490,14 +490,11 @@ def handle_annotate(
if entity_id is None:
raise RuntimeError(f"No entity to annotate for object '{data_id}'")

# The entity must be retrieved and saved or schema validation on the
# entity must be done manually.
entity = Entity.objects.get(pk=entity_id)
for key, val in message.message_data.items():
dict_dot(entity.descriptor, key, val)
entity.save()
handler = self.plugin_manager.get_handler("set_entity_annotations")
assert handler is not None, "Handler for 'set_entity_annotations' not found"

return message.respond_ok("OK")
message.message_data = [entity_id, message.message_data, True]
return handler(data_id, message, manager)

def handle_process_log(
self, data_id, message: Message[dict], manager: "Processor"
Expand Down
95 changes: 92 additions & 3 deletions resolwe/flow/managers/listener/python_process_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,26 @@
from base64 import b64encode
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
from zipfile import ZIP_STORED, ZipFile

from django.apps import apps
from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.postgres.fields.jsonb import JSONField as JSONFieldb
from django.db.models import ForeignKey, JSONField, ManyToManyField, Model
from django.db.models import ForeignKey, JSONField, ManyToManyField, Model, Q, Value
from django.db.models.functions import Concat

from resolwe.flow.executors import constants
from resolwe.flow.executors.socket_utils import Message, Response, retry
from resolwe.flow.managers.listener.permission_plugin import permission_manager
from resolwe.flow.models import Collection, Process
from resolwe.flow.models import (
AnnotationField,
AnnotationValue,
Collection,
Entity,
Process,
)
from resolwe.flow.models.base import UniqueSlugError
from resolwe.flow.models.utils import serialize_collection_relations
from resolwe.flow.utils import dict_dot
Expand Down Expand Up @@ -159,6 +166,88 @@ def handle_filter_objects(
list(filtered_objects.order_by(*sorting).values_list(*attributes))
)

def handle_set_entity_annotations(
self,
data_id: int,
message: Message[tuple[int, dict[str, Any], bool]],
manager: "Processor",
):
"""Handle update entity annotation request.
The first part of the message is the id of the entity and the second one the
mapping representing annotations. The keys in the are in the format
f"{group_name}.{field_name}".
"""
entity_id, annotations, update = message.message_data
entity = Entity.objects.get(pk=entity_id)
# Check that the user has the permissions to update the entity.
self._permission_manager.can_update(
manager.contributor(data_id), "flow.Entity", entity, {}, data_id
)
# Delete all annotations except the ones that are in the request.
if not update:
entity.annotations.all().delete()

annotation_values: list[AnnotationValue] = []
for field_path, value in annotations.items():
field = AnnotationField.field_from_path(field_path)
if field is None:
raise ValueError(f"Invalid field path: '{field_path}'.")
value = AnnotationValue(entity_id=entity_id, field=field, value=value)
# The validation is necessary since values are bulk-created and validation
# is done inside the save method, which is not called.
value.validate()
annotation_values.append(value)
if annotation_values:
AnnotationValue.objects.bulk_create(
annotation_values,
update_conflicts=True,
update_fields=["_value"],
unique_fields=["entity", "field"],
)
return message.respond_ok("OK")

def handle_get_entity_annotations(
self,
data_id: int,
message: Message[tuple[int, Optional[list[str]]]],
manager: "Processor",
) -> Response[dict[str, Any]]:
"""Handle get annotations request.
The first part of the message is the id of the entity and the second one the
list representing annotations to retrieve. The annotations are in the format
f"{group_name}.{field_name}".
"""
entity_id, annotations = message.message_data
# Check that the user has the permissions to read the entity.
entity = self._permission_manager.filter_objects(
manager.contributor(data_id),
"flow.Entity",
Entity.objects.filter(pk=entity_id),
data_id,
).get()
entity_annotations = entity.annotations
if annotations is not None:
if len(annotations) == 0:
entity_annotations = entity_annotations.none()
else:
# Filter only requested annotations. Return only annotations matching group
# and field name.
annotation_filter = Q()
for annotation in annotations:
group_name, field_name = annotation.split(".")
annotation_filter |= Q(
field__group__name=group_name, field__name=field_name
)
entity_annotations.filter(annotation_filter)
to_return = dict(
entity_annotations.annotate(
full_path=Concat("field__group__name", Value("."), "field__name")
).values_list("full_path", "_value__value")
)
return message.respond_ok(to_return)

def handle_update_model_fields(
self,
data_id: int,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 4.2.2 on 2023-10-08 20:19

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("flow", "0015_alter_annotationfield_vocabulary"),
]

operations = [
migrations.AddConstraint(
model_name="annotationvalue",
constraint=models.UniqueConstraint(
fields=("entity", "field"), name="uniquetogether_entity_field"
),
),
]
9 changes: 9 additions & 0 deletions resolwe/flow/models/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,15 @@ def __str__(self):
class AnnotationValue(AuditModel):
"""The value of the annotation."""

class Meta:
"""Set the unique constraints."""

constraints = [
models.constraints.UniqueConstraint(
fields=["entity", "field"], name="uniquetogether_entity_field"
),
]

#: the entity this field belongs to
entity: "Entity" = models.ForeignKey(
"Entity", related_name="annotations", on_delete=models.CASCADE
Expand Down
1 change: 1 addition & 0 deletions resolwe/flow/tests/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,7 @@ def test_empty_vocabulary(self):
self.annotation_value1.refresh_from_db()
# Labels must be recomputed to the original value.
self.assertEqual(self.annotation_value1._value["label"], "string")
self.entity1.annotations.all().delete()
AnnotationValue.objects.create(
entity=self.entity1, field=self.annotation_field1, value="non_existing"
)
Expand Down
89 changes: 88 additions & 1 deletion resolwe/process/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ def __new__(mcs, name, bases, namespace, **kwargs):
fields_details = communicator.get_model_fields_details(app_name, name)
for field_name in fields_details:
field_type, required, related_model_name = fields_details[field_name]
if field_type in FIELDS_MAP:
# Do not overwrite explicitely defined class attributes.
if field_type in FIELDS_MAP and not field_name in model.__dict__:
kwargs = {"required": required}
if field_type == "ForeignKey":
id_field = IntegerField()
Expand Down Expand Up @@ -486,12 +487,98 @@ def get_latest(cls, slug: str) -> "Model":
return cls(pks[0][0])


class EntityAnnotation(MutableMapping[str, Any]):
"""Annotations on sample object."""

def __init__(self, entity: "Entity"):
"""Initialize empty cache."""
self._entity = entity
self._cache: Dict[str, Any] = dict()
self._all_read = False

def __getitem__(self, key: str) -> Any:
"""Get the value for the given key."""
if key not in self._cache:
annotations = communicator.get_entity_annotations(self._entity.id, [key])
if key in annotations:
self._cache[key] = annotations[key]
# Return None even if key is not known.
return self._cache.get(key)

def copy(self) -> dict:
"""Return a dictionary with all annotation values."""
self._fetch_all()
return self._cache.copy()

def __setitem__(self, key: str, value: Any):
"""Set the value for the given key."""
communicator.set_entity_annotations(self._entity.id, {key: value}, True)
self._cache[key] = value

def _set_annotations(self, data: Dict[str, Any]):
"""Bulk set annotations (delete others)."""
communicator.set_entity_annotations(self._entity.id, data, False)
self._cache = data.copy()
self._all_read = True

def update(self, data: Dict[str, Any]):
"""Bulk update annotations."""
communicator.set_entity_annotations(self._entity.id, data, True)
self._cache.update(data)

def __delitem__(self, name: str):
"""Delete the annotation."""
raise NotImplementedError("Delete is not implemented")

def _fetch_all(self):
"""Read all annotations from the server and cache them."""
if not self._all_read:
self._cache = communicator.get_entity_annotations(self._entity.id, None)
self._all_read = True

def __iter__(self):
"""Return iterator.
Read all the annotations from the backend and iterate over them.
"""
self._fetch_all()
return iter(self._cache)

def __len__(self) -> int:
"""Return the number of annotations.
Beware: this method can read all the annotations from the backend.
"""
self._fetch_all()
return len(self._cache)

def __str__(self) -> str:
"""Return string representation."""
return f"EntityAnnotations({self._entity})"


class Entity(Model):
"""Entity model."""

_app_name = "flow"
_model_name = "Entity"

@property
def annotations(self):
"""Return entity annotations object."""
if not hasattr(self, "_annotations"):
self._annotations = EntityAnnotation(self)
return self._annotations

@annotations.setter
def annotations(self, value: dict):
"""Set annotations on entity.
Annotations not present in the dictionary will be deleted.
"""
assert isinstance(value, dict), "Annotations must be stored in a dictionary."
self._annotations._set_annotations(value)


class Storage(Model):
"""Storage model."""
Expand Down
75 changes: 75 additions & 0 deletions resolwe/process/tests/processes/python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,81 @@ def run(self, inputs, outputs):
self.update_entity_descriptor(annotations)


class AnnotateProcessV2(Process):
slug = "test-python-process-annotate-entity-v2"
name = "Test Python Process Annotate Entity V2"
version = "0.0.1"
process_type = "data:python:annotate"
entity = {
"type": "sample",
}

def run(self, inputs, outputs):
self.data.entity.annotations["general.species"] = "Human"
self.data.entity.annotations["general.age"] = 42


class AnnotateProcessV2BulkSet(Process):
slug = "test-python-process-annotate-entity-v2-bulk-set"
name = "Test Python Process Annotate Entity V2 bulk-set"
version = "0.0.1"
process_type = "data:python:annotate"
entity = {
"type": "sample",
}

def run(self, inputs, outputs):
self.data.entity.annotations.update(
{
"general.species": "Human Bulk",
"general.age": 2 * 42,
}
)
self.data.entity.annotations = {"general.species": "Human Bulk Set"}


class AnnotateProcessV2BulkUpdate(Process):
slug = "test-python-process-annotate-entity-v2-bulk-update"
name = "Test Python Process Annotate Entity V2 bulk-update"
version = "0.0.1"
process_type = "data:python:annotate"
entity = {
"type": "sample",
}

def run(self, inputs, outputs):
self.data.entity.annotations.update(
{
"general.species": "Human Bulk",
"general.age": 2 * 42,
}
)


class AnnotateProcessUpdateV2(Process):
slug = "test-python-process-update-entity-annotations-v2"
name = "Test Python Process Update Annotations Entity V2"
version = "0.0.1"
process_type = "data:python:annotate"
entity = {
"type": "sample",
}

class Input:
entity_id = IntegerField(label="Entity id")

class Output:
"""Output fields."""

existing_annotations = StringField(label="Existing annotations")

def run(self, inputs, outputs):
entity = Entity.get(pk=inputs.entity_id)
outputs.existing_annotations = str(entity.annotations.copy())
entity.annotations["general.species"] = "Human"
entity.annotations["general.age"] = entity.annotations["general.age"] // 2


class FileProcess(Process):
slug = "test-python-process-file"
name = "Test Python Process File"
Expand Down
Loading

0 comments on commit 15bb805

Please sign in to comment.