diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index 30b1d852..754252e1 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -1,9 +1,9 @@ """ NumPy Array API compatibility library -This is a small wrapper around NumPy and CuPy that is compatible with the -Array API standard https://data-apis.org/array-api/latest/. See also NEP 47 -https://numpy.org/neps/nep-0047-array-api-standard.html. +This is a small wrapper around NumPy, CuPy, JAX, sparse and others that is +compatible with the Array API standard https://data-apis.org/array-api/latest/. +See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html. Unlike array_api_strict, this is not a strict minimal implementation of the Array API, but rather just an extension of the main NumPy namespace with diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index b011f08d..ef14e3b6 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -7,10 +7,11 @@ """ from __future__ import annotations +import operator from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Optional, Union, Any + from typing import Callable, Literal, Optional, Union, Any from ._typing import Array, Device import sys @@ -91,7 +92,7 @@ def is_cupy_array(x): import cupy as cp # TODO: Should we reject ndarray subclasses? - return isinstance(x, (cp.ndarray, cp.generic)) + return isinstance(x, cp.ndarray) def is_torch_array(x): """ @@ -787,6 +788,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] return x return x.to_device(device, stream=stream) + def size(x): """ Return the total number of elements of x. @@ -801,6 +803,253 @@ def size(x): return None return math.prod(x.shape) + +def is_writeable_array(x) -> bool: + """ + Return False if ``x.__setitem__`` is expected to raise; True otherwise + """ + if is_numpy_array(x): + return x.flags.writeable + if is_jax_array(x) or is_pydata_sparse_array(x): + return False + return True + + +def _is_fancy_index(idx) -> bool: + if not isinstance(idx, tuple): + idx = (idx,) + return any( + isinstance(i, (list, tuple)) or is_array_api_obj(i) + for i in idx + ) + + +_undef = object() + + +class at: + """ + Update operations for read-only arrays. + + This implements ``jax.numpy.ndarray.at`` for all backends. + + Keyword arguments are passed verbatim to backends that support the `ndarray.at` + method; e.g. you may pass ``indices_are_sorted=True`` to JAX; they are quietly + ignored for backends that don't support them. + + Additionally, this introduces support for the `copy` keyword for all backends: + + None + The array parameter *may* be modified in place if it is possible and beneficial + for performance. You should not reuse it after calling this function. + True + Ensure that the inputs are not modified. This is the default. + False + Raise ValueError if a copy cannot be avoided. + + Examples + -------- + Given either of these equivalent expressions:: + + x = at(x)[1].add(2, copy=None) + x = at(x, 1).add(2, copy=None) + + If x is a JAX array, they are the same as:: + + x = x.at[1].add(2) + + If x is a read-only numpy array, they are the same as:: + + x = x.copy() + x[1] += 2 + + Otherwise, they are the same as:: + + x[1] += 2 + + Warning + ------- + When you use copy=None, you should always immediately overwrite + the parameter array:: + + x = at(x, 0).set(2, copy=None) + + The anti-pattern below must be avoided, as it will result in different behaviour + on read-only versus writeable arrays:: + + x = xp.asarray([0, 0, 0]) + y = at(x, 0).set(2, copy=None) + z = at(x, 1).set(3, copy=None) + + In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]`` + when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable! + + Warning + ------- + The behaviour of update methods when the index is an array of integers which + contains multiple occurrences of the same index is undefined; + e.g. ``at(x, [0, 0]).set(2)`` + + Note + ---- + `sparse `_ is not supported by update methods yet. + + See Also + -------- + `jax.numpy.ndarray.at `_ + """ + + __slots__ = ("x", "idx") + + def __init__(self, x, idx=_undef, /): + self.x = x + self.idx = idx + + def __getitem__(self, idx): + """ + Allow for the alternate syntax ``at(x)[start:stop:step]``, + which looks prettier than ``at(x, slice(start, stop, step))`` + and feels more intuitive coming from the JAX documentation. + """ + if self.idx is not _undef: + raise ValueError("Index has already been set") + self.idx = idx + return self + + def _common( + self, + at_op: str, + y=_undef, + copy: bool | None | Literal["_force_false"] = True, + **kwargs, + ): + """Perform common prepocessing. + + Returns + ------- + If the operation can be resolved by at[], (return value, None) + Otherwise, (None, preprocessed x) + """ + if self.idx is _undef: + raise TypeError( + "Index has not been set.\n" + "Usage: either\n" + " at(x, idx).set(value)\n" + "or\n" + " at(x)[idx].set(value)\n" + "(same for all other methods)." + ) + + x = self.x + + if copy is False: + if not is_writeable_array(x) or is_dask_array(x): + raise ValueError("Cannot modify parameter in place") + elif copy is None: + copy = not is_writeable_array(x) + elif copy == "_force_false": + copy = False + elif copy is not True: + raise ValueError(f"Invalid value for copy: {copy!r}") + + if is_jax_array(x): + # Use JAX's at[] + at_ = x.at[self.idx] + args = (y,) if y is not _undef else () + return getattr(at_, at_op)(*args, **kwargs), None + + # Emulate at[] behaviour for non-JAX arrays + if copy: + # FIXME We blindly expect the output of x.copy() to be always writeable. + # This holds true for read-only numpy arrays, but not necessarily for + # other backends. + xp = array_namespace(x) + x = xp.asarray(x, copy=True) + + return None, x + + def get(self, **kwargs): + """ + Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring + that the output is either a copy or a view; it also allows passing + keyword arguments to the backend. + """ + # __getitem__ with a fancy index always returns a copy. + # Avoid an unnecessary double copy. + # If copy is forced to False, raise. + if _is_fancy_index(self.idx): + if kwargs.get("copy", True) is False: + raise TypeError( + "Indexing a numpy array with a fancy index always " + "results in a copy" + ) + # Skip copy inside _common, even if array is not writeable + kwargs["copy"] = "_force_false" + + res, x = self._common("get", **kwargs) + if res is not None: + return res + return x[self.idx] + + def set(self, y, /, **kwargs): + """Apply ``x[idx] = y`` and return the update array""" + res, x = self._common("set", y, **kwargs) + if res is not None: + return res + x[self.idx] = y + return x + + def _iop( + self, at_op: str, elwise_op: Callable[[Array, Array], Array], y: Array, **kwargs + ): + """x[idx] += y or equivalent in-place operation on a subset of x + + which is the same as saying + x[idx] = x[idx] + y + Note that this is not the same as + operator.iadd(x[idx], y) + Consider for example when x is a numpy array and idx is a fancy index, which + triggers a deep copy on __getitem__. + """ + res, x = self._common(at_op, y, **kwargs) + if res is not None: + return res + x[self.idx] = elwise_op(x[self.idx], y) + return x + + def add(self, y, /, **kwargs): + """Apply ``x[idx] += y`` and return the updated array""" + return self._iop("add", operator.add, y, **kwargs) + + def subtract(self, y, /, **kwargs): + """Apply ``x[idx] -= y`` and return the updated array""" + return self._iop("subtract", operator.sub, y, **kwargs) + + def multiply(self, y, /, **kwargs): + """Apply ``x[idx] *= y`` and return the updated array""" + return self._iop("multiply", operator.mul, y, **kwargs) + + def divide(self, y, /, **kwargs): + """Apply ``x[idx] /= y`` and return the updated array""" + return self._iop("divide", operator.truediv, y, **kwargs) + + def power(self, y, /, **kwargs): + """Apply ``x[idx] **= y`` and return the updated array""" + return self._iop("power", operator.pow, y, **kwargs) + + def min(self, y, /, **kwargs): + """Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array""" + xp = array_namespace(self.x) + y = xp.asarray(y) + return self._iop("min", xp.minimum, y, **kwargs) + + def max(self, y, /, **kwargs): + """Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array""" + xp = array_namespace(self.x) + y = xp.asarray(y) + return self._iop("max", xp.maximum, y, **kwargs) + + __all__ = [ "array_namespace", "device", @@ -821,8 +1070,10 @@ def size(x): "is_ndonnx_namespace", "is_pydata_sparse_array", "is_pydata_sparse_namespace", + "is_writeable_array", "size", "to_device", + "at", ] -_all_ignore = ['sys', 'math', 'inspect', 'warnings'] +_all_ignore = ['inspect', 'math', 'operator', 'warnings', 'sys'] diff --git a/docs/helper-functions.rst b/docs/helper-functions.rst index f44dc070..e244cd5a 100644 --- a/docs/helper-functions.rst +++ b/docs/helper-functions.rst @@ -36,6 +36,8 @@ instead, which would be wrapped. .. autofunction:: device .. autofunction:: to_device .. autofunction:: size +.. autoclass:: at(array[, index]) + :members: Inspection Helpers ------------------ @@ -51,6 +53,7 @@ yet. .. autofunction:: is_jax_array .. autofunction:: is_pydata_sparse_array .. autofunction:: is_ndonnx_array +.. autofunction:: is_writeable_array .. autofunction:: is_numpy_namespace .. autofunction:: is_cupy_namespace .. autofunction:: is_torch_namespace diff --git a/tests/test_at.py b/tests/test_at.py new file mode 100644 index 00000000..3abea959 --- /dev/null +++ b/tests/test_at.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +from contextlib import contextmanager, suppress + +import numpy as np +import pytest + +from array_api_compat import ( + array_namespace, + at, + is_cupy_array, + is_dask_array, + is_jax_array, + is_torch_array, + is_torch_namespace, + is_pydata_sparse_array, + is_writeable_array, +) +from ._helpers import import_, all_libraries + + +def assert_array_equal(a, b): + if is_pydata_sparse_array(a): + a = a.todense() + elif is_cupy_array(a): + a = a.get() + elif is_dask_array(a): + a = a.compute() + np.testing.assert_array_equal(a, b) + + +@contextmanager +def assert_copy(x, copy: bool | None): + # dask arrays are writeable, but writing to them will hot-swap the + # dask graph inside the collection so that anything that references + # the original graph, i.e. the input collection, won't be mutated. + if copy is False and (not is_writeable_array(x) or is_dask_array(x)): + with pytest.raises((TypeError, ValueError)): + yield + return + + xp = array_namespace(x) + x_orig = xp.asarray(x, copy=True) + yield + + expect_copy = not is_writeable_array(x) if copy is None else copy + assert_array_equal((x == x_orig).all(), expect_copy) + + +@pytest.fixture(params=all_libraries + ["np_readonly"]) +def x(request): + library = request.param + if library == "np_readonly": + x = np.asarray([10, 20, 30]) + x.flags.writeable = False + else: + lib = import_(library) + x = lib.asarray([10, 20, 30]) + return x + + +@pytest.mark.parametrize("copy", [True, False, None]) +@pytest.mark.parametrize( + "op,arg,expect", + [ + ("set", 40, [10, 40, 40]), + ("add", 40, [10, 60, 70]), + ("subtract", 100, [10, -80, -70]), + ("multiply", 2, [10, 40, 60]), + ("divide", 2, [10, 10, 15]), + ("power", 2, [10, 400, 900]), + ("min", 25, [10, 20, 25]), + ("max", 25, [10, 25, 30]), + ], +) +def test_update_ops(x, copy, op, arg, expect): + if is_pydata_sparse_array(x): + pytest.skip("at() does not support updates on sparse arrays") + + with assert_copy(x, copy): + y = getattr(at(x, slice(1, None)), op)(arg, copy=copy) + assert isinstance(y, type(x)) + assert_array_equal(y, expect) + + +@pytest.mark.parametrize("copy", [True, False, None]) +def test_get(x, copy): + + expect_copy = copy + if is_dask_array(x) and copy is None: + # dask is mutable, but __getitem__ never returns a view + expect_copy = True + + with assert_copy(x, expect_copy): + y = at(x, slice(2)).get(copy=copy) + assert isinstance(y, type(x)) + assert_array_equal(y, [10, 20]) + # Let assert_copy test that y is a view or copy + with suppress((TypeError, ValueError)): + y[0] = 40 + + +@pytest.mark.parametrize( + "idx", + [ + [0, 1], + (0, 1), + np.array([0, 1], dtype="int32"), + np.array([0, 1], dtype="uint32"), + # torch only supports tensors of native integers as indices + lambda xp: xp.asarray([0, 1], dtype=None if is_torch_namespace(xp) else "int32"), + lambda xp: xp.asarray([0, 1], dtype=None if is_torch_namespace(xp) else "uint32"), + [True, True, False], + (True, True, False), + np.array([True, True, False]), + lambda xp: xp.asarray([True, True, False]), + ], +) +@pytest.mark.parametrize("tuple_index", [True, False]) +def test_get_fancy_indices(x, idx, tuple_index): + """get() with a fancy index always returns a copy""" + if callable(idx): + xp = array_namespace(x) + idx = idx(xp) + + if is_jax_array(x) and isinstance(idx, (list, tuple)): + pytest.skip("JAX fancy indices must always be arrays") + if is_pydata_sparse_array(x) and is_pydata_sparse_array(idx): + pytest.skip("sparse fancy indices can't be sparse themselves") + if is_torch_array(x) and isinstance(idx, np.ndarray) and idx.dtype.kind == "u": + pytest.skip("torch does not support unsigned integer fancy indices") + if is_dask_array(x) and isinstance(idx, tuple): + pytest.skip("dask does not support tuples; only lists or arrays") + if isinstance(idx, tuple) and not tuple_index: + pytest.skip("tuple indices must always be wrapped in a tuple") + + if tuple_index: + idx = (idx,) + + with assert_copy(x, True): + y = at(x, idx).get() + assert isinstance(y, type(x)) + assert_array_equal(y, [10, 20]) + # Let assert_copy test that y is a view or copy + with suppress((TypeError, ValueError)): + y[0] = 40 + + with assert_copy(x, True): + y = at(x, idx).get(copy=None) + assert isinstance(y, type(x)) + assert_array_equal(y, [10, 20]) + # Let assert_copy test that y is a view or copy + with suppress((TypeError, ValueError)): + y[0] = 40 + + with pytest.raises(TypeError, match="fancy index"): + at(x, idx).get(copy=False) + + +def test_variant_index_syntax(x): + y = at(x)[:2].get() + assert isinstance(y, type(x)) + assert_array_equal(y, [10, 20]) + + with pytest.raises(ValueError): + at(x, 1)[2] + with pytest.raises(ValueError): + at(x)[1][2] diff --git a/tests/test_common.py b/tests/test_common.py index e1cfa9eb..2a74dd88 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -5,7 +5,7 @@ is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace, ) -from array_api_compat import is_array_api_obj, device, to_device +from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device from ._helpers import import_, wrapped_libraries, all_libraries @@ -55,6 +55,24 @@ def test_is_xp_namespace(library, func): assert is_func(lib) == (func == is_namespace_functions[library]) +@pytest.mark.parametrize("library", all_libraries) +def test_is_writeable_array(library): + lib = import_(library) + x = lib.asarray([1, 2, 3]) + if is_writeable_array(x): + x[1] = 4 + else: + with pytest.raises((TypeError, ValueError)): + x[1] = 4 + + +def test_is_writeable_array_numpy(): + x = np.asarray([1, 2, 3]) + assert is_writeable_array(x) + x.flags.writeable = False + assert not is_writeable_array(x) + + @pytest.mark.parametrize("library", all_libraries) def test_device(library): xp = import_(library, wrapper=True)