Skip to content

Commit

Permalink
Unit tests WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Nov 27, 2024
1 parent 53a4ac9 commit e428fa6
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 31 deletions.
73 changes: 43 additions & 30 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,17 +804,26 @@ def size(x):
return math.prod(x.shape)


def is_writeable_array(x):
def is_writeable_array(x) -> bool:
"""
Return False if x.__setitem__ is expected to raise; True otherwise
"""
if is_numpy_array(x):
return x.flags.writeable
if is_jax_array(x):
if is_jax_array(x) or is_pydata_sparse_array(x):
return False
return True


def _is_fancy_index(idx) -> bool:
if not isinstance(idx, tuple):
idx = (idx,)
return any(
isinstance(i, (list, tuple)) or is_array_api_obj(i)
for i in idx
)


_undef = object()


Expand Down Expand Up @@ -900,7 +909,7 @@ def __getitem__(self, idx):
and feels more intuitive coming from the JAX documentation.
"""
if self.idx is not _undef:
raise TypeError("Index has already been set")
raise ValueError("Index has already been set")
self.idx = idx
return self

Expand All @@ -911,14 +920,12 @@ def _common(
copy: bool | None | Literal["_force_false"] = True,
**kwargs,
):
"""Validate kwargs and perform common prepocessing.
"""Perform common prepocessing.
Returns
-------
If the operation can be resolved by at[],
(return value, None)
Otherwise,
(None, preprocessed x)
If the operation can be resolved by at[], (return value, None)
Otherwise, (None, preprocessed x)
"""
if self.idx is _undef:
raise TypeError(
Expand All @@ -929,40 +936,44 @@ def _common(
" at(x)[idx].set(value)\n"
"(same for all other methods)."
)

x = self.x

if copy is False:
if not is_writeable_array(self.x):
raise ValueError("Cannot avoid modifying parameter in place")
if not is_writeable_array(x) or is_dask_array(x):
raise ValueError("Cannot modify parameter in place")
elif copy is None:
copy = not is_writeable_array(self.x)
copy = not is_writeable_array(x)
elif copy == "_force_false":
copy = False
elif copy is not True:
raise ValueError(f"Invalid value for copy: {copy!r}")

if copy and is_jax_array(self.x):
if is_jax_array(x):
# Use JAX's at[]
at_ = self.x.at[self.idx]
at_ = x.at[self.idx]
args = (y,) if y is not _undef else ()
return getattr(at_, at_op)(*args, **kwargs), None

# Emulate at[] behaviour for non-JAX arrays
# FIXME We blindly expect the output of x.copy() to be always writeable.
# This holds true for read-only numpy arrays, but not necessarily for
# other backends.
x = self.x.copy() if copy else self.x
if copy:
# FIXME We blindly expect the output of x.copy() to be always writeable.
# This holds true for read-only numpy arrays, but not necessarily for
# other backends.
xp = get_namespace(x)
x = xp.asarray(x, copy=True)

return None, x

def get(self, copy: bool | None = True, **kwargs):
"""Return x[idx]. In addition to plain __getitem__, this allows ensuring
that the output is (not) a copy and kwargs are passed to the backend."""
# Special case when xp=numpy and idx is a fancy index
# If copy is not False, avoid an unnecessary double copy.
# if copy is forced to False, raise.
if is_numpy_array(self.x) and (
isinstance(self.idx, (list, tuple))
or (is_numpy_array(self.idx) and self.idx.dtype.kind in "biu")
):
"""
Return x[idx]. In addition to plain __getitem__, this allows ensuring
that the output is (not) a copy and kwargs are passed to the backend.
"""
# __getitem__ with a fancy index always returns a copy.
# Avoid an unnecessary double copy.
# If copy is forced to False, raise.
if _is_fancy_index(self.idx):
if copy is False:
raise ValueError(
"Indexing a numpy array with a fancy index always "
Expand Down Expand Up @@ -1032,13 +1043,15 @@ def power(self, y, /, **kwargs):

def min(self, y, /, **kwargs):
"""x[idx] = minimum(x[idx], y)"""
xp = array_namespace(self.x)
return self._iop("min", xp.minimum, y, **kwargs)
import numpy as np

return self._iop("min", np.minimum, y, **kwargs)

def max(self, y, /, **kwargs):
"""x[idx] = maximum(x[idx], y)"""
xp = array_namespace(self.x)
return self._iop("max", xp.maximum, y, **kwargs)
import numpy as np

return self._iop("max", np.maximum, y, **kwargs)


__all__ = [
Expand Down
158 changes: 158 additions & 0 deletions tests/test_at.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
from __future__ import annotations

from contextlib import contextmanager, suppress

import numpy as np
import pytest

from array_api_compat import (
array_namespace,
at,
is_dask_array,
is_jax_array,
is_torch_namespace,

Check failure on line 13 in tests/test_at.py

View workflow job for this annotation

GitHub Actions / check-ruff

Ruff (F401)

tests/test_at.py:13:5: F401 `array_api_compat.is_torch_namespace` imported but unused
is_pydata_sparse_array,
is_writeable_array,
)
from ._helpers import import_, all_libraries


def assert_array_equal(a, b):
if is_pydata_sparse_array(a):
a = a.todense()
elif is_dask_array(a):
a = a.compute()
np.testing.assert_array_equal(a, b)


@contextmanager
def assert_copy(x, copy: bool | None):
# dask arrays are writeable, but writing to them will hot-swap the
# dask graph inside the collection so that anything that references
# the original graph, i.e. the input collection, won't be mutated.
if copy is False and (not is_writeable_array(x) or is_dask_array(x)):
with pytest.raises((TypeError, ValueError)):
yield
return

xp = array_namespace(x)
x_orig = xp.asarray(x, copy=True)
yield

if is_dask_array(x):
expect_copy = True
elif copy is None:
expect_copy = not is_writeable_array(x)
else:
expect_copy = copy
assert_array_equal((x == x_orig).all(), expect_copy)


@pytest.fixture(params=all_libraries + ["np_readonly"])
def x(request):
library = request.param
if library == "np_readonly":
x = np.asarray([10, 20, 30])
x.flags.writeable = False
else:
lib = import_(library)
x = lib.asarray([10, 20, 30])
return x


@pytest.mark.parametrize("copy", [True, False, None])
@pytest.mark.parametrize(
"op,arg,expect",
[
("apply", np.negative, [10, -20, -30]),
("set", 40, [10, 40, 40]),
("add", 40, [10, 60, 70]),
("subtract", 100, [10, -80, -70]),
("multiply", 2, [10, 40, 60]),
("divide", 3, [10, 6, 10]),
("power", 2, [10, 400, 900]),
("min", 25, [10, 20, 25]),
("max", 25, [10, 25, 30]),
],
)
def test_operations(x, copy, op, arg, expect):
with assert_copy(x, copy):
y = getattr(at(x, slice(1, None)), op)(arg, copy=copy)
assert isinstance(y, type(x))
assert_array_equal(y, expect)


@pytest.mark.parametrize("copy", [True, False, None])
def test_get(x, copy):
with assert_copy(x, copy):
y = at(x, slice(2)).get(copy=copy)
assert isinstance(y, type(x))
assert_array_equal(y, [10, 20])
# Let assert_copy test that y is a view or copy
with suppress((TypeError, ValueError)):
y[0] = 40


@pytest.mark.parametrize(
"idx",
[
[0, 1],
(0, 1),
np.array([0, 1], dtype="i1"),
np.array([0, 1], dtype="u1"),
lambda xp: xp.asarray([0, 1], dtype="i1"),
lambda xp: xp.asarray([0, 1], dtype="u1"),
[True, True, False],
(True, True, False),
np.array([True, True, False]),
lambda xp: xp.asarray([True, True, False]),
],
)
@pytest.mark.parametrize("tuple_index", [True, False])
def test_get_fancy_indices(x, idx, tuple_index):
"""get() with a fancy index always returns a copy"""
if callable(idx):
xp = array_namespace(x)
idx = idx(xp)

if is_jax_array(x) and isinstance(idx, (list, tuple)):
pytest.skip("JAX fancy indices must always be arrays")
if is_pydata_sparse_array(x) and is_pydata_sparse_array(idx):
pytest.skip("sparse fancy indices can't be sparse themselves")
if is_dask_array(x) and isinstance(idx, tuple):
pytest.skip("dask does not support tuples; only lists or arrays")
if isinstance(idx, tuple) and not tuple_index:
pytest.skip("tuple indices must always be wrapped in a tuple")

if tuple_index:
idx = (idx,)

with assert_copy(x, True):
y = at(x, idx).get()
assert isinstance(y, type(x))
assert_array_equal(y, [10, 20])
# Let assert_copy test that y is a view or copy
with suppress((TypeError, ValueError)):
y[0] = 40

with assert_copy(x, True):
y = at(x, idx).get(copy=None)
assert isinstance(y, type(x))
assert_array_equal(y, [10, 20])
# Let assert_copy test that y is a view or copy
with suppress((TypeError, ValueError)):
y[0] = 40

with pytest.raises(ValueError, match="fancy index"):
at(x, idx).get(copy=False)


def test_variant_index_syntax(x):
y = at(x)[:2].set(40)
assert isinstance(y, type(x))
assert_array_equal(y, [40, 40, 30])

with pytest.raises(ValueError):
at(x, 1)[2]
with pytest.raises(ValueError):
at(x)[1][2]
21 changes: 20 additions & 1 deletion tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
)

from array_api_compat import is_array_api_obj, device, to_device
from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device

from ._helpers import import_, wrapped_libraries, all_libraries

Expand Down Expand Up @@ -55,6 +55,25 @@ def test_is_xp_namespace(library, func):
assert is_func(lib) == (func == is_namespace_functions[library])


@pytest.mark.parametrize("library", all_libraries)
def test_is_writeable_array(library):
lib = import_(library)
x = lib.asarray([1, 2, 3])
if is_writeable_array(x):
x[1] = 4
np.testing.assert_equal(np.asarray(x), [1, 4, 3])
else:
with pytest.raises((TypeError, ValueError)):
x[1] = 4


def test_is_writeable_array_numpy():
x = np.asarray([1, 2, 3])
assert is_writeable_array(x)
x.flags.writeable = False
assert not is_writeable_array(x)


@pytest.mark.parametrize("library", all_libraries)
def test_device(library):
xp = import_(library, wrapper=True)
Expand Down

0 comments on commit e428fa6

Please sign in to comment.