diff --git a/.github/workflows/array-api-tests-dask.yml b/.github/workflows/array-api-tests-dask.yml
new file mode 100644
index 00000000..b0ce007e
--- /dev/null
+++ b/.github/workflows/array-api-tests-dask.yml
@@ -0,0 +1,12 @@
+name: Array API Tests (Dask)
+
+on: [push, pull_request]
+
+jobs:
+  array-api-tests-dask:
+    uses: ./.github/workflows/array-api-tests.yml
+    with:
+      package-name: dask
+      module-name: dask.array
+      extra-requires: numpy
+      pytest-extra-args: --disable-deadline --max-examples=5
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 8610b6f0..71083fbc 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -15,7 +15,7 @@ jobs:
       - name: Install Dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install pytest numpy torch
+          python -m pip install pytest numpy torch dask[array]
 
       - name: Run Tests
         run: |
diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py
index c057e71d..7713213e 100644
--- a/array_api_compat/common/_aliases.py
+++ b/array_api_compat/common/_aliases.py
@@ -303,6 +303,8 @@ def _asarray(
         import numpy as xp
     elif namespace == 'cupy':
         import cupy as xp
+    elif namespace == 'dask.array':
+        import dask.array as xp
     else:
         raise ValueError("Unrecognized namespace argument to asarray()")
 
@@ -322,11 +324,22 @@ def _asarray(
         if copy in COPY_FALSE:
             # copy=False is not yet implemented in xp.asarray
             raise NotImplementedError("copy=False is not yet implemented")
-        if isinstance(obj, xp.ndarray):
+        if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)) or hasattr(obj, "__array__"):
             if dtype is not None and obj.dtype != dtype:
                 copy = True
             if copy in COPY_TRUE:
-                return xp.array(obj, copy=True, dtype=dtype)
+                copy_kwargs = {}
+                if namespace != "dask.array":
+                    copy_kwargs["copy"] = True
+                else:
+                    # No copy kwarg in dask.asarray, so we go through np.asarray first
+                    # (as dask itself does) and copy afterwards
+                    if dtype is None:
+                        # Same-dtype copy is a no-op in dask
+                        return obj.copy()
+                    import numpy as np
+                    obj = np.asarray(obj).copy()
+                return xp.array(obj, dtype=dtype, **copy_kwargs)
             return obj
 
     return xp.asarray(obj, dtype=dtype, **kwargs)
diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py
index c1b0aef3..82bf47c1 100644
--- a/array_api_compat/common/_helpers.py
+++ b/array_api_compat/common/_helpers.py
@@ -40,6 +40,15 @@ def _is_torch_array(x):
     # TODO: Should we reject ndarray subclasses?
     return isinstance(x, torch.Tensor)
 
+def _is_dask_array(x):
+    # Avoid importing dask if it isn't already
+    if 'dask.array' not in sys.modules:
+        return False
+
+    import dask.array
+
+    return isinstance(x, dask.array.Array)
+
 def is_array_api_obj(x):
     """
     Check if x is an array API compatible array object.
@@ -47,6 +56,7 @@ def is_array_api_obj(x):
     return _is_numpy_array(x) \
         or _is_cupy_array(x) \
         or _is_torch_array(x) \
+        or _is_dask_array(x) \
         or hasattr(x, '__array_namespace__')
 
 def _check_api_version(api_version):
@@ -95,6 +105,13 @@ def your_function(x, y):
         else:
             import torch
             namespaces.add(torch)
+        elif _is_dask_array(x):
+            _check_api_version(api_version)
+            if _use_compat:
+                from ..dask import array as dask_namespace
+                namespaces.add(dask_namespace)
+            else:
+                raise TypeError("_use_compat cannot be False if input array is a dask array!")
         elif hasattr(x, '__array_namespace__'):
             namespaces.add(x.__array_namespace__(api_version=api_version))
         else:
@@ -219,6 +236,13 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A
         return _cupy_to_device(x, device, stream=stream)
     elif _is_torch_array(x):
         return _torch_to_device(x, device, stream=stream)
+    elif _is_dask_array(x):
+        if stream is not None:
+            raise ValueError("The stream argument to to_device() is not supported")
+        # TODO: What if our array is on the GPU already?
+        if device == 'cpu':
+            return x
+        raise ValueError(f"Unsupported device {device!r}")
     return x.to_device(device, stream=stream)
 
 def size(x):
diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py
index ce5b55d1..9f0c993f 100644
--- a/array_api_compat/common/_linalg.py
+++ b/array_api_compat/common/_linalg.py
@@ -77,7 +77,7 @@ def matrix_rank(x: ndarray,
     # dimensional arrays.
     if x.ndim < 2:
         raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
-    S = xp.linalg.svd(x, compute_uv=False, **kwargs)
+    S = get_xp(xp)(svdvals)(x, **kwargs)
     if rtol is None:
         tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps
     else:
diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py
new file mode 100644
index 00000000..a7c0b22e
--- /dev/null
+++ b/array_api_compat/dask/array/__init__.py
@@ -0,0 +1,8 @@
+from dask.array import *
+
+# These imports may overwrite names from the import * above.
+from ._aliases import *
+
+__array_api_version__ = '2022.12'
+
+__import__(__package__ + '.linalg')
diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py
new file mode 100644
index 00000000..ef9ea356
--- /dev/null
+++ b/array_api_compat/dask/array/_aliases.py
@@ -0,0 +1,145 @@
+from __future__ import annotations
+
+from ...common import _aliases
+from ...common._helpers import _check_device
+
+from ..._internal import get_xp
+
+import numpy as np
+from numpy import (
+    # Constants
+    e,
+    inf,
+    nan,
+    pi,
+    newaxis,
+    # Dtypes
+    bool_ as bool,
+    float32,
+    float64,
+    int8,
+    int16,
+    int32,
+    int64,
+    uint8,
+    uint16,
+    uint32,
+    uint64,
+    complex64,
+    complex128,
+    iinfo,
+    finfo,
+    can_cast,
+    result_type,
+)
+
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from typing import Optional, Union
+    from ...common._typing import ndarray, Device, Dtype
+
+import dask.array as da
+
+isdtype = get_xp(np)(_aliases.isdtype)
+astype = _aliases.astype
+
+# Common aliases
+
+# This arange func is modified from the common one to
+# not pass stop/step as keyword arguments, which will cause
+# an error with dask
+
+# TODO: delete the xp stuff, it shouldn't be necessary
+def dask_arange(
+    start: Union[int, float],
+    /,
+    stop: Optional[Union[int, float]] = None,
+    step: Union[int, float] = 1,
+    *,
+    xp,
+    dtype: Optional[Dtype] = None,
+    device: Optional[Device] = None,
+    **kwargs
+) -> ndarray:
+    _check_device(xp, device)
+    args = [start]
+    if stop is not None:
+        args.append(stop)
+    else:
+        # stop is None, so start is actually stop
+        # prepend the default value for start which is 0
+        args.insert(0, 0)
+    args.append(step)
+    return xp.arange(*args, dtype=dtype, **kwargs)
+
+arange = get_xp(da)(dask_arange)
+eye = get_xp(da)(_aliases.eye)
+
+from functools import partial
+asarray = partial(_aliases._asarray, namespace='dask.array')
+asarray.__doc__ = _aliases._asarray.__doc__
+
+linspace = get_xp(da)(_aliases.linspace)
+eye = get_xp(da)(_aliases.eye)
+UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
+UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult)
+UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult)
+unique_all = get_xp(da)(_aliases.unique_all)
+unique_counts = get_xp(da)(_aliases.unique_counts)
+unique_inverse = get_xp(da)(_aliases.unique_inverse)
+unique_values = get_xp(da)(_aliases.unique_values)
+permute_dims = get_xp(da)(_aliases.permute_dims)
+std = get_xp(da)(_aliases.std)
+var = get_xp(da)(_aliases.var)
+empty = get_xp(da)(_aliases.empty)
+empty_like = get_xp(da)(_aliases.empty_like)
+full = get_xp(da)(_aliases.full)
+full_like = get_xp(da)(_aliases.full_like)
+ones = get_xp(da)(_aliases.ones)
+ones_like = get_xp(da)(_aliases.ones_like)
+zeros = get_xp(da)(_aliases.zeros)
+zeros_like = get_xp(da)(_aliases.zeros_like)
+reshape = get_xp(da)(_aliases.reshape)
+matrix_transpose = get_xp(da)(_aliases.matrix_transpose)
+vecdot = get_xp(da)(_aliases.vecdot)
+
+nonzero = get_xp(da)(_aliases.nonzero)
+sum = get_xp(np)(_aliases.sum)
+prod = get_xp(np)(_aliases.prod)
+ceil = get_xp(np)(_aliases.ceil)
+floor = get_xp(np)(_aliases.floor)
+trunc = get_xp(np)(_aliases.trunc)
+matmul = get_xp(np)(_aliases.matmul)
+tensordot = get_xp(np)(_aliases.tensordot)
+
+from dask.array import (
+    # Element wise aliases
+    arccos as acos,
+    arccosh as acosh,
+    arcsin as asin,
+    arcsinh as asinh,
+    arctan as atan,
+    arctan2 as atan2,
+    arctanh as atanh,
+    left_shift as bitwise_left_shift,
+    right_shift as bitwise_right_shift,
+    invert as bitwise_invert,
+    power as pow,
+    # Other
+    concatenate as concat,
+)
+
+# exclude these from __all__ since dask does not implement them
+_da_unsupported = ['sort', 'argsort']
+
+common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]
+
+__all__ = common_aliases + ['asarray', 'bool', 'acos',
+                            'acosh', 'asin', 'asinh', 'atan', 'atan2',
+                            'atanh', 'bitwise_left_shift', 'bitwise_invert',
+                            'bitwise_right_shift', 'concat', 'pow',
+                            'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8',
+                            'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64',
+                            'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type']
+
+del da, partial, common_aliases, _da_unsupported
diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py
new file mode 100644
index 00000000..c8aa7c9f
--- /dev/null
+++ b/array_api_compat/dask/array/linalg.py
@@ -0,0 +1,48 @@
+from __future__ import annotations
+
+from dask.array.linalg import *
+from ...common import _linalg
+from ..._internal import get_xp
+from dask.array import matmul, tensordot, trace, outer
+from ._aliases import matrix_transpose, vecdot
+
+import dask.array as da
+
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from typing import Union, Tuple
+    from ...common._typing import ndarray
+
+# dask.array.linalg doesn't have __all__. If it is added, replace this with
+#
+# from dask.array.linalg import __all__ as linalg_all
+_n = {}
+exec('from dask.array.linalg import *', _n)
+del _n['__builtins__']
+linalg_all = list(_n)
+del _n
+
+EighResult = _linalg.EighResult
+QRResult = _linalg.QRResult
+SlogdetResult = _linalg.SlogdetResult
+SVDResult = _linalg.SVDResult
+qr = get_xp(da)(_linalg.qr)
+cholesky = get_xp(da)(_linalg.cholesky)
+matrix_rank = get_xp(da)(_linalg.matrix_rank)
+matrix_norm = get_xp(da)(_linalg.matrix_norm)
+
+def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]:
+    # TODO: can't avoid computing U or V for dask
+    _, s, _ = svd(x)
+    return s
+
+vector_norm = get_xp(da)(_linalg.vector_norm)
+diagonal = get_xp(da)(_linalg.diagonal)
+
+__all__ = linalg_all + ["EighResult", "QRResult", "SlogdetResult",
+                        "SVDResult", "qr", "cholesky", "matrix_rank", "matrix_norm",
+                        "svdvals", "vector_norm", "diagonal"]
+
+del get_xp
+del da
+del _linalg
diff --git a/dask-skips.txt b/dask-skips.txt
new file mode 100644
index 00000000..8e884ac0
--- /dev/null
+++ b/dask-skips.txt
@@ -0,0 +1,5 @@
+# FFT isn't conformant
+array_api_tests/test_fft.py
+
+# slow and not implemented in dask
+array_api_tests/test_linalg.py::test_matrix_power
diff --git a/dask-xfails.txt b/dask-xfails.txt
new file mode 100644
index 00000000..39a1dd8a
--- /dev/null
+++ b/dask-xfails.txt
@@ -0,0 +1,132 @@
+# This fails in dask
+# import dask.array as da
+# a = da.array([1]).reshape((1,1))
+# key = (0, slice(None, None, -1))
+# a[key] = da.array([1])
+
+# Failing hypothesis test case
+#x=dask.array
+#| Draw 1 (key): (slice(None, None, None), slice(None, None, None))
+#| Draw 2 (value): dask.array
+
+# Various shape mismatches, e.g.
+# ValueError: shape mismatch: value array of shape (0, 2) could not be broadcast to indexing result of shape (0, 2)
+array_api_tests/test_array_object.py::test_setitem
+
+# Fails since bad upcast from uint8 -> int64
+# MRE:
+# a = da.array(0, dtype="uint8")
+# b = da.array(False)
+# a[b] = 0
+array_api_tests/test_array_object.py::test_setitem_masking
+
+# Various indexing errors
+array_api_tests/test_array_object.py::test_getitem_masking
+
+# asarray(copy=False) is not yet implemented
+# copied from numpy xfails, TODO: should this pass with dask?
+array_api_tests/test_creation_functions.py::test_asarray_arrays
+
+# ZeroDivisionError, and TypeError: tuple indices must be integers or slices, not tuple
+array_api_tests/test_creation_functions.py::test_eye
+
+# finfo(float32).eps returns float32 but should return float
+array_api_tests/test_data_type_functions.py::test_finfo[float32]
+
+# out[-1] is a dask array but should be some floating-point number
+# (I think the test is not forcing the op to be computed?)
+array_api_tests/test_creation_functions.py::test_linspace
+
+# out=-0, but should be +0
+array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
+array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
+
+# output is nan but should be infinity
+array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
+array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
+
+# No sorting in dask
+array_api_tests/test_has_names.py::test_has_names[sorting-argsort]
+array_api_tests/test_has_names.py::test_has_names[sorting-sort]
+array_api_tests/test_sorting_functions.py::test_argsort
+array_api_tests/test_sorting_functions.py::test_sort
+array_api_tests/test_signatures.py::test_func_signature[argsort]
+array_api_tests/test_signatures.py::test_func_signature[sort]
+
+# Array methods and attributes not already on np.ndarray cannot be wrapped
+array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
+array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
+array_api_tests/test_has_names.py::test_has_names[array_attribute-device]
+array_api_tests/test_has_names.py::test_has_names[array_attribute-mT]
+
+# Fails because shape is NaN since we don't materialize it yet
+array_api_tests/test_searching_functions.py::test_nonzero
+array_api_tests/test_set_functions.py::test_unique_all
+array_api_tests/test_set_functions.py::test_unique_counts
+
+# Different error but same cause as above, we're just trying to do ndindex on nan shape
+array_api_tests/test_set_functions.py::test_unique_inverse
+array_api_tests/test_set_functions.py::test_unique_values
+
+# Linalg failures (signature failures/missing methods)
+
+# fails for ndim > 2
+array_api_tests/test_linalg.py::test_svdvals
+array_api_tests/test_linalg.py::test_cholesky
+# dtype mismatch: got uint64, but should be uint8; NPY_PROMOTION_STATE=weak doesn't help :(
+array_api_tests/test_linalg.py::test_tensordot
+# probably same reason for failing as numpy
+array_api_tests/test_linalg.py::test_trace
+
+# Linalg - these don't exist in dask
+array_api_tests/test_signatures.py::test_extension_func_signature[linalg.cross]
+array_api_tests/test_signatures.py::test_extension_func_signature[linalg.det]
+array_api_tests/test_signatures.py::test_extension_func_signature[linalg.eigh]
+array_api_tests/test_signatures.py::test_extension_func_signature[linalg.eigvalsh]
+array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_power]
+array_api_tests/test_signatures.py::test_extension_func_signature[linalg.pinv]
+array_api_tests/test_signatures.py::test_extension_func_signature[linalg.slogdet]
+array_api_tests/test_linalg.py::test_cross
+array_api_tests/test_linalg.py::test_det
+array_api_tests/test_linalg.py::test_eigvalsh
+array_api_tests/test_linalg.py::test_pinv
+array_api_tests/test_linalg.py::test_slogdet
+array_api_tests/test_has_names.py::test_has_names[linalg-cross]
+array_api_tests/test_has_names.py::test_has_names[linalg-det]
+array_api_tests/test_has_names.py::test_has_names[linalg-eigh]
+array_api_tests/test_has_names.py::test_has_names[linalg-eigvalsh]
+array_api_tests/test_has_names.py::test_has_names[linalg-matrix_power]
+array_api_tests/test_has_names.py::test_has_names[linalg-pinv]
+array_api_tests/test_has_names.py::test_has_names[linalg-slogdet]
+
+array_api_tests/test_linalg.py::test_matrix_norm
+array_api_tests/test_linalg.py::test_matrix_rank
+
+# missing mode kw
+# https://github.com/dask/dask/issues/10388
+array_api_tests/test_linalg.py::test_qr
+
+# Constructing the input arrays fails with a weird shape error...
+array_api_tests/test_linalg.py::test_solve
+
+# missing full_matrices kw
+# https://github.com/dask/dask/issues/10389
+# also only supports 2-d inputs
+array_api_tests/test_signatures.py::test_extension_func_signature[linalg.svd]
+array_api_tests/test_linalg.py::test_svd
+
+# Missing dlpack stuff
+array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
+array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__]
+array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
+array_api_tests/test_signatures.py::test_array_method_signature[__dlpack_device__]
+array_api_tests/test_signatures.py::test_array_method_signature[to_device]
+array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack]
+array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__]
+array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__]
+
+# Some cases unsupported by dask
+array_api_tests/test_manipulation_functions.py::test_roll
+
+# No mT on dask array
+array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices
diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py
index 1675377d..0becfc3d 100644
--- a/tests/test_array_namespace.py
+++ b/tests/test_array_namespace.py
@@ -5,8 +5,7 @@
 import pytest
 
 
-
-@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
+@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
 @pytest.mark.parametrize("api_version", [None, '2021.12'])
 def test_array_namespace(library, api_version):
     lib = import_(library)
@@ -17,7 +16,10 @@ def test_array_namespace(library, api_version):
     if 'array_api' in library:
         assert namespace == lib
     else:
-        assert namespace == getattr(array_api_compat, library)
+        if library == "dask.array":
+            assert namespace == array_api_compat.dask.array
+        else:
+            assert namespace == getattr(array_api_compat, library)
 
 
 def test_array_namespace_errors():
diff --git a/tests/test_common.py b/tests/test_common.py
index 86886b7f..f98a717a 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -5,7 +5,7 @@
 import numpy as np
 from numpy.testing import assert_allclose
 
-@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
+@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
 def test_to_device_host(library):
     # different libraries have different semantics
     # for DtoH transfers; ensure that we support a portable
diff --git a/tests/test_isdtype.py b/tests/test_isdtype.py
index d164699e..77e7ce72 100644
--- a/tests/test_isdtype.py
+++ b/tests/test_isdtype.py
@@ -64,7 +64,7 @@ def isdtype_(dtype_, kind):
     assert type(res) is bool
     return res
 
-@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
+@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
 def test_isdtype_spec_dtypes(library):
     xp = import_('array_api_compat.' + library)
 
@@ -98,7 +98,7 @@ def test_isdtype_spec_dtypes(library):
     'bfloat16',
 ]
 
-@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
+@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
 @pytest.mark.parametrize("dtype_", additional_dtypes)
 def test_isdtype_additional_dtypes(library, dtype_):
     xp = import_('array_api_compat.' + library)
diff --git a/tests/test_vendoring.py b/tests/test_vendoring.py
index 93a961aa..873b233a 100644
--- a/tests/test_vendoring.py
+++ b/tests/test_vendoring.py
@@ -17,3 +17,7 @@ def test_vendoring_cupy():
 def test_vendoring_torch():
     from vendor_test import uses_torch
     uses_torch._test_torch()
+
+def test_vendoring_dask():
+    from vendor_test import uses_dask
+    uses_dask._test_dask()
diff --git a/vendor_test/uses_dask.py b/vendor_test/uses_dask.py
new file mode 100644
index 00000000..65a00916
--- /dev/null
+++ b/vendor_test/uses_dask.py
@@ -0,0 +1,19 @@
+# Basic test that vendoring works
+
+from .vendored._compat.dask import array as dask_compat
+
+import dask.array as da
+import numpy as np
+
+def _test_dask():
+    a = dask_compat.asarray([1., 2., 3.])
+    b = dask_compat.arange(3, dtype=dask_compat.float32)
+
+    # np.pow does not exist. Update this to use something else if it is added
+    res = dask_compat.pow(a, b)
+    assert res.dtype == dask_compat.float64 == np.float64
+    assert isinstance(a, da.Array)
+    assert isinstance(b, da.Array)
+    assert isinstance(res, da.Array)
+
+    np.testing.assert_allclose(res, [1., 2., 9.])
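
Reviewer note: below is a minimal usage sketch of what this patch enables, not part of the diff itself. It assumes this branch of array-api-compat plus dask[array] are installed, and only exercises APIs introduced above (the dask dispatch in _helpers.py, the asarray copy path in common/_aliases.py, and the svdvals helper in dask/array/linalg.py).

# Usage sketch (illustrative only, not part of the patch).
import dask.array as da
from array_api_compat import array_namespace, is_array_api_obj, to_device

x = da.asarray([[1., 2.], [3., 4.]])
assert is_array_api_obj(x)                      # True via the new _is_dask_array check

xp = array_namespace(x)                         # resolves to array_api_compat.dask.array
y = xp.asarray(x, dtype=xp.float32, copy=True)  # goes through the dask copy path added above
s = xp.linalg.svdvals(y)                        # svdvals from the new dask linalg module
print(to_device(s, 'cpu').compute())            # to_device is a no-op for device='cpu'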