diff --git a/invenio_records_lom/utils/__init__.py b/invenio_records_lom/utils/__init__.py index c6a3b6f..4f69fd6 100644 --- a/invenio_records_lom/utils/__init__.py +++ b/invenio_records_lom/utils/__init__.py @@ -16,6 +16,7 @@ create_record, get_learningresourcetypedict, get_oefosdict, + update_record, ) from .vcard import make_lom_vcard @@ -30,5 +31,6 @@ "build_record_unique_id", "check_about_duplicate", "create_record", + "update_record", "LOMDuplicateRecordError", ) diff --git a/invenio_records_lom/utils/util.py b/invenio_records_lom/utils/util.py index 823168d..4121a94 100644 --- a/invenio_records_lom/utils/util.py +++ b/invenio_records_lom/utils/util.py @@ -18,14 +18,11 @@ from typing import Any, Iterator, Optional, Union from flask_principal import Identity -from invenio_drafts_resources.records import Draft from invenio_records_resources.services.base import Service from invenio_search import RecordsSearch from invenio_search.engine import dsl from marshmallow.exceptions import ValidationError -# from .. import records, services # due to circular imports no direct import possible - class DotAccessWrapper(MutableMapping): """Provides getting/setting for passed-in mapping via dot-notated keys. @@ -301,14 +298,10 @@ def create_record( identity: Identity, *, do_publish: bool = True, - pre_created_draft: Draft = None, # records.LOMDraft ): """Create record.""" - if pre_created_draft: - draft = service.update_draft(identity, id_=pre_created_draft.id, data=data) - else: - are_files = len(file_paths) > 0 - draft = service.create(data=data, identity=identity, files=are_files) + are_files = len(file_paths) > 0 + draft = service.create(data=data, identity=identity, files=are_files) try: for file_path in file_paths: @@ -330,3 +323,18 @@ def create_record( raise error return draft + + +def update_record( + pid: str, + service: Service, # services.LOMRecordService + data: dict, + identity: Identity, + *, + do_publish: bool = True, +): + """Update record.""" + service.update_draft(id_=pid, data=data, identity=identity) + + if do_publish: + service.publish(id_=pid, identity=identity)