Skip to content

Commit

Permalink
XyData: Allow defining array(s) on construction
Browse files Browse the repository at this point in the history
Currently, the constructor does not allow to define any arrays to set
when constructing a new node, so one is forced to multi line code:

    node = XyData()
    node.set_x(np.array([1, 2]), 'name', unit')
    node.set_y(np.array([3, 4]), 'name', unit')

This commit allows initialization upon construction simplifying the code
above to:

    node = XyData(
        np.array([1, 2]),
        np.array([3, 4]),
        x_name='name',
        x_unit='unit',
        y_names='name',
        y_units='unit'
    )

The units and names are intentionally made into keyword argument only
in order to prevent accidental swapping of values.

For backwards compatibility, it remains possible to construct an
`XyData` without any arrays.
  • Loading branch information
sphuber committed Sep 5, 2023
1 parent cec24cd commit cc87441
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 0 deletions.
28 changes: 28 additions & 0 deletions aiida/orm/nodes/data/array/xy.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,34 @@ class XyData(ArrayData):
Y arrays, which can be considered functions of X.
"""

def __init__(
self,
x_array: 'ndarray' | None = None,
y_arrays: 'ndarray' | list['ndarray'] | None = None,
*,
x_name: str | None = None,
x_units: str | None = None,
y_names: str | list[str] | None = None,
y_units: str | list[str] | None = None,
**kwargs
):
"""Construct a new instance, optionally setting the x and y arrays.
.. note:: If the ``x_array`` is specified, all other keywords need to be specified as well.
:param x_array: The x array.
:param y_arrays: The y arrays.
:param x_name: The name of the x array.
:param x_units: The unit of the x array.
:param y_names: The names of the y arrays.
:param y_units: The units of the y arrays.
"""
super().__init__(**kwargs)

if x_array is not None:
self.set_x(x_array, x_name, x_units) # type: ignore[arg-type]
self.set_y(y_arrays, y_names, y_units) # type: ignore[arg-type]

@staticmethod
def _arrayandname_validator(array: 'ndarray', name: str, units: str) -> None:
"""
Expand Down
53 changes: 53 additions & 0 deletions tests/orm/nodes/data/test_xy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Tests for the :mod:`aiida.orm.nodes.data.array.xy` module."""
import numpy
import pytest

from aiida.common.exceptions import NotExistent
from aiida.orm import XyData, load_node


def test_read_stored():
"""Test reading an array from an ``XyData`` after storing and loading it."""
x_array = numpy.array([1, 2])
y_array = numpy.array([3, 4])
node = XyData(x_array, y_array, x_name='x_name', x_units='x_unit', y_names='y_name', y_units='y_units')

assert numpy.array_equal(node.get_x()[1], x_array)
assert numpy.array_equal(node.get_y()[0][1], y_array)

node.store()
assert numpy.array_equal(node.get_x()[1], x_array)
assert numpy.array_equal(node.get_y()[0][1], y_array)

loaded = load_node(node.uuid)
assert numpy.array_equal(loaded.get_x()[1], x_array)
assert numpy.array_equal(loaded.get_y()[0][1], y_array)


def test_constructor():
"""Test the various construction options."""
with pytest.raises(TypeError):
node = XyData(numpy.array([1, 2]))

node = XyData()

with pytest.raises(NotExistent):
node.get_x()

with pytest.raises(NotExistent):
node.get_y()

x_array = numpy.array([1, 2])
y_array = numpy.array([3, 4])
node = XyData(x_array, y_array, x_name='x_name', x_units='x_unit', y_names='y_name', y_units='y_units')
assert numpy.array_equal(node.get_x()[1], x_array)
assert numpy.array_equal(node.get_y()[0][1], y_array)

0 comments on commit cc87441

Please sign in to comment.