From d8dd776a68f438702aa07b58d754b35ab0745937 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Thu, 19 Oct 2023 16:34:42 +0200 Subject: [PATCH] ORM: Register `numpy.ndarray` with the `to_aiida_type` to `ArrayData` (#6149) This will allow `numpy.ndarray` to be passed to process inputs that add the `to_aiida_type` serializer and expect an `ArrayData`. The single dispatch will automatically convert the numpy array to an `ArrayData` instance. --- aiida/orm/nodes/data/array/array.py | 27 +++++++++++++--------- docs/source/nitpick-exceptions | 1 + tests/orm/nodes/data/test_to_aiida_type.py | 7 +++++- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/aiida/orm/nodes/data/array/array.py b/aiida/orm/nodes/data/array/array.py index 602aa3d939..cd7c0f5a0c 100644 --- a/aiida/orm/nodes/data/array/array.py +++ b/aiida/orm/nodes/data/array/array.py @@ -12,16 +12,21 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Iterator +from typing import Any, Iterator -from ..data import Data +from numpy import ndarray -if TYPE_CHECKING: - from numpy import ndarray +from ..base import to_aiida_type +from ..data import Data __all__ = ('ArrayData',) +@to_aiida_type.register(ndarray) +def _(value): + return ArrayData(value) + + class ArrayData(Data): """ Store a set of arrays on disk (rather than on the database) in an efficient @@ -41,7 +46,7 @@ class ArrayData(Data): array_prefix = 'array|' default_array_name = 'default' - def __init__(self, arrays: 'ndarray' | dict[str, 'ndarray'] | None = None, **kwargs): + def __init__(self, arrays: ndarray | dict[str, ndarray] | None = None, **kwargs): """Construct a new instance and set one or multiple numpy arrays. :param arrays: An optional single numpy array, or dictionary of numpy arrays to store. @@ -49,7 +54,7 @@ def __init__(self, arrays: 'ndarray' | dict[str, 'ndarray'] | None = None, **kwa import numpy super().__init__(**kwargs) - self._cached_arrays: dict[str, 'ndarray'] = {} + self._cached_arrays: dict[str, ndarray] = {} arrays = arrays if arrays is not None else {} @@ -120,7 +125,7 @@ def get_shape(self, name: str) -> tuple[int, ...]: """ return tuple(self.base.attributes.get(f'{self.array_prefix}{name}')) - def get_iterarrays(self) -> Iterator[tuple[str, 'ndarray']]: + def get_iterarrays(self) -> Iterator[tuple[str, ndarray]]: """ Iterator that returns tuples (name, array) for each array stored in the node. @@ -130,7 +135,7 @@ def get_iterarrays(self) -> Iterator[tuple[str, 'ndarray']]: for name in self.get_arraynames(): yield (name, self.get_array(name)) - def get_array(self, name: str | None = None) -> 'ndarray': + def get_array(self, name: str | None = None) -> ndarray: """ Return an array stored in the node @@ -152,7 +157,7 @@ def get_array(self, name: str | None = None) -> 'ndarray': name = names[0] - def get_array_from_file(self, name: str) -> 'ndarray': + def get_array_from_file(self, name: str) -> ndarray: """Return the array stored in a .npy file""" filename = f'{name}.npy' @@ -182,7 +187,7 @@ def clear_internal_cache(self) -> None: """ self._cached_arrays = {} - def set_array(self, name: str, array: 'ndarray') -> None: + def set_array(self, name: str, array: ndarray) -> None: """ Store a new numpy array inside the node. Possibly overwrite the array if it already existed. @@ -271,7 +276,7 @@ def _prepare_json(self, main_file_name='', comments=True) -> tuple[bytes, dict]: return json.dumps(json_dict).encode('utf-8'), {} -def clean_array(array: 'ndarray') -> list: +def clean_array(array: ndarray) -> list: """ Replacing np.nan and np.inf/-np.inf for Nones. diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index da1eebfa24..bdec739ebf 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -140,6 +140,7 @@ py:class disk_objectstore.utils.LazyOpener py:class frozenset py:class numpy.bool_ +py:class numpy.ndarray py:class ndarray py:class paramiko.proxy.ProxyCommand diff --git a/tests/orm/nodes/data/test_to_aiida_type.py b/tests/orm/nodes/data/test_to_aiida_type.py index aa76e79e98..6c106ab65f 100644 --- a/tests/orm/nodes/data/test_to_aiida_type.py +++ b/tests/orm/nodes/data/test_to_aiida_type.py @@ -8,6 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Test the :meth:`aiida.orm.data.base.to_aiida_type` serializer.""" +import numpy import pytest from aiida import orm @@ -24,6 +25,7 @@ (orm.List, [0, 1, 2]), (orm.Str, 'test-string'), (orm.EnumData, LinkType.RETURN), + (orm.ArrayData, numpy.array([[0, 0, 0], [1, 1, 1]])), ) ) # yapf: enable @@ -31,4 +33,7 @@ def test_to_aiida_type(expected_type, value): """Test the ``to_aiida_type`` dispatch.""" converted = orm.to_aiida_type(value) assert isinstance(converted, expected_type) - assert converted == value + if expected_type is orm.ArrayData: + assert converted.get_array().all() == value.all() + else: + assert converted == value