From 53a4ac96367fc06ef99512c5b58e1c5573fd0627 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 27 Nov 2024 14:00:51 +0000 Subject: [PATCH] Revert `where` --- array_api_compat/common/_helpers.py | 60 ++++++----------------------- docs/helper-functions.rst | 1 - 2 files changed, 12 insertions(+), 49 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index a8a2ed89..4e6c94ee 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -815,23 +815,6 @@ def is_writeable_array(x): return True -def _parse_copy_param(x, copy: bool | None | Literal["_force_false"]) -> bool: - """Preprocess and validate a copy parameter, in line with the same - parameter in np.asarray(), np.astype(), etc. - """ - if copy is True: - return True - elif copy is False: - if not is_writeable_array(x): - raise ValueError("Cannot avoid modifying parameter in place") - return False - elif copy is None: - return not is_writeable_array(x) - elif copy == "_force_false": - return False - raise ValueError(f"Invalid value for copy: {copy!r}") - - _undef = object() @@ -947,7 +930,15 @@ def _common( "(same for all other methods)." ) - copy = _parse_copy_param(self.x, copy) + if copy is False: + if not is_writeable_array(self.x): + raise ValueError("Cannot avoid modifying parameter in place") + elif copy is None: + copy = not is_writeable_array(self.x) + elif copy == "_force_false": + copy = False + elif copy is not True: + raise ValueError(f"Invalid value for copy: {copy!r}") if copy and is_jax_array(self.x): # Use JAX's at[] @@ -956,6 +947,9 @@ def _common( return getattr(at_, at_op)(*args, **kwargs), None # Emulate at[] behaviour for non-JAX arrays + # FIXME We blindly expect the output of x.copy() to be always writeable. + # This holds true for read-only numpy arrays, but not necessarily for + # other backends. x = self.x.copy() if copy else self.x return None, x @@ -1047,35 +1041,6 @@ def max(self, y, /, **kwargs): return self._iop("max", xp.maximum, y, **kwargs) -def where(condition, x=None, y=None, /, copy: bool | None = True): - """Return elements from x when condition is True and from y when - it is False. - - This is a wrapper around xp.where that adds the copy parameter: - - None - x *may* be modified in place if it is possible and beneficial - for performance. You should not use x after calling this function. - True - Ensure that the inputs are not modified. - This is the default, in line with np.where. - False - Raise ValueError if a copy cannot be avoided. - """ - if x is None and y is None: - xp = array_namespace(condition, use_compat=False) - return xp.where(condition) - - copy = _parse_copy_param(x, copy) - xp = array_namespace(condition, x, y, use_compat=False) - if copy: - return xp.where(condition, x, y) - else: - condition, x, y = xp.broadcast_arrays(condition, x, y) - x[condition] = y[condition] - return x - - __all__ = [ "array_namespace", "device", @@ -1100,7 +1065,6 @@ def where(condition, x=None, y=None, /, copy: bool | None = True): "size", "to_device", "at", - "where", ] _all_ignore = ['inspect', 'math', 'operator', 'warnings', 'sys'] diff --git a/docs/helper-functions.rst b/docs/helper-functions.rst index ce32713e..ddba6268 100644 --- a/docs/helper-functions.rst +++ b/docs/helper-functions.rst @@ -37,7 +37,6 @@ instead, which would be wrapped. .. autofunction:: to_device .. autofunction:: size .. autofunction:: at -.. autofunction:: where Inspection Helpers ------------------