Skip to content

Commit

Permalink
Unit tests WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Nov 27, 2024
1 parent 53a4ac9 commit a1f1b0f
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 28 deletions.
61 changes: 34 additions & 27 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ def is_writeable_array(x):
"""
if is_numpy_array(x):
return x.flags.writeable
if is_jax_array(x):
if is_jax_array(x) or is_pydata_sparse_array(x):
return False
return True

Expand Down Expand Up @@ -900,7 +900,7 @@ def __getitem__(self, idx):
and feels more intuitive coming from the JAX documentation.
"""
if self.idx is not _undef:
raise TypeError("Index has already been set")
raise ValueError("Index has already been set")
self.idx = idx
return self

Expand All @@ -911,14 +911,12 @@ def _common(
copy: bool | None | Literal["_force_false"] = True,
**kwargs,
):
"""Validate kwargs and perform common prepocessing.
"""Perform common prepocessing.
Returns
-------
If the operation can be resolved by at[],
(return value, None)
Otherwise,
(None, preprocessed x)
If the operation can be resolved by at[], (return value, None)
Otherwise, (None, preprocessed x)
"""
if self.idx is _undef:
raise TypeError(
Expand All @@ -929,39 +927,46 @@ def _common(
" at(x)[idx].set(value)\n"
"(same for all other methods)."
)

x = self.x

if copy is False:
if not is_writeable_array(self.x):
raise ValueError("Cannot avoid modifying parameter in place")
if not is_writeable_array(x) or is_dask_array(x):
raise ValueError("Cannot modify parameter in place")
elif copy is None:
copy = not is_writeable_array(self.x)
copy = not is_writeable_array(x)
elif copy == "_force_false":
copy = False
elif copy is not True:
raise ValueError(f"Invalid value for copy: {copy!r}")

if copy and is_jax_array(self.x):
if is_jax_array(x):
# Use JAX's at[]
at_ = self.x.at[self.idx]
at_ = x.at[self.idx]
args = (y,) if y is not _undef else ()
return getattr(at_, at_op)(*args, **kwargs), None

# Emulate at[] behaviour for non-JAX arrays
# FIXME We blindly expect the output of x.copy() to be always writeable.
# This holds true for read-only numpy arrays, but not necessarily for
# other backends.
x = self.x.copy() if copy else self.x
if copy:
# FIXME We blindly expect the output of x.copy() to be always writeable.
# This holds true for read-only numpy arrays, but not necessarily for
# other backends.
xp = get_namespace(x)
x = xp.asarray(x, copy=True)

return None, x

def get(self, copy: bool | None = True, **kwargs):
"""Return x[idx]. In addition to plain __getitem__, this allows ensuring
that the output is (not) a copy and kwargs are passed to the backend."""
# Special case when xp=numpy and idx is a fancy index
# If copy is not False, avoid an unnecessary double copy.
# if copy is forced to False, raise.
if is_numpy_array(self.x) and (
"""
Return x[idx]. In addition to plain __getitem__, this allows ensuring
that the output is (not) a copy and kwargs are passed to the backend.
"""
# __getitem__ with a fancy index always returns a copy.
# Avoid an unnecessary double copy.
# If copy is forced to False, raise.
if (
isinstance(self.idx, (list, tuple))
or (is_numpy_array(self.idx) and self.idx.dtype.kind in "biu")
or (is_array_api_obj(self.idx) and self.idx.dtype.kind in "biu")
):
if copy is False:
raise ValueError(
Expand Down Expand Up @@ -1032,13 +1037,15 @@ def power(self, y, /, **kwargs):

def min(self, y, /, **kwargs):
"""x[idx] = minimum(x[idx], y)"""
xp = array_namespace(self.x)
return self._iop("min", xp.minimum, y, **kwargs)
import numpy as np

return self._iop("min", np.minimum, y, **kwargs)

def max(self, y, /, **kwargs):
"""x[idx] = maximum(x[idx], y)"""
xp = array_namespace(self.x)
return self._iop("max", xp.maximum, y, **kwargs)
import numpy as np

return self._iop("max", np.maximum, y, **kwargs)


__all__ = [
Expand Down
126 changes: 126 additions & 0 deletions tests/test_at.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from contextlib import contextmanager, suppress

import numpy as np
import pytest

from array_api_compat import array_namespace, at, is_dask_array, is_jax_array, is_writeable_array
from ._helpers import import_, all_libraries


@contextmanager
def assert_copy(x, copy: bool | None):
# dask arrays are writeable, but writing to them will hot-swap the
# dask graph inside the collection so that anything that references
# the original graph, i.e. the input collection, won't be mutated.
if copy is False and (not is_writeable_array(x) or is_dask_array(x)):
with pytest.raises((TypeError, ValueError)):
yield
return

xp = array_namespace(x)
x_orig = xp.asarray(x, copy=True)
yield

expect_copy = (
copy if copy is not None else (not is_writeable_array(x) or is_dask_array(x))
)
np.testing.assert_array_equal((x == x_orig).all(), expect_copy)


@pytest.fixture(params=all_libraries + ["np_readonly"])
def x(request):
library = request.param
if library == "np_readonly":
x = np.asarray([10, 20, 30])
x.flags.writeable = False
else:
lib = import_(library)
x = lib.asarray([10, 20, 30])
return x


@pytest.mark.parametrize("copy", [True, False, None])
@pytest.mark.parametrize(
"op,arg,expect",
[
("apply", np.negative, [10, -20, -30]),
("set", 40, [10, 40, 40]),
("add", 40, [10, 60, 70]),
("subtract", 100, [10, -80, -70]),
("multiply", 2, [10, 40, 60]),
("divide", 3, [10, 6, 10]),
("power", 2, [10, 400, 900]),
("min", 25, [10, 20, 25]),
("max", 25, [10, 25, 30]),
],
)
def test_operations(x, copy, op, arg, expect):
with assert_copy(x, copy):
y = getattr(at(x, slice(1, None)), op)(arg, copy=copy)
assert isinstance(y, type(x))
np.testing.assert_equal(y, expect)


@pytest.mark.parametrize("copy", [True, False, None])
def test_get(x, copy):
with assert_copy(x, copy):
y = at(x, slice(2)).get(copy=copy)
assert isinstance(y, type(x))
np.testing.assert_array_equal(y, [10, 20])
# Let assert_copy test that y is a view or copy
with suppress((TypeError, ValueError)):
y[0] = 40


@pytest.mark.parametrize(
"idx",
[
[0, 1],
((0, 1), ),
np.array([0, 1], dtype="i1"),
np.array([0, 1], dtype="u1"),
[True, True, False],
(True, True, False),
np.array([True, True, False]),
],
)
@pytest.mark.parametrize("wrap_index", [True, False])
def test_get_fancy_indices(x, idx, wrap_index):
"""get() with a fancy index always returns a copy"""
if not wrap_index and is_jax_array(x) and isinstance(idx, (list, tuple)):
pytest.skip("JAX fancy indices must always be arrays")

if wrap_index:
xp = array_namespace(x)
idx = xp.asarray(idx)

with assert_copy(x, True):
y = at(x, [0, 1]).get()
assert isinstance(y, type(x))
np.testing.assert_array_equal(y, [10, 20])
# Let assert_copy test that y is a view or copy
with suppress((TypeError, ValueError)):
y[0] = 40

with assert_copy(x, True):
y = at(x, [0, 1]).get(copy=None)
assert isinstance(y, type(x))
np.testing.assert_array_equal(y, [10, 20])
# Let assert_copy test that y is a view or copy
with suppress((TypeError, ValueError)):
y[0] = 40

with pytest.raises(ValueError, match="fancy index"):
at(x, [0, 1]).get(copy=False)


@pytest.mark.parametrize("copy", [True, False, None])
def test_variant_index_syntax(x, copy):
with assert_copy(x, copy):
y = at(x)[:2].set(40, copy=copy)
assert isinstance(y, type(x))
np.testing.assert_array_equal(y, [40, 40, 30])
with pytest.raises(ValueError):
at(x, 1)[2]
with pytest.raises(ValueError):
at(x)[1][2]
21 changes: 20 additions & 1 deletion tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
)

from array_api_compat import is_array_api_obj, device, to_device
from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device

from ._helpers import import_, wrapped_libraries, all_libraries

Expand Down Expand Up @@ -55,6 +55,25 @@ def test_is_xp_namespace(library, func):
assert is_func(lib) == (func == is_namespace_functions[library])


@pytest.mark.parametrize("library", all_libraries)
def test_is_writeable_array(library):
lib = import_(library)
x = lib.asarray([1, 2, 3])
if is_writeable_array(x):
x[1] = 4
np.testing.assert_equal(np.asarray(x), [1, 4, 3])
else:
with pytest.raises((TypeError, ValueError)):
x[1] = 4


def test_is_writeable_array_numpy():
x = np.asarray([1, 2, 3])
assert is_writeable_array(x)
x.flags.writeable = False
assert not is_writeable_array(x)


@pytest.mark.parametrize("library", all_libraries)
def test_device(library):
xp = import_(library, wrapper=True)
Expand Down

0 comments on commit a1f1b0f

Please sign in to comment.