Skip to content

Commit

Permalink
Add type hinting for aiida.orm.nodes.data.array.xy
Browse files Browse the repository at this point in the history
  • Loading branch information
sphuber committed Sep 5, 2023
1 parent 7fbe67c commit cec24cd
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions aiida/orm/nodes/data/array/xy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,23 @@
collections of y-arrays bound to a single x-array, and the methods to operate
on them.
"""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Sequence

import numpy as np

from aiida.common.exceptions import NotExistent

from .array import ArrayData

if TYPE_CHECKING:
from numpy import ndarray

__all__ = ('XyData',)


def check_convert_single_to_tuple(item):
def check_convert_single_to_tuple(item: Any | Sequence[Any]) -> Sequence[Any]:
"""
Checks if the item is a list or tuple, and converts it to a list if it is
not already a list or tuple
Expand All @@ -44,7 +51,7 @@ class XyData(ArrayData):
"""

@staticmethod
def _arrayandname_validator(array, name, units):
def _arrayandname_validator(array: 'ndarray', name: str, units: str) -> None:
"""
Validates that the array is an numpy.ndarray and that the name is
of type str. Raises TypeError or ValueError if this not the case.
Expand All @@ -61,7 +68,7 @@ def _arrayandname_validator(array, name, units):
if not isinstance(units, str):
raise TypeError('The units must always be a str.')

def set_x(self, x_array, x_name, x_units):
def set_x(self, x_array: 'ndarray', x_name: str, x_units: str) -> None:
"""
Sets the array and the name for the x values.
Expand All @@ -74,7 +81,9 @@ def set_x(self, x_array, x_name, x_units):
self.base.attributes.set('x_units', x_units)
self.set_array('x_array', x_array)

def set_y(self, y_arrays, y_names, y_units):
def set_y(
self, y_arrays: 'ndarray' | Sequence['ndarray'], y_names: str | Sequence[str], y_units: str | Sequence[str]
) -> None:
"""
Set array(s) for the y part of the dataset. Also checks if the
x_array has already been set, and that, the shape of the y_arrays
Expand Down Expand Up @@ -110,7 +119,7 @@ def set_y(self, y_arrays, y_names, y_units):
self.base.attributes.set('y_names', y_names)
self.base.attributes.set('y_units', y_units)

def get_x(self):
def get_x(self) -> tuple[str, 'ndarray', str]:
"""
Tries to retrieve the x array and x name raises a NotExistent
exception if no x array has been set yet.
Expand All @@ -126,7 +135,7 @@ def get_x(self):
raise NotExistent('No x array has been set yet!')
return x_name, x_array, x_units

def get_y(self):
def get_y(self) -> list[tuple[str, 'ndarray', str]]:
"""
Tries to retrieve the y arrays and the y names, raises a
NotExistent exception if they have not been set yet, or cannot be
Expand Down

0 comments on commit cec24cd

Please sign in to comment.