From 0734064094f6e86cc610579857858f796afc5757 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 29 Jul 2024 13:05:36 -0600 Subject: [PATCH 01/27] Remove floating-point promotion from sum, prod, and trace Fixes #152 --- array_api_compat/common/_aliases.py | 42 ++----------------------- array_api_compat/common/_linalg.py | 5 --- array_api_compat/cupy/_aliases.py | 2 -- array_api_compat/dask/array/_aliases.py | 2 -- array_api_compat/numpy/_aliases.py | 2 -- 5 files changed, 3 insertions(+), 50 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index bb53e8e5..ec43ae31 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -389,42 +389,6 @@ def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]: raise ValueError("nonzero() does not support zero-dimensional arrays") return xp.nonzero(x, **kwargs) -# sum() and prod() should always upcast when dtype=None -def sum( - x: ndarray, - /, - xp, - *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, - keepdims: bool = False, - **kwargs, -) -> ndarray: - # `xp.sum` already upcasts integers, but not floats or complexes - if dtype is None: - if x.dtype == xp.float32: - dtype = xp.float64 - elif x.dtype == xp.complex64: - dtype = xp.complex128 - return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs) - -def prod( - x: ndarray, - /, - xp, - *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, - keepdims: bool = False, - **kwargs, -) -> ndarray: - if dtype is None: - if x.dtype == xp.float32: - dtype = xp.float64 - elif x.dtype == xp.complex64: - dtype = xp.complex128 - return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs) - # ceil, floor, and trunc return integers for integer inputs def ceil(x: ndarray, /, xp, **kwargs) -> ndarray: @@ -525,6 +489,6 @@ def isdtype( 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', - 'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape', 'argsort', - 'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc', - 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype'] + 'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape', + 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', 'matmul', + 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype'] diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index dc2b69d8..bfa1f1b9 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -147,11 +147,6 @@ def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray: return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs) def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray: - if dtype is None: - if x.dtype == xp.float32: - dtype = xp.float64 - elif x.dtype == xp.complex64: - dtype = xp.complex128 return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)) __all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult', diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index d7e78fdd..68ff378a 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -53,8 +53,6 @@ argsort = get_xp(cp)(_aliases.argsort) sort = get_xp(cp)(_aliases.sort) nonzero = get_xp(cp)(_aliases.nonzero) -sum = get_xp(cp)(_aliases.sum) -prod = get_xp(cp)(_aliases.prod) ceil = get_xp(cp)(_aliases.ceil) floor = get_xp(cp)(_aliases.floor) trunc = get_xp(cp)(_aliases.trunc) diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index d26ec6a2..5d89aa1c 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -102,8 +102,6 @@ def _dask_arange( vecdot = get_xp(da)(_aliases.vecdot) nonzero = get_xp(da)(_aliases.nonzero) -sum = get_xp(np)(_aliases.sum) -prod = get_xp(np)(_aliases.prod) ceil = get_xp(np)(_aliases.ceil) floor = get_xp(np)(_aliases.floor) trunc = get_xp(np)(_aliases.trunc) diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index e29b0751..4fd6a68a 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -53,8 +53,6 @@ argsort = get_xp(np)(_aliases.argsort) sort = get_xp(np)(_aliases.sort) nonzero = get_xp(np)(_aliases.nonzero) -sum = get_xp(np)(_aliases.sum) -prod = get_xp(np)(_aliases.prod) ceil = get_xp(np)(_aliases.ceil) floor = get_xp(np)(_aliases.floor) trunc = get_xp(np)(_aliases.trunc) From 25da0e443078e661efe37590ca2003e80260488c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 29 Jul 2024 13:09:22 -0600 Subject: [PATCH 02/27] Wrap trace for dask --- array_api_compat/dask/array/linalg.py | 3 ++- dask-xfails.txt | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 7f5b2c6e..49c26d8b 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -5,7 +5,7 @@ # Exports from dask.array.linalg import * # noqa: F403 -from dask.array import trace, outer +from dask.array import outer # These functions are in both the main and linalg namespaces from dask.array import matmul, tensordot @@ -42,6 +42,7 @@ def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced', if mode != "reduced": raise ValueError("dask arrays only support using mode='reduced'") return QRResult(*da.linalg.qr(x, **kwargs)) +trace = get_xp(da)(_linalg.trace) cholesky = get_xp(da)(_linalg.cholesky) matrix_rank = get_xp(da)(_linalg.matrix_rank) matrix_norm = get_xp(da)(_linalg.matrix_norm) diff --git a/dask-xfails.txt b/dask-xfails.txt index 0b651b45..f4c96995 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -81,8 +81,6 @@ array_api_tests/test_linalg.py::test_svdvals array_api_tests/test_linalg.py::test_cholesky # dtype mismatch got uint64, but should be uint8, NPY_PROMOTION_STATE=weak doesn't help :( array_api_tests/test_linalg.py::test_tensordot -# probably same reason for failing as numpy -array_api_tests/test_linalg.py::test_trace # AssertionError: out.dtype=uint64, but should be uint8 [tensordot(uint8, uint8)] array_api_tests/test_linalg.py::test_linalg_tensordot From e3834418f7aa20398a1723fd90648fc0e912041b Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 29 Jul 2024 13:12:46 -0600 Subject: [PATCH 03/27] Wrap hypot in torch --- array_api_compat/torch/_aliases.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index c2be21fe..ad169901 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -162,6 +162,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: floor_divide = _two_arg(torch.floor_divide) greater = _two_arg(torch.greater) greater_equal = _two_arg(torch.greater_equal) +hypot = _two_arg(torch.hypot) less = _two_arg(torch.less) less_equal = _two_arg(torch.less_equal) logaddexp = _two_arg(torch.logaddexp) @@ -707,15 +708,15 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) - 'newaxis', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign', 'divide', 'equal', 'floor_divide', 'greater', 'greater_equal', - 'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow', - 'remainder', 'subtract', 'max', 'min', 'clip', 'sort', 'prod', - 'sum', 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze', - 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape', - 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty', - 'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays', - 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', - 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', - 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', - 'take'] + 'hypot', 'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal', + 'pow', 'remainder', 'subtract', 'max', 'min', 'clip', 'sort', + 'prod', 'sum', 'any', 'all', 'mean', 'std', 'var', 'concat', + 'squeeze', 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', + 'reshape', 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', + 'empty', 'tril', 'triu', 'expand_dims', 'astype', + 'broadcast_arrays', 'UniqueAllResult', 'UniqueCountsResult', + 'UniqueInverseResult', 'unique_all', 'unique_counts', + 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', + 'vecdot', 'tensordot', 'isdtype', 'take'] _all_ignore = ['torch', 'get_xp'] From 5eb7abfdc94c0b9a2fe75a2236542c2a31c43aa0 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 29 Jul 2024 13:19:29 -0600 Subject: [PATCH 04/27] Add unstack to all wrapped libraries --- array_api_compat/common/_aliases.py | 8 +++++++- array_api_compat/cupy/_aliases.py | 6 ++++++ array_api_compat/dask/array/_aliases.py | 1 + array_api_compat/numpy/_aliases.py | 6 ++++++ array_api_compat/torch/_aliases.py | 6 ++++-- 5 files changed, 24 insertions(+), 3 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index ec43ae31..291fae8a 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -485,10 +485,16 @@ def isdtype( # array_api_strict implementation will be very strict. return dtype == kind +# unstack is a new function in the 2023.12 array API standard +def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]: + if x.ndim == 0: + raise ValueError("Input array must be at least 1-d.") + return tuple(xp.moveaxis(x, axis, 0)) + __all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', 'matmul', - 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype'] + 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype', 'unstack'] diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 68ff378a..1ffba901 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -112,11 +112,17 @@ def asarray( vecdot = cp.vecdot else: vecdot = get_xp(cp)(_aliases.vecdot) + if hasattr(cp, 'isdtype'): isdtype = cp.isdtype else: isdtype = get_xp(cp)(_aliases.isdtype) +if hasattr(cp, 'unstack'): + unstack = cp.unstack +else: + unstack = get_xp(cp)(_aliases.unstack) + __all__ = _aliases.__all__ + ['asarray', 'bool', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 5d89aa1c..7519f59f 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -42,6 +42,7 @@ import dask.array as da isdtype = get_xp(np)(_aliases.isdtype) +unstack = get_xp(da)(_aliases.unstack) astype = _aliases.astype # Common aliases diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 4fd6a68a..e77e4eb9 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -117,11 +117,17 @@ def asarray( vecdot = np.vecdot else: vecdot = get_xp(np)(_aliases.vecdot) + if hasattr(np, 'isdtype'): isdtype = np.isdtype else: isdtype = get_xp(np)(_aliases.isdtype) +if hasattr(np, 'unstack'): + unstack = np.unstack +else: + unstack = get_xp(np)(_aliases.unstack) + __all__ = _aliases.__all__ + ['asarray', 'bool', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index ad169901..ec4d5253 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -4,7 +4,8 @@ from builtins import all as _builtin_all, any as _builtin_any from ..common._aliases import (matrix_transpose as _aliases_matrix_transpose, - vecdot as _aliases_vecdot, clip as _aliases_clip) + vecdot as _aliases_vecdot, clip as + _aliases_clip, unstack as _aliases_unstack,) from .._internal import get_xp import torch @@ -191,6 +192,7 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep return torch.amin(x, axis, keepdims=keepdims) clip = get_xp(torch)(_aliases_clip) +unstack = get_xp(torch)(_aliases_unstack) # torch.sort also returns a tuple # https://github.com/pytorch/pytorch/issues/70921 @@ -709,7 +711,7 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) - 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign', 'divide', 'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot', 'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal', - 'pow', 'remainder', 'subtract', 'max', 'min', 'clip', 'sort', + 'pow', 'remainder', 'subtract', 'max', 'min', 'clip', 'unstack', 'sort', 'prod', 'sum', 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze', 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', From e70bcc86b4e567eb8220c4763d0381c38225fae8 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 29 Jul 2024 14:39:39 -0600 Subject: [PATCH 05/27] Run the test suite against the 2023 version of the standard on CI --- .github/workflows/array-api-tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 6e709438..254e4e61 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -74,6 +74,7 @@ jobs: if: "! ((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))" env: ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.module-name || inputs.package-name }} + ARRAY_API_TESTS_VERSION: 2023.12 # This enables the NEP 50 type promotion behavior (without it a lot of # tests fail on bad scalar type promotion behavior) NPY_PROMOTION_STATE: weak From d751db6568cf75a6a1eed8e2c53ab79c034da837 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 30 Jul 2024 12:27:16 -0600 Subject: [PATCH 06/27] Update xfails for failing 2023 tests --- dask-xfails.txt | 17 +++++++++++++++++ numpy-1-21-xfails.txt | 21 +++++++++++++++++++++ numpy-1-26-xfails.txt | 18 ++++++++++++++++++ numpy-dev-xfails.txt | 1 + numpy-xfails.txt | 18 ++++++++++++++++++ 5 files changed, 75 insertions(+) diff --git a/dask-xfails.txt b/dask-xfails.txt index f4c96995..b601da45 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -153,3 +153,20 @@ array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices # (https://github.com/data-apis/array-api-tests/issues/168) array_api_tests/test_statistical_functions.py::test_sum array_api_tests/test_statistical_functions.py::test_prod + +# 2023.12 support +array_api_tests/test_has_names.py::test_has_names[statistical-cumulative_sum] +array_api_tests/test_has_names.py::test_has_names[info-__array_namespace_info__] +array_api_tests/test_inspection_functions.py::test_array_namespace_info +array_api_tests/test_inspection_functions.py::test_array_namespace_info_dtypes +array_api_tests/test_manipulation_functions.py::test_repeat +array_api_tests/test_searching_functions.py::test_searchsorted +array_api_tests/test_signatures.py::test_func_signature[cumulative_sum] +array_api_tests/test_signatures.py::test_func_signature[astype] +array_api_tests/test_signatures.py::test_func_signature[__array_namespace_info__] +array_api_tests/test_signatures.py::test_info_func_signature[capabilities] +array_api_tests/test_signatures.py::test_info_func_signature[default_device] +array_api_tests/test_signatures.py::test_info_func_signature[default_dtypes] +array_api_tests/test_signatures.py::test_info_func_signature[devices] +array_api_tests/test_signatures.py::test_info_func_signature[dtypes] +array_api_tests/test_statistical_functions.py::test_cumulative_sum diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt index d921bd97..09767108 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-21-xfails.txt @@ -119,6 +119,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__ array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[bitwise_xor(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_copysign array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_divide[divide(x1, x2)] @@ -131,11 +132,13 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[f array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_hypot array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_less[less(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[__le__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[less_equal(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_logaddexp +array_api_tests/test_operators_and_elementwise_functions.py::test_minimum array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)] @@ -246,3 +249,21 @@ array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i < 0) -> -0] array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i > 0) -> +0] array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0] + +# 2023.12 support +array_api_tests/test_has_names.py::test_has_names[statistical-cumulative_sum] +array_api_tests/test_has_names.py::test_has_names[info-__array_namespace_info__] +array_api_tests/test_inspection_functions.py::test_array_namespace_info +array_api_tests/test_inspection_functions.py::test_array_namespace_info_dtypes +array_api_tests/test_searching_functions.py::test_searchsorted +array_api_tests/test_signatures.py::test_func_signature[cumulative_sum] +array_api_tests/test_signatures.py::test_func_signature[from_dlpack] +array_api_tests/test_signatures.py::test_func_signature[astype] +array_api_tests/test_signatures.py::test_func_signature[__array_namespace_info__] +array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] +array_api_tests/test_signatures.py::test_info_func_signature[capabilities] +array_api_tests/test_signatures.py::test_info_func_signature[default_device] +array_api_tests/test_signatures.py::test_info_func_signature[default_dtypes] +array_api_tests/test_signatures.py::test_info_func_signature[devices] +array_api_tests/test_signatures.py::test_info_func_signature[dtypes] +array_api_tests/test_statistical_functions.py::test_cumulative_sum diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index 40c6cbc4..81efc942 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -45,3 +45,21 @@ array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices # (https://github.com/data-apis/array-api-tests/issues/168) array_api_tests/test_statistical_functions.py::test_sum array_api_tests/test_statistical_functions.py::test_prod + +# 2023.12 support +array_api_tests/test_has_names.py::test_has_names[statistical-cumulative_sum] +array_api_tests/test_has_names.py::test_has_names[info-__array_namespace_info__] +array_api_tests/test_inspection_functions.py::test_array_namespace_info +array_api_tests/test_inspection_functions.py::test_array_namespace_info_dtypes +array_api_tests/test_searching_functions.py::test_searchsorted +array_api_tests/test_signatures.py::test_func_signature[cumulative_sum] +array_api_tests/test_signatures.py::test_func_signature[from_dlpack] +array_api_tests/test_signatures.py::test_func_signature[astype] +array_api_tests/test_signatures.py::test_func_signature[__array_namespace_info__] +array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] +array_api_tests/test_signatures.py::test_info_func_signature[capabilities] +array_api_tests/test_signatures.py::test_info_func_signature[default_device] +array_api_tests/test_signatures.py::test_info_func_signature[default_dtypes] +array_api_tests/test_signatures.py::test_info_func_signature[devices] +array_api_tests/test_signatures.py::test_info_func_signature[dtypes] +array_api_tests/test_statistical_functions.py::test_cumulative_sum diff --git a/numpy-dev-xfails.txt b/numpy-dev-xfails.txt index 6aa54bb4..8908910c 100644 --- a/numpy-dev-xfails.txt +++ b/numpy-dev-xfails.txt @@ -32,6 +32,7 @@ array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices # (https://github.com/data-apis/array-api-tests/issues/168) array_api_tests/test_statistical_functions.py::test_sum array_api_tests/test_statistical_functions.py::test_prod +array_api_tests/test_statistical_functions.py::test_cumulative_sum # The test suite cannot properly get the signature for vecdot # https://github.com/numpy/numpy/pull/26237 diff --git a/numpy-xfails.txt b/numpy-xfails.txt index 6aa54bb4..ce6fdb15 100644 --- a/numpy-xfails.txt +++ b/numpy-xfails.txt @@ -37,3 +37,21 @@ array_api_tests/test_statistical_functions.py::test_prod # https://github.com/numpy/numpy/pull/26237 array_api_tests/test_signatures.py::test_func_signature[vecdot] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] + +# 2023.12 support +array_api_tests/test_has_names.py::test_has_names[statistical-cumulative_sum] +array_api_tests/test_has_names.py::test_has_names[info-__array_namespace_info__] +array_api_tests/test_inspection_functions.py::test_array_namespace_info +array_api_tests/test_inspection_functions.py::test_array_namespace_info_dtypes +array_api_tests/test_searching_functions.py::test_searchsorted +array_api_tests/test_signatures.py::test_func_signature[cumulative_sum] +array_api_tests/test_signatures.py::test_func_signature[from_dlpack] +array_api_tests/test_signatures.py::test_func_signature[astype] +array_api_tests/test_signatures.py::test_func_signature[__array_namespace_info__] +array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] +array_api_tests/test_signatures.py::test_info_func_signature[capabilities] +array_api_tests/test_signatures.py::test_info_func_signature[default_device] +array_api_tests/test_signatures.py::test_info_func_signature[default_dtypes] +array_api_tests/test_signatures.py::test_info_func_signature[devices] +array_api_tests/test_signatures.py::test_info_func_signature[dtypes] +array_api_tests/test_statistical_functions.py::test_cumulative_sum From 5f8b5d615423926e8a517d1f4b03ec42fbd63dcc Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 12 Aug 2024 15:49:25 -0600 Subject: [PATCH 07/27] Improvements to the clip wrapper - Ensure the arrays that are created are created on the same device as x. (fixes #177) - Make clip() work with dask.array. The workaround avoid uint64 -> float64 promotion does not work here. (fixes #176) - Fix loss of precision when clipping a float64 tensor with torch due to the scalar being converted to a float32 tensor. --- array_api_compat/common/_aliases.py | 21 +++++++++----- array_api_compat/dask/array/_aliases.py | 38 ++++++++++++++++++++++++- 2 files changed, 51 insertions(+), 8 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 291fae8a..7b23c910 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -12,7 +12,7 @@ from typing import NamedTuple import inspect -from ._helpers import array_namespace, _check_device +from ._helpers import array_namespace, _check_device, device, is_torch_array # These functions are modified from the NumPy versions. @@ -281,10 +281,11 @@ def _isscalar(a): return isinstance(a, (int, float, type(None))) min_shape = () if _isscalar(min) else min.shape max_shape = () if _isscalar(max) else max.shape - result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape) wrapped_xp = array_namespace(x) + result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape) + # np.clip does type promotion but the array API clip requires that the # output have the same dtype as x. We do this instead of just downcasting # the result of xp.clip() to handle some corner cases better (e.g., @@ -305,20 +306,26 @@ def _isscalar(a): # At least handle the case of Python integers correctly (see # https://github.com/numpy/numpy/pull/26892). - if type(min) is int and min <= xp.iinfo(x.dtype).min: + if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min: min = None - if type(max) is int and max >= xp.iinfo(x.dtype).max: + if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max: max = None if out is None: - out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), copy=True) + out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), + copy=True, device=device(x)) if min is not None: - a = xp.broadcast_to(xp.asarray(min), result_shape) + if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(min): + # Avoid loss of precision due to torch defaulting to float32 + min = wrapped_xp.asarray(min, dtype=xp.float64) + a = xp.broadcast_to(wrapped_xp.asarray(min, device=device(x)), result_shape) ia = (out < a) | xp.isnan(a) # torch requires an explicit cast here out[ia] = wrapped_xp.astype(a[ia], out.dtype) if max is not None: - b = xp.broadcast_to(xp.asarray(max), result_shape) + if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(max): + max = wrapped_xp.asarray(max, dtype=xp.float64) + b = xp.broadcast_to(wrapped_xp.asarray(max, device=device(x)), result_shape) ib = (out > b) | xp.isnan(b) out[ib] = wrapped_xp.astype(b[ib], out.dtype) # Return a scalar for 0-D diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 7519f59f..10a03cd8 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -89,7 +89,6 @@ def _dask_arange( permute_dims = get_xp(da)(_aliases.permute_dims) std = get_xp(da)(_aliases.std) var = get_xp(da)(_aliases.var) -clip = get_xp(da)(_aliases.clip) empty = get_xp(da)(_aliases.empty) empty_like = get_xp(da)(_aliases.empty_like) full = get_xp(da)(_aliases.full) @@ -167,6 +166,43 @@ def asarray( concatenate as concat, ) +# dask.array.clip does not work unless all three arguments are provided. +# Furthermore, the masking workaround in common._aliases.clip cannot work with +# dask (meaning uint64 promoting to float64 is going to just be unfixed for +# now). +@get_xp(da) +def clip( + x: Array, + /, + min: Optional[Union[int, float, Array]] = None, + max: Optional[Union[int, float, Array]] = None, + *, + xp, +) -> Array: + def _isscalar(a): + return isinstance(a, (int, float, type(None))) + min_shape = () if _isscalar(min) else min.shape + max_shape = () if _isscalar(max) else max.shape + + # TODO: This won't handle dask unknown shapes + import numpy as np + result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape) + + if min is not None: + min = xp.broadcast_to(xp.asarray(min), result_shape) + if max is not None: + max = xp.broadcast_to(xp.asarray(max), result_shape) + + if min is None and max is None: + return xp.positive(x) + + if min is None: + return astype(xp.minimum(x, max), x.dtype) + if max is None: + return astype(xp.maximum(x, min), x.dtype) + + return astype(xp.minimum(xp.maximum(x, min), max), x.dtype) + # exclude these from all since _da_unsupported = ['sort', 'argsort'] From 2995a0f59cc95c2bc9704fd4b9b2e6a710395358 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 12 Aug 2024 15:53:58 -0600 Subject: [PATCH 08/27] Remove clip from dask xfails --- dask-xfails.txt | 3 --- 1 file changed, 3 deletions(-) diff --git a/dask-xfails.txt b/dask-xfails.txt index b601da45..77f99dc7 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -48,9 +48,6 @@ array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] -# The clip helper uses boolean indexing -array_api_tests/test_operators_and_elementwise_functions.py::test_clip - # No sorting in dask array_api_tests/test_has_names.py::test_has_names[sorting-argsort] array_api_tests/test_has_names.py::test_has_names[sorting-sort] From 284bd99790bc9b26b142352b8da8ecbb252c35ab Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 13 Aug 2024 11:59:53 -0600 Subject: [PATCH 09/27] Update numpy 1.21 xfails --- numpy-1-21-xfails.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt index 09767108..3e858141 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-21-xfails.txt @@ -138,6 +138,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_less[less(x1, array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[__le__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[less_equal(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_logaddexp +array_api_tests/test_operators_and_elementwise_functions.py::test_maximum array_api_tests/test_operators_and_elementwise_functions.py::test_minimum array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x, s)] From 11cb6ef8446a0703fe22e40f0a168d70f539aefa Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 19 Aug 2024 15:34:01 -0600 Subject: [PATCH 10/27] Add NumPy inspection namespace --- array_api_compat/numpy/_aliases.py | 11 +- array_api_compat/numpy/_info.py | 346 +++++++++++++++++++++++++++++ 2 files changed, 353 insertions(+), 4 deletions(-) create mode 100644 array_api_compat/numpy/_info.py diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index e77e4eb9..14d0960b 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -4,6 +4,8 @@ from .._internal import get_xp +from ._info import __array_namespace_info__ + from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Optional, Union @@ -128,9 +130,10 @@ def asarray( else: unstack = get_xp(np)(_aliases.unstack) -__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos', - 'acosh', 'asin', 'asinh', 'atan', 'atan2', - 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow'] +__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool', + 'acos', 'acosh', 'asin', 'asinh', 'atan', + 'atan2', 'atanh', 'bitwise_left_shift', + 'bitwise_invert', 'bitwise_right_shift', + 'concat', 'pow'] _all_ignore = ['np', 'get_xp'] diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py new file mode 100644 index 00000000..62f7ae62 --- /dev/null +++ b/array_api_compat/numpy/_info.py @@ -0,0 +1,346 @@ +""" +Array API Inspection namespace + +This is the namespace for inspection functions as defined by the array API +standard. See +https://data-apis.org/array-api/latest/API_specification/inspection.html for +more details. + +""" +from numpy import ( + dtype, + bool_ as bool, + intp, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, + complex64, + complex128, +) + + +class __array_namespace_info__: + """ + Get the array API inspection namespace for NumPy. + + The array API inspection namespace defines the following functions: + + - capabilities() + - default_device() + - default_dtypes() + - dtypes() + - devices() + + See + https://data-apis.org/array-api/latest/API_specification/inspection.html + for more details. + + Returns + ------- + info : ModuleType + The array API inspection namespace for NumPy. + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.default_dtypes() + {'real floating': numpy.float64, + 'complex floating': numpy.complex128, + 'integral': numpy.int64, + 'indexing': numpy.int64} + + """ + + __module__ = 'numpy' + + def capabilities(self): + """ + Return a dictionary of array API library capabilities. + + The resulting dictionary has the following keys: + + - **"boolean indexing"**: boolean indicating whether an array library + supports boolean indexing. Always ``True`` for NumPy. + + - **"data-dependent shapes"**: boolean indicating whether an array + library supports data-dependent output shapes. Always ``True`` for + NumPy. + + See + https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html + for more details. + + See Also + -------- + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Returns + ------- + capabilities : dict + A dictionary of array API library capabilities. + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.capabilities() + {'boolean indexing': True, + 'data-dependent shapes': True} + + """ + return { + "boolean indexing": True, + "data-dependent shapes": True, + # 'max rank' will be part of the 2024.12 standard + # "max rank": 64, + } + + def default_device(self): + """ + The default device used for new NumPy arrays. + + For NumPy, this always returns ``'cpu'``. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Returns + ------- + device : str + The default device used for new NumPy arrays. + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.default_device() + 'cpu' + + """ + return "cpu" + + def default_dtypes(self, *, device=None): + """ + The default data types used for new NumPy arrays. + + For NumPy, this always returns the following dictionary: + + - **"real floating"**: ``numpy.float64`` + - **"complex floating"**: ``numpy.complex128`` + - **"integral"**: ``numpy.intp`` + - **"indexing"**: ``numpy.intp`` + + Parameters + ---------- + device : str, optional + The device to get the default data types for. For NumPy, only + ``'cpu'`` is allowed. + + Returns + ------- + dtypes : dict + A dictionary describing the default data types used for new NumPy + arrays. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.default_dtypes() + {'real floating': numpy.float64, + 'complex floating': numpy.complex128, + 'integral': numpy.int64, + 'indexing': numpy.int64} + + """ + if device not in ["cpu", None]: + raise ValueError( + 'Device not understood. Only "cpu" is allowed, but received:' + f' {device}' + ) + return { + "real floating": dtype(float64), + "complex floating": dtype(complex128), + "integral": dtype(intp), + "indexing": dtype(intp), + } + + def dtypes(self, *, device=None, kind=None): + """ + The array API data types supported by NumPy. + + Note that this function only returns data types that are defined by + the array API. + + Parameters + ---------- + device : str, optional + The device to get the data types for. For NumPy, only ``'cpu'`` is + allowed. + kind : str or tuple of str, optional + The kind of data types to return. If ``None``, all data types are + returned. If a string, only data types of that kind are returned. + If a tuple, a dictionary containing the union of the given kinds + is returned. The following kinds are supported: + + - ``'bool'``: boolean data types (i.e., ``bool``). + - ``'signed integer'``: signed integer data types (i.e., ``int8``, + ``int16``, ``int32``, ``int64``). + - ``'unsigned integer'``: unsigned integer data types (i.e., + ``uint8``, ``uint16``, ``uint32``, ``uint64``). + - ``'integral'``: integer data types. Shorthand for ``('signed + integer', 'unsigned integer')``. + - ``'real floating'``: real-valued floating-point data types + (i.e., ``float32``, ``float64``). + - ``'complex floating'``: complex floating-point data types (i.e., + ``complex64``, ``complex128``). + - ``'numeric'``: numeric data types. Shorthand for ``('integral', + 'real floating', 'complex floating')``. + + Returns + ------- + dtypes : dict + A dictionary mapping the names of data types to the corresponding + NumPy data types. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.devices + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.dtypes(kind='signed integer') + {'int8': numpy.int8, + 'int16': numpy.int16, + 'int32': numpy.int32, + 'int64': numpy.int64} + + """ + if device not in ["cpu", None]: + raise ValueError( + 'Device not understood. Only "cpu" is allowed, but received:' + f' {device}' + ) + if kind is None: + return { + "bool": dtype(bool), + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + "float32": dtype(float32), + "float64": dtype(float64), + "complex64": dtype(complex64), + "complex128": dtype(complex128), + } + if kind == "bool": + return {"bool": bool} + if kind == "signed integer": + return { + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + } + if kind == "unsigned integer": + return { + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + } + if kind == "integral": + return { + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + } + if kind == "real floating": + return { + "float32": dtype(float32), + "float64": dtype(float64), + } + if kind == "complex floating": + return { + "complex64": dtype(complex64), + "complex128": dtype(complex128), + } + if kind == "numeric": + return { + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + "float32": dtype(float32), + "float64": dtype(float64), + "complex64": dtype(complex64), + "complex128": dtype(complex128), + } + if isinstance(kind, tuple): + res = {} + for k in kind: + res.update(self.dtypes(kind=k)) + return res + raise ValueError(f"unsupported kind: {kind!r}") + + def devices(self): + """ + The devices supported by NumPy. + + For NumPy, this always returns ``['cpu']``. + + Returns + ------- + devices : list of str + The devices supported by NumPy. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.devices() + ['cpu'] + + """ + return ["cpu"] From 4c9dd0e509060ca4d38bfc77713459f930252298 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 19 Aug 2024 15:59:45 -0600 Subject: [PATCH 11/27] Add CuPy inspection APIs I'm not sure if all the details here are correct. See https://github.com/data-apis/array-api-compat/issues/127#issuecomment-2297514930. --- array_api_compat/cupy/_aliases.py | 11 +- array_api_compat/cupy/_info.py | 326 ++++++++++++++++++++++++++++++ 2 files changed, 333 insertions(+), 4 deletions(-) create mode 100644 array_api_compat/cupy/_info.py diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 1ffba901..fad41102 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -5,6 +5,8 @@ from ..common import _aliases from .._internal import get_xp +from ._info import __array_namespace_info__ + from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Optional, Union @@ -123,9 +125,10 @@ def asarray( else: unstack = get_xp(cp)(_aliases.unstack) -__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos', - 'acosh', 'asin', 'asinh', 'atan', 'atan2', - 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow'] +__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool', + 'acos', 'acosh', 'asin', 'asinh', 'atan', + 'atan2', 'atanh', 'bitwise_left_shift', + 'bitwise_invert', 'bitwise_right_shift', + 'concat', 'pow'] _all_ignore = ['cp', 'get_xp'] diff --git a/array_api_compat/cupy/_info.py b/array_api_compat/cupy/_info.py new file mode 100644 index 00000000..4440807d --- /dev/null +++ b/array_api_compat/cupy/_info.py @@ -0,0 +1,326 @@ +""" +Array API Inspection namespace + +This is the namespace for inspection functions as defined by the array API +standard. See +https://data-apis.org/array-api/latest/API_specification/inspection.html for +more details. + +""" +from cupy import ( + dtype, + cuda, + bool_ as bool, + intp, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, + complex64, + complex128, +) + +class __array_namespace_info__: + """ + Get the array API inspection namespace for CuPy. + + The array API inspection namespace defines the following functions: + + - capabilities() + - default_device() + - default_dtypes() + - dtypes() + - devices() + + See + https://data-apis.org/array-api/latest/API_specification/inspection.html + for more details. + + Returns + ------- + info : ModuleType + The array API inspection namespace for CuPy. + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.default_dtypes() + {'real floating': cupy.float64, + 'complex floating': cupy.complex128, + 'integral': cupy.int64, + 'indexing': cupy.int64} + + """ + + __module__ = 'cupy' + + def capabilities(self): + """ + Return a dictionary of array API library capabilities. + + The resulting dictionary has the following keys: + + - **"boolean indexing"**: boolean indicating whether an array library + supports boolean indexing. Always ``True`` for CuPy. + + - **"data-dependent shapes"**: boolean indicating whether an array + library supports data-dependent output shapes. Always ``True`` for + CuPy. + + See + https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html + for more details. + + See Also + -------- + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Returns + ------- + capabilities : dict + A dictionary of array API library capabilities. + + Examples + -------- + >>> info = xp.__array_namespace_info__() + >>> info.capabilities() + {'boolean indexing': True, + 'data-dependent shapes': True} + + """ + return { + "boolean indexing": True, + "data-dependent shapes": True, + # 'max rank' will be part of the 2024.12 standard + # "max rank": 64, + } + + def default_device(self): + """ + The default device used for new CuPy arrays. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Returns + ------- + device : str + The default device used for new CuPy arrays. + + Examples + -------- + >>> info = xp.__array_namespace_info__() + >>> info.default_device() + Device(0) + + """ + return cuda.Device(0) + + def default_dtypes(self, *, device=None): + """ + The default data types used for new CuPy arrays. + + For CuPy, this always returns the following dictionary: + + - **"real floating"**: ``cupy.float64`` + - **"complex floating"**: ``cupy.complex128`` + - **"integral"**: ``cupy.intp`` + - **"indexing"**: ``cupy.intp`` + + Parameters + ---------- + device : str, optional + The device to get the default data types for. + + Returns + ------- + dtypes : dict + A dictionary describing the default data types used for new CuPy + arrays. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Examples + -------- + >>> info = xp.__array_namespace_info__() + >>> info.default_dtypes() + {'real floating': cupy.float64, + 'complex floating': cupy.complex128, + 'integral': cupy.int64, + 'indexing': cupy.int64} + + """ + # TODO: Does this depend on device? + return { + "real floating": dtype(float64), + "complex floating": dtype(complex128), + "integral": dtype(intp), + "indexing": dtype(intp), + } + + def dtypes(self, *, device=None, kind=None): + """ + The array API data types supported by CuPy. + + Note that this function only returns data types that are defined by + the array API. + + Parameters + ---------- + device : str, optional + The device to get the data types for. + kind : str or tuple of str, optional + The kind of data types to return. If ``None``, all data types are + returned. If a string, only data types of that kind are returned. + If a tuple, a dictionary containing the union of the given kinds + is returned. The following kinds are supported: + + - ``'bool'``: boolean data types (i.e., ``bool``). + - ``'signed integer'``: signed integer data types (i.e., ``int8``, + ``int16``, ``int32``, ``int64``). + - ``'unsigned integer'``: unsigned integer data types (i.e., + ``uint8``, ``uint16``, ``uint32``, ``uint64``). + - ``'integral'``: integer data types. Shorthand for ``('signed + integer', 'unsigned integer')``. + - ``'real floating'``: real-valued floating-point data types + (i.e., ``float32``, ``float64``). + - ``'complex floating'``: complex floating-point data types (i.e., + ``complex64``, ``complex128``). + - ``'numeric'``: numeric data types. Shorthand for ``('integral', + 'real floating', 'complex floating')``. + + Returns + ------- + dtypes : dict + A dictionary mapping the names of data types to the corresponding + CuPy data types. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.devices + + Examples + -------- + >>> info = xp.__array_namespace_info__() + >>> info.dtypes(kind='signed integer') + {'int8': cupy.int8, + 'int16': cupy.int16, + 'int32': cupy.int32, + 'int64': cupy.int64} + + """ + # TODO: Does this depend on device? + if kind is None: + return { + "bool": dtype(bool), + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + "float32": dtype(float32), + "float64": dtype(float64), + "complex64": dtype(complex64), + "complex128": dtype(complex128), + } + if kind == "bool": + return {"bool": bool} + if kind == "signed integer": + return { + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + } + if kind == "unsigned integer": + return { + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + } + if kind == "integral": + return { + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + } + if kind == "real floating": + return { + "float32": dtype(float32), + "float64": dtype(float64), + } + if kind == "complex floating": + return { + "complex64": dtype(complex64), + "complex128": dtype(complex128), + } + if kind == "numeric": + return { + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + "float32": dtype(float32), + "float64": dtype(float64), + "complex64": dtype(complex64), + "complex128": dtype(complex128), + } + if isinstance(kind, tuple): + res = {} + for k in kind: + res.update(self.dtypes(kind=k)) + return res + raise ValueError(f"unsupported kind: {kind!r}") + + def devices(self): + """ + The devices supported by CuPy. + + Returns + ------- + devices : list of str + The devices supported by CuPy. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes + + """ + return [cuda.Device(i) for i in range(cuda.runtime.getDeviceCount())] From d3c4b3c6dbf2e0b7156abd2acbe64693b595155b Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 11 Sep 2024 15:34:20 -0600 Subject: [PATCH 12/27] Add inspection namespace for torch Some of these things have to be inspected manually, and I'm not completely certain everything here is correct. --- array_api_compat/torch/_aliases.py | 15 +- array_api_compat/torch/_info.py | 369 +++++++++++++++++++++++++++++ 2 files changed, 378 insertions(+), 6 deletions(-) create mode 100644 array_api_compat/torch/_info.py diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 0675944b..1c1ce26b 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -8,6 +8,8 @@ _aliases_clip, unstack as _aliases_unstack,) from .._internal import get_xp +from ._info import __array_namespace_info__ + import torch from typing import TYPE_CHECKING @@ -724,12 +726,13 @@ def sign(x: array, /) -> array: return out -__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', - 'newaxis', 'conj', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', - 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign', - 'divide', 'equal', 'floor_divide', 'greater', 'greater_equal', - 'hypot', 'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal', - 'pow', 'remainder', 'subtract', 'max', 'min', 'clip', 'unstack', 'sort', +__all__ = ['__array_namespace_info__', 'result_type', 'can_cast', + 'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add', + 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', + 'bitwise_right_shift', 'bitwise_xor', 'copysign', 'divide', + 'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot', + 'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow', + 'remainder', 'subtract', 'max', 'min', 'clip', 'unstack', 'sort', 'prod', 'sum', 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze', 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py new file mode 100644 index 00000000..d5fe1bb5 --- /dev/null +++ b/array_api_compat/torch/_info.py @@ -0,0 +1,369 @@ +""" +Array API Inspection namespace + +This is the namespace for inspection functions as defined by the array API +standard. See +https://data-apis.org/array-api/latest/API_specification/inspection.html for +more details. + +""" +from torch import ( + asarray, + get_default_dtype, + device, + empty, + bool, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, + complex64, + complex128, +) + +from functools import cache + +class __array_namespace_info__: + """ + Get the array API inspection namespace for PyTorch. + + The array API inspection namespace defines the following functions: + + - capabilities() + - default_device() + - default_dtypes() + - dtypes() + - devices() + + See + https://data-apis.org/array-api/latest/API_specification/inspection.html + for more details. + + Returns + ------- + info : ModuleType + The array API inspection namespace for PyTorch. + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.default_dtypes() + {'real floating': numpy.float64, + 'complex floating': numpy.complex128, + 'integral': numpy.int64, + 'indexing': numpy.int64} + + """ + + __module__ = 'torch' + + def capabilities(self): + """ + Return a dictionary of array API library capabilities. + + The resulting dictionary has the following keys: + + - **"boolean indexing"**: boolean indicating whether an array library + supports boolean indexing. Always ``True`` for PyTorch. + + - **"data-dependent shapes"**: boolean indicating whether an array + library supports data-dependent output shapes. Always ``True`` for + PyTorch. + + See + https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html + for more details. + + See Also + -------- + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Returns + ------- + capabilities : dict + A dictionary of array API library capabilities. + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.capabilities() + {'boolean indexing': True, + 'data-dependent shapes': True} + + """ + return { + "boolean indexing": True, + "data-dependent shapes": True, + # 'max rank' will be part of the 2024.12 standard + # "max rank": 64, + } + + def default_device(self): + """ + The default device used for new PyTorch arrays. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Returns + ------- + device : str + The default device used for new PyTorch arrays. + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.default_device() + 'cpu' + + """ + return device("cpu") + + def default_dtypes(self, *, device=None): + """ + The default data types used for new PyTorch arrays. + + Parameters + ---------- + device : str, optional + The device to get the default data types for. For PyTorch, only + ``'cpu'`` is allowed. + + Returns + ------- + dtypes : dict + A dictionary describing the default data types used for new PyTorch + arrays. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.default_dtypes() + {'real floating': torch.float32, + 'complex floating': torch.complex64, + 'integral': torch.int64, + 'indexing': torch.int64} + + """ + default_floating = get_default_dtype() + default_complex = complex64 if default_floating == float32 else complex128 + default_integral = asarray(0, device=device).dtype + return { + "real floating": default_floating, + "complex floating": default_complex, + "integral": default_integral, + "indexing": default_integral, + } + + @cache + def dtypes(self, *, device=None, kind=None): + """ + The array API data types supported by PyTorch. + + Note that this function only returns data types that are defined by + the array API. + + Parameters + ---------- + device : str, optional + The device to get the data types for. + kind : str or tuple of str, optional + The kind of data types to return. If ``None``, all data types are + returned. If a string, only data types of that kind are returned. + If a tuple, a dictionary containing the union of the given kinds + is returned. The following kinds are supported: + + - ``'bool'``: boolean data types (i.e., ``bool``). + - ``'signed integer'``: signed integer data types (i.e., ``int8``, + ``int16``, ``int32``, ``int64``). + - ``'unsigned integer'``: unsigned integer data types (i.e., + ``uint8``, ``uint16``, ``uint32``, ``uint64``). + - ``'integral'``: integer data types. Shorthand for ``('signed + integer', 'unsigned integer')``. + - ``'real floating'``: real-valued floating-point data types + (i.e., ``float32``, ``float64``). + - ``'complex floating'``: complex floating-point data types (i.e., + ``complex64``, ``complex128``). + - ``'numeric'``: numeric data types. Shorthand for ``('integral', + 'real floating', 'complex floating')``. + + Returns + ------- + dtypes : dict + A dictionary mapping the names of data types to the corresponding + PyTorch data types. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.devices + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.dtypes(kind='signed integer') + {'int8': numpy.int8, + 'int16': numpy.int16, + 'int32': numpy.int32, + 'int64': numpy.int64} + + """ + res = self._dtypes(kind) + for k, v in res.copy().items(): + try: + empty((0,), dtype=v, device=device) + except: + del res[k] + return res + + def _dtypes(self, kind): + if kind is None: + return { + "bool": bool, + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + "float32": float32, + "float64": float64, + "complex64": complex64, + "complex128": complex128, + } + if kind == "bool": + return {"bool": bool} + if kind == "signed integer": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + } + if kind == "unsigned integer": + return { + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + } + if kind == "integral": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + } + if kind == "real floating": + return { + "float32": float32, + "float64": float64, + } + if kind == "complex floating": + return { + "complex64": complex64, + "complex128": complex128, + } + if kind == "numeric": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + "float32": float32, + "float64": float64, + "complex64": complex64, + "complex128": complex128, + } + if isinstance(kind, tuple): + res = {} + for k in kind: + res.update(self.dtypes(kind=k)) + return res + raise ValueError(f"unsupported kind: {kind!r}") + + @cache + def devices(self): + """ + The devices supported by PyTorch. + + Returns + ------- + devices : list of str + The devices supported by PyTorch. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.devices() + [device(type='cpu'), device(type='mps', index=0), device(type='meta')] + + """ + # Torch doesn't have a straightforward way to get the list of all + # currently supported devices. To do this, we first parse the error + # message of torch.device to get the list of all possible types of + # device: + try: + device('notadevice') + except RuntimeError as e: + # The error message is something like: + # "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: notadevice" + devices_names = e.args[0].split('Expected one of ')[1].split(' device type')[0].split(', ') + + # Next we need to check for different indices for different devices. + # device(device_name, index=index) doesn't actually check if the + # device name or index is valid. We have to try to create a tensor + # with it (which is why this function is cached). + devices = [] + for device_name in devices_names: + i = 0 + while True: + try: + a = empty((0,), device=device(device_name, index=i)) + if a.device in devices: + break + devices.append(a.device) + except: + break + i += 1 + + return devices From be4fa68f7235654c65e89179b6a7301e5bbc3e18 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 11 Sep 2024 15:40:48 -0600 Subject: [PATCH 13/27] Handle torch versions that do not have uint dtypes --- array_api_compat/torch/_info.py | 174 ++++++++++++++++---------------- 1 file changed, 87 insertions(+), 87 deletions(-) diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py index d5fe1bb5..2eab72a4 100644 --- a/array_api_compat/torch/_info.py +++ b/array_api_compat/torch/_info.py @@ -7,25 +7,7 @@ more details. """ -from torch import ( - asarray, - get_default_dtype, - device, - empty, - bool, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, - complex64, - complex128, -) +import torch from functools import cache @@ -130,7 +112,7 @@ def default_device(self): 'cpu' """ - return device("cpu") + return torch.device("cpu") def default_dtypes(self, *, device=None): """ @@ -165,9 +147,9 @@ def default_dtypes(self, *, device=None): 'indexing': torch.int64} """ - default_floating = get_default_dtype() - default_complex = complex64 if default_floating == float32 else complex128 - default_integral = asarray(0, device=device).dtype + default_floating = torch.get_default_dtype() + default_complex = torch.complex64 if default_floating == torch.float32 else torch.complex128 + default_integral = torch.asarray(0, device=device).dtype return { "real floating": default_floating, "complex floating": default_complex, @@ -175,70 +157,22 @@ def default_dtypes(self, *, device=None): "indexing": default_integral, } - @cache - def dtypes(self, *, device=None, kind=None): - """ - The array API data types supported by PyTorch. - - Note that this function only returns data types that are defined by - the array API. - - Parameters - ---------- - device : str, optional - The device to get the data types for. - kind : str or tuple of str, optional - The kind of data types to return. If ``None``, all data types are - returned. If a string, only data types of that kind are returned. - If a tuple, a dictionary containing the union of the given kinds - is returned. The following kinds are supported: - - - ``'bool'``: boolean data types (i.e., ``bool``). - - ``'signed integer'``: signed integer data types (i.e., ``int8``, - ``int16``, ``int32``, ``int64``). - - ``'unsigned integer'``: unsigned integer data types (i.e., - ``uint8``, ``uint16``, ``uint32``, ``uint64``). - - ``'integral'``: integer data types. Shorthand for ``('signed - integer', 'unsigned integer')``. - - ``'real floating'``: real-valued floating-point data types - (i.e., ``float32``, ``float64``). - - ``'complex floating'``: complex floating-point data types (i.e., - ``complex64``, ``complex128``). - - ``'numeric'``: numeric data types. Shorthand for ``('integral', - 'real floating', 'complex floating')``. - - Returns - ------- - dtypes : dict - A dictionary mapping the names of data types to the corresponding - PyTorch data types. - - See Also - -------- - __array_namespace_info__.capabilities, - __array_namespace_info__.default_device, - __array_namespace_info__.default_dtypes, - __array_namespace_info__.devices - - Examples - -------- - >>> info = np.__array_namespace_info__() - >>> info.dtypes(kind='signed integer') - {'int8': numpy.int8, - 'int16': numpy.int16, - 'int32': numpy.int32, - 'int64': numpy.int64} - - """ - res = self._dtypes(kind) - for k, v in res.copy().items(): - try: - empty((0,), dtype=v, device=device) - except: - del res[k] - return res def _dtypes(self, kind): + bool = torch.bool + int8 = torch.int8 + int16 = torch.int16 + int32 = torch.int32 + int64 = torch.int64 + uint8 = getattr(torch, "uint8", None) + uint16 = getattr(torch, "uint16", None) + uint32 = getattr(torch, "uint32", None) + uint64 = getattr(torch, "uint64", None) + float32 = torch.float32 + float64 = torch.float64 + complex64 = torch.complex64 + complex128 = torch.complex128 + if kind is None: return { "bool": bool, @@ -314,6 +248,72 @@ def _dtypes(self, kind): return res raise ValueError(f"unsupported kind: {kind!r}") + @cache + def dtypes(self, *, device=None, kind=None): + """ + The array API data types supported by PyTorch. + + Note that this function only returns data types that are defined by + the array API. + + Parameters + ---------- + device : str, optional + The device to get the data types for. + kind : str or tuple of str, optional + The kind of data types to return. If ``None``, all data types are + returned. If a string, only data types of that kind are returned. + If a tuple, a dictionary containing the union of the given kinds + is returned. The following kinds are supported: + + - ``'bool'``: boolean data types (i.e., ``bool``). + - ``'signed integer'``: signed integer data types (i.e., ``int8``, + ``int16``, ``int32``, ``int64``). + - ``'unsigned integer'``: unsigned integer data types (i.e., + ``uint8``, ``uint16``, ``uint32``, ``uint64``). + - ``'integral'``: integer data types. Shorthand for ``('signed + integer', 'unsigned integer')``. + - ``'real floating'``: real-valued floating-point data types + (i.e., ``float32``, ``float64``). + - ``'complex floating'``: complex floating-point data types (i.e., + ``complex64``, ``complex128``). + - ``'numeric'``: numeric data types. Shorthand for ``('integral', + 'real floating', 'complex floating')``. + + Returns + ------- + dtypes : dict + A dictionary mapping the names of data types to the corresponding + PyTorch data types. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.devices + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.dtypes(kind='signed integer') + {'int8': numpy.int8, + 'int16': numpy.int16, + 'int32': numpy.int32, + 'int64': numpy.int64} + + """ + res = self._dtypes(kind) + for k, v in res.copy().items(): + if v is None: + del res[k] + continue + try: + torch.empty((0,), dtype=v, device=device) + except: + del res[k] + return res + @cache def devices(self): """ @@ -343,7 +343,7 @@ def devices(self): # message of torch.device to get the list of all possible types of # device: try: - device('notadevice') + torch.device('notadevice') except RuntimeError as e: # The error message is something like: # "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: notadevice" @@ -358,7 +358,7 @@ def devices(self): i = 0 while True: try: - a = empty((0,), device=device(device_name, index=i)) + a = torch.empty((0,), device=torch.device(device_name, index=i)) if a.device in devices: break devices.append(a.device) From 774d17539d63a64ba0c898cfeb3dc455d4b9c40d Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 12 Sep 2024 14:12:25 -0600 Subject: [PATCH 14/27] Remove uint16, uint32, and uint64 from the pytorch dtypes() output --- array_api_compat/torch/_info.py | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py index 2eab72a4..5211d7de 100644 --- a/array_api_compat/torch/_info.py +++ b/array_api_compat/torch/_info.py @@ -164,10 +164,10 @@ def _dtypes(self, kind): int16 = torch.int16 int32 = torch.int32 int64 = torch.int64 - uint8 = getattr(torch, "uint8", None) - uint16 = getattr(torch, "uint16", None) - uint32 = getattr(torch, "uint32", None) - uint64 = getattr(torch, "uint64", None) + uint8 = torch.uint8 + # uint16, uint32, and uint64 are present in newer versions of pytorch, + # but they aren't generally supported by the array API functions, so + # we omit them from this function. float32 = torch.float32 float64 = torch.float64 complex64 = torch.complex64 @@ -181,9 +181,6 @@ def _dtypes(self, kind): "int32": int32, "int64": int64, "uint8": uint8, - "uint16": uint16, - "uint32": uint32, - "uint64": uint64, "float32": float32, "float64": float64, "complex64": complex64, @@ -201,9 +198,6 @@ def _dtypes(self, kind): if kind == "unsigned integer": return { "uint8": uint8, - "uint16": uint16, - "uint32": uint32, - "uint64": uint64, } if kind == "integral": return { @@ -212,9 +206,6 @@ def _dtypes(self, kind): "int32": int32, "int64": int64, "uint8": uint8, - "uint16": uint16, - "uint32": uint32, - "uint64": uint64, } if kind == "real floating": return { @@ -233,9 +224,6 @@ def _dtypes(self, kind): "int32": int32, "int64": int64, "uint8": uint8, - "uint16": uint16, - "uint32": uint32, - "uint64": uint64, "float32": float32, "float64": float64, "complex64": complex64, @@ -305,9 +293,6 @@ def dtypes(self, *, device=None, kind=None): """ res = self._dtypes(kind) for k, v in res.copy().items(): - if v is None: - del res[k] - continue try: torch.empty((0,), dtype=v, device=device) except: From 231ef9594717e341aaa31de12a2092e6c20b5eba Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 12 Sep 2024 14:26:31 -0600 Subject: [PATCH 15/27] Add inspection namespace for dask --- array_api_compat/dask/array/_aliases.py | 16 +- array_api_compat/dask/array/_info.py | 345 ++++++++++++++++++++++++ 2 files changed, 355 insertions(+), 6 deletions(-) create mode 100644 array_api_compat/dask/array/_info.py diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 10a03cd8..b89cca52 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -5,6 +5,8 @@ from ..._internal import get_xp +from ._info import __array_namespace_info__ + import numpy as np from numpy import ( # Constants @@ -208,12 +210,14 @@ def _isscalar(a): common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported] -__all__ = common_aliases + ['asarray', 'bool', 'acos', - 'acosh', 'asin', 'asinh', 'atan', 'atan2', +__all__ = common_aliases + ['__array_namespace_info__', 'asarray', 'bool', + 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow', - 'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8', - 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64', - 'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type'] + 'bitwise_right_shift', 'concat', 'pow', 'e', + 'inf', 'nan', 'pi', 'newaxis', 'float32', + 'float64', 'int8', 'int16', 'int32', 'int64', + 'uint8', 'uint16', 'uint32', 'uint64', + 'complex64', 'complex128', 'iinfo', 'finfo', + 'can_cast', 'result_type'] _all_ignore = ['get_xp', 'da', 'partial', 'common_aliases', 'np'] diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py new file mode 100644 index 00000000..d3b12dc9 --- /dev/null +++ b/array_api_compat/dask/array/_info.py @@ -0,0 +1,345 @@ +""" +Array API Inspection namespace + +This is the namespace for inspection functions as defined by the array API +standard. See +https://data-apis.org/array-api/latest/API_specification/inspection.html for +more details. + +""" +from numpy import ( + dtype, + bool_ as bool, + intp, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, + complex64, + complex128, +) + +from ...common._helpers import _DASK_DEVICE + +class __array_namespace_info__: + """ + Get the array API inspection namespace for Dask. + + The array API inspection namespace defines the following functions: + + - capabilities() + - default_device() + - default_dtypes() + - dtypes() + - devices() + + See + https://data-apis.org/array-api/latest/API_specification/inspection.html + for more details. + + Returns + ------- + info : ModuleType + The array API inspection namespace for Dask. + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.default_dtypes() + {'real floating': dask.float64, + 'complex floating': dask.complex128, + 'integral': dask.int64, + 'indexing': dask.int64} + + """ + + __module__ = 'dask.array' + + def capabilities(self): + """ + Return a dictionary of array API library capabilities. + + The resulting dictionary has the following keys: + + - **"boolean indexing"**: boolean indicating whether an array library + supports boolean indexing. Always ``False`` for Dask. + + - **"data-dependent shapes"**: boolean indicating whether an array + library supports data-dependent output shapes. Always ``False`` for + Dask. + + See + https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html + for more details. + + See Also + -------- + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Returns + ------- + capabilities : dict + A dictionary of array API library capabilities. + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.capabilities() + {'boolean indexing': True, + 'data-dependent shapes': True} + + """ + return { + "boolean indexing": False, + "data-dependent shapes": False, + # 'max rank' will be part of the 2024.12 standard + # "max rank": 64, + } + + def default_device(self): + """ + The default device used for new Dask arrays. + + For Dask, this always returns ``'cpu'``. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Returns + ------- + device : str + The default device used for new Dask arrays. + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.default_device() + 'cpu' + + """ + return "cpu" + + def default_dtypes(self, *, device=None): + """ + The default data types used for new Dask arrays. + + For Dask, this always returns the following dictionary: + + - **"real floating"**: ``numpy.float64`` + - **"complex floating"**: ``numpy.complex128`` + - **"integral"**: ``numpy.intp`` + - **"indexing"**: ``numpy.intp`` + + Parameters + ---------- + device : str, optional + The device to get the default data types for. + + Returns + ------- + dtypes : dict + A dictionary describing the default data types used for new Dask + arrays. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.default_dtypes() + {'real floating': dask.float64, + 'complex floating': dask.complex128, + 'integral': dask.int64, + 'indexing': dask.int64} + + """ + if device not in ["cpu", _DASK_DEVICE, None]: + raise ValueError( + 'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:' + f' {device}' + ) + return { + "real floating": dtype(float64), + "complex floating": dtype(complex128), + "integral": dtype(intp), + "indexing": dtype(intp), + } + + def dtypes(self, *, device=None, kind=None): + """ + The array API data types supported by Dask. + + Note that this function only returns data types that are defined by + the array API. + + Parameters + ---------- + device : str, optional + The device to get the data types for. + kind : str or tuple of str, optional + The kind of data types to return. If ``None``, all data types are + returned. If a string, only data types of that kind are returned. + If a tuple, a dictionary containing the union of the given kinds + is returned. The following kinds are supported: + + - ``'bool'``: boolean data types (i.e., ``bool``). + - ``'signed integer'``: signed integer data types (i.e., ``int8``, + ``int16``, ``int32``, ``int64``). + - ``'unsigned integer'``: unsigned integer data types (i.e., + ``uint8``, ``uint16``, ``uint32``, ``uint64``). + - ``'integral'``: integer data types. Shorthand for ``('signed + integer', 'unsigned integer')``. + - ``'real floating'``: real-valued floating-point data types + (i.e., ``float32``, ``float64``). + - ``'complex floating'``: complex floating-point data types (i.e., + ``complex64``, ``complex128``). + - ``'numeric'``: numeric data types. Shorthand for ``('integral', + 'real floating', 'complex floating')``. + + Returns + ------- + dtypes : dict + A dictionary mapping the names of data types to the corresponding + Dask data types. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.devices + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.dtypes(kind='signed integer') + {'int8': dask.int8, + 'int16': dask.int16, + 'int32': dask.int32, + 'int64': dask.int64} + + """ + if device not in ["cpu", _DASK_DEVICE, None]: + raise ValueError( + 'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:' + f' {device}' + ) + if kind is None: + return { + "bool": dtype(bool), + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + "float32": dtype(float32), + "float64": dtype(float64), + "complex64": dtype(complex64), + "complex128": dtype(complex128), + } + if kind == "bool": + return {"bool": bool} + if kind == "signed integer": + return { + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + } + if kind == "unsigned integer": + return { + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + } + if kind == "integral": + return { + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + } + if kind == "real floating": + return { + "float32": dtype(float32), + "float64": dtype(float64), + } + if kind == "complex floating": + return { + "complex64": dtype(complex64), + "complex128": dtype(complex128), + } + if kind == "numeric": + return { + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + "float32": dtype(float32), + "float64": dtype(float64), + "complex64": dtype(complex64), + "complex128": dtype(complex128), + } + if isinstance(kind, tuple): + res = {} + for k in kind: + res.update(self.dtypes(kind=k)) + return res + raise ValueError(f"unsupported kind: {kind!r}") + + def devices(self): + """ + The devices supported by Dask. + + For Dask, this always returns ``['cpu', DASK_DEVICE]``. + + Returns + ------- + devices : list of str + The devices supported by Dask. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.devices() + ['cpu', DASK_DEVICE] + + """ + return ["cpu", _DASK_DEVICE] From 92ebdffe1d421131832a93d2851c2f104d6a6d31 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 18 Sep 2024 15:11:56 -0600 Subject: [PATCH 16/27] Add cumulative_sum wrapper for the numpy-likes --- array_api_compat/common/_aliases.py | 37 +++++++++++++++++++++++-- array_api_compat/cupy/_aliases.py | 1 + array_api_compat/dask/array/_aliases.py | 1 + array_api_compat/numpy/_aliases.py | 1 + 4 files changed, 37 insertions(+), 3 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 7b23c910..a3437526 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -264,6 +264,36 @@ def var( ) -> ndarray: return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) +# cumulative_sum is renamed from cumsum, and adds the include_initial keyword +# argument + +def cumulative_sum( + x: ndarray, + /, + xp, + *, + axis: Optional[int] = None, + dtype: Optional[Dtype] = None, + include_initial: bool = False, + **kwargs +) -> ndarray: + # TODO: The standard is not clear about what should happen when x.ndim == 0. + if axis is None: + if x.ndim > 1: + raise ValueError("axis must be specified in cumulative_sum for more than one dimension") + axis = 0 + + res = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs) + + # np.cumsum does not support include_initial + if include_initial: + initial_shape = list(x.shape) + initial_shape[axis] = 1 + res = xp.concatenate( + [xp.zeros_like(res, shape=initial_shape), res], + axis=axis, + ) + return res # The min and max argument names in clip are different and not optional in numpy, and type # promotion behavior is different. @@ -502,6 +532,7 @@ def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]: 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', - 'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape', - 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', 'matmul', - 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype', 'unstack'] + 'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims', + 'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', + 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype', + 'unstack'] diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 4de1326a..30ae2943 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -49,6 +49,7 @@ astype = _aliases.astype std = get_xp(cp)(_aliases.std) var = get_xp(cp)(_aliases.var) +cumulative_sum = get_xp(cp)(_aliases.cumulative_sum) clip = get_xp(cp)(_aliases.clip) permute_dims = get_xp(cp)(_aliases.permute_dims) reshape = get_xp(cp)(_aliases.reshape) diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index b89cca52..cf57c824 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -91,6 +91,7 @@ def _dask_arange( permute_dims = get_xp(da)(_aliases.permute_dims) std = get_xp(da)(_aliases.std) var = get_xp(da)(_aliases.var) +cumulative_sum = get_xp(da)(_aliases.cumulative_sum) empty = get_xp(da)(_aliases.empty) empty_like = get_xp(da)(_aliases.empty_like) full = get_xp(da)(_aliases.full) diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 14d0960b..355215e4 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -49,6 +49,7 @@ astype = _aliases.astype std = get_xp(np)(_aliases.std) var = get_xp(np)(_aliases.var) +cumulative_sum = get_xp(np)(_aliases.cumulative_sum) clip = get_xp(np)(_aliases.clip) permute_dims = get_xp(np)(_aliases.permute_dims) reshape = get_xp(np)(_aliases.reshape) From 176a66ae87c0cf41975b70d17c0b020b05a28066 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 18 Sep 2024 15:15:45 -0600 Subject: [PATCH 17/27] Remove a bunch of 2023.12 xfails --- dask-xfails.txt | 12 ------------ numpy-1-21-xfails.txt | 12 ------------ numpy-1-26-xfails.txt | 12 ------------ numpy-xfails.txt | 12 ------------ 4 files changed, 48 deletions(-) diff --git a/dask-xfails.txt b/dask-xfails.txt index 77f99dc7..1e9c421c 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -152,18 +152,6 @@ array_api_tests/test_statistical_functions.py::test_sum array_api_tests/test_statistical_functions.py::test_prod # 2023.12 support -array_api_tests/test_has_names.py::test_has_names[statistical-cumulative_sum] -array_api_tests/test_has_names.py::test_has_names[info-__array_namespace_info__] -array_api_tests/test_inspection_functions.py::test_array_namespace_info -array_api_tests/test_inspection_functions.py::test_array_namespace_info_dtypes array_api_tests/test_manipulation_functions.py::test_repeat array_api_tests/test_searching_functions.py::test_searchsorted -array_api_tests/test_signatures.py::test_func_signature[cumulative_sum] array_api_tests/test_signatures.py::test_func_signature[astype] -array_api_tests/test_signatures.py::test_func_signature[__array_namespace_info__] -array_api_tests/test_signatures.py::test_info_func_signature[capabilities] -array_api_tests/test_signatures.py::test_info_func_signature[default_device] -array_api_tests/test_signatures.py::test_info_func_signature[default_dtypes] -array_api_tests/test_signatures.py::test_info_func_signature[devices] -array_api_tests/test_signatures.py::test_info_func_signature[dtypes] -array_api_tests/test_statistical_functions.py::test_cumulative_sum diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt index 3e858141..d6a2c251 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-21-xfails.txt @@ -252,19 +252,7 @@ array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0] # 2023.12 support -array_api_tests/test_has_names.py::test_has_names[statistical-cumulative_sum] -array_api_tests/test_has_names.py::test_has_names[info-__array_namespace_info__] -array_api_tests/test_inspection_functions.py::test_array_namespace_info -array_api_tests/test_inspection_functions.py::test_array_namespace_info_dtypes array_api_tests/test_searching_functions.py::test_searchsorted -array_api_tests/test_signatures.py::test_func_signature[cumulative_sum] array_api_tests/test_signatures.py::test_func_signature[from_dlpack] array_api_tests/test_signatures.py::test_func_signature[astype] -array_api_tests/test_signatures.py::test_func_signature[__array_namespace_info__] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] -array_api_tests/test_signatures.py::test_info_func_signature[capabilities] -array_api_tests/test_signatures.py::test_info_func_signature[default_device] -array_api_tests/test_signatures.py::test_info_func_signature[default_dtypes] -array_api_tests/test_signatures.py::test_info_func_signature[devices] -array_api_tests/test_signatures.py::test_info_func_signature[dtypes] -array_api_tests/test_statistical_functions.py::test_cumulative_sum diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index 81efc942..30b5d302 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -47,19 +47,7 @@ array_api_tests/test_statistical_functions.py::test_sum array_api_tests/test_statistical_functions.py::test_prod # 2023.12 support -array_api_tests/test_has_names.py::test_has_names[statistical-cumulative_sum] -array_api_tests/test_has_names.py::test_has_names[info-__array_namespace_info__] -array_api_tests/test_inspection_functions.py::test_array_namespace_info -array_api_tests/test_inspection_functions.py::test_array_namespace_info_dtypes array_api_tests/test_searching_functions.py::test_searchsorted -array_api_tests/test_signatures.py::test_func_signature[cumulative_sum] array_api_tests/test_signatures.py::test_func_signature[from_dlpack] array_api_tests/test_signatures.py::test_func_signature[astype] -array_api_tests/test_signatures.py::test_func_signature[__array_namespace_info__] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] -array_api_tests/test_signatures.py::test_info_func_signature[capabilities] -array_api_tests/test_signatures.py::test_info_func_signature[default_device] -array_api_tests/test_signatures.py::test_info_func_signature[default_dtypes] -array_api_tests/test_signatures.py::test_info_func_signature[devices] -array_api_tests/test_signatures.py::test_info_func_signature[dtypes] -array_api_tests/test_statistical_functions.py::test_cumulative_sum diff --git a/numpy-xfails.txt b/numpy-xfails.txt index ce6fdb15..cdcd2576 100644 --- a/numpy-xfails.txt +++ b/numpy-xfails.txt @@ -39,19 +39,7 @@ array_api_tests/test_signatures.py::test_func_signature[vecdot] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] # 2023.12 support -array_api_tests/test_has_names.py::test_has_names[statistical-cumulative_sum] -array_api_tests/test_has_names.py::test_has_names[info-__array_namespace_info__] -array_api_tests/test_inspection_functions.py::test_array_namespace_info -array_api_tests/test_inspection_functions.py::test_array_namespace_info_dtypes array_api_tests/test_searching_functions.py::test_searchsorted -array_api_tests/test_signatures.py::test_func_signature[cumulative_sum] array_api_tests/test_signatures.py::test_func_signature[from_dlpack] array_api_tests/test_signatures.py::test_func_signature[astype] -array_api_tests/test_signatures.py::test_func_signature[__array_namespace_info__] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] -array_api_tests/test_signatures.py::test_info_func_signature[capabilities] -array_api_tests/test_signatures.py::test_info_func_signature[default_device] -array_api_tests/test_signatures.py::test_info_func_signature[default_dtypes] -array_api_tests/test_signatures.py::test_info_func_signature[devices] -array_api_tests/test_signatures.py::test_info_func_signature[dtypes] -array_api_tests/test_statistical_functions.py::test_cumulative_sum From 46a2227be326bace56ce08bbb4a675bffd6db4f0 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 24 Sep 2024 14:22:00 -0600 Subject: [PATCH 18/27] Hard-code the default torch integral type to int64 See the discussion at https://github.com/data-apis/array-api-compat/pull/166#discussion_r1755724807 --- array_api_compat/torch/_info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py index 5211d7de..a85e684e 100644 --- a/array_api_compat/torch/_info.py +++ b/array_api_compat/torch/_info.py @@ -149,7 +149,7 @@ def default_dtypes(self, *, device=None): """ default_floating = torch.get_default_dtype() default_complex = torch.complex64 if default_floating == torch.float32 else torch.complex128 - default_integral = torch.asarray(0, device=device).dtype + default_integral = torch.int64 return { "real floating": default_floating, "complex floating": default_complex, From ab822f59c4580899abb447c4c1cdf99a54da8dfb Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 24 Sep 2024 14:28:40 -0600 Subject: [PATCH 19/27] Ignore bare except in ruff checks --- ruff.toml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ruff.toml b/ruff.toml index 53e8596c..72e111b5 100644 --- a/ruff.toml +++ b/ruff.toml @@ -9,5 +9,9 @@ select = [ "PLC0414" ] -# Ignore module import not at top of file -ignore = ["E402"] +ignore = [ + # Module import not at top of file + "E402", + # Do not use bare `except` + "E722" +] From cb9acd4073e6772c77465261db965cd02dd58244 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 24 Sep 2024 14:28:54 -0600 Subject: [PATCH 20/27] Add a comment --- array_api_compat/torch/_info.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py index a85e684e..264caa9e 100644 --- a/array_api_compat/torch/_info.py +++ b/array_api_compat/torch/_info.py @@ -147,6 +147,10 @@ def default_dtypes(self, *, device=None): 'indexing': torch.int64} """ + # Note: if the default is set to float64, the devices like MPS that + # don't support float64 will error. We still return the default_dtype + # value here because this error doesn't represent a different default + # per-device. default_floating = torch.get_default_dtype() default_complex = torch.complex64 if default_floating == torch.float32 else torch.complex128 default_integral = torch.int64 From c0dd5b0e2e09df342d392b7ca4404cbf7c34af8f Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 24 Sep 2024 14:34:55 -0600 Subject: [PATCH 21/27] Add cumulative_sum to torch --- array_api_compat/common/_aliases.py | 4 +++- array_api_compat/torch/_aliases.py | 18 +++++++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index a3437526..91c4d9a7 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -277,6 +277,8 @@ def cumulative_sum( include_initial: bool = False, **kwargs ) -> ndarray: + wrapped_xp = array_namespace(x) + # TODO: The standard is not clear about what should happen when x.ndim == 0. if axis is None: if x.ndim > 1: @@ -290,7 +292,7 @@ def cumulative_sum( initial_shape = list(x.shape) initial_shape[axis] = 1 res = xp.concatenate( - [xp.zeros_like(res, shape=initial_shape), res], + [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=device(res)), res], axis=axis, ) return res diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 1c1ce26b..10725253 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -4,8 +4,11 @@ from builtins import all as _builtin_all, any as _builtin_any from ..common._aliases import (matrix_transpose as _aliases_matrix_transpose, - vecdot as _aliases_vecdot, clip as - _aliases_clip, unstack as _aliases_unstack,) + vecdot as _aliases_vecdot, + clip as _aliases_clip, + unstack as _aliases_unstack, + cumulative_sum as _aliases_cumulative_sum, + ) from .._internal import get_xp from ._info import __array_namespace_info__ @@ -198,6 +201,7 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep clip = get_xp(torch)(_aliases_clip) unstack = get_xp(torch)(_aliases_unstack) +cumulative_sum = get_xp(torch)(_aliases_cumulative_sum) # torch.sort also returns a tuple # https://github.com/pytorch/pytorch/issues/70921 @@ -732,11 +736,11 @@ def sign(x: array, /) -> array: 'bitwise_right_shift', 'bitwise_xor', 'copysign', 'divide', 'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot', 'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow', - 'remainder', 'subtract', 'max', 'min', 'clip', 'unstack', 'sort', - 'prod', 'sum', 'any', 'all', 'mean', 'std', 'var', 'concat', - 'squeeze', 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', - 'reshape', 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', - 'empty', 'tril', 'triu', 'expand_dims', 'astype', + 'remainder', 'subtract', 'max', 'min', 'clip', 'unstack', + 'cumulative_sum', 'sort', 'prod', 'sum', 'any', 'all', 'mean', + 'std', 'var', 'concat', 'squeeze', 'broadcast_to', 'flip', 'roll', + 'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full', + 'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays', 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', From 1b47d96c21c5aaaa1724c14006456fa9b74d5def Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 30 Sep 2024 15:07:56 -0600 Subject: [PATCH 22/27] Fix the tests --- tests/test_all.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_all.py b/tests/test_all.py index ff01fbdd..969d5cfb 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -32,6 +32,8 @@ def test_all(library): continue dir_names = [n for n in dir(module) if not n.startswith('_')] + if '__array_namespace_info__' in dir(module): + dir_names.append('__array_namespace_info__') ignore_all_names = getattr(module, '_all_ignore', []) ignore_all_names += ['annotations', 'TYPE_CHECKING'] dir_names = set(dir_names) - set(ignore_all_names) From e472dcb362c7e72a13ddeedbe4e2b76e205ad3fa Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 30 Sep 2024 15:13:01 -0600 Subject: [PATCH 23/27] Add repeat to the torch xfails --- torch-xfails.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torch-xfails.txt b/torch-xfails.txt index aedbc4a7..3cce6698 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -189,3 +189,8 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_expm1 array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_set_functions.py::test_unique_counts array_api_tests/test_set_functions.py::test_unique_values + +# 2023.12 support +array_api_tests/test_has_names.py::test_has_names[manipulation-repeat] +array_api_tests/test_manipulation_functions.py::test_repeat +array_api_tests/test_signatures.py::test_func_signature[repeat] From 470e41a5388bac73523eb204f78f0b9dcd6ddd08 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 30 Sep 2024 15:15:26 -0600 Subject: [PATCH 24/27] Update torch xfails --- torch-xfails.txt | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torch-xfails.txt b/torch-xfails.txt index 3cce6698..c7abe2e9 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -194,3 +194,9 @@ array_api_tests/test_set_functions.py::test_unique_values array_api_tests/test_has_names.py::test_has_names[manipulation-repeat] array_api_tests/test_manipulation_functions.py::test_repeat array_api_tests/test_signatures.py::test_func_signature[repeat] +# Argument 'device' missing from signature +array_api_tests/test_signatures.py::test_func_signature[from_dlpack] +# Argument 'max_version' missing from signature +array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] +# Argument 'device' missing from signature +array_api_tests/test_signatures.py::test_func_signature[astype] From 198d6d732e0fc1d8cd6e3bd0c24c6282bd6c56c7 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 30 Sep 2024 15:17:03 -0600 Subject: [PATCH 25/27] Update numpy dev xfails --- numpy-dev-xfails.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/numpy-dev-xfails.txt b/numpy-dev-xfails.txt index f3239dbf..34d747e3 100644 --- a/numpy-dev-xfails.txt +++ b/numpy-dev-xfails.txt @@ -17,3 +17,7 @@ array_api_tests/test_statistical_functions.py::test_cumulative_sum # https://github.com/numpy/numpy/pull/26237 array_api_tests/test_signatures.py::test_func_signature[vecdot] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] + +# 2023.12 support +# Argument 'device' missing from signature +array_api_tests/test_signatures.py::test_func_signature[astype] From ef7ad7ab80c6a0df8288253bd0129d23597e74ed Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 30 Sep 2024 15:30:20 -0600 Subject: [PATCH 26/27] Add maximum and minimum torch wrappers (for fixed two-arg type promotion) --- array_api_compat/torch/_aliases.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 10725253..5ac66bcb 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -177,6 +177,8 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: logaddexp = _two_arg(torch.logaddexp) # logical functions are not included here because they only accept bool in the # spec, so type promotion is irrelevant. +maximum = _two_arg(torch.maximum) +minimum = _two_arg(torch.minimum) multiply = _two_arg(torch.multiply) not_equal = _two_arg(torch.not_equal) pow = _two_arg(torch.pow) @@ -735,15 +737,16 @@ def sign(x: array, /) -> array: 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign', 'divide', 'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot', - 'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow', - 'remainder', 'subtract', 'max', 'min', 'clip', 'unstack', - 'cumulative_sum', 'sort', 'prod', 'sum', 'any', 'all', 'mean', - 'std', 'var', 'concat', 'squeeze', 'broadcast_to', 'flip', 'roll', - 'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full', - 'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype', - 'broadcast_arrays', 'UniqueAllResult', 'UniqueCountsResult', - 'UniqueInverseResult', 'unique_all', 'unique_counts', - 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', - 'vecdot', 'tensordot', 'isdtype', 'take', 'sign'] + 'less', 'less_equal', 'logaddexp', 'maximum', 'minimum', + 'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max', + 'min', 'clip', 'unstack', 'cumulative_sum', 'sort', 'prod', 'sum', + 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze', + 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape', + 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty', + 'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays', + 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', + 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', + 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', + 'take', 'sign'] _all_ignore = ['torch', 'get_xp'] From 9d2e283401a12def1e749f2fb496d15fde49aba8 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 30 Sep 2024 15:39:42 -0600 Subject: [PATCH 27/27] Add dlpack xfails to numpy-dev --- numpy-dev-xfails.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/numpy-dev-xfails.txt b/numpy-dev-xfails.txt index 34d747e3..68f954d2 100644 --- a/numpy-dev-xfails.txt +++ b/numpy-dev-xfails.txt @@ -21,3 +21,5 @@ array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] # 2023.12 support # Argument 'device' missing from signature array_api_tests/test_signatures.py::test_func_signature[astype] +array_api_tests/test_signatures.py::test_func_signature[from_dlpack] +array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]