diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 4e6c94ee..a9892452 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -804,17 +804,26 @@ def size(x): return math.prod(x.shape) -def is_writeable_array(x): +def is_writeable_array(x) -> bool: """ Return False if x.__setitem__ is expected to raise; True otherwise """ if is_numpy_array(x): return x.flags.writeable - if is_jax_array(x): + if is_jax_array(x) or is_pydata_sparse_array(x): return False return True +def _is_fancy_index(idx) -> bool: + if not isinstance(idx, tuple): + idx = (idx,) + return any( + isinstance(i, (list, tuple)) or is_array_api_obj(i) + for i in idx + ) + + _undef = object() @@ -900,7 +909,7 @@ def __getitem__(self, idx): and feels more intuitive coming from the JAX documentation. """ if self.idx is not _undef: - raise TypeError("Index has already been set") + raise ValueError("Index has already been set") self.idx = idx return self @@ -911,14 +920,12 @@ def _common( copy: bool | None | Literal["_force_false"] = True, **kwargs, ): - """Validate kwargs and perform common prepocessing. + """Perform common prepocessing. Returns ------- - If the operation can be resolved by at[], - (return value, None) - Otherwise, - (None, preprocessed x) + If the operation can be resolved by at[], (return value, None) + Otherwise, (None, preprocessed x) """ if self.idx is _undef: raise TypeError( @@ -929,40 +936,44 @@ def _common( " at(x)[idx].set(value)\n" "(same for all other methods)." ) + + x = self.x if copy is False: - if not is_writeable_array(self.x): - raise ValueError("Cannot avoid modifying parameter in place") + if not is_writeable_array(x) or is_dask_array(x): + raise ValueError("Cannot modify parameter in place") elif copy is None: - copy = not is_writeable_array(self.x) + copy = not is_writeable_array(x) elif copy == "_force_false": copy = False elif copy is not True: raise ValueError(f"Invalid value for copy: {copy!r}") - if copy and is_jax_array(self.x): + if is_jax_array(x): # Use JAX's at[] - at_ = self.x.at[self.idx] + at_ = x.at[self.idx] args = (y,) if y is not _undef else () return getattr(at_, at_op)(*args, **kwargs), None # Emulate at[] behaviour for non-JAX arrays - # FIXME We blindly expect the output of x.copy() to be always writeable. - # This holds true for read-only numpy arrays, but not necessarily for - # other backends. - x = self.x.copy() if copy else self.x + if copy: + # FIXME We blindly expect the output of x.copy() to be always writeable. + # This holds true for read-only numpy arrays, but not necessarily for + # other backends. + xp = get_namespace(x) + x = xp.asarray(x, copy=True) + return None, x def get(self, copy: bool | None = True, **kwargs): - """Return x[idx]. In addition to plain __getitem__, this allows ensuring - that the output is (not) a copy and kwargs are passed to the backend.""" - # Special case when xp=numpy and idx is a fancy index - # If copy is not False, avoid an unnecessary double copy. - # if copy is forced to False, raise. - if is_numpy_array(self.x) and ( - isinstance(self.idx, (list, tuple)) - or (is_numpy_array(self.idx) and self.idx.dtype.kind in "biu") - ): + """ + Return x[idx]. In addition to plain __getitem__, this allows ensuring + that the output is (not) a copy and kwargs are passed to the backend. + """ + # __getitem__ with a fancy index always returns a copy. + # Avoid an unnecessary double copy. + # If copy is forced to False, raise. + if _is_fancy_index(self.idx): if copy is False: raise ValueError( "Indexing a numpy array with a fancy index always " @@ -1032,13 +1043,15 @@ def power(self, y, /, **kwargs): def min(self, y, /, **kwargs): """x[idx] = minimum(x[idx], y)""" - xp = array_namespace(self.x) - return self._iop("min", xp.minimum, y, **kwargs) + import numpy as np + + return self._iop("min", np.minimum, y, **kwargs) def max(self, y, /, **kwargs): """x[idx] = maximum(x[idx], y)""" - xp = array_namespace(self.x) - return self._iop("max", xp.maximum, y, **kwargs) + import numpy as np + + return self._iop("max", np.maximum, y, **kwargs) __all__ = [ diff --git a/tests/test_at.py b/tests/test_at.py new file mode 100644 index 00000000..f1732ee7 --- /dev/null +++ b/tests/test_at.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +from contextlib import contextmanager, suppress + +import numpy as np +import pytest + +from array_api_compat import ( + array_namespace, + at, + is_dask_array, + is_jax_array, + is_torch_namespace, + is_pydata_sparse_array, + is_writeable_array, +) +from ._helpers import import_, all_libraries + + +def assert_array_equal(a, b): + if is_pydata_sparse_array(a): + a = a.todense() + elif is_dask_array(a): + a = a.compute() + np.testing.assert_array_equal(a, b) + + +@contextmanager +def assert_copy(x, copy: bool | None): + # dask arrays are writeable, but writing to them will hot-swap the + # dask graph inside the collection so that anything that references + # the original graph, i.e. the input collection, won't be mutated. + if copy is False and (not is_writeable_array(x) or is_dask_array(x)): + with pytest.raises((TypeError, ValueError)): + yield + return + + xp = array_namespace(x) + x_orig = xp.asarray(x, copy=True) + yield + + if is_dask_array(x): + expect_copy = True + elif copy is None: + expect_copy = not is_writeable_array(x) + else: + expect_copy = copy + assert_array_equal((x == x_orig).all(), expect_copy) + + +@pytest.fixture(params=all_libraries + ["np_readonly"]) +def x(request): + library = request.param + if library == "np_readonly": + x = np.asarray([10, 20, 30]) + x.flags.writeable = False + else: + lib = import_(library) + x = lib.asarray([10, 20, 30]) + return x + + +@pytest.mark.parametrize("copy", [True, False, None]) +@pytest.mark.parametrize( + "op,arg,expect", + [ + ("apply", np.negative, [10, -20, -30]), + ("set", 40, [10, 40, 40]), + ("add", 40, [10, 60, 70]), + ("subtract", 100, [10, -80, -70]), + ("multiply", 2, [10, 40, 60]), + ("divide", 3, [10, 6, 10]), + ("power", 2, [10, 400, 900]), + ("min", 25, [10, 20, 25]), + ("max", 25, [10, 25, 30]), + ], +) +def test_operations(x, copy, op, arg, expect): + with assert_copy(x, copy): + y = getattr(at(x, slice(1, None)), op)(arg, copy=copy) + assert isinstance(y, type(x)) + assert_array_equal(y, expect) + + +@pytest.mark.parametrize("copy", [True, False, None]) +def test_get(x, copy): + with assert_copy(x, copy): + y = at(x, slice(2)).get(copy=copy) + assert isinstance(y, type(x)) + assert_array_equal(y, [10, 20]) + # Let assert_copy test that y is a view or copy + with suppress((TypeError, ValueError)): + y[0] = 40 + + +@pytest.mark.parametrize( + "idx", + [ + [0, 1], + (0, 1), + np.array([0, 1], dtype="i1"), + np.array([0, 1], dtype="u1"), + lambda xp: xp.asarray([0, 1], dtype="i1"), + lambda xp: xp.asarray([0, 1], dtype="u1"), + [True, True, False], + (True, True, False), + np.array([True, True, False]), + lambda xp: xp.asarray([True, True, False]), + ], +) +@pytest.mark.parametrize("tuple_index", [True, False]) +def test_get_fancy_indices(x, idx, tuple_index): + """get() with a fancy index always returns a copy""" + if callable(idx): + xp = array_namespace(x) + idx = idx(xp) + + if is_jax_array(x) and isinstance(idx, (list, tuple)): + pytest.skip("JAX fancy indices must always be arrays") + if is_pydata_sparse_array(x) and is_pydata_sparse_array(idx): + pytest.skip("sparse fancy indices can't be sparse themselves") + if is_dask_array(x) and isinstance(idx, tuple): + pytest.skip("dask does not support tuples; only lists or arrays") + if isinstance(idx, tuple) and not tuple_index: + pytest.skip("tuple indices must always be wrapped in a tuple") + + if tuple_index: + idx = (idx,) + + with assert_copy(x, True): + y = at(x, idx).get() + assert isinstance(y, type(x)) + assert_array_equal(y, [10, 20]) + # Let assert_copy test that y is a view or copy + with suppress((TypeError, ValueError)): + y[0] = 40 + + with assert_copy(x, True): + y = at(x, idx).get(copy=None) + assert isinstance(y, type(x)) + assert_array_equal(y, [10, 20]) + # Let assert_copy test that y is a view or copy + with suppress((TypeError, ValueError)): + y[0] = 40 + + with pytest.raises(ValueError, match="fancy index"): + at(x, idx).get(copy=False) + + +def test_variant_index_syntax(x): + y = at(x)[:2].set(40) + assert isinstance(y, type(x)) + assert_array_equal(y, [40, 40, 30]) + + with pytest.raises(ValueError): + at(x, 1)[2] + with pytest.raises(ValueError): + at(x)[1][2] diff --git a/tests/test_common.py b/tests/test_common.py index e1cfa9eb..955bcb48 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -5,7 +5,7 @@ is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace, ) -from array_api_compat import is_array_api_obj, device, to_device +from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device from ._helpers import import_, wrapped_libraries, all_libraries @@ -55,6 +55,25 @@ def test_is_xp_namespace(library, func): assert is_func(lib) == (func == is_namespace_functions[library]) +@pytest.mark.parametrize("library", all_libraries) +def test_is_writeable_array(library): + lib = import_(library) + x = lib.asarray([1, 2, 3]) + if is_writeable_array(x): + x[1] = 4 + np.testing.assert_equal(np.asarray(x), [1, 4, 3]) + else: + with pytest.raises((TypeError, ValueError)): + x[1] = 4 + + +def test_is_writeable_array_numpy(): + x = np.asarray([1, 2, 3]) + assert is_writeable_array(x) + x.flags.writeable = False + assert not is_writeable_array(x) + + @pytest.mark.parametrize("library", all_libraries) def test_device(library): xp = import_(library, wrapper=True)