diff --git a/invenio_records/dumpers/elasticsearch.py b/invenio_records/dumpers/elasticsearch.py index b503e5fc..868a8d7b 100644 --- a/invenio_records/dumpers/elasticsearch.py +++ b/invenio_records/dumpers/elasticsearch.py @@ -151,7 +151,7 @@ def _dump_model_field(self, record, model_field_name, dump, dump_key, dump[dump_key] = self._serialize(val, dump_type) def _load_model_field(self, record_cls, model_field_name, dump, dump_key, - dump_type): + load_type): """Helper method to load model fields from dump. :param record_cls: The record class being used for loading. @@ -169,12 +169,11 @@ def _load_model_field(self, record_cls, model_field_name, dump, dump_key, return val # Determine dump data type if not provided - if dump_type is None: - sa_field = getattr(record_cls.model_cls, model_field_name) - dump_type = self._sa_type(record_cls.model_cls, model_field_name) + if load_type is None: + load_type = self._sa_type(record_cls.model_cls, model_field_name) # Deserialize the value - return self._deserialize(val, dump_type) + return self._deserialize(val, load_type) @staticmethod def _iter_modelfields(record_cls): @@ -247,15 +246,15 @@ def load(self, dump_data, record_cls): # Load explicitly defined model fields. model_data = {} it = self._model_fields.items() - for model_field_name, (dump_key, dump_type) in it: + for model_field_name, (dump_key, load_type) in it: model_data[model_field_name] = self._load_model_field( - record_cls, model_field_name, dump_data, dump_key, dump_type) + record_cls, model_field_name, dump_data, dump_key, load_type) # Load model fields defined as system fields for systemfield in self._iter_modelfields(record_cls): model_data[systemfield.model_field_name] = self._load_model_field( record_cls, systemfield.model_field_name, dump_data, - systemfield.dump_key, systemfield.dump_type) + systemfield.dump_key, systemfield.load_type) # Initialize model if an id was provided. if model_data.get('id') is not None: diff --git a/invenio_records/systemfields/model.py b/invenio_records/systemfields/model.py index efc74c3c..3b410a84 100644 --- a/invenio_records/systemfields/model.py +++ b/invenio_records/systemfields/model.py @@ -16,7 +16,7 @@ class ModelField(SystemField): """Model field for providing get and set access on a model field.""" def __init__(self, model_field_name=None, dump=True, dump_key=None, - dump_type=None): + dump_type=None, load_type=None): """Initialize the field. :param model_field_name: Name of field on the database model. @@ -29,6 +29,7 @@ def __init__(self, model_field_name=None, dump=True, dump_key=None, self.dump = dump self._dump_key = dump_key self._dump_type = dump_type + self._load_type = load_type # # Helpers @@ -58,6 +59,14 @@ def dump_type(self): """ return self._dump_type + @property + def load_type(self): + """The data type used to determine how to deserialize the model field. + + Defaults to dump_type. + """ + return self._load_type or self._dump_type + def _set(self, model, value): """Internal method to set value on the model's field.""" setattr(model, self.model_field_name, value) diff --git a/tests/test_api_dumpers.py b/tests/test_api_dumpers.py index e6e69cf8..b7aba26b 100644 --- a/tests/test_api_dumpers.py +++ b/tests/test_api_dumpers.py @@ -9,15 +9,20 @@ """Test the dumpers API.""" from datetime import date, datetime +from enum import Enum from uuid import UUID import pytest +from invenio_db import db from sqlalchemy.dialects import mysql +from sqlalchemy_utils.types import ChoiceType from invenio_records.api import Record from invenio_records.dumpers import ElasticsearchDumper, ElasticsearchDumperExt from invenio_records.dumpers.relations import RelationDumperExt from invenio_records.models import RecordMetadataBase +from invenio_records.systemfields import SystemFieldsMixin +from invenio_records.systemfields.model import ModelField from invenio_records.systemfields.relations import PKListRelation, \ PKRelation, RelationsField @@ -215,3 +220,40 @@ class RecordWithRelations(Record): # Load it # new_record = Record.loads(dump, loader=dumper) # assert 'count' not in new_record + + +def test_load_dump_type(testapp): + dumper = ElasticsearchDumper() + rec = TestRecord.create({}, test=EnumTestModel.REGISTERED) + # Serialize + dumped_data = dumper.dump(rec, {}) + assert isinstance(dumped_data["test"], str) + # Deserialize + loaded_data = dumper.load(dumped_data, TestRecord) + assert isinstance(loaded_data.test, EnumTestModel) + + +# Similar to PIDStatus +class EnumTestModel(Enum): + NEW = "N" + REGISTERED = "R" + + def __init__(self, value): + """Hack.""" + + def __str__(self): + """Return its value.""" + return self.value + + +class TestMetadata(db.Model, RecordMetadataBase): + """Represent a record metadata.""" + + __tablename__ = 'test_dumper_table' + test = db.Column(ChoiceType(EnumTestModel, impl=db.CHAR(1))) + + +# Similar to ModelPIDField +class TestRecord(Record, SystemFieldsMixin): + model_cls = TestMetadata + test = ModelField(dump_type=str)