diff --git a/pyproject.toml b/pyproject.toml index c8096d4..0eef94e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -195,6 +195,8 @@ reportAny = false reportExplicitAny = false # data-apis/array-api-strict#6 reportUnknownMemberType = false +# no array-api-compat type stubs +reportUnknownVariableType = false # Ruff diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 599048c..e0ca4fa 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -2,8 +2,10 @@ import operator import warnings -from collections.abc import Callable -from typing import Any + +# https://github.com/pylint-dev/pylint/issues/10112 +from collections.abc import Callable # pylint: disable=import-error +from typing import ClassVar from ._lib import _utils from ._lib._compat import ( @@ -12,7 +14,7 @@ is_dask_array, is_writeable_array, ) -from ._lib._typing import Array, ModuleType +from ._lib._typing import Array, Index, ModuleType, Untyped __all__ = [ "at", @@ -559,7 +561,7 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: _undef = object() -class at: +class at: # pylint: disable=invalid-name """ Update operations for read-only arrays. @@ -651,14 +653,14 @@ class at: """ x: Array - idx: Any - __slots__ = ("idx", "x") + idx: Index + __slots__: ClassVar[tuple[str, str]] = ("idx", "x") - def __init__(self, x: Array, idx: Any = _undef, /): + def __init__(self, x: Array, idx: Index = _undef, /): self.x = x self.idx = idx - def __getitem__(self, idx: Any) -> Any: + def __getitem__(self, idx: Index) -> at: """Allow for the alternate syntax ``at(x)[start:stop:step]``, which looks prettier than ``at(x, slice(start, stop, step))`` and feels more intuitive coming from the JAX documentation. @@ -677,8 +679,8 @@ def _common( copy: bool | None = True, xp: ModuleType | None = None, _is_update: bool = True, - **kwargs: Any, - ) -> tuple[Any, None] | tuple[None, Array]: + **kwargs: Untyped, + ) -> tuple[Untyped, None] | tuple[None, Array]: """Perform common prepocessing. Returns @@ -706,11 +708,11 @@ def _common( if not writeable: msg = "Cannot modify parameter in place" raise ValueError(msg) - elif copy is None: + elif copy is None: # type: ignore[redundant-expr] writeable = is_writeable_array(x) copy = _is_update and not writeable else: - msg = f"Invalid value for copy: {copy!r}" # type: ignore[unreachable] + msg = f"Invalid value for copy: {copy!r}" # type: ignore[unreachable] # pyright: ignore[reportUnreachable] raise ValueError(msg) if copy: @@ -741,7 +743,7 @@ def _common( return None, x - def get(self, **kwargs: Any) -> Any: + def get(self, **kwargs: Untyped) -> Untyped: """Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring that the output is either a copy or a view; it also allows passing keyword arguments to the backend. @@ -766,7 +768,7 @@ def get(self, **kwargs: Any) -> Any: assert x is not None return x[self.idx] - def set(self, y: Array, /, **kwargs: Any) -> Array: + def set(self, y: Array, /, **kwargs: Untyped) -> Array: """Apply ``x[idx] = y`` and return the update array""" res, x = self._common("set", y, **kwargs) if res is not None: @@ -781,7 +783,7 @@ def _iop( elwise_op: Callable[[Array, Array], Array], y: Array, /, - **kwargs: Any, + **kwargs: Untyped, ) -> Array: """x[idx] += y or equivalent in-place operation on a subset of x @@ -799,33 +801,33 @@ def _iop( x[self.idx] = elwise_op(x[self.idx], y) return x - def add(self, y: Array, /, **kwargs: Any) -> Array: + def add(self, y: Array, /, **kwargs: Untyped) -> Array: """Apply ``x[idx] += y`` and return the updated array""" return self._iop("add", operator.add, y, **kwargs) - def subtract(self, y: Array, /, **kwargs: Any) -> Array: + def subtract(self, y: Array, /, **kwargs: Untyped) -> Array: """Apply ``x[idx] -= y`` and return the updated array""" return self._iop("subtract", operator.sub, y, **kwargs) - def multiply(self, y: Array, /, **kwargs: Any) -> Array: + def multiply(self, y: Array, /, **kwargs: Untyped) -> Array: """Apply ``x[idx] *= y`` and return the updated array""" return self._iop("multiply", operator.mul, y, **kwargs) - def divide(self, y: Array, /, **kwargs: Any) -> Array: + def divide(self, y: Array, /, **kwargs: Untyped) -> Array: """Apply ``x[idx] /= y`` and return the updated array""" return self._iop("divide", operator.truediv, y, **kwargs) - def power(self, y: Array, /, **kwargs: Any) -> Array: + def power(self, y: Array, /, **kwargs: Untyped) -> Array: """Apply ``x[idx] **= y`` and return the updated array""" return self._iop("power", operator.pow, y, **kwargs) - def min(self, y: Array, /, **kwargs: Any) -> Array: + def min(self, y: Array, /, **kwargs: Untyped) -> Array: """Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array""" xp = array_namespace(self.x) y = xp.asarray(y) return self._iop("min", xp.minimum, y, **kwargs) - def max(self, y: Array, /, **kwargs: Any) -> Array: + def max(self, y: Array, /, **kwargs: Untyped) -> Array: """Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array""" xp = array_namespace(self.x) y = xp.asarray(y) diff --git a/src/array_api_extra/_lib/_compat.py b/src/array_api_extra/_lib/_compat.py index 7189d38..20bbda9 100644 --- a/src/array_api_extra/_lib/_compat.py +++ b/src/array_api_extra/_lib/_compat.py @@ -4,19 +4,19 @@ try: from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports] - array_namespace, # pyright: ignore[reportUnknownVariableType] - device, # pyright: ignore[reportUnknownVariableType] - is_array_api_obj, # pyright: ignore[reportUnknownVariableType] - is_dask_array, # pyright: ignore[reportUnknownVariableType] - is_writeable_array, # pyright: ignore[reportUnknownVariableType] + array_namespace, + device, + is_array_api_obj, + is_dask_array, + is_writeable_array, ) except ImportError: from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs] - array_namespace, # pyright: ignore[reportUnknownVariableType] + array_namespace, device, - is_array_api_obj, # pyright: ignore[reportUnknownVariableType] - is_dask_array, # pyright: ignore[reportUnknownVariableType] - is_writeable_array, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue] + is_array_api_obj, + is_dask_array, + is_writeable_array, ) __all__ = ( diff --git a/src/array_api_extra/_lib/_typing.py b/src/array_api_extra/_lib/_typing.py index f84b1d2..aa99a1a 100644 --- a/src/array_api_extra/_lib/_typing.py +++ b/src/array_api_extra/_lib/_typing.py @@ -10,6 +10,8 @@ # To be changed to a Protocol later (see data-apis/array-api#589) Array = Any # type: ignore[no-any-explicit] Device = Any # type: ignore[no-any-explicit] + Index = Any # type: ignore[no-any-explicit] + Untyped = Any # type: ignore[no-any-explicit] else: def no_op_decorator(f): # pyright: ignore[reportUnreachable] @@ -19,4 +21,4 @@ def no_op_decorator(f): # pyright: ignore[reportUnreachable] __all__ = ["ModuleType", "override"] if typing.TYPE_CHECKING: - __all__ += ["Array", "Device"] + __all__ += ["Array", "Device", "Index", "Untyped"] diff --git a/tests/test_at.py b/tests/test_at.py index d9ce49e..1c8fa93 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -6,7 +6,7 @@ import numpy as np import pytest -from array_api_compat import ( +from array_api_compat import ( # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs] array_namespace, is_dask_array, is_pydata_sparse_array, @@ -16,7 +16,7 @@ from array_api_extra import at if TYPE_CHECKING: - from array_api_extra._lib._typing import Array + from array_api_extra._lib._typing import Array, Untyped all_libraries = ( "array_api_strict", @@ -31,7 +31,7 @@ @pytest.fixture(params=all_libraries) -def array(request): +def array(request: pytest.FixtureRequest) -> Array: library = request.param if library == "numpy_readonly": x = np.asarray([10.0, 20.0, 30.0]) @@ -55,7 +55,7 @@ def assert_array_equal(a: Array, b: Array) -> None: @contextmanager -def assert_copy(array, copy: bool | None): +def assert_copy(array: Array, copy: bool | None) -> Untyped: # type: ignore[no-any-decorated] # dask arrays are writeable, but writing to them will hot-swap the # dask graph inside the collection so that anything that references # the original graph, i.e. the input collection, won't be mutated. @@ -86,7 +86,9 @@ def assert_copy(array, copy: bool | None): ("max", 25.0, [10.0, 25.0, 30.0]), ], ) -def test_update_ops(array, copy, op, arg, expect): +def test_update_ops( + array: Array, copy: bool | None, op: str, arg: float, expect: list[float] +): if is_pydata_sparse_array(array): pytest.skip("at() does not support updates on sparse arrays") @@ -97,7 +99,7 @@ def test_update_ops(array, copy, op, arg, expect): @pytest.mark.parametrize("copy", [True, False, None]) -def test_get(array, copy): +def test_get(array: Array, copy: bool | None): expect_copy = copy # dask is mutable, but __getitem__ never returns a view @@ -117,7 +119,7 @@ def test_get(array, copy): y[:] = 40 -def test_get_bool_indices(array): +def test_get_bool_indices(array: Array): """get() with a boolean array index always returns a copy""" # sparse violates the array API as it doesn't support # a boolean index that is another sparse array.