Skip to content

Commit

Permalink
appease linter
Browse files Browse the repository at this point in the history
  • Loading branch information
lucascolley committed Dec 10, 2024
1 parent aa0d364 commit 78692b7
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 39 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ reportAny = false
reportExplicitAny = false
# data-apis/array-api-strict#6
reportUnknownMemberType = false
# no array-api-compat type stubs
reportUnknownVariableType = false


# Ruff
Expand Down
46 changes: 24 additions & 22 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import operator
import warnings
from collections.abc import Callable
from typing import Any

# https://github.com/pylint-dev/pylint/issues/10112
from collections.abc import Callable # pylint: disable=import-error
from typing import ClassVar

from ._lib import _utils
from ._lib._compat import (
Expand All @@ -12,7 +14,7 @@
is_dask_array,
is_writeable_array,
)
from ._lib._typing import Array, ModuleType
from ._lib._typing import Array, Index, ModuleType, Untyped

__all__ = [
"at",
Expand Down Expand Up @@ -559,7 +561,7 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
_undef = object()


class at:
class at: # pylint: disable=invalid-name
"""
Update operations for read-only arrays.
Expand Down Expand Up @@ -651,14 +653,14 @@ class at:
"""

x: Array
idx: Any
__slots__ = ("idx", "x")
idx: Index
__slots__: ClassVar[tuple[str, str]] = ("idx", "x")

def __init__(self, x: Array, idx: Any = _undef, /):
def __init__(self, x: Array, idx: Index = _undef, /):
self.x = x
self.idx = idx

def __getitem__(self, idx: Any) -> Any:
def __getitem__(self, idx: Index) -> at:
"""Allow for the alternate syntax ``at(x)[start:stop:step]``,
which looks prettier than ``at(x, slice(start, stop, step))``
and feels more intuitive coming from the JAX documentation.
Expand All @@ -677,8 +679,8 @@ def _common(
copy: bool | None = True,
xp: ModuleType | None = None,
_is_update: bool = True,
**kwargs: Any,
) -> tuple[Any, None] | tuple[None, Array]:
**kwargs: Untyped,
) -> tuple[Untyped, None] | tuple[None, Array]:
"""Perform common prepocessing.
Returns
Expand Down Expand Up @@ -706,11 +708,11 @@ def _common(
if not writeable:
msg = "Cannot modify parameter in place"
raise ValueError(msg)
elif copy is None:
elif copy is None: # type: ignore[redundant-expr]
writeable = is_writeable_array(x)
copy = _is_update and not writeable
else:
msg = f"Invalid value for copy: {copy!r}" # type: ignore[unreachable]
msg = f"Invalid value for copy: {copy!r}" # type: ignore[unreachable] # pyright: ignore[reportUnreachable]
raise ValueError(msg)

if copy:
Expand Down Expand Up @@ -741,7 +743,7 @@ def _common(

return None, x

def get(self, **kwargs: Any) -> Any:
def get(self, **kwargs: Untyped) -> Untyped:
"""Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
that the output is either a copy or a view; it also allows passing
keyword arguments to the backend.
Expand All @@ -766,7 +768,7 @@ def get(self, **kwargs: Any) -> Any:
assert x is not None
return x[self.idx]

def set(self, y: Array, /, **kwargs: Any) -> Array:
def set(self, y: Array, /, **kwargs: Untyped) -> Array:
"""Apply ``x[idx] = y`` and return the update array"""
res, x = self._common("set", y, **kwargs)
if res is not None:
Expand All @@ -781,7 +783,7 @@ def _iop(
elwise_op: Callable[[Array, Array], Array],
y: Array,
/,
**kwargs: Any,
**kwargs: Untyped,
) -> Array:
"""x[idx] += y or equivalent in-place operation on a subset of x
Expand All @@ -799,33 +801,33 @@ def _iop(
x[self.idx] = elwise_op(x[self.idx], y)
return x

def add(self, y: Array, /, **kwargs: Any) -> Array:
def add(self, y: Array, /, **kwargs: Untyped) -> Array:
"""Apply ``x[idx] += y`` and return the updated array"""
return self._iop("add", operator.add, y, **kwargs)

def subtract(self, y: Array, /, **kwargs: Any) -> Array:
def subtract(self, y: Array, /, **kwargs: Untyped) -> Array:
"""Apply ``x[idx] -= y`` and return the updated array"""
return self._iop("subtract", operator.sub, y, **kwargs)

def multiply(self, y: Array, /, **kwargs: Any) -> Array:
def multiply(self, y: Array, /, **kwargs: Untyped) -> Array:
"""Apply ``x[idx] *= y`` and return the updated array"""
return self._iop("multiply", operator.mul, y, **kwargs)

def divide(self, y: Array, /, **kwargs: Any) -> Array:
def divide(self, y: Array, /, **kwargs: Untyped) -> Array:
"""Apply ``x[idx] /= y`` and return the updated array"""
return self._iop("divide", operator.truediv, y, **kwargs)

def power(self, y: Array, /, **kwargs: Any) -> Array:
def power(self, y: Array, /, **kwargs: Untyped) -> Array:
"""Apply ``x[idx] **= y`` and return the updated array"""
return self._iop("power", operator.pow, y, **kwargs)

def min(self, y: Array, /, **kwargs: Any) -> Array:
def min(self, y: Array, /, **kwargs: Untyped) -> Array:
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
xp = array_namespace(self.x)
y = xp.asarray(y)
return self._iop("min", xp.minimum, y, **kwargs)

def max(self, y: Array, /, **kwargs: Any) -> Array:
def max(self, y: Array, /, **kwargs: Untyped) -> Array:
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
xp = array_namespace(self.x)
y = xp.asarray(y)
Expand Down
18 changes: 9 additions & 9 deletions src/array_api_extra/_lib/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@

try:
from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
array_namespace, # pyright: ignore[reportUnknownVariableType]
device, # pyright: ignore[reportUnknownVariableType]
is_array_api_obj, # pyright: ignore[reportUnknownVariableType]
is_dask_array, # pyright: ignore[reportUnknownVariableType]
is_writeable_array, # pyright: ignore[reportUnknownVariableType]
array_namespace,
device,
is_array_api_obj,
is_dask_array,
is_writeable_array,
)
except ImportError:
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
array_namespace, # pyright: ignore[reportUnknownVariableType]
array_namespace,
device,
is_array_api_obj, # pyright: ignore[reportUnknownVariableType]
is_dask_array, # pyright: ignore[reportUnknownVariableType]
is_writeable_array, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
is_array_api_obj,
is_dask_array,
is_writeable_array,
)

__all__ = (
Expand Down
4 changes: 3 additions & 1 deletion src/array_api_extra/_lib/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
# To be changed to a Protocol later (see data-apis/array-api#589)
Array = Any # type: ignore[no-any-explicit]
Device = Any # type: ignore[no-any-explicit]
Index = Any # type: ignore[no-any-explicit]
Untyped = Any # type: ignore[no-any-explicit]
else:

def no_op_decorator(f): # pyright: ignore[reportUnreachable]
Expand All @@ -19,4 +21,4 @@ def no_op_decorator(f): # pyright: ignore[reportUnreachable]

__all__ = ["ModuleType", "override"]
if typing.TYPE_CHECKING:
__all__ += ["Array", "Device"]
__all__ += ["Array", "Device", "Index", "Untyped"]
16 changes: 9 additions & 7 deletions tests/test_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np
import pytest
from array_api_compat import (
from array_api_compat import ( # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs]
array_namespace,
is_dask_array,
is_pydata_sparse_array,
Expand All @@ -16,7 +16,7 @@
from array_api_extra import at

if TYPE_CHECKING:
from array_api_extra._lib._typing import Array
from array_api_extra._lib._typing import Array, Untyped

all_libraries = (
"array_api_strict",
Expand All @@ -31,7 +31,7 @@


@pytest.fixture(params=all_libraries)
def array(request):
def array(request: pytest.FixtureRequest) -> Array:
library = request.param
if library == "numpy_readonly":
x = np.asarray([10.0, 20.0, 30.0])
Expand All @@ -55,7 +55,7 @@ def assert_array_equal(a: Array, b: Array) -> None:


@contextmanager
def assert_copy(array, copy: bool | None):
def assert_copy(array: Array, copy: bool | None) -> Untyped: # type: ignore[no-any-decorated]
# dask arrays are writeable, but writing to them will hot-swap the
# dask graph inside the collection so that anything that references
# the original graph, i.e. the input collection, won't be mutated.
Expand Down Expand Up @@ -86,7 +86,9 @@ def assert_copy(array, copy: bool | None):
("max", 25.0, [10.0, 25.0, 30.0]),
],
)
def test_update_ops(array, copy, op, arg, expect):
def test_update_ops(
array: Array, copy: bool | None, op: str, arg: float, expect: list[float]
):
if is_pydata_sparse_array(array):
pytest.skip("at() does not support updates on sparse arrays")

Expand All @@ -97,7 +99,7 @@ def test_update_ops(array, copy, op, arg, expect):


@pytest.mark.parametrize("copy", [True, False, None])
def test_get(array, copy):
def test_get(array: Array, copy: bool | None):
expect_copy = copy

# dask is mutable, but __getitem__ never returns a view
Expand All @@ -117,7 +119,7 @@ def test_get(array, copy):
y[:] = 40


def test_get_bool_indices(array):
def test_get_bool_indices(array: Array):
"""get() with a boolean array index always returns a copy"""
# sparse violates the array API as it doesn't support
# a boolean index that is another sparse array.
Expand Down

0 comments on commit 78692b7

Please sign in to comment.