Skip to content

Commit

Permalink
address review
Browse files Browse the repository at this point in the history
  • Loading branch information
lucascolley committed Dec 10, 2024
1 parent 7ae5766 commit 55a039a
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 18 deletions.
23 changes: 11 additions & 12 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

# https://github.com/pylint-dev/pylint/issues/10112
from collections.abc import Callable # pylint: disable=import-error
from typing import ClassVar
from typing import ClassVar, Literal

from ._lib import _utils
from ._lib._compat import (
Expand Down Expand Up @@ -659,11 +659,11 @@ class at: # pylint: disable=invalid-name
idx: Index
__slots__: ClassVar[tuple[str, str]] = ("idx", "x")

def __init__(self, x: Array, idx: Index = _undef, /):
def __init__(self, x: Array, idx: Index = _undef, /) -> None:
self.x = x
self.idx = idx

def __getitem__(self, idx: Index) -> at:
def __getitem__(self, idx: Index, /) -> at:
"""Allow for the alternate syntax ``at(x)[start:stop:step]``,
which looks prettier than ``at(x, slice(start, stop, step))``
and feels more intuitive coming from the JAX documentation.
Expand Down Expand Up @@ -704,19 +704,16 @@ def _common(

x = self.x

if copy is True:
if copy is None:
writeable = is_writeable_array(x)
copy = _is_update and not writeable
elif copy:
writeable = None
elif copy is False:
else:
writeable = is_writeable_array(x)
if not writeable:
msg = "Cannot modify parameter in place"
raise ValueError(msg)
elif copy is None: # type: ignore[redundant-expr]
writeable = is_writeable_array(x)
copy = _is_update and not writeable
else:
msg = f"Invalid value for copy: {copy!r}" # type: ignore[unreachable] # pyright: ignore[reportUnreachable]
raise ValueError(msg)

if copy:
try:
Expand Down Expand Up @@ -782,7 +779,9 @@ def set(self, y: Array, /, **kwargs: Untyped) -> Array:

def _iop(
self,
at_op: str,
at_op: Literal[
"set", "add", "subtract", "multiply", "divide", "power", "min", "max"
],
elwise_op: Callable[[Array, Array], Array],
y: Array,
/,
Expand Down
16 changes: 12 additions & 4 deletions src/array_api_extra/_lib/_typing.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,32 @@
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990

import typing
from collections.abc import Mapping
from types import ModuleType
from typing import Any
from typing import Any, Protocol

if typing.TYPE_CHECKING:
from typing_extensions import override

# To be changed to a Protocol later (see data-apis/array-api#589)
Array = Any # type: ignore[no-any-explicit]
Device = Any # type: ignore[no-any-explicit]
Index = Any # type: ignore[no-any-explicit]
Untyped = Any # type: ignore[no-any-explicit]
Array = Untyped
Device = Untyped
Index = Untyped

class CanAt(Protocol):
@property
def at(self) -> Mapping[Index, Untyped]: ...

else:

def no_op_decorator(f): # pyright: ignore[reportUnreachable]
return f

override = no_op_decorator

CanAt = object

__all__ = ["ModuleType", "override"]
if typing.TYPE_CHECKING:
__all__ += ["Array", "Device", "Index", "Untyped"]
4 changes: 2 additions & 2 deletions tests/test_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from contextlib import contextmanager, suppress
from importlib import import_module
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Final

import numpy as np
import pytest
Expand All @@ -18,7 +18,7 @@
if TYPE_CHECKING:
from array_api_extra._lib._typing import Array, Untyped

all_libraries = (
all_libraries: Final = (
"array_api_strict",
"numpy",
"numpy_readonly",
Expand Down

0 comments on commit 55a039a

Please sign in to comment.