###########################################################################
# Copyright (c), The AiiDA team. All rights reserved.                     #
# This file is part of the AiiDA code.                                    #
#                                                                         #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file       #
# For further information please visit http://www.aiida.net               #
###########################################################################
"""This file contains the exceptions that are raised by the RESTapi at the
highest level, namely that of the interaction with the client. Their
specificity resides in the fact that they return a message that is embedded
into the HTTP response.

Example:
--------
    .../api/v1/nodes/ ... (TODO complete this with an actual example)

Other errors arising at deeper level, e.g. those raised by the QueryBuilder
or internal errors, are not embedded into the HTTP response.

"""

from aiida.common.exceptions import FeatureNotAvailable, InputValidationError


class RestInputValidationError(InputValidationError):
    """If inputs passed in query strings are wrong"""


class RestFeatureNotAvailable(FeatureNotAvailable):
    """If endpoint is not implemented for given node type"""
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Utility functions to work with node "full types" which identify node types. + +A node's `full_type` is defined as a string that uniquely defines the node type. A valid `full_type` is constructed by +concatenating the `node_type` and `process_type` of a node with the `FULL_TYPE_CONCATENATOR`. Each segment of the full +type can optionally be terminated by a single `LIKE_OPERATOR_CHARACTER` to indicate that the `node_type` or +`process_type` should start with that value but can be followed by any amount of other characters. A full type is +invalid if it does not contain exactly one `FULL_TYPE_CONCATENATOR` character. Additionally, each segment can contain +at most one occurrence of the `LIKE_OPERATOR_CHARACTER` and it has to be at the end of the segment. + +Examples of valid full types: + + 'data.bool.Bool.|' + 'process.calculation.calcfunction.%|%' + 'process.calculation.calcjob.CalcJobNode.|aiida.calculations:arithmetic.add' + 'process.calculation.calcfunction.CalcFunctionNode.|aiida.workflows:codtools.primitive_structure_from_cif' + +Examples of invalid full types: + + 'data.bool' # Only a single segment without concatenator + 'data.|bool.Bool.|process.' 
from typing import Any, Optional

from aiida.common.escaping import escape_for_sql_like

# Character joining the `node_type` and `process_type` segments of a full type.
FULL_TYPE_CONCATENATOR = '|'
# Optional trailing wildcard marker for either segment of a full type.
LIKE_OPERATOR_CHARACTER = '%'
DEFAULT_NAMESPACE_LABEL = '~no-entry-point~'


def validate_full_type(full_type: str) -> None:
    """Validate that the `full_type` is a valid full type unique node identifier.

    A valid full type contains exactly one `FULL_TYPE_CONCATENATOR` character.

    :param full_type: a `Node` full type
    :raises ValueError: if the `full_type` is invalid
    :raises TypeError: if the `full_type` is not a string type
    """
    from aiida.common.lang import type_check

    type_check(full_type, str)

    if FULL_TYPE_CONCATENATOR not in full_type:
        raise ValueError(
            f'full type `{full_type}` does not include the required concatenator symbol `{FULL_TYPE_CONCATENATOR}`.'
        )

    if full_type.count(FULL_TYPE_CONCATENATOR) > 1:
        raise ValueError(
            f'full type `{full_type}` includes the concatenator symbol `{FULL_TYPE_CONCATENATOR}` more than once.'
        )


def construct_full_type(node_type: Optional[str], process_type: Optional[str]) -> str:
    """Return the full type, which fully identifies the type of any `Node` with the given `node_type` and
    `process_type`.

    Either argument may be `None` (as stored for some nodes), in which case it is treated as the empty
    string. The previous annotations claimed `str` while the body explicitly handled `None`; the
    signature now documents the actual contract.

    :param node_type: the `node_type` of the `Node`, or `None`
    :param process_type: the `process_type` of the `Node`, or `None`
    :return: the full type, which is a unique identifier
    """
    if node_type is None:
        node_type = ''

    if process_type is None:
        process_type = ''

    return f'{node_type}{FULL_TYPE_CONCATENATOR}{process_type}'
def get_full_type_filters(full_type: str) -> dict[str, Any]:
    """Return the `QueryBuilder` filters that will return all `Nodes` identified by the given `full_type`.

    :param full_type: the `full_type` node type identifier
    :return: dictionary of filters to be passed for the `filters` keyword in `QueryBuilder.append`
    :raises ValueError: if the `full_type` is invalid
    :raises TypeError: if the `full_type` is not a string type
    """
    validate_full_type(full_type)

    filters: dict[str, Any] = {}
    node_type, process_type = full_type.split(FULL_TYPE_CONCATENATOR)

    # Each segment may carry at most one like-operator character and only at its end.
    for entry in (node_type, process_type):
        if entry.count(LIKE_OPERATOR_CHARACTER) > 1:
            raise ValueError(f'full type component `{entry}` contained more than one like-operator character')

        if LIKE_OPERATOR_CHARACTER in entry and entry[-1] != LIKE_OPERATOR_CHARACTER:
            raise ValueError(f'like-operator character in full type component `{entry}` is not at the end')

    if LIKE_OPERATOR_CHARACTER in node_type:
        # Remove the trailing `LIKE_OPERATOR_CHARACTER`, escape the string and reattach the character
        node_type = node_type[:-1]
        node_type = escape_for_sql_like(node_type) + LIKE_OPERATOR_CHARACTER
        filters['node_type'] = {'like': node_type}
    else:
        filters['node_type'] = {'==': node_type}

    if LIKE_OPERATOR_CHARACTER in process_type:
        # Remove the trailing `LIKE_OPERATOR_CHARACTER`.
        # If that was the only specification, just ignore this filter (looking for any process_type)
        # If there was more: escape the string and reattach the character
        process_type = process_type[:-1]
        if process_type:
            process_type = escape_for_sql_like(process_type) + LIKE_OPERATOR_CHARACTER
            filters['process_type'] = {'like': process_type}
    elif process_type:
        filters['process_type'] = {'==': process_type}
    else:
        # A `process_type=''` is used to represent both `process_type=''` and `process_type=None`.
        # This is because there is no simple way to single out null `process_types`, and therefore
        # we consider them together with empty-string process_types.
        # Moreover, the existence of both is most likely a bug of migrations and thus both share
        # this same "erroneous" origin.
        filters['process_type'] = {'or': [{'==': ''}, {'==': None}]}

    return filters


def load_entry_point_from_full_type(full_type: str) -> Any:
    """Return the loaded entry point for the given `full_type` unique node type identifier.

    :param full_type: the `full_type` unique node type identifier
    :raises ValueError: if the `full_type` is invalid
    :raises TypeError: if the `full_type` is not a string type
    :raises `~aiida.common.exceptions.EntryPointError`: if the corresponding entry point cannot be loaded
    """
    from aiida.common import EntryPointError
    from aiida.common.utils import strip_prefix
    from aiida.plugins.entry_point import (
        is_valid_entry_point_string,
        load_entry_point,
        load_entry_point_from_string,
    )

    data_prefix = 'data.'

    validate_full_type(full_type)

    node_type, process_type = full_type.split(FULL_TYPE_CONCATENATOR)

    if is_valid_entry_point_string(process_type):
        try:
            return load_entry_point_from_string(process_type)
        except EntryPointError as exc:
            # Chain the cause so the original loading failure is preserved in tracebacks.
            raise EntryPointError(f'could not load entry point `{process_type}`') from exc

    elif node_type.startswith(data_prefix):
        base_name = strip_prefix(node_type, data_prefix)
        entry_point_name = base_name.rsplit('.', 2)[0]

        try:
            return load_entry_point('aiida.data', entry_point_name)
        except EntryPointError as exc:
            # Bug fix: report the data entry point that failed to load; the previous message
            # interpolated `process_type`, which is typically empty for data nodes.
            raise EntryPointError(f'could not load entry point `{entry_point_name}`') from exc

    # Here we are dealing with a `ProcessNode` with a `process_type` that is not an entry point string.
    # Which means it is most likely a full module path (the fallback option) and we cannot necessarily load the
    # class from this. We could try with `importlib` but not sure that we should
    raise EntryPointError('entry point of the given full type cannot be loaded')
from typing import Any, Union

from aiida.common.exceptions import EntryPointError, LoadingEntryPointError
from aiida.plugins.entry_point import get_entry_point_names, load_entry_point

from aiida_restapi.exceptions import RestFeatureNotAvailable, RestInputValidationError
from aiida_restapi.identifiers import construct_full_type, load_entry_point_from_full_type


def get_all_download_formats(full_type: Union[str, None] = None) -> dict[str, Any]:
    """Return the possible download formats, per node type.

    :param full_type: optional full type identifier; when given, only the formats for that
        single node type are returned.
    :return: mapping from full type identifier to the list of available export format names
    :raises RestInputValidationError: if `full_type` is not a valid full type string
    :raises RestFeatureNotAvailable: if the entry point for `full_type` cannot be loaded
    """
    all_formats: dict[str, Any] = {}

    if full_type:
        try:
            node_cls = load_entry_point_from_full_type(full_type)
        except (TypeError, ValueError) as exc:
            raise RestInputValidationError(f'The full type {full_type} is invalid.') from exc
        except EntryPointError as exc:
            raise RestFeatureNotAvailable('The download formats for this node type are not available.') from exc

        try:
            all_formats[full_type] = node_cls.get_export_formats()
        except AttributeError:
            # The node class does not support exporting; return an empty mapping.
            pass

        return all_formats

    entry_point_group = 'aiida.data'

    for name in get_entry_point_names(entry_point_group):
        try:
            node_cls = load_entry_point(entry_point_group, name)
            available_formats = node_cls.get_export_formats()
        except (AttributeError, LoadingEntryPointError):
            # Skip plugins that cannot be loaded or do not support exporting.
            continue

        if available_formats:
            # Use a dedicated name instead of clobbering the `full_type` parameter.
            node_full_type = construct_full_type(node_cls.class_node_type, '')
            all_formats[node_full_type] = available_formats

    return all_formats
@router.get('/nodes/download_formats', response_model=dict[str, Any])
async def get_nodes_download_formats() -> dict[str, Any]:
    """Return the available download formats for every exportable node type.

    Thin wrapper delegating to ``resources.get_all_download_formats``; must be registered
    before the ``/nodes/{nodes_id}`` route so the literal path takes precedence.
    """

    return resources.get_all_download_formats()
def test_get_download_formats(client):
    """Test get download formats for nodes.

    Only checks that the reference keys/formats are a subset of the response, so that
    additional plugins installed in the test environment do not break the test.
    """
    response = client.get('/nodes/download_formats')

    assert response.status_code == 200

    reference = {
        'data.core.array.ArrayData.|': ['json'],
        'data.core.array.bands.BandsData.|': [
            'agr',
            'agr_batch',
            'dat_blocks',
            'dat_multicolumn',
            'gnuplot',
            'json',
            'mpl_pdf',
            'mpl_png',
            'mpl_singlefile',
            'mpl_withjson',
        ],
        'data.core.array.kpoints.KpointsData.|': ['json'],
        'data.core.array.projection.ProjectionData.|': ['json'],
        'data.core.array.trajectory.TrajectoryData.|': ['cif', 'json', 'xsf'],
        'data.core.array.xy.XyData.|': ['json'],
        'data.core.cif.CifData.|': ['cif'],
        'data.core.structure.StructureData.|': ['chemdoodle', 'cif', 'xsf', 'xyz'],
        'data.core.upf.UpfData.|': ['json', 'upf'],
    }
    response_json = response.json()

    for key, value in reference.items():
        # Idiomatic pytest: plain asserts with messages keep the diagnostics of the previous
        # manual `raise AssertionError` while gaining pytest's assertion introspection.
        assert key in response_json, f'The key {key!r} is not found in the response: {response_json}'
        assert set(value) <= set(
            response_json[key]
        ), f'The value {value} in key {key!r} is not contained in the response: {response_json}'