diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 4e6c94ee..e9f40089 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -900,7 +900,7 @@ def __getitem__(self, idx): and feels more intuitive coming from the JAX documentation. """ if self.idx is not _undef: - raise TypeError("Index has already been set") + raise ValueError("Index has already been set") self.idx = idx return self @@ -911,14 +911,12 @@ def _common( copy: bool | None | Literal["_force_false"] = True, **kwargs, ): - """Validate kwargs and perform common prepocessing. + """Perform common prepocessing. Returns ------- - If the operation can be resolved by at[], - (return value, None) - Otherwise, - (None, preprocessed x) + If the operation can be resolved by at[], (return value, None) + Otherwise, (None, preprocessed x) """ if self.idx is _undef: raise TypeError( @@ -929,39 +927,46 @@ def _common( " at(x)[idx].set(value)\n" "(same for all other methods)." ) + + x = self.x if copy is False: - if not is_writeable_array(self.x): - raise ValueError("Cannot avoid modifying parameter in place") + if not is_writeable_array(x) or is_dask_array(x): + raise ValueError("Cannot modify parameter in place") elif copy is None: - copy = not is_writeable_array(self.x) + copy = not is_writeable_array(x) elif copy == "_force_false": copy = False elif copy is not True: raise ValueError(f"Invalid value for copy: {copy!r}") - if copy and is_jax_array(self.x): + if is_jax_array(x): # Use JAX's at[] - at_ = self.x.at[self.idx] + at_ = x.at[self.idx] args = (y,) if y is not _undef else () return getattr(at_, at_op)(*args, **kwargs), None # Emulate at[] behaviour for non-JAX arrays - # FIXME We blindly expect the output of x.copy() to be always writeable. - # This holds true for read-only numpy arrays, but not necessarily for - # other backends. - x = self.x.copy() if copy else self.x + if copy: + # FIXME We blindly expect the output of x.copy() to be always writeable. + # This holds true for read-only numpy arrays, but not necessarily for + # other backends. + xp = get_namespace(x) + x = xp.asarray(x, copy=True) + return None, x def get(self, copy: bool | None = True, **kwargs): - """Return x[idx]. In addition to plain __getitem__, this allows ensuring - that the output is (not) a copy and kwargs are passed to the backend.""" - # Special case when xp=numpy and idx is a fancy index - # If copy is not False, avoid an unnecessary double copy. - # if copy is forced to False, raise. - if is_numpy_array(self.x) and ( + """ + Return x[idx]. In addition to plain __getitem__, this allows ensuring + that the output is (not) a copy and kwargs are passed to the backend. + """ + # __getitem__ with a fancy index always returns a copy. + # Avoid an unnecessary double copy. + # If copy is forced to False, raise. + if ( isinstance(self.idx, (list, tuple)) - or (is_numpy_array(self.idx) and self.idx.dtype.kind in "biu") + or (is_array_api_obj(self.idx) and self.idx.dtype.kind in "biu") ): if copy is False: raise ValueError( @@ -1032,13 +1037,15 @@ def power(self, y, /, **kwargs): def min(self, y, /, **kwargs): """x[idx] = minimum(x[idx], y)""" - xp = array_namespace(self.x) - return self._iop("min", xp.minimum, y, **kwargs) + import numpy as np + + return self._iop("min", np.minimum, y, **kwargs) def max(self, y, /, **kwargs): """x[idx] = maximum(x[idx], y)""" - xp = array_namespace(self.x) - return self._iop("max", xp.maximum, y, **kwargs) + import numpy as np + + return self._iop("max", np.maximum, y, **kwargs) __all__ = [ diff --git a/tests/test_at.py b/tests/test_at.py new file mode 100644 index 00000000..6ed5dbee --- /dev/null +++ b/tests/test_at.py @@ -0,0 +1,128 @@ +from contextlib import contextmanager, suppress + +import numpy as np +import pytest + +from array_api_compat import array_namespace, at, is_dask_array, is_jax_array, is_writeable_array +from ._helpers import import_, all_libraries + + +@contextmanager +def assert_copy(x, copy: bool | None): + # dask arrays are writeable, but writing to them will hot-swap the + # dask graph inside the collection so that anything that references + # the original graph, i.e. the input collection, won't be mutated. + if copy is False and (not is_writeable_array(x) or is_dask_array(x)): + with pytest.raises((TypeError, ValueError)): + yield + return + + xp = array_namespace(x) + x_orig = xp.asarray(x, copy=True) + yield + + expect_copy = ( + copy if copy is not None else (not is_writeable_array(x) or is_dask_array(x)) + ) + np.testing.assert_array_equal((x == x_orig).all(), expect_copy) + + +@pytest.fixture(params=all_libraries + ["np_readonly"]) +def x(request): + library = request.param + if library == "np_readonly": + x = np.asarray([10, 20, 30]) + x.flags.writeable = False + else: + lib = import_(library) + x = lib.asarray([10, 20, 30]) + return x + + +@pytest.mark.parametrize("copy", [True, False, None]) +@pytest.mark.parametrize( + "op,arg,expect", + [ + ("apply", np.negative, [10, -20, 30]), + ("set", 40, [10, 40, 30]), + ("add", 40, [10, 60, 30]), + ("subtract", 100, [10, -80, 30]), + ("multiply", 2, [10, 40, 30]), + ("divide", 3, [10, 6, 30]), + ("power", 2, [10, 400, 30]), + ("min", 15, [10, 15, 30]), + ("min", 25, [10, 20, 30]), + ("max", 15, [10, 20, 30]), + ("max", 25, [10, 25, 30]), + ], +) +def test_operations(x, copy, op, arg, expect): + with assert_copy(x, copy): + y = getattr(at(x, 1), op)(arg, copy=copy) + assert isinstance(y, type(x)) + np.testing.assert_equal(y, expect) + + +@pytest.mark.parametrize("copy", [True, False, None]) +def test_get(x, copy): + with assert_copy(x, copy): + y = at(x, slice(2)).get(copy=copy) + assert isinstance(y, type(x)) + np.testing.assert_array_equal(y, [10, 20]) + # Let assert_copy test that y is a view or copy + with suppress((TypeError, ValueError)): + y[0] = 40 + + +@pytest.mark.parametrize( + "idx", + [ + [0, 1], + ((0, 1), ), + np.array([0, 1], dtype="i1"), + np.array([0, 1], dtype="u1"), + [True, True, False], + (True, True, False), + np.array([True, True, False]), + ], +) +@pytest.mark.parametrize("wrap_index", [True, False]) +def test_get_fancy_indices(x, idx, wrap_index): + """get() with a fancy index always returns a copy""" + if not wrap_index and is_jax_array(x) and isinstance(idx, (list, tuple)): + pytest.skip("JAX fancy indices must always be arrays") + + if wrap_index: + xp = array_namespace(x) + idx = xp.asarray(idx) + + with assert_copy(x, True): + y = at(x, [0, 1]).get() + assert isinstance(y, type(x)) + np.testing.assert_array_equal(y, [10, 20]) + # Let assert_copy test that y is a view or copy + with suppress((TypeError, ValueError)): + y[0] = 40 + + with assert_copy(x, True): + y = at(x, [0, 1]).get(copy=None) + assert isinstance(y, type(x)) + np.testing.assert_array_equal(y, [10, 20]) + # Let assert_copy test that y is a view or copy + with suppress((TypeError, ValueError)): + y[0] = 40 + + with pytest.raises(ValueError, match="fancy index"): + at(x, [0, 1]).get(copy=False) + + +@pytest.mark.parametrize("copy", [True, False, None]) +def test_variant_index_syntax(x, copy): + with assert_copy(x, copy): + y = at(x)[:2].set(40, copy=copy) + assert isinstance(y, type(x)) + np.testing.assert_array_equal(y, [40, 40, 30]) + with pytest.raises(ValueError): + at(x, 1)[2] + with pytest.raises(ValueError): + at(x)[1][2] diff --git a/tests/test_common.py b/tests/test_common.py index e1cfa9eb..955bcb48 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -5,7 +5,7 @@ is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace, ) -from array_api_compat import is_array_api_obj, device, to_device +from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device from ._helpers import import_, wrapped_libraries, all_libraries @@ -55,6 +55,25 @@ def test_is_xp_namespace(library, func): assert is_func(lib) == (func == is_namespace_functions[library]) +@pytest.mark.parametrize("library", all_libraries) +def test_is_writeable_array(library): + lib = import_(library) + x = lib.asarray([1, 2, 3]) + if is_writeable_array(x): + x[1] = 4 + np.testing.assert_equal(np.asarray(x), [1, 4, 3]) + else: + with pytest.raises((TypeError, ValueError)): + x[1] = 4 + + +def test_is_writeable_array_numpy(): + x = np.asarray([1, 2, 3]) + assert is_writeable_array(x) + x.flags.writeable = False + assert not is_writeable_array(x) + + @pytest.mark.parametrize("library", all_libraries) def test_device(library): xp = import_(library, wrapper=True)