Skip to content

Commit

Permalink
Feat: Add array comparison methods. (#60)
Browse files Browse the repository at this point in the history
* feat: Improved tests for broadcasting

* feat: added support for comparison operators over nada_arrays

* chore: bump version for new release
  • Loading branch information
jcabrero authored Aug 27, 2024
1 parent dd112a0 commit d95eb5d
Show file tree
Hide file tree
Showing 17 changed files with 651 additions and 307 deletions.
119 changes: 116 additions & 3 deletions nada_numpy/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
SecretUnsignedInteger, UnsignedInteger)

from nada_numpy.context import UnsafeArithmeticSession
from nada_numpy.nada_typing import (NadaBoolean, NadaCleartextType,
NadaInteger, NadaRational,
NadaUnsignedInteger)
from nada_numpy.nada_typing import (AnyNadaType, NadaBoolean,
NadaCleartextType, NadaInteger,
NadaRational, NadaUnsignedInteger)
from nada_numpy.types import (Rational, SecretRational, fxp_abs, get_log_scale,
public_rational, rational, secret_rational, sign)
from nada_numpy.utils import copy_metadata
Expand Down Expand Up @@ -390,6 +390,119 @@ def __imatmul__(self, other: Any) -> "NadaArray":
"""
return self.matmul(other)

def __comparison_operator(
self, value: Union["NadaArray", "AnyNadaType", np.ndarray], operator: Callable
) -> "NadaArray":
"""
Perform element-wise comparison with broadcasting.
NOTE: Specially for __eq__ and __ne__ operators, the result expected is bool.
If we don't define this method, the result will be a NadaArray with bool outputs.
Args:
value (Any): The object to compare.
operator (str): The comparison operator.
Returns:
NadaArray: A new NadaArray representing the element-wise comparison result.
"""
if isinstance(value, NadaArray):
value = value.inner
if isinstance(
value,
(
SecretInteger,
Integer,
SecretUnsignedInteger,
UnsignedInteger,
SecretRational,
Rational,
),
):
return self.apply(lambda x: operator(x, value))

if isinstance(value, np.ndarray):
if len(self.inner) != len(value):
raise ValueError("Arrays must have the same length")
return NadaArray(
np.array([operator(x, y) for x, y in zip(self.inner, value)])
)

raise ValueError(f"Unsupported type: {type(value)}")

def __eq__(self, value: Any) -> "NadaArray": # type: ignore
"""
Perform equality comparison with broadcasting.
Args:
value (object): The object to compare.
Returns:
NadaArray: A boolean representing the element-wise equality comparison result.
"""
return self.__comparison_operator(value, lambda x, y: x == y)

def __ne__(self, value: Any) -> "NadaArray": # type: ignore
"""
Perform inequality comparison with broadcasting.
Args:
value (object): The object to compare.
Returns:
NadaArray: A boolean array representing the element-wise inequality comparison result.
"""
return self.__comparison_operator(value, lambda x, y: ~(x == y))

def __lt__(self, value: Any) -> "NadaArray":
"""
Perform less than comparison with broadcasting.
Args:
value (object): The object to compare.
Returns:
NadaArray: A boolean array representing the element-wise less than comparison result.
"""
return self.__comparison_operator(value, lambda x, y: x < y)

def __le__(self, value: Any) -> "NadaArray":
"""
Perform less than or equal comparison with broadcasting.
Args:
value (object): The object to compare.
Returns:
NadaArray: A boolean array representing
the element-wise less or equal thancomparison result.
"""
return self.__comparison_operator(value, lambda x, y: x <= y)

def __gt__(self, value: Any) -> "NadaArray":
"""
Perform greater than comparison with broadcasting.
Args:
value (object): The object to compare.
Returns:
NadaArray: A boolean array representing the element-wise greater than comparison result.
"""
return self.__comparison_operator(value, lambda x, y: x > y)

def __ge__(self, value: Any) -> "NadaArray":
"""
Perform greater than or equal comparison with broadcasting.
Args:
value (object): The object to compare.
Returns:
NadaArray: A boolean representing the element-wise greater or equal than comparison.
"""
return self.__comparison_operator(value, lambda x, y: x >= y)

def dot(self, other: "NadaArray") -> "NadaArray":
"""
Compute the dot product between two NadaArray objects.
Expand Down
1 change: 0 additions & 1 deletion nada_numpy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2929,7 +2929,6 @@ def _chebyshev_polynomials(x: _NadaRational, terms: int) -> np.ndarray:

# return polynomials


polynomials = [x]
y = rational(4) * x * x - rational(2)
z = y - rational(1)
Expand Down
Loading

0 comments on commit d95eb5d

Please sign in to comment.