Skip to content

Commit

Permalink
ORM: Register numpy.ndarray with the to_aiida_type to ArrayData (
Browse files Browse the repository at this point in the history
…#6149)

This will allow `numpy.ndarray` to be passed to process inputs that add
the `to_aiida_type` serializer and expect an `ArrayData`. The single
dispatch will automatically convert the numpy array to an `ArrayData`
instance.
  • Loading branch information
sphuber authored Oct 19, 2023
1 parent ec64780 commit d8dd776
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 12 deletions.
27 changes: 16 additions & 11 deletions aiida/orm/nodes/data/array/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,21 @@
"""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Iterator
from typing import Any, Iterator

from ..data import Data
from numpy import ndarray

if TYPE_CHECKING:
from numpy import ndarray
from ..base import to_aiida_type
from ..data import Data

__all__ = ('ArrayData',)


@to_aiida_type.register(ndarray)
def _(value):
return ArrayData(value)


class ArrayData(Data):
"""
Store a set of arrays on disk (rather than on the database) in an efficient
Expand All @@ -41,15 +46,15 @@ class ArrayData(Data):
array_prefix = 'array|'
default_array_name = 'default'

def __init__(self, arrays: 'ndarray' | dict[str, 'ndarray'] | None = None, **kwargs):
def __init__(self, arrays: ndarray | dict[str, ndarray] | None = None, **kwargs):
"""Construct a new instance and set one or multiple numpy arrays.
:param arrays: An optional single numpy array, or dictionary of numpy arrays to store.
"""
import numpy

super().__init__(**kwargs)
self._cached_arrays: dict[str, 'ndarray'] = {}
self._cached_arrays: dict[str, ndarray] = {}

arrays = arrays if arrays is not None else {}

Expand Down Expand Up @@ -120,7 +125,7 @@ def get_shape(self, name: str) -> tuple[int, ...]:
"""
return tuple(self.base.attributes.get(f'{self.array_prefix}{name}'))

def get_iterarrays(self) -> Iterator[tuple[str, 'ndarray']]:
def get_iterarrays(self) -> Iterator[tuple[str, ndarray]]:
"""
Iterator that returns tuples (name, array) for each array stored in the node.
Expand All @@ -130,7 +135,7 @@ def get_iterarrays(self) -> Iterator[tuple[str, 'ndarray']]:
for name in self.get_arraynames():
yield (name, self.get_array(name))

def get_array(self, name: str | None = None) -> 'ndarray':
def get_array(self, name: str | None = None) -> ndarray:
"""
Return an array stored in the node
Expand All @@ -152,7 +157,7 @@ def get_array(self, name: str | None = None) -> 'ndarray':

name = names[0]

def get_array_from_file(self, name: str) -> 'ndarray':
def get_array_from_file(self, name: str) -> ndarray:
"""Return the array stored in a .npy file"""
filename = f'{name}.npy'

Expand Down Expand Up @@ -182,7 +187,7 @@ def clear_internal_cache(self) -> None:
"""
self._cached_arrays = {}

def set_array(self, name: str, array: 'ndarray') -> None:
def set_array(self, name: str, array: ndarray) -> None:
"""
Store a new numpy array inside the node. Possibly overwrite the array
if it already existed.
Expand Down Expand Up @@ -271,7 +276,7 @@ def _prepare_json(self, main_file_name='', comments=True) -> tuple[bytes, dict]:
return json.dumps(json_dict).encode('utf-8'), {}


def clean_array(array: 'ndarray') -> list:
def clean_array(array: ndarray) -> list:
"""
Replacing np.nan and np.inf/-np.inf for Nones.
Expand Down
1 change: 1 addition & 0 deletions docs/source/nitpick-exceptions
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ py:class disk_objectstore.utils.LazyOpener
py:class frozenset

py:class numpy.bool_
py:class numpy.ndarray
py:class ndarray

py:class paramiko.proxy.ProxyCommand
Expand Down
7 changes: 6 additions & 1 deletion tests/orm/nodes/data/test_to_aiida_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Test the :meth:`aiida.orm.data.base.to_aiida_type` serializer."""
import numpy
import pytest

from aiida import orm
Expand All @@ -24,11 +25,15 @@
(orm.List, [0, 1, 2]),
(orm.Str, 'test-string'),
(orm.EnumData, LinkType.RETURN),
(orm.ArrayData, numpy.array([[0, 0, 0], [1, 1, 1]])),
)
)
# yapf: enable
def test_to_aiida_type(expected_type, value):
"""Test the ``to_aiida_type`` dispatch."""
converted = orm.to_aiida_type(value)
assert isinstance(converted, expected_type)
assert converted == value
if expected_type is orm.ArrayData:
assert converted.get_array().all() == value.all()
else:
assert converted == value

0 comments on commit d8dd776

Please sign in to comment.