diff --git a/array_api_compat/paddle/__init__.py b/array_api_compat/paddle/__init__.py
index 9f96fa9f..1016312d 100644
--- a/array_api_compat/paddle/__init__.py
+++ b/array_api_compat/paddle/__init__.py
@@ -4,16 +4,10 @@
 import paddle

 for n in dir(paddle):
-    if (
-        n.startswith("_")
-        or n.endswith("_")
-        or "gpu" in n
-        or "cpu" in n
-        or "backward" in n
-    ):
+    if n.startswith("_") or n.endswith("_") or "gpu" in n or "cpu" in n or "backward" in n:
         continue
-    exec(n + " = paddle." + n)
-    exec("asarray = paddle.to_tensor")
+    exec(f"{n} = paddle.{n}")

+# These imports may overwrite names assigned by the loop above.
 from ._aliases import * # noqa: F403

diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index 14d3de7f..601afa5f 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -1,14 +1,17 @@
 from __future__ import annotations

+from typing import Literal
+import numpy as np
+
 from functools import wraps as _wraps
 from builtins import all as _builtin_all, any as _builtin_any

 from ..common._aliases import (
-    matrix_transpose as _aliases_matrix_transpose,
-    vecdot as _aliases_vecdot,
-    clip as _aliases_clip,
     unstack as _aliases_unstack,
-    cumulative_sum as _aliases_cumulative_sum,
+)
+from ..common._typing import (
+    SupportsBufferProtocol,
+    NestedSequence,
 )

 from .._internal import get_xp

@@ -94,7 +97,7 @@ def _fix_promotion(x1, x2, only_scalar=True):
         return x1, x2
     if x1.dtype not in _array_api_dtypes or x2.dtype not in _array_api_dtypes:
         return x1, x2
-    # If an argument is 0-D pytorch downcasts the other argument
+    # If an argument is 0-D, paddle downcasts the other argument
     if not only_scalar or x1.shape == ():
         dtype = result_type(x1, x2)
         x2 = x2.to(dtype)

@@ -131,6 +134,12 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:


 def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
+    if paddle.is_tensor(from_):
+        from_ = from_.dtype
+
+    assert isinstance(from_, paddle.dtype), from_
+    assert isinstance(to, paddle.dtype), to
+
     can_cast_dict = {
         paddle.bfloat16: {
             paddle.bfloat16: True,
@@ -341,9 +350,6 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
 remainder = _two_arg(paddle.remainder)
 subtract = _two_arg(paddle.subtract)

-# These wrappers are mostly based on the fact that pytorch uses 'dim' instead
-# of 'axis'.
-

 def max(
     x: array,
     /,
     *,
     axis: Optional[Union[int, Tuple[int, ...]]] = None,
     keepdims: bool = False,
 ) -> array:
-    # https://github.com/pytorch/pytorch/issues/29137
     if axis == ():
         return paddle.clone(x)
     return paddle.amax(x, axis, keepdim=keepdims)


+def argmax(
+    x: array,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+) -> array:
+    return paddle.argmax(x, axis, keepdim=keepdims)
+
+
 def min(
     x: array,
     /,
     *,
     axis: Optional[Union[int, Tuple[int, ...]]] = None,
     keepdims: bool = False,
 ) -> array:
-    # https://github.com/pytorch/pytorch/issues/29137
     if axis == ():
         return paddle.clone(x)
     return paddle.min(x, axis, keepdim=keepdims)


-clip = get_xp(paddle)(_aliases_clip)
+def argmin(
+    x: array,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+) -> array:
+    return paddle.argmin(x, axis, keepdim=keepdims)
+
+
 unstack = get_xp(paddle)(_aliases_unstack)
-cumulative_sum = get_xp(paddle)(_aliases_cumulative_sum)


-# paddle.sort also returns a tuple
-# https://github.com/pytorch/pytorch/issues/70921
+# paddle.sort returns the sorted tensor directly rather than a
+# (values, indices) tuple, so no .values access is needed
 def sort(
     x: array,
     /,
     *,
     axis: int = -1,
     descending: bool = False,
     stable: bool = True,
     **kwargs,
 ) -> array:
-    return paddle.sort(
-        x, axis=axis, descending=descending, stable=stable, **kwargs
-    ).values
+    return paddle.sort(x, axis=axis, descending=descending, stable=stable, **kwargs)


 def _normalize_axes(axis, ndim):
@@ -401,9 +420,7 @@ def _normalize_axes(axis, ndim):
     for a in axis:
         if a < lower or a > upper:
             # Match paddle error message (e.g., from sum())
-            raise IndexError(
-                f"Dimension out of range (expected to be in range of [{lower}, {upper}], but got {a}"
-            )
+            raise IndexError(f"Dimension out of range (expected to be in range of [{lower}, {upper}], but got {a})")
         if a < 0:
             a = a + ndim
         if a in axes:
@@ -415,7 +432,6 @@
 def _axis_none_keepdims(x, ndim, keepdims):
     # Apply keepdims when axis=None
-    # (https://github.com/pytorch/pytorch/issues/71209)
     # Note that this is only valid for the axis=None case.
     if keepdims:
         for i in range(ndim):
@@ -425,7 +441,6 @@
 def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
     # Some reductions don't support multiple axes
-    # (https://github.com/pytorch/pytorch/issues/56586).
     axes = _normalize_axes(axis, x.ndim)
     for a in reversed(axes):
         x = paddle.movedim(x, a, -1)
@@ -448,10 +463,10 @@
 def prod(
     x: array,
     /,
     *,
     axis: Optional[Union[int, Tuple[int, ...]]] = None,
     dtype: Optional[Dtype] = None,
     keepdims: bool = False,
     **kwargs,
 ) -> array:
-    x = paddle.asarray(x)
+    if not paddle.is_tensor(x):
+        x = paddle.to_tensor(x)
     ndim = x.ndim

-    # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
-    # below because it still needs to upcast.
+    # Separate from the logic below because it still needs to upcast.
     if axis == ():
         if dtype is None:
@@ -464,14 +479,10 @@
         return x.to(dtype)

     # paddle.prod doesn't support multiple axes
-    # (https://github.com/pytorch/pytorch/issues/56586).
     if isinstance(axis, tuple):
-        return _reduce_multiple_axes(
-            paddle.prod, x, axis, keepdim=keepdims, dtype=dtype, **kwargs
-        )
+        return _reduce_multiple_axes(paddle.prod, x, axis, keepdim=keepdims, dtype=dtype, **kwargs)
     if axis is None:
         # paddle doesn't support keepdims with axis=None
-        # (https://github.com/pytorch/pytorch/issues/71209)
         res = paddle.prod(x, dtype=dtype, **kwargs)
         res = _axis_none_keepdims(res, ndim, keepdims)
         return res
@@ -488,10 +499,10 @@
 def sum(
     x: array,
     /,
     *,
     axis: Optional[Union[int, Tuple[int, ...]]] = None,
     dtype: Optional[Dtype] = None,
     keepdims: bool = False,
     **kwargs,
 ) -> array:
-    x = paddle.asarray(x)
+    if not paddle.is_tensor(x):
+        x = paddle.to_tensor(x)
     ndim = x.ndim

-    # https://github.com/pytorch/pytorch/issues/29137.
     # Make sure it upcasts.
     if axis == ():
         if dtype is None:
@@ -505,7 +516,6 @@
     if axis is None:
         # paddle doesn't support keepdims with axis=None
-        # (https://github.com/pytorch/pytorch/issues/71209)
         res = paddle.sum(x, dtype=dtype, **kwargs)
         res = _axis_none_keepdims(res, ndim, keepdims)
         return res
@@ -521,18 +531,17 @@
 def any(
     x: array,
     /,
     *,
     axis: Optional[Union[int, Tuple[int, ...]]] = None,
     keepdims: bool = False,
     **kwargs,
 ) -> array:
-    x = paddle.asarray(x)
+    if not paddle.is_tensor(x):
+        x = paddle.to_tensor(x)
     ndim = x.ndim
     if axis == ():
         return x.to(paddle.bool)
     # paddle.any doesn't support multiple axes
-    # (https://github.com/pytorch/pytorch/issues/56586).
     if isinstance(axis, tuple):
         res = _reduce_multiple_axes(paddle.any, x, axis, keepdim=keepdims, **kwargs)
         return res.to(paddle.bool)
     if axis is None:
         # paddle doesn't support keepdims with axis=None
-        # (https://github.com/pytorch/pytorch/issues/71209)
         res = paddle.any(x, **kwargs)
         res = _axis_none_keepdims(res, ndim, keepdims)
         return res.to(paddle.bool)
@@ -549,18 +558,17 @@
 def all(
     x: array,
     /,
     *,
     axis: Optional[Union[int, Tuple[int, ...]]] = None,
     keepdims: bool = False,
     **kwargs,
 ) -> array:
-    x = paddle.asarray(x)
+    if not paddle.is_tensor(x):
+        x = paddle.to_tensor(x)
     ndim = x.ndim
     if axis == ():
         return x.to(paddle.bool)
     # paddle.all doesn't support multiple axes
-    # (https://github.com/pytorch/pytorch/issues/56586).
     if isinstance(axis, tuple):
         res = _reduce_multiple_axes(paddle.all, x, axis, keepdim=keepdims, **kwargs)
         return res.to(paddle.bool)
     if axis is None:
         # paddle doesn't support keepdims with axis=None
-        # (https://github.com/pytorch/pytorch/issues/71209)
         res = paddle.all(x, **kwargs)
         res = _axis_none_keepdims(res, ndim, keepdims)
         return res.to(paddle.bool)
@@ -577,12 +585,10 @@
 def mean(
     x: array,
     /,
     *,
     axis: Optional[Union[int, Tuple[int, ...]]] = None,
     keepdims: bool = False,
     **kwargs,
 ) -> array:
-    # https://github.com/pytorch/pytorch/issues/29137
     if axis == ():
         return paddle.clone(x)
     if axis is None:
         # paddle doesn't support keepdims with axis=None
-        # (https://github.com/pytorch/pytorch/issues/71209)
         res = paddle.mean(x, **kwargs)
         res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res
@@ -599,15 +605,12 @@
 def std(
     x: array,
     /,
     *,
     axis: Optional[Union[int, Tuple[int, ...]]] = None,
     correction: Union[int, float] = 0.0,
     keepdims: bool = False,
     **kwargs,
 ) -> array:
     # Note, float correction is not supported
-    # https://github.com/pytorch/pytorch/issues/61492. We don't try to
+    # in paddle. We don't try to
     # implement it here for now.
     if isinstance(correction, float):
         _correction = int(correction)
         if correction != _correction:
-            raise NotImplementedError(
-                "float correction in paddle std() is not yet supported"
-            )
+            raise NotImplementedError("float correction in paddle std() is not yet supported")
     elif isinstance(correction, int):
         if correction not in [0, 1]:
             raise NotImplementedError("correction only can be 0 or 1")
@@ -616,14 +619,12 @@
     _correction = bool(_correction)

-    # https://github.com/pytorch/pytorch/issues/29137
     if axis == ():
         return paddle.zeros_like(x)
     if isinstance(axis, int):
         axis = (axis,)
     if axis is None:
         # paddle doesn't support keepdims with axis=None
-        # (https://github.com/pytorch/pytorch/issues/71209)
         res = paddle.std(x, tuple(range(x.ndim)), unbiased=_correction, **kwargs)
         res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res
@@ -640,7 +641,6 @@
 def var(
     x: array,
     /,
     *,
     axis: Optional[Union[int, Tuple[int, ...]]] = None,
     correction: Union[int, float] = 0.0,
     keepdims: bool = False,
     **kwargs,
 ) -> array:
     # Note, float correction is not supported
-    # https://github.com/pytorch/pytorch/issues/61492. We don't try to
+    # in paddle. We don't try to
     # implement it here for now.
     #
     # if isinstance(correction, float):
@@ -648,9 +648,7 @@
     if isinstance(correction, float):
         _correction = int(correction)
         if correction != _correction:
-            raise NotImplementedError(
-                "float correction in paddle std() is not yet supported"
-            )
+            raise NotImplementedError("float correction in paddle var() is not yet supported")
     elif isinstance(correction, int):
         if correction not in [0, 1]:
             raise NotImplementedError("correction only can be 0 or 1")
@@ -659,14 +657,12 @@
     _correction = bool(_correction)

-    # https://github.com/pytorch/pytorch/issues/29137
     if axis == ():
         return paddle.zeros_like(x)
     if isinstance(axis, int):
         axis = (axis,)
     if axis is None:
         # paddle doesn't support keepdims with axis=None
-        # (https://github.com/pytorch/pytorch/issues/71209)
         res = paddle.var(x, tuple(range(x.ndim)), unbiased=_correction, **kwargs)
         res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res
@@ -674,7 +670,6 @@
 # paddle.concat doesn't support axis=None
-# https://github.com/pytorch/pytorch/issues/70925
 def concat(
     arrays: Union[Tuple[array, ...], List[array]],
     /,
@@ -688,9 +683,6 @@
     return paddle.concat(arrays, axis, **kwargs)


-# paddle.squeeze only accepts int dim and doesn't require it
-# https://github.com/pytorch/pytorch/issues/70924. Support for tuple dim was
-# added at https://github.com/pytorch/pytorch/pull/89017.
 def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array:
     if isinstance(axis, int):
         axis = (axis,)
     for a in axis:
         if x.shape[a] != 1:
             raise ValueError("squeezed dimensions must be equal to 1")
     axes = _normalize_axes(axis, x.ndim)
-    # Remove this once pytorch 1.14 is released with the above PR #89017.
+
     sequence = [a - i for i, a in enumerate(axes)]
     for a in sequence:
         x = paddle.squeeze(x, a)
@@ -712,23 +704,15 @@ def broadcast_to(x: array, /, shape: Tuple[int, ...], **kwargs) -> array:

-# paddle.permute uses dims instead of axes
+# paddle.transpose takes a perm argument equivalent to axes
 def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array:
-    if len(axes) == 2:
-        perm = list(range(x.ndim))
-        perm[axes[0]], perm[axes[1]] = perm[axes[1]], perm[axes[0]]
-        axes = perm
     return paddle.transpose(x, axes)


 # The axis parameter doesn't work for flip() and roll()
-# https://github.com/pytorch/pytorch/issues/71210. Also paddle.flip() doesn't
-# accept axis=None
-def flip(
-    x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs
-) -> array:
+# with axis=None; paddle.flip() requires an explicit axis list
+def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array:
     if axis is None:
         axis = tuple(range(x.ndim))
     # paddle.flip doesn't accept dim as an int but the method does
-    # https://github.com/pytorch/pytorch/issues/18095
     return x.flip(axis, **kwargs)
@@ -754,19 +738,48 @@ def where(condition: array, x1: array, x2: array, /) -> array:
     return paddle.where(condition, x1, x2)


-# paddle.reshape doesn't have the copy keyword
-def reshape(
-    x: array, /, shape: Tuple[int, ...], copy: Optional[bool] = None, **kwargs
+def empty_like(x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> array:
+    out = paddle.empty_like(x, dtype=dtype)
+    if device is not None:
+        out = out.to(device)
+    return out
+
+
+def zeros_like(x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> array:
+    out = paddle.zeros_like(x, dtype=dtype)
+    if device is not None:
+        out = out.to(device)
+    return out
+
+
+def ones_like(x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> array:
+    out = paddle.ones_like(x, dtype=dtype)
+    if device is not None:
+        out = out.to(device)
+    return out
+
+
+def full_like(
+    x: array,
+    /,
+    fill_value: bool | int | float | complex,
+    *,
+    dtype: Optional[Dtype] = None,
+    device: Optional[Device] = None,
 ) -> array:
-    if copy is not None:
-        raise NotImplementedError("paddle.reshape doesn't yet support the copy keyword")
+    out = paddle.full_like(x, fill_value, dtype=dtype)
+    if device is not None:
+        out = out.to(device)
+    return out
+
+
+# paddle.reshape doesn't have the copy keyword
+def reshape(x: array, /, shape: Tuple[int, ...], copy: Optional[bool] = None, **kwargs) -> array:
     return paddle.reshape(x, shape, **kwargs)


 # paddle.arange doesn't support returning empty arrays
-# (https://github.com/pytorch/pytorch/issues/70915), and doesn't support some
-# keyword argument combinations
-# (https://github.com/pytorch/pytorch/issues/70914)
+# and doesn't support some keyword argument combinations
 def arange(
     start: Union[int, float],
     /,
@@ -790,7 +803,6 @@
 # paddle.eye does not accept None as a default for the second argument and
-# doesn't support off-diagonals (https://github.com/pytorch/pytorch/issues/70910)
+# doesn't support off-diagonals
 def eye(
     n_rows: int,
     n_cols: Optional[int] = None,
@@ -822,14 +834,11 @@ def linspace(
     **kwargs,
 ) -> array:
     if not endpoint:
-        return paddle.linspace(start, stop, num + 1, dtype=dtype, **kwargs).to(device)[
-            :-1
-        ]
+        return paddle.linspace(start, stop, num + 1, dtype=dtype, **kwargs).to(device)[:-1]
     return paddle.linspace(start, stop, num, dtype=dtype, **kwargs).to(device)


 # paddle.full does not accept an int size
-# https://github.com/pytorch/pytorch/issues/70906
 def full(
     shape: Union[int, Tuple[int, ...]],
     fill_value: Union[bool, int, float, complex],
@@ -886,17 +895,21 @@ def triu(x: array, /, *, k: int = 0) -> array:
     return paddle.triu(x, k)


-# Functions that aren't in paddle https://github.com/pytorch/pytorch/issues/58742
 def expand_dims(x: array, /, *, axis: int = 0) -> array:
     return paddle.unsqueeze(x, axis)


-def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array:
-    return x.to(dtype, copy=copy)
+def astype(x: array, dtype: Dtype, /, *, copy: bool = True, device: Optional[Device] = None) -> array:
+    t = x.to(dtype, device=device)
+    if copy:
+        t = t.detach().clone()
+    return t


 def broadcast_arrays(*arrays: array) -> List[array]:
-    shape = paddle.broadcast_shapes(*[a.shape for a in arrays])
+    shape = broadcast_shapes(*[a.shape for a in arrays])
     return [paddle.broadcast_to(a, shape) for a in arrays]


@@ -905,28 +918,19 @@ def broadcast_arrays(*arrays: array) -> List[array]:
 from ..common._aliases import UniqueAllResult, UniqueCountsResult, UniqueInverseResult


-# https://github.com/pytorch/pytorch/issues/70920
 def unique_all(x: array) -> UniqueAllResult:
-    # paddle.unique doesn't support returning indices.
-    # https://github.com/pytorch/pytorch/issues/36748. The workaround
-    # suggested in that issue doesn't actually function correctly (it relies
-    # on non-deterministic behavior of scatter()).
-    raise NotImplementedError(
-        "unique_all() not yet implemented for paddle (see https://github.com/pytorch/pytorch/issues/36748)"
-    )
-
-    # values, inverse_indices, counts = paddle.unique(x, return_counts=True, return_inverse=True)
-    # # paddle.unique incorrectly gives a 0 count for nan values.
-    # # https://github.com/pytorch/pytorch/issues/94106
-    # counts[paddle.isnan(values)] = 1
-    # return UniqueAllResult(values, indices, inverse_indices, counts)
+    # paddle.unique returns (values, indices, inverse_indices, counts) in this
+    # order; wrap it so the declared UniqueAllResult return type holds
+    return UniqueAllResult(
+        *paddle.unique(
+            x,
+            return_index=True,
+            return_inverse=True,
+            return_counts=True,
+        )
+    )


 def unique_counts(x: array) -> UniqueCountsResult:
     values, counts = paddle.unique(x, return_counts=True)

     # paddle.unique incorrectly gives a 0 count for nan values.
-    # https://github.com/pytorch/pytorch/issues/94106
     counts[paddle.isnan(values)] = 1
     return UniqueCountsResult(values, counts)
@@ -946,13 +950,19 @@
 def matmul(x1: array, x2: array, /, **kwargs) -> array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
     return paddle.matmul(x1, x2, **kwargs)


-matrix_transpose = get_xp(paddle)(_aliases_matrix_transpose)
-_vecdot = get_xp(paddle)(_aliases_vecdot)
+def meshgrid(*arrays: array, indexing: str = "xy") -> List[array]:
+    if indexing == "ij":
+        return paddle.meshgrid(*arrays)
+    # "xy" indexing swaps the first two dimensions of each output; .T would
+    # incorrectly reverse all dimensions for more than two inputs
+    out = paddle.meshgrid(*arrays)
+    perm = list(range(len(arrays)))
+    if len(perm) >= 2:
+        perm[0], perm[1] = perm[1], perm[0]
+    return [paddle.transpose(t, perm) for t in out]
+
+
+matrix_transpose = paddle.linalg.matrix_transpose


 def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
-    return _vecdot(x1, x2, axis=axis)
+    return paddle.linalg.vecdot(x1, x2, axis=axis)


-# paddle.tensordot uses dims instead of axes
 def tensordot(
     x1: array,
     x2: array,
     /,
     *,
     axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
     **kwargs,
 ) -> array:
     # Note: paddle.tensordot fails with integer dtypes when there is only 1
-    # element in the axis (https://github.com/pytorch/pytorch/issues/84530).
+    # element in the axis.
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
     return paddle.tensordot(x1, x2, axes=axes, **kwargs)

@@ -990,16 +999,6 @@ def isdtype(
     def is_signed(dtype):
         return dtype in [paddle.int8, paddle.int16, paddle.int32, paddle.int64]

-    def is_floating_point(dtype):
-        return dtype in [
-            paddle.float32,
-            paddle.float64,
-            paddle.float16,
-            paddle.bfloat16,
-            paddle.float8_e4m3fn,
-            paddle.float8_e5m2,
-        ]
-
     def is_complex(dtype):
         return dtype in [paddle.complex64, paddle.complex128]

@@ -1016,7 +1015,7 @@
     elif kind == "integral":
         return dtype in _int_dtypes
     elif kind == "real floating":
-        return is_floating_point(dtype)
+        return paddle.is_floating_point(dtype)
     elif kind == "complex floating":
         return is_complex(dtype)
     elif kind == "numeric":
@@ -1038,18 +1037,172 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array:


 def sign(x: array, /) -> array:
     # paddle sign() does not support complex numbers and does not propagate
     # nans. See https://github.com/data-apis/array-api-compat/issues/136
-    if x.dtype.is_complex:
+    if paddle.is_complex(x):
         out = x / paddle.abs(x)
         # sign(0) = 0 but the above formula would give nan
         out[x == 0 + 0j] = 0 + 0j
         return out
     else:
         out = paddle.sign(x)
-        if x.dtype.is_floating_point:
-            out[paddle.isnan(x)] = paddle.nan
+        if paddle.is_floating_point(x):
+            # propagate nans via where() instead of in-place boolean indexing
+            out = paddle.where(paddle.isnan(x), paddle.full_like(out, float("nan")), out)
         return out


+def broadcast_shapes(*shapes: List[int]) -> List[int]:
+    out_shape = shapes[0]
+    for shape in shapes[1:]:
+        out_shape = paddle.broadcast_shape(out_shape, shape)
+    return out_shape
+
+
+# asarray also adds the copy keyword, which paddle.to_tensor does not have
+def asarray(
+    obj: Union[
+        array,
+        bool,
+        int,
+        float,
+        NestedSequence[bool | int | float],
+        SupportsBufferProtocol,
+    ],
+    /,
+    *,
+    dtype: Optional[Dtype] = None,
+    device: Optional[Device] = None,
+    copy: Optional[bool] = None,
+    **kwargs,
+) -> array:
+    """
+    Array API compatibility wrapper for asarray().
+
+    See the corresponding documentation in the array library and/or the array API
+    specification for more details.
+    """
+    if copy is False:
+        if hasattr(obj, "__dlpack__"):
+            obj = paddle.from_dlpack(obj.__dlpack__())
+            if device is not None:
+                obj = obj.to(device)
+            if dtype is not None:
+                obj = obj.to(dtype)
+            return obj
+        else:
+            raise NotImplementedError(
+                "asarray(obj, ..., copy=False) is not supported for objects that "
+                "do not have a '__dlpack__()' method"
+            )
+    elif copy is True:
+        obj = np.array(obj, copy=True)
+        return paddle.to_tensor(obj, dtype=dtype, place=device)
+    else:
+        if not paddle.is_tensor(obj) or (dtype is not None and obj.dtype != dtype):
+            obj = np.array(obj, copy=False)
+            obj = paddle.from_dlpack(obj.__dlpack__(), **kwargs)
+            if dtype is not None:
+                obj = obj.to(dtype)
+        if device is not None:
+            obj = obj.to(device)
+        return obj


+def clip(
+    x: array,
+    /,
+    min: Optional[Union[int, float, array]] = None,
+    max: Optional[Union[int, float, array]] = None,
+) -> array:
+    if min is None and max is None:
+        return x
+
+    def _isscalar(a):
+        return isinstance(a, (int, float, type(None)))
+
+    min_shape = [] if _isscalar(min) else min.shape
+    max_shape = [] if _isscalar(max) else max.shape
+
+    result_shape = broadcast_shapes(x.shape, min_shape, max_shape)
+
+    # np.clip does type promotion but the array API clip requires that the
+    # output have the same dtype as x. We do this instead of just downcasting
+    # the result of paddle.clip() to handle some corner cases better (e.g.,
+    # avoiding uint64 -> float64 promotion).
+
+    # Note: cases where min or max overflow (integer) or round (float) in the
+    # wrong direction when downcasting to x.dtype are unspecified. This code
+    # just does whatever NumPy does when it downcasts in the assignment, but
+    # other behavior could be preferred, especially for integers. For example,
+    # this code produces:
+
+    # >>> clip(asarray(0, dtype=int8), asarray(128, dtype=int16), None)
+    # -128
+
+    # but an answer of 0 might be preferred. See
+    # https://github.com/numpy/numpy/issues/24976 for more discussion on this issue.
+
+    # At least handle the case of Python integers correctly (see
+    # https://github.com/numpy/numpy/pull/26892).
+    if x.dtype in _int_dtypes:
+        if type(min) is int and min <= paddle.iinfo(x.dtype).min:
+            min = None
+        if type(max) is int and max >= paddle.iinfo(x.dtype).max:
+            max = None
+
+    # Start from a writable copy of x broadcast to the result shape
+    out = paddle.to_tensor(broadcast_to(x, result_shape), place=x.place)
+    if min is not None:
+        if paddle.is_tensor(x) and x.dtype == paddle.float64 and _isscalar(min):
+            # Avoid loss of precision due to paddle defaulting to float32
+            min = paddle.to_tensor(min, dtype=paddle.float64)
+        a = broadcast_to(paddle.to_tensor(min, place=x.place), result_shape)
+        ia = (out < a) | paddle.isnan(a)
+        # paddle requires an explicit cast here
+        out[ia] = astype(a[ia], out.dtype)
+    if max is not None:
+        if paddle.is_tensor(x) and x.dtype == paddle.float64 and _isscalar(max):
+            max = paddle.to_tensor(max, dtype=paddle.float64)
+        b = broadcast_to(paddle.to_tensor(max, place=x.place), result_shape)
+        ib = (out > b) | paddle.isnan(b)
+        out[ib] = astype(b[ib], out.dtype)
+    # Return a scalar for 0-D
+    return out[()]
+
+
+def cumulative_sum(
+    x: array, /, *, axis: Optional[int] = None, dtype: Optional[Dtype] = None, include_initial: bool = False
+) -> array:
+    if axis is None:
+        if x.ndim > 1:
+            raise ValueError("axis must be specified in cumulative_sum for more than one dimension")
+        axis = 0
+
+    res = paddle.cumsum(x, axis=axis, dtype=dtype)
+
+    # paddle.cumsum does not support include_initial
+    if include_initial:
+        initial_shape = list(x.shape)
+        initial_shape[axis] = 1
+        res = paddle.concat(
+            [paddle.zeros(shape=initial_shape, dtype=res.dtype).to(res.place), res],
+            axis=axis,
+        )
+    return res
+
+
+def searchsorted(
+    x1: array, x2: array, /, *, side: Literal["left", "right"] = "left", sorter: array | None = None
+) -> array:
+    if sorter is None:
+        return paddle.searchsorted(x1, x2, right=(side == "right"))
+
+    return paddle.searchsorted(
+        x1.take_along_axis(axis=-1, indices=sorter),
+        x2,
+        right=(side == "right"),
+    )
+
+
 __all__ = [
     "__array_namespace_info__",
     "result_type",
@@ -1129,6 +1282,15 @@
     "isdtype",
     "take",
     "sign",
+    "broadcast_shapes",
+    "argmax",
+    "argmin",
+    "searchsorted",
+    "meshgrid",
+    "empty_like",
+    "zeros_like",
+    "ones_like",
+    "full_like",
+    "asarray",
 ]

 _all_ignore = ["paddle", "get_xp"]
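
Reviewer note (illustrative, not part of the patch): a minimal sketch of how the new asarray/clip wrappers above are expected to behave, assuming an installed paddle build and a numpy with dlpack support:

    import paddle
    from array_api_compat import paddle as xp

    x = xp.asarray([1, 2, 3])      # copy=None: converted via numpy + dlpack
    y = xp.asarray(x, copy=True)   # copy=True: forces a copy through numpy
    z = xp.clip(x, 0, 2)           # output keeps x's dtype, per the array API
    assert z.dtype == x.dtype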
diff --git a/array_api_compat/paddle/_info.py b/array_api_compat/paddle/_info.py
index d8dab7ee..5d29e270 100644
--- a/array_api_compat/paddle/_info.py
+++ b/array_api_compat/paddle/_info.py
@@ -332,18 +332,12 @@ def devices(self):
-        # message of paddle.device to get the list of all possible types of
+        # message of paddle.set_device to get the list of all possible types of
         # device:
         try:
-            paddle.device("notadevice")
-        except RuntimeError as e:
+            paddle.set_device("notadevice")
+        except ValueError as e:
             # The error message is something like:
             # ValueError: The device must be a string which is like 'cpu', 'gpu', 'gpu:x', 'xpu', 'xpu:x', 'npu', 'npu:x
-            devices_names = (
-                e.args[0]
-                .split("ValueError: The device must be a string which is like ")[1]
-                .split(", ")
-            )
-            devices_names = [
-                name.strip("'") for name in devices_names if ":" not in name
-            ]
+            devices_names = e.args[0].split("The device must be a string which is like ")[1].split(", ")
+            devices_names = [name.strip("'") for name in devices_names if ":" not in name]

         # Next we need to check for different indices for different devices.
         # device(device_name, index=index) doesn't actually check if the

diff --git a/array_api_compat/paddle/fft.py b/array_api_compat/paddle/fft.py
index 15519b5a..1442aed8 100644
--- a/array_api_compat/paddle/fft.py
+++ b/array_api_compat/paddle/fft.py
@@ -4,9 +4,10 @@
 if TYPE_CHECKING:
     import paddle
+    from ..common._typing import Device

     array = paddle.Tensor
-    from typing import Union, Sequence, Literal
+    from typing import Optional, Union, Sequence, Literal

 from paddle.fft import * # noqa: F403
 import paddle.fft
@@ -80,6 +81,32 @@ def ifftshift(
     return paddle.fft.ifftshift(x, axes=axes, **kwargs)


+def fftfreq(
+    n: int,
+    /,
+    *,
+    d: float = 1.0,
+    device: Optional[Device] = None,
+) -> array:
+    out = paddle.fft.fftfreq(n, d)
+    if device is not None:
+        out = out.to(device)
+    return out
+
+
+def rfftfreq(
+    n: int,
+    /,
+    *,
+    d: float = 1.0,
+    device: Optional[Device] = None,
+) -> array:
+    out = paddle.fft.rfftfreq(n, d)
+    if device is not None:
+        out = out.to(device)
+    return out
+
+
 __all__ = paddle.fft.__all__ + [
     "fftn",
     "ifftn",
@@ -87,6 +114,8 @@ def ifftshift(
     "irfftn",
     "fftshift",
     "ifftshift",
+    "fftfreq",
+    "rfftfreq",
 ]

 _all_ignore = ["paddle"]
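
Reviewer note (illustrative): the new fftfreq/rfftfreq wrappers add the array-API device keyword on top of paddle.fft. A quick check, assuming a CPU build:

    from array_api_compat.paddle import fft

    f = fft.fftfreq(8, d=0.5, device="cpu")  # sample frequencies, moved to CPU
    rf = fft.rfftfreq(8, d=0.5)              # device=None leaves placement as-is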
diff --git a/array_api_compat/paddle/linalg.py b/array_api_compat/paddle/linalg.py
index 7ef04a90..7dd1a266 100644
--- a/array_api_compat/paddle/linalg.py
+++ b/array_api_compat/paddle/linalg.py
@@ -12,7 +12,9 @@
 inf = float("inf")

 from ._aliases import _fix_promotion, sum
+from collections import namedtuple

+import paddle
 from paddle.linalg import * # noqa: F403

 # paddle.linalg doesn't define __all__
@@ -30,21 +33,18 @@
 # Note: paddle.linalg.cross does not default to axis=-1 (it defaults to the
 # first axis with size 3)
+
 # paddle.cross also does not support broadcasting when it would add new
 # dimensions
 def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
     if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
-        raise ValueError(
-            f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}"
-        )
+        raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
     if not (x1.shape[axis] == x2.shape[axis] == 3):
-        raise ValueError(
-            f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}"
-        )
+        raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")

-    x1, x2 = paddle.broadcast_tensors(x1, x2)
+    x1, x2 = paddle.broadcast_tensors([x1, x2])
     return paddle_linalg.cross(x1, x2, axis=axis)

@@ -64,7 +64,7 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
     x1_ = paddle.moveaxis(x1, axis, -1)
     x2_ = paddle.moveaxis(x2, axis, -1)
-    x1_, x2_ = paddle.broadcast_tensors(x1_, x2_)
+    x1_, x2_ = paddle.broadcast_tensors([x1_, x2_])

     res = x1_[..., None, :] @ x2_[..., None]
     return res[..., 0, 0]

@@ -82,9 +82,7 @@ def solve(x1: array, x2: array, /, **kwargs) -> array:

 # paddle.trace doesn't support the offset argument and doesn't support stacking
 def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
     # Use our wrapped sum to make sure it does upcasting correctly
-    return sum(
-        paddle.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype
-    )
+    return sum(paddle.diagonal(x, offset=offset, axis1=-2, axis2=-1), axis=-1, dtype=dtype)


 def vector_norm(
@@ -118,16 +116,44 @@ def vector_norm(
     return paddle.linalg.vector_norm(x, p=ord, axis=axis, keepdim=keepdims, **kwargs)


+def matrix_norm(
+    x: array,
+    /,
+    *,
+    keepdims: bool = False,
+    ord: Optional[Union[int, float, Literal["fro", "nuc"]]] = "fro",
+) -> array:
+    return paddle.linalg.matrix_norm(x, p=ord, axis=(-2, -1), keepdim=keepdims)
+
+
+def pinv(x: array, /, *, rtol: Optional[Union[float, array]] = None) -> array:
+    if rtol is None:
+        return paddle.linalg.pinv(x)
+
+    return paddle.linalg.pinv(x, rcond=rtol)
+
+
+def slogdet(x: array):
+    det = paddle.linalg.det(x)
+    sign = paddle.sign(det)
+    # logabsdet must be computed from |det|; log() of a negative determinant
+    # would produce nan
+    logabsdet = paddle.log(paddle.abs(det))
+
+    slogdet = namedtuple("slogdet", ["sign", "logabsdet"])
+    return slogdet(sign, logabsdet)
+
+
 __all__ = linalg_all + [
     "outer",
     "matmul",
     "matrix_transpose",
+    "matrix_norm",
     "tensordot",
     "cross",
     "vecdot",
     "solve",
     "trace",
     "vector_norm",
+    "slogdet",
 ]

 _all_ignore = ["paddle_linalg", "sum"]
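
Reviewer note (illustrative): slogdet is emulated from det/sign/log(abs(.)) rather than a native kernel, so a small numerical sanity check is worth keeping in mind. Sketch, assuming a float input:

    import paddle
    from array_api_compat.paddle import linalg

    a = paddle.to_tensor([[2.0, 0.0], [0.0, -3.0]])
    sign, logabsdet = linalg.slogdet(a)     # sign == -1, logabsdet == log(6)
    fro = linalg.matrix_norm(a, ord="fro")  # Frobenius norm over the last two axes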