diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 291fae8a..7b23c910 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -12,7 +12,7 @@ from typing import NamedTuple import inspect -from ._helpers import array_namespace, _check_device +from ._helpers import array_namespace, _check_device, device, is_torch_array # These functions are modified from the NumPy versions. @@ -281,10 +281,11 @@ def _isscalar(a): return isinstance(a, (int, float, type(None))) min_shape = () if _isscalar(min) else min.shape max_shape = () if _isscalar(max) else max.shape - result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape) wrapped_xp = array_namespace(x) + result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape) + # np.clip does type promotion but the array API clip requires that the # output have the same dtype as x. We do this instead of just downcasting # the result of xp.clip() to handle some corner cases better (e.g., @@ -305,20 +306,26 @@ def _isscalar(a): # At least handle the case of Python integers correctly (see # https://github.com/numpy/numpy/pull/26892). - if type(min) is int and min <= xp.iinfo(x.dtype).min: + if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min: min = None - if type(max) is int and max >= xp.iinfo(x.dtype).max: + if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max: max = None if out is None: - out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), copy=True) + out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), + copy=True, device=device(x)) if min is not None: - a = xp.broadcast_to(xp.asarray(min), result_shape) + if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(min): + # Avoid loss of precision due to torch defaulting to float32 + min = wrapped_xp.asarray(min, dtype=xp.float64) + a = xp.broadcast_to(wrapped_xp.asarray(min, device=device(x)), result_shape) ia = (out < a) | xp.isnan(a) # torch requires an explicit cast here out[ia] = wrapped_xp.astype(a[ia], out.dtype) if max is not None: - b = xp.broadcast_to(xp.asarray(max), result_shape) + if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(max): + max = wrapped_xp.asarray(max, dtype=xp.float64) + b = xp.broadcast_to(wrapped_xp.asarray(max, device=device(x)), result_shape) ib = (out > b) | xp.isnan(b) out[ib] = wrapped_xp.astype(b[ib], out.dtype) # Return a scalar for 0-D diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 7519f59f..10a03cd8 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -89,7 +89,6 @@ def _dask_arange( permute_dims = get_xp(da)(_aliases.permute_dims) std = get_xp(da)(_aliases.std) var = get_xp(da)(_aliases.var) -clip = get_xp(da)(_aliases.clip) empty = get_xp(da)(_aliases.empty) empty_like = get_xp(da)(_aliases.empty_like) full = get_xp(da)(_aliases.full) @@ -167,6 +166,43 @@ def asarray( concatenate as concat, ) +# dask.array.clip does not work unless all three arguments are provided. +# Furthermore, the masking workaround in common._aliases.clip cannot work with +# dask (meaning uint64 promoting to float64 is going to just be unfixed for +# now). +@get_xp(da) +def clip( + x: Array, + /, + min: Optional[Union[int, float, Array]] = None, + max: Optional[Union[int, float, Array]] = None, + *, + xp, +) -> Array: + def _isscalar(a): + return isinstance(a, (int, float, type(None))) + min_shape = () if _isscalar(min) else min.shape + max_shape = () if _isscalar(max) else max.shape + + # TODO: This won't handle dask unknown shapes + import numpy as np + result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape) + + if min is not None: + min = xp.broadcast_to(xp.asarray(min), result_shape) + if max is not None: + max = xp.broadcast_to(xp.asarray(max), result_shape) + + if min is None and max is None: + return xp.positive(x) + + if min is None: + return astype(xp.minimum(x, max), x.dtype) + if max is None: + return astype(xp.maximum(x, min), x.dtype) + + return astype(xp.minimum(xp.maximum(x, min), max), x.dtype) + # exclude these from all since _da_unsupported = ['sort', 'argsort']