Skip to content

Commit

Permalink
Revert where
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Nov 27, 2024
1 parent 437d73a commit 53a4ac9
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 49 deletions.
60 changes: 12 additions & 48 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,23 +815,6 @@ def is_writeable_array(x):
return True


def _parse_copy_param(x, copy: bool | None | Literal["_force_false"]) -> bool:
"""Preprocess and validate a copy parameter, in line with the same
parameter in np.asarray(), np.astype(), etc.
"""
if copy is True:
return True
elif copy is False:
if not is_writeable_array(x):
raise ValueError("Cannot avoid modifying parameter in place")
return False
elif copy is None:
return not is_writeable_array(x)
elif copy == "_force_false":
return False
raise ValueError(f"Invalid value for copy: {copy!r}")


_undef = object()


Expand Down Expand Up @@ -947,7 +930,15 @@ def _common(
"(same for all other methods)."
)

copy = _parse_copy_param(self.x, copy)
if copy is False:
if not is_writeable_array(self.x):
raise ValueError("Cannot avoid modifying parameter in place")
elif copy is None:
copy = not is_writeable_array(self.x)
elif copy == "_force_false":
copy = False
elif copy is not True:
raise ValueError(f"Invalid value for copy: {copy!r}")

if copy and is_jax_array(self.x):
# Use JAX's at[]
Expand All @@ -956,6 +947,9 @@ def _common(
return getattr(at_, at_op)(*args, **kwargs), None

# Emulate at[] behaviour for non-JAX arrays
# FIXME We blindly expect the output of x.copy() to be always writeable.
# This holds true for read-only numpy arrays, but not necessarily for
# other backends.
x = self.x.copy() if copy else self.x
return None, x

Expand Down Expand Up @@ -1047,35 +1041,6 @@ def max(self, y, /, **kwargs):
return self._iop("max", xp.maximum, y, **kwargs)


def where(condition, x=None, y=None, /, copy: bool | None = True):
"""Return elements from x when condition is True and from y when
it is False.
This is a wrapper around xp.where that adds the copy parameter:
None
x *may* be modified in place if it is possible and beneficial
for performance. You should not use x after calling this function.
True
Ensure that the inputs are not modified.
This is the default, in line with np.where.
False
Raise ValueError if a copy cannot be avoided.
"""
if x is None and y is None:
xp = array_namespace(condition, use_compat=False)
return xp.where(condition)

copy = _parse_copy_param(x, copy)
xp = array_namespace(condition, x, y, use_compat=False)
if copy:
return xp.where(condition, x, y)
else:
condition, x, y = xp.broadcast_arrays(condition, x, y)
x[condition] = y[condition]
return x


__all__ = [
"array_namespace",
"device",
Expand All @@ -1100,7 +1065,6 @@ def where(condition, x=None, y=None, /, copy: bool | None = True):
"size",
"to_device",
"at",
"where",
]

_all_ignore = ['inspect', 'math', 'operator', 'warnings', 'sys']
1 change: 0 additions & 1 deletion docs/helper-functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ instead, which would be wrapped.
.. autofunction:: to_device
.. autofunction:: size
.. autofunction:: at
.. autofunction:: where

Inspection Helpers
------------------
Expand Down

0 comments on commit 53a4ac9

Please sign in to comment.