From aa0d36436f7e9102b384f1b489139b79c8a0f278 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 10 Dec 2024 12:36:18 +0000 Subject: [PATCH] WIP at() method --- docs/api-reference.md | 1 + pyproject.toml | 3 + src/array_api_extra/__init__.py | 12 +- src/array_api_extra/_funcs.py | 292 ++++++++++++++++++++++++++- src/array_api_extra/_lib/_compat.py | 13 +- src/array_api_extra/_lib/_compat.pyi | 3 + tests/test_at.py | 153 ++++++++++++++ vendor_tests/test_vendor.py | 15 +- 8 files changed, 482 insertions(+), 10 deletions(-) create mode 100644 tests/test_at.py diff --git a/docs/api-reference.md b/docs/api-reference.md index ffe68f2..b43c960 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -6,6 +6,7 @@ :nosignatures: :toctree: generated + at atleast_nd cov create_diagonal diff --git a/pyproject.toml b/pyproject.toml index f1a6f49..c8096d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -230,6 +230,9 @@ ignore = [ "PLR09", # Too many <...> "PLR2004", # Magic value used in comparison "ISC001", # Conflicts with formatter + # "N802", # Function name should be lowercase + # "N806", # Variable in function should be lowercase + # "PD008", # pandas-use-of-dot-at ] [tool.ruff.lint.per-file-ignores] diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index d1107b1..bd676fe 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -1,12 +1,22 @@ from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990 -from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc +from ._funcs import ( + at, + atleast_nd, + cov, + create_diagonal, + expand_dims, + kron, + setdiff1d, + sinc, +) __version__ = "0.3.3.dev0" # pylint: disable=duplicate-code __all__ = [ "__version__", + "at", "atleast_nd", "cov", "create_diagonal", diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 3d961e2..599048c 100644 --- a/src/array_api_extra/_funcs.py +++ 
b/src/array_api_extra/_funcs.py @@ -1,15 +1,21 @@ from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990 -import typing +import operator import warnings - -if typing.TYPE_CHECKING: - from ._lib._typing import Array, ModuleType +from collections.abc import Callable +from typing import Any from ._lib import _utils -from ._lib._compat import array_namespace +from ._lib._compat import ( + array_namespace, + is_array_api_obj, + is_dask_array, + is_writeable_array, +) +from ._lib._typing import Array, ModuleType __all__ = [ + "at", "atleast_nd", "cov", "create_diagonal", @@ -548,3 +554,279 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device), ) return xp.sin(y) / y + + +_undef = object() + + +class at: + """ + Update operations for read-only arrays. + + This implements ``jax.numpy.ndarray.at`` for all backends. + + Parameters + ---------- + x : array + Input array. + idx : index, optional + You may use two alternate syntaxes:: + + at(x, idx).set(value) # or get(), add(), etc. + at(x)[idx].set(value) + + copy : bool, optional + True (default) + Ensure that the inputs are not modified. + False + Ensure that the update operation writes back to the input. + Raise ValueError if a copy cannot be avoided. + None + The array parameter *may* be modified in place if it is possible and + beneficial for performance. + You should not reuse it after calling this function. + xp : array_namespace, optional + The standard-compatible namespace for `x`. Default: infer + + **kwargs: + If the backend supports an `at` method, any additional keyword + arguments are passed to it verbatim; e.g. this allows passing + ``indices_are_sorted=True`` to JAX. + + Returns + ------- + Updated input array. 
+ + Examples + -------- + Given either of these equivalent expressions:: + + x = at(x)[1].add(2, copy=None) + x = at(x, 1).add(2, copy=None) + + If x is a JAX array, they are the same as:: + + x = x.at[1].add(2) + + If x is a read-only numpy array, they are the same as:: + + x = x.copy() + x[1] += 2 + + Otherwise, they are the same as:: + + x[1] += 2 + + Warning + ------- + When you use copy=None, you should always immediately overwrite + the parameter array:: + + x = at(x, 0).set(2, copy=None) + + The anti-pattern below must be avoided, as it will result in different behaviour + on read-only versus writeable arrays:: + + x = xp.asarray([0, 0, 0]) + y = at(x, 0).set(2, copy=None) + z = at(x, 1).set(3, copy=None) + + In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and ``z == [0, 3, 0]`` + when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable! + + Warning + ------- + The array API standard does not support integer array indices. + The behaviour of update methods when the index is an array of integers + is undefined; this is particularly true when the index contains multiple + occurrences of the same index, e.g. ``at(x, [0, 0]).set(2)``. + + Note + ---- + `sparse <https://sparse.pydata.org/>`_ is not supported by update methods yet. + + See Also + -------- + `jax.numpy.ndarray.at <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_ + """ + + x: Array + idx: Any + __slots__ = ("idx", "x") + + def __init__(self, x: Array, idx: Any = _undef, /): + self.x = x + self.idx = idx + + def __getitem__(self, idx: Any) -> Any: + """Allow for the alternate syntax ``at(x)[start:stop:step]``, + which looks prettier than ``at(x, slice(start, stop, step))`` + and feels more intuitive coming from the JAX documentation.
+ """ + if self.idx is not _undef: + msg = "Index has already been set" + raise ValueError(msg) + self.idx = idx + return self + + def _common( + self, + at_op: str, + y: Array = _undef, + /, + copy: bool | None = True, + xp: ModuleType | None = None, + _is_update: bool = True, + **kwargs: Any, + ) -> tuple[Any, None] | tuple[None, Array]: + """Perform common preprocessing. + + Returns + ------- + If the operation can be resolved by at[], (return value, None) + Otherwise, (None, preprocessed x) + """ + if self.idx is _undef: + msg = ( + "Index has not been set.\n" + "Usage: either\n" + " at(x, idx).set(value)\n" + "or\n" + " at(x)[idx].set(value)\n" + "(same for all other methods)." + ) + raise TypeError(msg) + + x = self.x + + if copy is True: + writeable = None + elif copy is False: + writeable = is_writeable_array(x) + if not writeable: + msg = "Cannot modify parameter in place" + raise ValueError(msg) + elif copy is None: + writeable = is_writeable_array(x) + copy = _is_update and not writeable + else: + msg = f"Invalid value for copy: {copy!r}" # type: ignore[unreachable] + raise ValueError(msg) + + if copy: + try: + at_ = x.at + except AttributeError: + # Emulate at[] behaviour for non-JAX arrays + # with a copy followed by an update + if xp is None: + xp = array_namespace(x) + # Create writeable copy of read-only numpy array + x = xp.asarray(x, copy=True) + if writeable is False: + # A copy of a read-only numpy array is writeable + writeable = None + else: + # Use JAX's at[] or another library with the same duck-type API + args = (y,) if y is not _undef else () + return getattr(at_[self.idx], at_op)(*args, **kwargs), None + + if _is_update: + if writeable is None: + writeable = is_writeable_array(x) + if not writeable: + # sparse crashes here + msg = f"Array {x} has no `at` method and is read-only" + raise ValueError(msg) + + return None, x + + def get(self, **kwargs: Any) -> Any: + """Return ``x[idx]``.
In addition to plain ``__getitem__``, this allows ensuring + that the output is either a copy or a view; it also allows passing + keyword arguments to the backend. + """ + if kwargs.get("copy") is False: + if is_array_api_obj(self.idx): + # Boolean index. Note that the array API spec + # https://data-apis.org/array-api/latest/API_specification/indexing.html + # does not allow for list, tuple, and tuples of slices plus one or more + # one-dimensional array indices, although many backends support them. + # So this check will encounter a lot of false negatives in real life, + # which can be caught by testing the user code vs. array-api-strict. + msg = "get() with an array index always returns a copy" + raise ValueError(msg) + if is_dask_array(self.x): + msg = "get() on Dask arrays always returns a copy" + raise ValueError(msg) + + res, x = self._common("get", _is_update=False, **kwargs) + if res is not None: + return res + assert x is not None + return x[self.idx] + + def set(self, y: Array, /, **kwargs: Any) -> Array: + """Apply ``x[idx] = y`` and return the updated array""" + res, x = self._common("set", y, **kwargs) + if res is not None: + return res + assert x is not None + x[self.idx] = y + return x + + def _iop( + self, + at_op: str, + elwise_op: Callable[[Array, Array], Array], + y: Array, + /, + **kwargs: Any, + ) -> Array: + """x[idx] += y or equivalent in-place operation on a subset of x + + which is the same as saying + x[idx] = x[idx] + y + Note that this is not the same as + operator.iadd(x[idx], y) + Consider for example when x is a numpy array and idx is a fancy index, which + triggers a deep copy on __getitem__.
+ """ + res, x = self._common(at_op, y, **kwargs) + if res is not None: + return res + assert x is not None + x[self.idx] = elwise_op(x[self.idx], y) + return x + + def add(self, y: Array, /, **kwargs: Any) -> Array: + """Apply ``x[idx] += y`` and return the updated array""" + return self._iop("add", operator.add, y, **kwargs) + + def subtract(self, y: Array, /, **kwargs: Any) -> Array: + """Apply ``x[idx] -= y`` and return the updated array""" + return self._iop("subtract", operator.sub, y, **kwargs) + + def multiply(self, y: Array, /, **kwargs: Any) -> Array: + """Apply ``x[idx] *= y`` and return the updated array""" + return self._iop("multiply", operator.mul, y, **kwargs) + + def divide(self, y: Array, /, **kwargs: Any) -> Array: + """Apply ``x[idx] /= y`` and return the updated array""" + return self._iop("divide", operator.truediv, y, **kwargs) + + def power(self, y: Array, /, **kwargs: Any) -> Array: + """Apply ``x[idx] **= y`` and return the updated array""" + return self._iop("power", operator.pow, y, **kwargs) + + def min(self, y: Array, /, **kwargs: Any) -> Array: + """Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array""" + xp = array_namespace(self.x) + y = xp.asarray(y) + return self._iop("min", xp.minimum, y, **kwargs) + + def max(self, y: Array, /, **kwargs: Any) -> Array: + """Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array""" + xp = array_namespace(self.x) + y = xp.asarray(y) + return self._iop("max", xp.maximum, y, **kwargs) diff --git a/src/array_api_extra/_lib/_compat.py b/src/array_api_extra/_lib/_compat.py index 03e47d1..7189d38 100644 --- a/src/array_api_extra/_lib/_compat.py +++ b/src/array_api_extra/_lib/_compat.py @@ -6,14 +6,23 @@ from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports] array_namespace, # pyright: ignore[reportUnknownVariableType] device, # pyright: ignore[reportUnknownVariableType] + is_array_api_obj, # pyright: ignore[reportUnknownVariableType] + is_dask_array, 
# pyright: ignore[reportUnknownVariableType] + is_writeable_array, # pyright: ignore[reportUnknownVariableType] ) except ImportError: from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs] array_namespace, # pyright: ignore[reportUnknownVariableType] device, + is_array_api_obj, # pyright: ignore[reportUnknownVariableType] + is_dask_array, # pyright: ignore[reportUnknownVariableType] + is_writeable_array, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue] ) -__all__ = [ +__all__ = ( "array_namespace", "device", -] + "is_array_api_obj", + "is_dask_array", + "is_writeable_array", +) diff --git a/src/array_api_extra/_lib/_compat.pyi b/src/array_api_extra/_lib/_compat.pyi index 3b4eb43..ec0ece5 100644 --- a/src/array_api_extra/_lib/_compat.pyi +++ b/src/array_api_extra/_lib/_compat.pyi @@ -11,3 +11,6 @@ def array_namespace( use_compat: bool | None = None, ) -> ArrayModule: ... def device(x: Array, /) -> Device: ... +def is_array_api_obj(x: object, /) -> bool: ... +def is_dask_array(x: object, /) -> bool: ... +def is_writeable_array(x: object, /) -> bool: ... 
diff --git a/tests/test_at.py b/tests/test_at.py new file mode 100644 index 0000000..d9ce49e --- /dev/null +++ b/tests/test_at.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +from contextlib import contextmanager, suppress +from importlib import import_module +from typing import TYPE_CHECKING + +import numpy as np +import pytest +from array_api_compat import ( + array_namespace, + is_dask_array, + is_pydata_sparse_array, + is_writeable_array, +) + +from array_api_extra import at + +if TYPE_CHECKING: + from array_api_extra._lib._typing import Array + +all_libraries = ( + "array_api_strict", + "numpy", + "numpy_readonly", + "cupy", + "torch", + "dask.array", + "sparse", + "jax.numpy", +) + + +@pytest.fixture(params=all_libraries) +def array(request): + library = request.param + if library == "numpy_readonly": + x = np.asarray([10.0, 20.0, 30.0]) + x.flags.writeable = False + else: + try: + lib = import_module(library) + except ImportError: + pytest.skip(f"{library} is not installed") + x = lib.asarray([10.0, 20.0, 30.0]) + return x + + +def assert_array_equal(a: Array, b: Array) -> None: + xp = array_namespace(a) + b = xp.asarray(b) + eq = xp.all(a == b) + if is_dask_array(a): + eq = eq.compute() + assert eq + + +@contextmanager +def assert_copy(array, copy: bool | None): + # dask arrays are writeable, but writing to them will hot-swap the + # dask graph inside the collection so that anything that references + # the original graph, i.e. the input collection, won't be mutated. 
+ if copy is False and not is_writeable_array(array): + with pytest.raises((TypeError, ValueError)): + yield + return + + xp = array_namespace(array) + array_orig = xp.asarray(array, copy=True) + yield + + expect_copy = not is_writeable_array(array) if copy is None else copy + assert_array_equal(xp.all(array == array_orig), expect_copy) + + +@pytest.mark.parametrize("copy", [True, False, None]) +@pytest.mark.parametrize( + ("op", "arg", "expect"), + [ + ("set", 40.0, [10.0, 40.0, 40.0]), + ("add", 40.0, [10.0, 60.0, 70.0]), + ("subtract", 100.0, [10.0, -80.0, -70.0]), + ("multiply", 2.0, [10.0, 40.0, 60.0]), + ("divide", 2.0, [10.0, 10.0, 15.0]), + ("power", 2.0, [10.0, 400.0, 900.0]), + ("min", 25.0, [10.0, 20.0, 25.0]), + ("max", 25.0, [10.0, 25.0, 30.0]), + ], +) +def test_update_ops(array, copy, op, arg, expect): + if is_pydata_sparse_array(array): + pytest.skip("at() does not support updates on sparse arrays") + + with assert_copy(array, copy): + y = getattr(at(array, slice(1, None)), op)(arg, copy=copy) + assert isinstance(y, type(array)) + assert_array_equal(y, expect) + + +@pytest.mark.parametrize("copy", [True, False, None]) +def test_get(array, copy): + expect_copy = copy + + # dask is mutable, but __getitem__ never returns a view + if is_dask_array(array): + if copy is False: + with pytest.raises(ValueError, match="always returns a copy"): + at(array, slice(2)).get(copy=False) + return + expect_copy = True + + with assert_copy(array, expect_copy): + y = at(array, slice(2)).get(copy=copy) + assert isinstance(y, type(array)) + assert_array_equal(y, [10.0, 20.0]) + # Let assert_copy test that y is a view or copy + with suppress(TypeError, ValueError): + y[:] = 40 + + +def test_get_bool_indices(array): + """get() with a boolean array index always returns a copy""" + # sparse violates the array API as it doesn't support + # a boolean index that is another sparse array. + # dask with dask index has NaN size, which complicates testing. 
+ if is_pydata_sparse_array(array) or is_dask_array(array): + xp = np + else: + xp = array_namespace(array) + idx = xp.asarray([True, False, True]) + + with pytest.raises(ValueError, match="copy"): + at(array, idx).get(copy=False) + + assert_array_equal(at(array, idx).get(), [10.0, 30.0]) + + with assert_copy(array, True): + y = at(array, idx).get(copy=True) + assert_array_equal(y, [10.0, 30.0]) + # Let assert_copy test that y is a view or copy + with suppress(TypeError, ValueError): + y[:] = 40 + + +def test_copy_invalid(): + a = np.asarray([1, 2, 3]) + with pytest.raises(ValueError, match="copy"): + at(a, 0).set(4, copy="invalid") + + +def test_xp(): + a = np.asarray([1, 2, 3]) + b = at(a, 0).set(4, xp=np) + assert_array_equal(b, [4, 2, 3]) diff --git a/vendor_tests/test_vendor.py b/vendor_tests/test_vendor.py index 8b00a37..d549d90 100644 --- a/vendor_tests/test_vendor.py +++ b/vendor_tests/test_vendor.py @@ -5,10 +5,21 @@ def test_vendor_compat(): - from ._array_api_compat_vendor import array_namespace + from ._array_api_compat_vendor import ( # type: ignore[attr-defined] + array_namespace, + device, + is_array_api_obj, + is_dask_array, + is_writeable_array, + ) x = xp.asarray([1, 2, 3]) - assert array_namespace(x) is xp + assert array_namespace(x) is xp # type: ignore[no-untyped-call] + device(x) + assert is_array_api_obj(x) + assert not is_array_api_obj(123) + assert not is_dask_array(x) + assert is_writeable_array(x) def test_vendor_extra():