diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py
index 30b1d852..754252e1 100644
--- a/array_api_compat/__init__.py
+++ b/array_api_compat/__init__.py
@@ -1,9 +1,9 @@
"""
NumPy Array API compatibility library
-This is a small wrapper around NumPy and CuPy that is compatible with the
-Array API standard https://data-apis.org/array-api/latest/. See also NEP 47
-https://numpy.org/neps/nep-0047-array-api-standard.html.
+This is a small wrapper around NumPy, CuPy, JAX, sparse and others that is
+compatible with the Array API standard https://data-apis.org/array-api/latest/.
+See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html.
Unlike array_api_strict, this is not a strict minimal implementation of the
Array API, but rather just an extension of the main NumPy namespace with
diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py
index b011f08d..cbc0505f 100644
--- a/array_api_compat/common/_helpers.py
+++ b/array_api_compat/common/_helpers.py
@@ -7,10 +7,11 @@
"""
from __future__ import annotations
+import operator
from typing import TYPE_CHECKING
if TYPE_CHECKING:
- from typing import Optional, Union, Any
+ from typing import Callable, Literal, Optional, Union, Any
from ._typing import Array, Device
import sys
@@ -91,7 +92,7 @@ def is_cupy_array(x):
import cupy as cp
# TODO: Should we reject ndarray subclasses?
- return isinstance(x, (cp.ndarray, cp.generic))
+ return isinstance(x, cp.ndarray)
def is_torch_array(x):
"""
@@ -787,6 +788,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
return x
return x.to_device(device, stream=stream)
+
def size(x):
"""
Return the total number of elements of x.
@@ -801,6 +803,253 @@ def size(x):
return None
return math.prod(x.shape)
+
+def is_writeable_array(x) -> bool:
+ """
+ Return False if ``x.__setitem__`` is expected to raise; True otherwise
+ """
+ if is_numpy_array(x):
+ return x.flags.writeable
+ if is_jax_array(x) or is_pydata_sparse_array(x):
+ return False
+ return True
+
+
+def _is_fancy_index(idx) -> bool:
+ if not isinstance(idx, tuple):
+ idx = (idx,)
+ return any(
+ isinstance(i, (list, tuple)) or is_array_api_obj(i)
+ for i in idx
+ )
+
+
+_undef = object()
+
+
+class at:
+ """
+ Update operations for read-only arrays.
+
+ This implements ``jax.numpy.ndarray.at`` for all backends.
+
+ Keyword arguments are passed verbatim to backends that support the `ndarray.at`
+ method; e.g. you may pass ``indices_are_sorted=True`` to JAX; they are quietly
+ ignored for backends that don't support them.
+
+ Additionally, this introduces support for the `copy` keyword for all backends:
+
+ None
+ The array parameter *may* be modified in place if it is possible and beneficial
+ for performance. You should not reuse it after calling this function.
+ True
+ Ensure that the inputs are not modified. This is the default.
+ False
+ Raise ValueError if a copy cannot be avoided.
+
+ Examples
+ --------
+ Given either of these equivalent expressions::
+
+ x = at(x)[1].add(2, copy=None)
+ x = at(x, 1).add(2, copy=None)
+
+ If x is a JAX array, they are the same as::
+
+ x = x.at[1].add(2)
+
+ If x is a read-only numpy array, they are the same as::
+
+ x = x.copy()
+ x[1] += 2
+
+ Otherwise, they are the same as::
+
+ x[1] += 2
+
+ Warning
+ -------
+ When you use copy=None, you should always immediately overwrite
+ the parameter array::
+
+ x = at(x, 0).set(2, copy=None)
+
+ The anti-pattern below must be avoided, as it will result in different behaviour
+ on read-only versus writeable arrays::
+
+ x = xp.asarray([0, 0, 0])
+ y = at(x, 0).set(2, copy=None)
+ z = at(x, 1).set(3, copy=None)
+
+ In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``
+ when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable!
+
+ Warning
+ -------
+ The behaviour of update methods when the index is an array of integers which
+ contains multiple occurrences of the same index is undefined;
+ e.g. ``at(x, [0, 0]).set(2)``
+
+ Note
+ ----
+ `sparse `_ is not supported by update methods yet.
+
+ See Also
+ --------
+ `jax.numpy.ndarray.at `_
+ """
+
+ __slots__ = ("x", "idx")
+
+ def __init__(self, x, idx=_undef, /):
+ self.x = x
+ self.idx = idx
+
+ def __getitem__(self, idx):
+ """
+ Allow for the alternate syntax ``at(x)[start:stop:step]``,
+ which looks prettier than ``at(x, slice(start, stop, step))``
+ and feels more intuitive coming from the JAX documentation.
+ """
+ if self.idx is not _undef:
+ raise ValueError("Index has already been set")
+ self.idx = idx
+ return self
+
+ def _common(
+ self,
+ at_op: str,
+ y=_undef,
+ copy: bool | None | Literal["_force_false"] = True,
+ **kwargs,
+ ):
+ """Perform common prepocessing.
+
+ Returns
+ -------
+ If the operation can be resolved by at[], (return value, None)
+ Otherwise, (None, preprocessed x)
+ """
+ if self.idx is _undef:
+ raise TypeError(
+ "Index has not been set.\n"
+ "Usage: either\n"
+ " at(x, idx).set(value)\n"
+ "or\n"
+ " at(x)[idx].set(value)\n"
+ "(same for all other methods)."
+ )
+
+ x = self.x
+
+ if copy is False:
+ if not is_writeable_array(x) or is_dask_array(x):
+ raise ValueError("Cannot modify parameter in place")
+ elif copy is None:
+ copy = not is_writeable_array(x)
+ elif copy == "_force_false":
+ copy = False
+ elif copy is not True:
+ raise ValueError(f"Invalid value for copy: {copy!r}")
+
+ if is_jax_array(x):
+ # Use JAX's at[]
+ at_ = x.at[self.idx]
+ args = (y,) if y is not _undef else ()
+ return getattr(at_, at_op)(*args, **kwargs), None
+
+ # Emulate at[] behaviour for non-JAX arrays
+ if copy:
+ # FIXME We blindly expect the output of x.copy() to be always writeable.
+ # This holds true for read-only numpy arrays, but not necessarily for
+ # other backends.
+ xp = array_namespace(x)
+ x = xp.asarray(x, copy=True)
+
+ return None, x
+
+ def get(self, **kwargs):
+ """
+ Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
+ that the output is either a copy or a view; it also allows passing
+ keyword arguments to the backend.
+ """
+ # __getitem__ with a fancy index always returns a copy.
+ # Avoid an unnecessary double copy.
+ # If copy is forced to False, raise.
+ if _is_fancy_index(self.idx):
+ if kwargs.get("copy", True) is False:
+ raise TypeError(
+ "Indexing a numpy array with a fancy index always "
+ "results in a copy"
+ )
+ # Skip copy inside _common, even if array is not writeable
+ kwargs["copy"] = "_force_false" # type: ignore
+
+ res, x = self._common("get", **kwargs)
+ if res is not None:
+ return res
+ return x[self.idx]
+
+ def set(self, y, /, **kwargs):
+ """Apply ``x[idx] = y`` and return the update array"""
+ res, x = self._common("set", y, **kwargs)
+ if res is not None:
+ return res
+ x[self.idx] = y
+ return x
+
+ def _iop(
+ self, at_op: str, elwise_op: Callable[[Array, Array], Array], y: Array, **kwargs
+ ):
+ """x[idx] += y or equivalent in-place operation on a subset of x
+
+ which is the same as saying
+ x[idx] = x[idx] + y
+ Note that this is not the same as
+ operator.iadd(x[idx], y)
+ Consider for example when x is a numpy array and idx is a fancy index, which
+ triggers a deep copy on __getitem__.
+ """
+ res, x = self._common(at_op, y, **kwargs)
+ if res is not None:
+ return res
+ x[self.idx] = elwise_op(x[self.idx], y)
+ return x
+
+ def add(self, y, /, **kwargs):
+ """Apply ``x[idx] += y`` and return the updated array"""
+ return self._iop("add", operator.add, y, **kwargs)
+
+ def subtract(self, y, /, **kwargs):
+ """Apply ``x[idx] -= y`` and return the updated array"""
+ return self._iop("subtract", operator.sub, y, **kwargs)
+
+ def multiply(self, y, /, **kwargs):
+ """Apply ``x[idx] *= y`` and return the updated array"""
+ return self._iop("multiply", operator.mul, y, **kwargs)
+
+ def divide(self, y, /, **kwargs):
+ """Apply ``x[idx] /= y`` and return the updated array"""
+ return self._iop("divide", operator.truediv, y, **kwargs)
+
+ def power(self, y, /, **kwargs):
+ """Apply ``x[idx] **= y`` and return the updated array"""
+ return self._iop("power", operator.pow, y, **kwargs)
+
+ def min(self, y, /, **kwargs):
+ """Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
+ xp = array_namespace(self.x)
+ y = xp.asarray(y)
+ return self._iop("min", xp.minimum, y, **kwargs)
+
+ def max(self, y, /, **kwargs):
+ """Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
+ xp = array_namespace(self.x)
+ y = xp.asarray(y)
+ return self._iop("max", xp.maximum, y, **kwargs)
+
+
__all__ = [
"array_namespace",
"device",
@@ -821,8 +1070,10 @@ def size(x):
"is_ndonnx_namespace",
"is_pydata_sparse_array",
"is_pydata_sparse_namespace",
+ "is_writeable_array",
"size",
"to_device",
+ "at",
]
-_all_ignore = ['sys', 'math', 'inspect', 'warnings']
+_all_ignore = ['inspect', 'math', 'operator', 'warnings', 'sys']
diff --git a/docs/helper-functions.rst b/docs/helper-functions.rst
index f44dc070..e244cd5a 100644
--- a/docs/helper-functions.rst
+++ b/docs/helper-functions.rst
@@ -36,6 +36,8 @@ instead, which would be wrapped.
.. autofunction:: device
.. autofunction:: to_device
.. autofunction:: size
+.. autoclass:: at(array[, index])
+ :members:
Inspection Helpers
------------------
@@ -51,6 +53,7 @@ yet.
.. autofunction:: is_jax_array
.. autofunction:: is_pydata_sparse_array
.. autofunction:: is_ndonnx_array
+.. autofunction:: is_writeable_array
.. autofunction:: is_numpy_namespace
.. autofunction:: is_cupy_namespace
.. autofunction:: is_torch_namespace
diff --git a/tests/test_at.py b/tests/test_at.py
new file mode 100644
index 00000000..3abea959
--- /dev/null
+++ b/tests/test_at.py
@@ -0,0 +1,168 @@
+from __future__ import annotations
+
+from contextlib import contextmanager, suppress
+
+import numpy as np
+import pytest
+
+from array_api_compat import (
+ array_namespace,
+ at,
+ is_cupy_array,
+ is_dask_array,
+ is_jax_array,
+ is_torch_array,
+ is_torch_namespace,
+ is_pydata_sparse_array,
+ is_writeable_array,
+)
+from ._helpers import import_, all_libraries
+
+
+def assert_array_equal(a, b):
+ if is_pydata_sparse_array(a):
+ a = a.todense()
+ elif is_cupy_array(a):
+ a = a.get()
+ elif is_dask_array(a):
+ a = a.compute()
+ np.testing.assert_array_equal(a, b)
+
+
+@contextmanager
+def assert_copy(x, copy: bool | None):
+ # dask arrays are writeable, but writing to them will hot-swap the
+ # dask graph inside the collection so that anything that references
+ # the original graph, i.e. the input collection, won't be mutated.
+ if copy is False and (not is_writeable_array(x) or is_dask_array(x)):
+ with pytest.raises((TypeError, ValueError)):
+ yield
+ return
+
+ xp = array_namespace(x)
+ x_orig = xp.asarray(x, copy=True)
+ yield
+
+ expect_copy = not is_writeable_array(x) if copy is None else copy
+ assert_array_equal((x == x_orig).all(), expect_copy)
+
+
+@pytest.fixture(params=all_libraries + ["np_readonly"])
+def x(request):
+ library = request.param
+ if library == "np_readonly":
+ x = np.asarray([10, 20, 30])
+ x.flags.writeable = False
+ else:
+ lib = import_(library)
+ x = lib.asarray([10, 20, 30])
+ return x
+
+
+@pytest.mark.parametrize("copy", [True, False, None])
+@pytest.mark.parametrize(
+ "op,arg,expect",
+ [
+ ("set", 40, [10, 40, 40]),
+ ("add", 40, [10, 60, 70]),
+ ("subtract", 100, [10, -80, -70]),
+ ("multiply", 2, [10, 40, 60]),
+ ("divide", 2, [10, 10, 15]),
+ ("power", 2, [10, 400, 900]),
+ ("min", 25, [10, 20, 25]),
+ ("max", 25, [10, 25, 30]),
+ ],
+)
+def test_update_ops(x, copy, op, arg, expect):
+ if is_pydata_sparse_array(x):
+ pytest.skip("at() does not support updates on sparse arrays")
+
+ with assert_copy(x, copy):
+ y = getattr(at(x, slice(1, None)), op)(arg, copy=copy)
+ assert isinstance(y, type(x))
+ assert_array_equal(y, expect)
+
+
+@pytest.mark.parametrize("copy", [True, False, None])
+def test_get(x, copy):
+
+ expect_copy = copy
+ if is_dask_array(x) and copy is None:
+ # dask is mutable, but __getitem__ never returns a view
+ expect_copy = True
+
+ with assert_copy(x, expect_copy):
+ y = at(x, slice(2)).get(copy=copy)
+ assert isinstance(y, type(x))
+ assert_array_equal(y, [10, 20])
+ # Let assert_copy test that y is a view or copy
+ with suppress((TypeError, ValueError)):
+ y[0] = 40
+
+
+@pytest.mark.parametrize(
+ "idx",
+ [
+ [0, 1],
+ (0, 1),
+ np.array([0, 1], dtype="int32"),
+ np.array([0, 1], dtype="uint32"),
+ # torch only supports tensors of native integers as indices
+ lambda xp: xp.asarray([0, 1], dtype=None if is_torch_namespace(xp) else "int32"),
+ lambda xp: xp.asarray([0, 1], dtype=None if is_torch_namespace(xp) else "uint32"),
+ [True, True, False],
+ (True, True, False),
+ np.array([True, True, False]),
+ lambda xp: xp.asarray([True, True, False]),
+ ],
+)
+@pytest.mark.parametrize("tuple_index", [True, False])
+def test_get_fancy_indices(x, idx, tuple_index):
+ """get() with a fancy index always returns a copy"""
+ if callable(idx):
+ xp = array_namespace(x)
+ idx = idx(xp)
+
+ if is_jax_array(x) and isinstance(idx, (list, tuple)):
+ pytest.skip("JAX fancy indices must always be arrays")
+ if is_pydata_sparse_array(x) and is_pydata_sparse_array(idx):
+ pytest.skip("sparse fancy indices can't be sparse themselves")
+ if is_torch_array(x) and isinstance(idx, np.ndarray) and idx.dtype.kind == "u":
+ pytest.skip("torch does not support unsigned integer fancy indices")
+ if is_dask_array(x) and isinstance(idx, tuple):
+ pytest.skip("dask does not support tuples; only lists or arrays")
+ if isinstance(idx, tuple) and not tuple_index:
+ pytest.skip("tuple indices must always be wrapped in a tuple")
+
+ if tuple_index:
+ idx = (idx,)
+
+ with assert_copy(x, True):
+ y = at(x, idx).get()
+ assert isinstance(y, type(x))
+ assert_array_equal(y, [10, 20])
+ # Let assert_copy test that y is a view or copy
+ with suppress((TypeError, ValueError)):
+ y[0] = 40
+
+ with assert_copy(x, True):
+ y = at(x, idx).get(copy=None)
+ assert isinstance(y, type(x))
+ assert_array_equal(y, [10, 20])
+ # Let assert_copy test that y is a view or copy
+ with suppress((TypeError, ValueError)):
+ y[0] = 40
+
+ with pytest.raises(TypeError, match="fancy index"):
+ at(x, idx).get(copy=False)
+
+
+def test_variant_index_syntax(x):
+ y = at(x)[:2].get()
+ assert isinstance(y, type(x))
+ assert_array_equal(y, [10, 20])
+
+ with pytest.raises(ValueError):
+ at(x, 1)[2]
+ with pytest.raises(ValueError):
+ at(x)[1][2]
diff --git a/tests/test_common.py b/tests/test_common.py
index e1cfa9eb..2a74dd88 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -5,7 +5,7 @@
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
)
-from array_api_compat import is_array_api_obj, device, to_device
+from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device
from ._helpers import import_, wrapped_libraries, all_libraries
@@ -55,6 +55,24 @@ def test_is_xp_namespace(library, func):
assert is_func(lib) == (func == is_namespace_functions[library])
+@pytest.mark.parametrize("library", all_libraries)
+def test_is_writeable_array(library):
+ lib = import_(library)
+ x = lib.asarray([1, 2, 3])
+ if is_writeable_array(x):
+ x[1] = 4
+ else:
+ with pytest.raises((TypeError, ValueError)):
+ x[1] = 4
+
+
+def test_is_writeable_array_numpy():
+ x = np.asarray([1, 2, 3])
+ assert is_writeable_array(x)
+ x.flags.writeable = False
+ assert not is_writeable_array(x)
+
+
@pytest.mark.parametrize("library", all_libraries)
def test_device(library):
xp = import_(library, wrapper=True)