diff --git a/aiida/orm/nodes/data/array/xy.py b/aiida/orm/nodes/data/array/xy.py index 93177a6cea..05907b3dd0 100644 --- a/aiida/orm/nodes/data/array/xy.py +++ b/aiida/orm/nodes/data/array/xy.py @@ -50,6 +50,34 @@ class XyData(ArrayData): Y arrays, which can be considered functions of X. """ + def __init__( + self, + x_array: 'ndarray' | None = None, + y_arrays: 'ndarray' | list['ndarray'] | None = None, + *, + x_name: str | None = None, + x_units: str | None = None, + y_names: str | list[str] | None = None, + y_units: str | list[str] | None = None, + **kwargs + ): + """Construct a new instance, optionally setting the x and y arrays. + + .. note:: If the ``x_array`` is specified, all other keywords need to be specified as well. + + :param x_array: The x array. + :param y_arrays: The y arrays. + :param x_name: The name of the x array. + :param x_units: The unit of the x array. + :param y_names: The names of the y arrays. + :param y_units: The units of the y arrays. + """ + super().__init__(**kwargs) + + if x_array is not None: + self.set_x(x_array, x_name, x_units) # type: ignore[arg-type] + self.set_y(y_arrays, y_names, y_units) # type: ignore[arg-type] + @staticmethod def _arrayandname_validator(array: 'ndarray', name: str, units: str) -> None: """ diff --git a/tests/orm/nodes/data/test_xy.py b/tests/orm/nodes/data/test_xy.py new file mode 100644 index 0000000000..9f1adb2e97 --- /dev/null +++ b/tests/orm/nodes/data/test_xy.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Tests for the :mod:`aiida.orm.nodes.data.array.xy` module.""" +import numpy +import pytest + +from aiida.common.exceptions import NotExistent +from aiida.orm import XyData, load_node + + +def test_read_stored(): + """Test reading an array from an ``XyData`` after storing and loading it.""" + x_array = numpy.array([1, 2]) + y_array = numpy.array([3, 4]) + node = XyData(x_array, y_array, x_name='x_name', x_units='x_unit', y_names='y_name', y_units='y_units') + + assert numpy.array_equal(node.get_x()[1], x_array) + assert numpy.array_equal(node.get_y()[0][1], y_array) + + node.store() + assert numpy.array_equal(node.get_x()[1], x_array) + assert numpy.array_equal(node.get_y()[0][1], y_array) + + loaded = load_node(node.uuid) + assert numpy.array_equal(loaded.get_x()[1], x_array) + assert numpy.array_equal(loaded.get_y()[0][1], y_array) + + +def test_constructor(): + """Test the various construction options.""" + with pytest.raises(TypeError): + node = XyData(numpy.array([1, 2])) + + node = XyData() + + with pytest.raises(NotExistent): + node.get_x() + + with pytest.raises(NotExistent): + node.get_y() + + x_array = numpy.array([1, 2]) + y_array = numpy.array([3, 4]) + node = XyData(x_array, y_array, x_name='x_name', x_units='x_unit', y_names='y_name', y_units='y_units') + assert numpy.array_equal(node.get_x()[1], x_array) + assert numpy.array_equal(node.get_y()[0][1], y_array)