diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 6e709438..254e4e61 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -74,6 +74,7 @@ jobs: if: "! ((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))" env: ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.module-name || inputs.package-name }} + ARRAY_API_TESTS_VERSION: 2023.12 # This enables the NEP 50 type promotion behavior (without it a lot of # tests fail on bad scalar type promotion behavior) NPY_PROMOTION_STATE: weak diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index bb53e8e5..91c4d9a7 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -12,7 +12,7 @@ from typing import NamedTuple import inspect -from ._helpers import array_namespace, _check_device +from ._helpers import array_namespace, _check_device, device, is_torch_array # These functions are modified from the NumPy versions. @@ -264,6 +264,38 @@ def var( ) -> ndarray: return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) +# cumulative_sum is renamed from cumsum, and adds the include_initial keyword +# argument + +def cumulative_sum( + x: ndarray, + /, + xp, + *, + axis: Optional[int] = None, + dtype: Optional[Dtype] = None, + include_initial: bool = False, + **kwargs +) -> ndarray: + wrapped_xp = array_namespace(x) + + # TODO: The standard is not clear about what should happen when x.ndim == 0. + if axis is None: + if x.ndim > 1: + raise ValueError("axis must be specified in cumulative_sum for more than one dimension") + axis = 0 + + res = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs) + + # np.cumsum does not support include_initial + if include_initial: + initial_shape = list(x.shape) + initial_shape[axis] = 1 + res = xp.concatenate( + [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=device(res)), res], + axis=axis, + ) + return res # The min and max argument names in clip are different and not optional in numpy, and type # promotion behavior is different. @@ -281,10 +313,11 @@ def _isscalar(a): return isinstance(a, (int, float, type(None))) min_shape = () if _isscalar(min) else min.shape max_shape = () if _isscalar(max) else max.shape - result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape) wrapped_xp = array_namespace(x) + result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape) + # np.clip does type promotion but the array API clip requires that the # output have the same dtype as x. We do this instead of just downcasting # the result of xp.clip() to handle some corner cases better (e.g., @@ -305,20 +338,26 @@ def _isscalar(a): # At least handle the case of Python integers correctly (see # https://github.com/numpy/numpy/pull/26892). - if type(min) is int and min <= xp.iinfo(x.dtype).min: + if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min: min = None - if type(max) is int and max >= xp.iinfo(x.dtype).max: + if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max: max = None if out is None: - out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), copy=True) + out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), + copy=True, device=device(x)) if min is not None: - a = xp.broadcast_to(xp.asarray(min), result_shape) + if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(min): + # Avoid loss of precision due to torch defaulting to float32 + min = wrapped_xp.asarray(min, dtype=xp.float64) + a = xp.broadcast_to(wrapped_xp.asarray(min, device=device(x)), result_shape) ia = (out < a) | xp.isnan(a) # torch requires an explicit cast here out[ia] = wrapped_xp.astype(a[ia], out.dtype) if max is not None: - b = xp.broadcast_to(xp.asarray(max), result_shape) + if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(max): + max = wrapped_xp.asarray(max, dtype=xp.float64) + b = xp.broadcast_to(wrapped_xp.asarray(max, device=device(x)), result_shape) ib = (out > b) | xp.isnan(b) out[ib] = wrapped_xp.astype(b[ib], out.dtype) # Return a scalar for 0-D @@ -389,42 +428,6 @@ def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]: raise ValueError("nonzero() does not support zero-dimensional arrays") return xp.nonzero(x, **kwargs) -# sum() and prod() should always upcast when dtype=None -def sum( - x: ndarray, - /, - xp, - *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, - keepdims: bool = False, - **kwargs, -) -> ndarray: - # `xp.sum` already upcasts integers, but not floats or complexes - if dtype is None: - if x.dtype == xp.float32: - dtype = xp.float64 - elif x.dtype == xp.complex64: - dtype = xp.complex128 - return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs) - -def prod( - x: ndarray, - /, - xp, - *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, - keepdims: bool = False, - **kwargs, -) -> ndarray: - if dtype is None: - if x.dtype == xp.float32: - dtype = xp.float64 - elif x.dtype == xp.complex64: - dtype = xp.complex128 - return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs) - # ceil, floor, and trunc return integers for integer inputs def ceil(x: ndarray, /, xp, **kwargs) -> ndarray: @@ -521,10 +524,17 @@ def isdtype( # array_api_strict implementation will be very strict. return dtype == kind +# unstack is a new function in the 2023.12 array API standard +def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]: + if x.ndim == 0: + raise ValueError("Input array must be at least 1-d.") + return tuple(xp.moveaxis(x, axis, 0)) + __all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', - 'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape', 'argsort', - 'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc', - 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype'] + 'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims', + 'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', + 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype', + 'unstack'] diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index dc2b69d8..bfa1f1b9 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -147,11 +147,6 @@ def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray: return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs) def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray: - if dtype is None: - if x.dtype == xp.float32: - dtype = xp.float64 - elif x.dtype == xp.complex64: - dtype = xp.complex128 return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)) __all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult', diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 0a15e89d..30ae2943 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -5,6 +5,8 @@ from ..common import _aliases from .._internal import get_xp +from ._info import __array_namespace_info__ + from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Optional, Union @@ -47,14 +49,13 @@ astype = _aliases.astype std = get_xp(cp)(_aliases.std) var = get_xp(cp)(_aliases.var) +cumulative_sum = get_xp(cp)(_aliases.cumulative_sum) clip = get_xp(cp)(_aliases.clip) permute_dims = get_xp(cp)(_aliases.permute_dims) reshape = get_xp(cp)(_aliases.reshape) argsort = get_xp(cp)(_aliases.argsort) sort = get_xp(cp)(_aliases.sort) nonzero = get_xp(cp)(_aliases.nonzero) -sum = get_xp(cp)(_aliases.sum) -prod = get_xp(cp)(_aliases.prod) ceil = get_xp(cp)(_aliases.ceil) floor = get_xp(cp)(_aliases.floor) trunc = get_xp(cp)(_aliases.trunc) @@ -121,14 +122,21 @@ def sign(x: ndarray, /) -> ndarray: vecdot = cp.vecdot else: vecdot = get_xp(cp)(_aliases.vecdot) + if hasattr(cp, 'isdtype'): isdtype = cp.isdtype else: isdtype = get_xp(cp)(_aliases.isdtype) -__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos', - 'acosh', 'asin', 'asinh', 'atan', 'atan2', - 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow', 'sign'] +if hasattr(cp, 'unstack'): + unstack = cp.unstack +else: + unstack = get_xp(cp)(_aliases.unstack) + +__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool', + 'acos', 'acosh', 'asin', 'asinh', 'atan', + 'atan2', 'atanh', 'bitwise_left_shift', + 'bitwise_invert', 'bitwise_right_shift', + 'concat', 'pow', 'sign'] _all_ignore = ['cp', 'get_xp'] diff --git a/array_api_compat/cupy/_info.py b/array_api_compat/cupy/_info.py new file mode 100644 index 00000000..4440807d --- /dev/null +++ b/array_api_compat/cupy/_info.py @@ -0,0 +1,326 @@ +""" +Array API Inspection namespace + +This is the namespace for inspection functions as defined by the array API +standard. See +https://data-apis.org/array-api/latest/API_specification/inspection.html for +more details. + +""" +from cupy import ( + dtype, + cuda, + bool_ as bool, + intp, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, + complex64, + complex128, +) + +class __array_namespace_info__: + """ + Get the array API inspection namespace for CuPy. + + The array API inspection namespace defines the following functions: + + - capabilities() + - default_device() + - default_dtypes() + - dtypes() + - devices() + + See + https://data-apis.org/array-api/latest/API_specification/inspection.html + for more details. + + Returns + ------- + info : ModuleType + The array API inspection namespace for CuPy. + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.default_dtypes() + {'real floating': cupy.float64, + 'complex floating': cupy.complex128, + 'integral': cupy.int64, + 'indexing': cupy.int64} + + """ + + __module__ = 'cupy' + + def capabilities(self): + """ + Return a dictionary of array API library capabilities. + + The resulting dictionary has the following keys: + + - **"boolean indexing"**: boolean indicating whether an array library + supports boolean indexing. Always ``True`` for CuPy. + + - **"data-dependent shapes"**: boolean indicating whether an array + library supports data-dependent output shapes. Always ``True`` for + CuPy. + + See + https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html + for more details. + + See Also + -------- + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Returns + ------- + capabilities : dict + A dictionary of array API library capabilities. + + Examples + -------- + >>> info = xp.__array_namespace_info__() + >>> info.capabilities() + {'boolean indexing': True, + 'data-dependent shapes': True} + + """ + return { + "boolean indexing": True, + "data-dependent shapes": True, + # 'max rank' will be part of the 2024.12 standard + # "max rank": 64, + } + + def default_device(self): + """ + The default device used for new CuPy arrays. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Returns + ------- + device : str + The default device used for new CuPy arrays. + + Examples + -------- + >>> info = xp.__array_namespace_info__() + >>> info.default_device() + Device(0) + + """ + return cuda.Device(0) + + def default_dtypes(self, *, device=None): + """ + The default data types used for new CuPy arrays. + + For CuPy, this always returns the following dictionary: + + - **"real floating"**: ``cupy.float64`` + - **"complex floating"**: ``cupy.complex128`` + - **"integral"**: ``cupy.intp`` + - **"indexing"**: ``cupy.intp`` + + Parameters + ---------- + device : str, optional + The device to get the default data types for. + + Returns + ------- + dtypes : dict + A dictionary describing the default data types used for new CuPy + arrays. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Examples + -------- + >>> info = xp.__array_namespace_info__() + >>> info.default_dtypes() + {'real floating': cupy.float64, + 'complex floating': cupy.complex128, + 'integral': cupy.int64, + 'indexing': cupy.int64} + + """ + # TODO: Does this depend on device? + return { + "real floating": dtype(float64), + "complex floating": dtype(complex128), + "integral": dtype(intp), + "indexing": dtype(intp), + } + + def dtypes(self, *, device=None, kind=None): + """ + The array API data types supported by CuPy. + + Note that this function only returns data types that are defined by + the array API. + + Parameters + ---------- + device : str, optional + The device to get the data types for. + kind : str or tuple of str, optional + The kind of data types to return. If ``None``, all data types are + returned. If a string, only data types of that kind are returned. + If a tuple, a dictionary containing the union of the given kinds + is returned. The following kinds are supported: + + - ``'bool'``: boolean data types (i.e., ``bool``). + - ``'signed integer'``: signed integer data types (i.e., ``int8``, + ``int16``, ``int32``, ``int64``). + - ``'unsigned integer'``: unsigned integer data types (i.e., + ``uint8``, ``uint16``, ``uint32``, ``uint64``). + - ``'integral'``: integer data types. Shorthand for ``('signed + integer', 'unsigned integer')``. + - ``'real floating'``: real-valued floating-point data types + (i.e., ``float32``, ``float64``). + - ``'complex floating'``: complex floating-point data types (i.e., + ``complex64``, ``complex128``). + - ``'numeric'``: numeric data types. Shorthand for ``('integral', + 'real floating', 'complex floating')``. + + Returns + ------- + dtypes : dict + A dictionary mapping the names of data types to the corresponding + CuPy data types. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.devices + + Examples + -------- + >>> info = xp.__array_namespace_info__() + >>> info.dtypes(kind='signed integer') + {'int8': cupy.int8, + 'int16': cupy.int16, + 'int32': cupy.int32, + 'int64': cupy.int64} + + """ + # TODO: Does this depend on device? + if kind is None: + return { + "bool": dtype(bool), + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + "float32": dtype(float32), + "float64": dtype(float64), + "complex64": dtype(complex64), + "complex128": dtype(complex128), + } + if kind == "bool": + return {"bool": bool} + if kind == "signed integer": + return { + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + } + if kind == "unsigned integer": + return { + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + } + if kind == "integral": + return { + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + } + if kind == "real floating": + return { + "float32": dtype(float32), + "float64": dtype(float64), + } + if kind == "complex floating": + return { + "complex64": dtype(complex64), + "complex128": dtype(complex128), + } + if kind == "numeric": + return { + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + "float32": dtype(float32), + "float64": dtype(float64), + "complex64": dtype(complex64), + "complex128": dtype(complex128), + } + if isinstance(kind, tuple): + res = {} + for k in kind: + res.update(self.dtypes(kind=k)) + return res + raise ValueError(f"unsupported kind: {kind!r}") + + def devices(self): + """ + The devices supported by CuPy. + + Returns + ------- + devices : list of str + The devices supported by CuPy. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes + + """ + return [cuda.Device(i) for i in range(cuda.runtime.getDeviceCount())] diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index d26ec6a2..cf57c824 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -5,6 +5,8 @@ from ..._internal import get_xp +from ._info import __array_namespace_info__ + import numpy as np from numpy import ( # Constants @@ -42,6 +44,7 @@ import dask.array as da isdtype = get_xp(np)(_aliases.isdtype) +unstack = get_xp(da)(_aliases.unstack) astype = _aliases.astype # Common aliases @@ -88,7 +91,7 @@ def _dask_arange( permute_dims = get_xp(da)(_aliases.permute_dims) std = get_xp(da)(_aliases.std) var = get_xp(da)(_aliases.var) -clip = get_xp(da)(_aliases.clip) +cumulative_sum = get_xp(da)(_aliases.cumulative_sum) empty = get_xp(da)(_aliases.empty) empty_like = get_xp(da)(_aliases.empty_like) full = get_xp(da)(_aliases.full) @@ -102,8 +105,6 @@ def _dask_arange( vecdot = get_xp(da)(_aliases.vecdot) nonzero = get_xp(da)(_aliases.nonzero) -sum = get_xp(np)(_aliases.sum) -prod = get_xp(np)(_aliases.prod) ceil = get_xp(np)(_aliases.ceil) floor = get_xp(np)(_aliases.floor) trunc = get_xp(np)(_aliases.trunc) @@ -168,17 +169,56 @@ def asarray( concatenate as concat, ) +# dask.array.clip does not work unless all three arguments are provided. +# Furthermore, the masking workaround in common._aliases.clip cannot work with +# dask (meaning uint64 promoting to float64 is going to just be unfixed for +# now). +@get_xp(da) +def clip( + x: Array, + /, + min: Optional[Union[int, float, Array]] = None, + max: Optional[Union[int, float, Array]] = None, + *, + xp, +) -> Array: + def _isscalar(a): + return isinstance(a, (int, float, type(None))) + min_shape = () if _isscalar(min) else min.shape + max_shape = () if _isscalar(max) else max.shape + + # TODO: This won't handle dask unknown shapes + import numpy as np + result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape) + + if min is not None: + min = xp.broadcast_to(xp.asarray(min), result_shape) + if max is not None: + max = xp.broadcast_to(xp.asarray(max), result_shape) + + if min is None and max is None: + return xp.positive(x) + + if min is None: + return astype(xp.minimum(x, max), x.dtype) + if max is None: + return astype(xp.maximum(x, min), x.dtype) + + return astype(xp.minimum(xp.maximum(x, min), max), x.dtype) + # exclude these from all since _da_unsupported = ['sort', 'argsort'] common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported] -__all__ = common_aliases + ['asarray', 'bool', 'acos', - 'acosh', 'asin', 'asinh', 'atan', 'atan2', +__all__ = common_aliases + ['__array_namespace_info__', 'asarray', 'bool', + 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow', - 'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8', - 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64', - 'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type'] + 'bitwise_right_shift', 'concat', 'pow', 'e', + 'inf', 'nan', 'pi', 'newaxis', 'float32', + 'float64', 'int8', 'int16', 'int32', 'int64', + 'uint8', 'uint16', 'uint32', 'uint64', + 'complex64', 'complex128', 'iinfo', 'finfo', + 'can_cast', 'result_type'] _all_ignore = ['get_xp', 'da', 'partial', 'common_aliases', 'np'] diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py new file mode 100644 index 00000000..d3b12dc9 --- /dev/null +++ b/array_api_compat/dask/array/_info.py @@ -0,0 +1,345 @@ +""" +Array API Inspection namespace + +This is the namespace for inspection functions as defined by the array API +standard. See +https://data-apis.org/array-api/latest/API_specification/inspection.html for +more details. + +""" +from numpy import ( + dtype, + bool_ as bool, + intp, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, + complex64, + complex128, +) + +from ...common._helpers import _DASK_DEVICE + +class __array_namespace_info__: + """ + Get the array API inspection namespace for Dask. + + The array API inspection namespace defines the following functions: + + - capabilities() + - default_device() + - default_dtypes() + - dtypes() + - devices() + + See + https://data-apis.org/array-api/latest/API_specification/inspection.html + for more details. + + Returns + ------- + info : ModuleType + The array API inspection namespace for Dask. + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.default_dtypes() + {'real floating': dask.float64, + 'complex floating': dask.complex128, + 'integral': dask.int64, + 'indexing': dask.int64} + + """ + + __module__ = 'dask.array' + + def capabilities(self): + """ + Return a dictionary of array API library capabilities. + + The resulting dictionary has the following keys: + + - **"boolean indexing"**: boolean indicating whether an array library + supports boolean indexing. Always ``False`` for Dask. + + - **"data-dependent shapes"**: boolean indicating whether an array + library supports data-dependent output shapes. Always ``False`` for + Dask. + + See + https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html + for more details. + + See Also + -------- + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Returns + ------- + capabilities : dict + A dictionary of array API library capabilities. + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.capabilities() + {'boolean indexing': True, + 'data-dependent shapes': True} + + """ + return { + "boolean indexing": False, + "data-dependent shapes": False, + # 'max rank' will be part of the 2024.12 standard + # "max rank": 64, + } + + def default_device(self): + """ + The default device used for new Dask arrays. + + For Dask, this always returns ``'cpu'``. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Returns + ------- + device : str + The default device used for new Dask arrays. + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.default_device() + 'cpu' + + """ + return "cpu" + + def default_dtypes(self, *, device=None): + """ + The default data types used for new Dask arrays. + + For Dask, this always returns the following dictionary: + + - **"real floating"**: ``numpy.float64`` + - **"complex floating"**: ``numpy.complex128`` + - **"integral"**: ``numpy.intp`` + - **"indexing"**: ``numpy.intp`` + + Parameters + ---------- + device : str, optional + The device to get the default data types for. + + Returns + ------- + dtypes : dict + A dictionary describing the default data types used for new Dask + arrays. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.default_dtypes() + {'real floating': dask.float64, + 'complex floating': dask.complex128, + 'integral': dask.int64, + 'indexing': dask.int64} + + """ + if device not in ["cpu", _DASK_DEVICE, None]: + raise ValueError( + 'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:' + f' {device}' + ) + return { + "real floating": dtype(float64), + "complex floating": dtype(complex128), + "integral": dtype(intp), + "indexing": dtype(intp), + } + + def dtypes(self, *, device=None, kind=None): + """ + The array API data types supported by Dask. + + Note that this function only returns data types that are defined by + the array API. + + Parameters + ---------- + device : str, optional + The device to get the data types for. + kind : str or tuple of str, optional + The kind of data types to return. If ``None``, all data types are + returned. If a string, only data types of that kind are returned. + If a tuple, a dictionary containing the union of the given kinds + is returned. The following kinds are supported: + + - ``'bool'``: boolean data types (i.e., ``bool``). + - ``'signed integer'``: signed integer data types (i.e., ``int8``, + ``int16``, ``int32``, ``int64``). + - ``'unsigned integer'``: unsigned integer data types (i.e., + ``uint8``, ``uint16``, ``uint32``, ``uint64``). + - ``'integral'``: integer data types. Shorthand for ``('signed + integer', 'unsigned integer')``. + - ``'real floating'``: real-valued floating-point data types + (i.e., ``float32``, ``float64``). + - ``'complex floating'``: complex floating-point data types (i.e., + ``complex64``, ``complex128``). + - ``'numeric'``: numeric data types. Shorthand for ``('integral', + 'real floating', 'complex floating')``. + + Returns + ------- + dtypes : dict + A dictionary mapping the names of data types to the corresponding + Dask data types. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.devices + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.dtypes(kind='signed integer') + {'int8': dask.int8, + 'int16': dask.int16, + 'int32': dask.int32, + 'int64': dask.int64} + + """ + if device not in ["cpu", _DASK_DEVICE, None]: + raise ValueError( + 'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:' + f' {device}' + ) + if kind is None: + return { + "bool": dtype(bool), + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + "float32": dtype(float32), + "float64": dtype(float64), + "complex64": dtype(complex64), + "complex128": dtype(complex128), + } + if kind == "bool": + return {"bool": bool} + if kind == "signed integer": + return { + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + } + if kind == "unsigned integer": + return { + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + } + if kind == "integral": + return { + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + } + if kind == "real floating": + return { + "float32": dtype(float32), + "float64": dtype(float64), + } + if kind == "complex floating": + return { + "complex64": dtype(complex64), + "complex128": dtype(complex128), + } + if kind == "numeric": + return { + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + "float32": dtype(float32), + "float64": dtype(float64), + "complex64": dtype(complex64), + "complex128": dtype(complex128), + } + if isinstance(kind, tuple): + res = {} + for k in kind: + res.update(self.dtypes(kind=k)) + return res + raise ValueError(f"unsupported kind: {kind!r}") + + def devices(self): + """ + The devices supported by Dask. + + For Dask, this always returns ``['cpu', DASK_DEVICE]``. + + Returns + ------- + devices : list of str + The devices supported by Dask. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.devices() + ['cpu', DASK_DEVICE] + + """ + return ["cpu", _DASK_DEVICE] diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 7f5b2c6e..49c26d8b 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -5,7 +5,7 @@ # Exports from dask.array.linalg import * # noqa: F403 -from dask.array import trace, outer +from dask.array import outer # These functions are in both the main and linalg namespaces from dask.array import matmul, tensordot @@ -42,6 +42,7 @@ def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced', if mode != "reduced": raise ValueError("dask arrays only support using mode='reduced'") return QRResult(*da.linalg.qr(x, **kwargs)) +trace = get_xp(da)(_linalg.trace) cholesky = get_xp(da)(_linalg.cholesky) matrix_rank = get_xp(da)(_linalg.matrix_rank) matrix_norm = get_xp(da)(_linalg.matrix_norm) diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index e29b0751..355215e4 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -4,6 +4,8 @@ from .._internal import get_xp +from ._info import __array_namespace_info__ + from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Optional, Union @@ -47,14 +49,13 @@ astype = _aliases.astype std = get_xp(np)(_aliases.std) var = get_xp(np)(_aliases.var) +cumulative_sum = get_xp(np)(_aliases.cumulative_sum) clip = get_xp(np)(_aliases.clip) permute_dims = get_xp(np)(_aliases.permute_dims) reshape = get_xp(np)(_aliases.reshape) argsort = get_xp(np)(_aliases.argsort) sort = get_xp(np)(_aliases.sort) nonzero = get_xp(np)(_aliases.nonzero) -sum = get_xp(np)(_aliases.sum) -prod = get_xp(np)(_aliases.prod) ceil = get_xp(np)(_aliases.ceil) floor = get_xp(np)(_aliases.floor) trunc = get_xp(np)(_aliases.trunc) @@ -119,14 +120,21 @@ def asarray( vecdot = np.vecdot else: vecdot = get_xp(np)(_aliases.vecdot) + if hasattr(np, 'isdtype'): isdtype = np.isdtype else: isdtype = get_xp(np)(_aliases.isdtype) -__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos', - 'acosh', 'asin', 'asinh', 'atan', 'atan2', - 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow'] +if hasattr(np, 'unstack'): + unstack = np.unstack +else: + unstack = get_xp(np)(_aliases.unstack) + +__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool', + 'acos', 'acosh', 'asin', 'asinh', 'atan', + 'atan2', 'atanh', 'bitwise_left_shift', + 'bitwise_invert', 'bitwise_right_shift', + 'concat', 'pow'] _all_ignore = ['np', 'get_xp'] diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py new file mode 100644 index 00000000..62f7ae62 --- /dev/null +++ b/array_api_compat/numpy/_info.py @@ -0,0 +1,346 @@ +""" +Array API Inspection namespace + +This is the namespace for inspection functions as defined by the array API +standard. See +https://data-apis.org/array-api/latest/API_specification/inspection.html for +more details. + +""" +from numpy import ( + dtype, + bool_ as bool, + intp, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, + complex64, + complex128, +) + + +class __array_namespace_info__: + """ + Get the array API inspection namespace for NumPy. + + The array API inspection namespace defines the following functions: + + - capabilities() + - default_device() + - default_dtypes() + - dtypes() + - devices() + + See + https://data-apis.org/array-api/latest/API_specification/inspection.html + for more details. + + Returns + ------- + info : ModuleType + The array API inspection namespace for NumPy. + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.default_dtypes() + {'real floating': numpy.float64, + 'complex floating': numpy.complex128, + 'integral': numpy.int64, + 'indexing': numpy.int64} + + """ + + __module__ = 'numpy' + + def capabilities(self): + """ + Return a dictionary of array API library capabilities. + + The resulting dictionary has the following keys: + + - **"boolean indexing"**: boolean indicating whether an array library + supports boolean indexing. Always ``True`` for NumPy. + + - **"data-dependent shapes"**: boolean indicating whether an array + library supports data-dependent output shapes. Always ``True`` for + NumPy. + + See + https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html + for more details. + + See Also + -------- + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Returns + ------- + capabilities : dict + A dictionary of array API library capabilities. + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.capabilities() + {'boolean indexing': True, + 'data-dependent shapes': True} + + """ + return { + "boolean indexing": True, + "data-dependent shapes": True, + # 'max rank' will be part of the 2024.12 standard + # "max rank": 64, + } + + def default_device(self): + """ + The default device used for new NumPy arrays. + + For NumPy, this always returns ``'cpu'``. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Returns + ------- + device : str + The default device used for new NumPy arrays. + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.default_device() + 'cpu' + + """ + return "cpu" + + def default_dtypes(self, *, device=None): + """ + The default data types used for new NumPy arrays. + + For NumPy, this always returns the following dictionary: + + - **"real floating"**: ``numpy.float64`` + - **"complex floating"**: ``numpy.complex128`` + - **"integral"**: ``numpy.intp`` + - **"indexing"**: ``numpy.intp`` + + Parameters + ---------- + device : str, optional + The device to get the default data types for. For NumPy, only + ``'cpu'`` is allowed. + + Returns + ------- + dtypes : dict + A dictionary describing the default data types used for new NumPy + arrays. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.default_dtypes() + {'real floating': numpy.float64, + 'complex floating': numpy.complex128, + 'integral': numpy.int64, + 'indexing': numpy.int64} + + """ + if device not in ["cpu", None]: + raise ValueError( + 'Device not understood. Only "cpu" is allowed, but received:' + f' {device}' + ) + return { + "real floating": dtype(float64), + "complex floating": dtype(complex128), + "integral": dtype(intp), + "indexing": dtype(intp), + } + + def dtypes(self, *, device=None, kind=None): + """ + The array API data types supported by NumPy. + + Note that this function only returns data types that are defined by + the array API. + + Parameters + ---------- + device : str, optional + The device to get the data types for. For NumPy, only ``'cpu'`` is + allowed. + kind : str or tuple of str, optional + The kind of data types to return. If ``None``, all data types are + returned. If a string, only data types of that kind are returned. + If a tuple, a dictionary containing the union of the given kinds + is returned. The following kinds are supported: + + - ``'bool'``: boolean data types (i.e., ``bool``). + - ``'signed integer'``: signed integer data types (i.e., ``int8``, + ``int16``, ``int32``, ``int64``). + - ``'unsigned integer'``: unsigned integer data types (i.e., + ``uint8``, ``uint16``, ``uint32``, ``uint64``). + - ``'integral'``: integer data types. Shorthand for ``('signed + integer', 'unsigned integer')``. + - ``'real floating'``: real-valued floating-point data types + (i.e., ``float32``, ``float64``). + - ``'complex floating'``: complex floating-point data types (i.e., + ``complex64``, ``complex128``). + - ``'numeric'``: numeric data types. Shorthand for ``('integral', + 'real floating', 'complex floating')``. + + Returns + ------- + dtypes : dict + A dictionary mapping the names of data types to the corresponding + NumPy data types. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.devices + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.dtypes(kind='signed integer') + {'int8': numpy.int8, + 'int16': numpy.int16, + 'int32': numpy.int32, + 'int64': numpy.int64} + + """ + if device not in ["cpu", None]: + raise ValueError( + 'Device not understood. Only "cpu" is allowed, but received:' + f' {device}' + ) + if kind is None: + return { + "bool": dtype(bool), + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + "float32": dtype(float32), + "float64": dtype(float64), + "complex64": dtype(complex64), + "complex128": dtype(complex128), + } + if kind == "bool": + return {"bool": bool} + if kind == "signed integer": + return { + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + } + if kind == "unsigned integer": + return { + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + } + if kind == "integral": + return { + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + } + if kind == "real floating": + return { + "float32": dtype(float32), + "float64": dtype(float64), + } + if kind == "complex floating": + return { + "complex64": dtype(complex64), + "complex128": dtype(complex128), + } + if kind == "numeric": + return { + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + "float32": dtype(float32), + "float64": dtype(float64), + "complex64": dtype(complex64), + "complex128": dtype(complex128), + } + if isinstance(kind, tuple): + res = {} + for k in kind: + res.update(self.dtypes(kind=k)) + return res + raise ValueError(f"unsupported kind: {kind!r}") + + def devices(self): + """ + The devices supported by NumPy. + + For NumPy, this always returns ``['cpu']``. + + Returns + ------- + devices : list of str + The devices supported by NumPy. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.devices() + ['cpu'] + + """ + return ["cpu"] diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 899d94fb..5ac66bcb 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -4,9 +4,15 @@ from builtins import all as _builtin_all, any as _builtin_any from ..common._aliases import (matrix_transpose as _aliases_matrix_transpose, - vecdot as _aliases_vecdot, clip as _aliases_clip) + vecdot as _aliases_vecdot, + clip as _aliases_clip, + unstack as _aliases_unstack, + cumulative_sum as _aliases_cumulative_sum, + ) from .._internal import get_xp +from ._info import __array_namespace_info__ + import torch from typing import TYPE_CHECKING @@ -165,11 +171,14 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: floor_divide = _two_arg(torch.floor_divide) greater = _two_arg(torch.greater) greater_equal = _two_arg(torch.greater_equal) +hypot = _two_arg(torch.hypot) less = _two_arg(torch.less) less_equal = _two_arg(torch.less_equal) logaddexp = _two_arg(torch.logaddexp) # logical functions are not included here because they only accept bool in the # spec, so type promotion is irrelevant. +maximum = _two_arg(torch.maximum) +minimum = _two_arg(torch.minimum) multiply = _two_arg(torch.multiply) not_equal = _two_arg(torch.not_equal) pow = _two_arg(torch.pow) @@ -193,6 +202,8 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep return torch.amin(x, axis, keepdims=keepdims) clip = get_xp(torch)(_aliases_clip) +unstack = get_xp(torch)(_aliases_unstack) +cumulative_sum = get_xp(torch)(_aliases_cumulative_sum) # torch.sort also returns a tuple # https://github.com/pytorch/pytorch/issues/70921 @@ -721,19 +732,21 @@ def sign(x: array, /) -> array: return out -__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', - 'newaxis', 'conj', 'add', 'atan2', 'bitwise_and', - 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', - 'bitwise_xor', 'copysign', 'divide', 'equal', 'floor_divide', - 'greater', 'greater_equal', 'less', 'less_equal', 'logaddexp', +__all__ = ['__array_namespace_info__', 'result_type', 'can_cast', + 'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add', + 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', + 'bitwise_right_shift', 'bitwise_xor', 'copysign', 'divide', + 'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot', + 'less', 'less_equal', 'logaddexp', 'maximum', 'minimum', 'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max', - 'min', 'clip', 'sort', 'prod', 'sum', 'any', 'all', 'mean', 'std', - 'var', 'concat', 'squeeze', 'broadcast_to', 'flip', 'roll', - 'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full', - 'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype', - 'broadcast_arrays', 'UniqueAllResult', 'UniqueCountsResult', - 'UniqueInverseResult', 'unique_all', 'unique_counts', - 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', - 'vecdot', 'tensordot', 'isdtype', 'take', 'sign'] + 'min', 'clip', 'unstack', 'cumulative_sum', 'sort', 'prod', 'sum', + 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze', + 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape', + 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty', + 'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays', + 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', + 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', + 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', + 'take', 'sign'] _all_ignore = ['torch', 'get_xp'] diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py new file mode 100644 index 00000000..264caa9e --- /dev/null +++ b/array_api_compat/torch/_info.py @@ -0,0 +1,358 @@ +""" +Array API Inspection namespace + +This is the namespace for inspection functions as defined by the array API +standard. See +https://data-apis.org/array-api/latest/API_specification/inspection.html for +more details. + +""" +import torch + +from functools import cache + +class __array_namespace_info__: + """ + Get the array API inspection namespace for PyTorch. + + The array API inspection namespace defines the following functions: + + - capabilities() + - default_device() + - default_dtypes() + - dtypes() + - devices() + + See + https://data-apis.org/array-api/latest/API_specification/inspection.html + for more details. + + Returns + ------- + info : ModuleType + The array API inspection namespace for PyTorch. + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.default_dtypes() + {'real floating': numpy.float64, + 'complex floating': numpy.complex128, + 'integral': numpy.int64, + 'indexing': numpy.int64} + + """ + + __module__ = 'torch' + + def capabilities(self): + """ + Return a dictionary of array API library capabilities. + + The resulting dictionary has the following keys: + + - **"boolean indexing"**: boolean indicating whether an array library + supports boolean indexing. Always ``True`` for PyTorch. + + - **"data-dependent shapes"**: boolean indicating whether an array + library supports data-dependent output shapes. Always ``True`` for + PyTorch. + + See + https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html + for more details. + + See Also + -------- + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Returns + ------- + capabilities : dict + A dictionary of array API library capabilities. + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.capabilities() + {'boolean indexing': True, + 'data-dependent shapes': True} + + """ + return { + "boolean indexing": True, + "data-dependent shapes": True, + # 'max rank' will be part of the 2024.12 standard + # "max rank": 64, + } + + def default_device(self): + """ + The default device used for new PyTorch arrays. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Returns + ------- + device : str + The default device used for new PyTorch arrays. + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.default_device() + 'cpu' + + """ + return torch.device("cpu") + + def default_dtypes(self, *, device=None): + """ + The default data types used for new PyTorch arrays. + + Parameters + ---------- + device : str, optional + The device to get the default data types for. For PyTorch, only + ``'cpu'`` is allowed. + + Returns + ------- + dtypes : dict + A dictionary describing the default data types used for new PyTorch + arrays. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.default_dtypes() + {'real floating': torch.float32, + 'complex floating': torch.complex64, + 'integral': torch.int64, + 'indexing': torch.int64} + + """ + # Note: if the default is set to float64, the devices like MPS that + # don't support float64 will error. We still return the default_dtype + # value here because this error doesn't represent a different default + # per-device. + default_floating = torch.get_default_dtype() + default_complex = torch.complex64 if default_floating == torch.float32 else torch.complex128 + default_integral = torch.int64 + return { + "real floating": default_floating, + "complex floating": default_complex, + "integral": default_integral, + "indexing": default_integral, + } + + + def _dtypes(self, kind): + bool = torch.bool + int8 = torch.int8 + int16 = torch.int16 + int32 = torch.int32 + int64 = torch.int64 + uint8 = torch.uint8 + # uint16, uint32, and uint64 are present in newer versions of pytorch, + # but they aren't generally supported by the array API functions, so + # we omit them from this function. + float32 = torch.float32 + float64 = torch.float64 + complex64 = torch.complex64 + complex128 = torch.complex128 + + if kind is None: + return { + "bool": bool, + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "float32": float32, + "float64": float64, + "complex64": complex64, + "complex128": complex128, + } + if kind == "bool": + return {"bool": bool} + if kind == "signed integer": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + } + if kind == "unsigned integer": + return { + "uint8": uint8, + } + if kind == "integral": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + } + if kind == "real floating": + return { + "float32": float32, + "float64": float64, + } + if kind == "complex floating": + return { + "complex64": complex64, + "complex128": complex128, + } + if kind == "numeric": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "float32": float32, + "float64": float64, + "complex64": complex64, + "complex128": complex128, + } + if isinstance(kind, tuple): + res = {} + for k in kind: + res.update(self.dtypes(kind=k)) + return res + raise ValueError(f"unsupported kind: {kind!r}") + + @cache + def dtypes(self, *, device=None, kind=None): + """ + The array API data types supported by PyTorch. + + Note that this function only returns data types that are defined by + the array API. + + Parameters + ---------- + device : str, optional + The device to get the data types for. + kind : str or tuple of str, optional + The kind of data types to return. If ``None``, all data types are + returned. If a string, only data types of that kind are returned. + If a tuple, a dictionary containing the union of the given kinds + is returned. The following kinds are supported: + + - ``'bool'``: boolean data types (i.e., ``bool``). + - ``'signed integer'``: signed integer data types (i.e., ``int8``, + ``int16``, ``int32``, ``int64``). + - ``'unsigned integer'``: unsigned integer data types (i.e., + ``uint8``, ``uint16``, ``uint32``, ``uint64``). + - ``'integral'``: integer data types. Shorthand for ``('signed + integer', 'unsigned integer')``. + - ``'real floating'``: real-valued floating-point data types + (i.e., ``float32``, ``float64``). + - ``'complex floating'``: complex floating-point data types (i.e., + ``complex64``, ``complex128``). + - ``'numeric'``: numeric data types. Shorthand for ``('integral', + 'real floating', 'complex floating')``. + + Returns + ------- + dtypes : dict + A dictionary mapping the names of data types to the corresponding + PyTorch data types. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.devices + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.dtypes(kind='signed integer') + {'int8': numpy.int8, + 'int16': numpy.int16, + 'int32': numpy.int32, + 'int64': numpy.int64} + + """ + res = self._dtypes(kind) + for k, v in res.copy().items(): + try: + torch.empty((0,), dtype=v, device=device) + except: + del res[k] + return res + + @cache + def devices(self): + """ + The devices supported by PyTorch. + + Returns + ------- + devices : list of str + The devices supported by PyTorch. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.devices() + [device(type='cpu'), device(type='mps', index=0), device(type='meta')] + + """ + # Torch doesn't have a straightforward way to get the list of all + # currently supported devices. To do this, we first parse the error + # message of torch.device to get the list of all possible types of + # device: + try: + torch.device('notadevice') + except RuntimeError as e: + # The error message is something like: + # "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: notadevice" + devices_names = e.args[0].split('Expected one of ')[1].split(' device type')[0].split(', ') + + # Next we need to check for different indices for different devices. + # device(device_name, index=index) doesn't actually check if the + # device name or index is valid. We have to try to create a tensor + # with it (which is why this function is cached). + devices = [] + for device_name in devices_names: + i = 0 + while True: + try: + a = torch.empty((0,), device=torch.device(device_name, index=i)) + if a.device in devices: + break + devices.append(a.device) + except: + break + i += 1 + + return devices diff --git a/dask-xfails.txt b/dask-xfails.txt index 0b651b45..1e9c421c 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -48,9 +48,6 @@ array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] -# The clip helper uses boolean indexing -array_api_tests/test_operators_and_elementwise_functions.py::test_clip - # No sorting in dask array_api_tests/test_has_names.py::test_has_names[sorting-argsort] array_api_tests/test_has_names.py::test_has_names[sorting-sort] @@ -81,8 +78,6 @@ array_api_tests/test_linalg.py::test_svdvals array_api_tests/test_linalg.py::test_cholesky # dtype mismatch got uint64, but should be uint8, NPY_PROMOTION_STATE=weak doesn't help :( array_api_tests/test_linalg.py::test_tensordot -# probably same reason for failing as numpy -array_api_tests/test_linalg.py::test_trace # AssertionError: out.dtype=uint64, but should be uint8 [tensordot(uint8, uint8)] array_api_tests/test_linalg.py::test_linalg_tensordot @@ -155,3 +150,8 @@ array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices # (https://github.com/data-apis/array-api-tests/issues/168) array_api_tests/test_statistical_functions.py::test_sum array_api_tests/test_statistical_functions.py::test_prod + +# 2023.12 support +array_api_tests/test_manipulation_functions.py::test_repeat +array_api_tests/test_searching_functions.py::test_searchsorted +array_api_tests/test_signatures.py::test_func_signature[astype] diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt index d921bd97..d6a2c251 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-21-xfails.txt @@ -119,6 +119,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__ array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[bitwise_xor(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_copysign array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_divide[divide(x1, x2)] @@ -131,11 +132,14 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[f array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_hypot array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_less[less(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[__le__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[less_equal(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_logaddexp +array_api_tests/test_operators_and_elementwise_functions.py::test_maximum +array_api_tests/test_operators_and_elementwise_functions.py::test_minimum array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)] @@ -246,3 +250,9 @@ array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i < 0) -> -0] array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i > 0) -> +0] array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0] + +# 2023.12 support +array_api_tests/test_searching_functions.py::test_searchsorted +array_api_tests/test_signatures.py::test_func_signature[from_dlpack] +array_api_tests/test_signatures.py::test_func_signature[astype] +array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index 40c6cbc4..30b5d302 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -45,3 +45,9 @@ array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices # (https://github.com/data-apis/array-api-tests/issues/168) array_api_tests/test_statistical_functions.py::test_sum array_api_tests/test_statistical_functions.py::test_prod + +# 2023.12 support +array_api_tests/test_searching_functions.py::test_searchsorted +array_api_tests/test_signatures.py::test_func_signature[from_dlpack] +array_api_tests/test_signatures.py::test_func_signature[astype] +array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] diff --git a/numpy-dev-xfails.txt b/numpy-dev-xfails.txt index 8de58bb4..68f954d2 100644 --- a/numpy-dev-xfails.txt +++ b/numpy-dev-xfails.txt @@ -11,8 +11,15 @@ array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 # (https://github.com/data-apis/array-api-tests/issues/168) array_api_tests/test_statistical_functions.py::test_sum array_api_tests/test_statistical_functions.py::test_prod +array_api_tests/test_statistical_functions.py::test_cumulative_sum # The test suite cannot properly get the signature for vecdot # https://github.com/numpy/numpy/pull/26237 array_api_tests/test_signatures.py::test_func_signature[vecdot] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] + +# 2023.12 support +# Argument 'device' missing from signature +array_api_tests/test_signatures.py::test_func_signature[astype] +array_api_tests/test_signatures.py::test_func_signature[from_dlpack] +array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] diff --git a/numpy-xfails.txt b/numpy-xfails.txt index 6aa54bb4..cdcd2576 100644 --- a/numpy-xfails.txt +++ b/numpy-xfails.txt @@ -37,3 +37,9 @@ array_api_tests/test_statistical_functions.py::test_prod # https://github.com/numpy/numpy/pull/26237 array_api_tests/test_signatures.py::test_func_signature[vecdot] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] + +# 2023.12 support +array_api_tests/test_searching_functions.py::test_searchsorted +array_api_tests/test_signatures.py::test_func_signature[from_dlpack] +array_api_tests/test_signatures.py::test_func_signature[astype] +array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] diff --git a/ruff.toml b/ruff.toml index 53e8596c..72e111b5 100644 --- a/ruff.toml +++ b/ruff.toml @@ -9,5 +9,9 @@ select = [ "PLC0414" ] -# Ignore module import not at top of file -ignore = ["E402"] +ignore = [ + # Module import not at top of file + "E402", + # Do not use bare `except` + "E722" +] diff --git a/tests/test_all.py b/tests/test_all.py index ff01fbdd..969d5cfb 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -32,6 +32,8 @@ def test_all(library): continue dir_names = [n for n in dir(module) if not n.startswith('_')] + if '__array_namespace_info__' in dir(module): + dir_names.append('__array_namespace_info__') ignore_all_names = getattr(module, '_all_ignore', []) ignore_all_names += ['annotations', 'TYPE_CHECKING'] dir_names = set(dir_names) - set(ignore_all_names) diff --git a/torch-xfails.txt b/torch-xfails.txt index aedbc4a7..c7abe2e9 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -189,3 +189,14 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_expm1 array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_set_functions.py::test_unique_counts array_api_tests/test_set_functions.py::test_unique_values + +# 2023.12 support +array_api_tests/test_has_names.py::test_has_names[manipulation-repeat] +array_api_tests/test_manipulation_functions.py::test_repeat +array_api_tests/test_signatures.py::test_func_signature[repeat] +# Argument 'device' missing from signature +array_api_tests/test_signatures.py::test_func_signature[from_dlpack] +# Argument 'max_version' missing from signature +array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] +# Argument 'device' missing from signature +array_api_tests/test_signatures.py::test_func_signature[astype]