Skip to content

Commit

Permalink
ArrayData: Make name optional in get_array
Browse files Browse the repository at this point in the history
The `ArrayData` was designed to be able to store multiple numpy arrays.
While useful, it forced users to be more verbose than necessary when
only storing a single array as an explicit array name is always required:

    node = ArrayData()
    node.set_array('some_key', numpy.array([]))
    node.get_array('some_key')

The `get_array` method is updated to allow `None` for the `name`
argument as long as the node only stores a single array so that it can
return the correct array unambiguously. This simplifies typical user
code significantly:

    node = ArrayData(numpy.array([]))
    node.get_array()
  • Loading branch information
sphuber committed Sep 5, 2023
1 parent 35e669f commit 7fbe67c
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 2 deletions.
18 changes: 16 additions & 2 deletions aiida/orm/nodes/data/array/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,28 @@ def get_iterarrays(self) -> Iterator[tuple[str, 'ndarray']]:
for name in self.get_arraynames():
yield (name, self.get_array(name))

def get_array(self, name: str) -> 'ndarray':
def get_array(self, name: str | None = None) -> 'ndarray':
"""
Return an array stored in the node
:param name: The name of the array to return.
:param name: The name of the array to return. The name can be omitted in case the node contains only a single
array, which will be returned in that case. If ``name`` is ``None`` and the node contains multiple arrays or
no arrays at all a ``ValueError`` is raised.
:raises ValueError: If ``name`` is ``None`` and the node contains more than one arrays or no arrays at all.
"""
import numpy

if name is None:
names = self.get_arraynames()
narrays = len(names)

if narrays == 0:
raise ValueError('`name` not specified but the node contains no arrays.')
if narrays > 1:
raise ValueError('`name` not specified but the node contains multiple arrays.')

name = names[0]

def get_array_from_file(self, name: str) -> 'ndarray':
"""Return the array stored in a .npy file"""
filename = f'{name}.npy'
Expand Down
1 change: 1 addition & 0 deletions docs/source/nitpick-exceptions
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ py:class disk_objectstore.utils.LazyOpener
py:class frozenset

py:class numpy.bool_
py:class ndarray

py:class paramiko.proxy.ProxyCommand

Expand Down
18 changes: 18 additions & 0 deletions tests/orm/nodes/data/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
###########################################################################
"""Tests for the :mod:`aiida.orm.nodes.data.array.array` module."""
import numpy
import pytest

from aiida.orm import ArrayData, load_node

Expand Down Expand Up @@ -43,3 +44,20 @@ def test_constructor():
assert sorted(node.get_arraynames()) == ['a', 'b']
assert (node.get_array('a') == arrays['a']).all()
assert (node.get_array('b') == arrays['b']).all()


def test_get_array():
"""Test :meth:`aiida.orm.nodes.data.array.array.ArrayData:get_array`."""
node = ArrayData()
with pytest.raises(ValueError, match='`name` not specified but the node contains no arrays.'):
node.get_array()

node = ArrayData({'a': numpy.array([]), 'b': numpy.array([])})
with pytest.raises(ValueError, match='`name` not specified but the node contains multiple arrays.'):
node.get_array()

node = ArrayData({'a': numpy.array([1, 2])})
assert (node.get_array() == numpy.array([1, 2])).all()

node = ArrayData(numpy.array([1, 2]))
assert (node.get_array() == numpy.array([1, 2])).all()

0 comments on commit 7fbe67c

Please sign in to comment.