diff --git a/.github/workflows/array-api-tests-dask.yml b/.github/workflows/array-api-tests-dask.yml
new file mode 100644
index 00000000..0b516ddb
--- /dev/null
+++ b/.github/workflows/array-api-tests-dask.yml
@@ -0,0 +1,9 @@
+name: Array API Tests (Dask)
+
+on: [push, pull_request]
+
+jobs:
+  array-api-tests-dask:
+    uses: ./.github/workflows/array-api-tests.yml
+    with:
+      package-name: dask
diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py
index e5231770..41da4d64 100644
--- a/array_api_compat/common/_aliases.py
+++ b/array_api_compat/common/_aliases.py
@@ -303,6 +303,8 @@ def _asarray(
         import numpy as xp
     elif namespace == 'cupy':
         import cupy as xp
+    elif namespace == 'dask':
+        import dask.array as xp
     else:
         raise ValueError("Unrecognized namespace argument to asarray()")
 
@@ -322,7 +324,12 @@ def _asarray(
     if copy in COPY_FALSE:
         # copy=False is not yet implemented in xp.asarray
         raise NotImplementedError("copy=False is not yet implemented")
-    if isinstance(obj, xp.ndarray):
+    # Only objects that are already arrays of the target namespace may be
+    # returned as-is. Testing for __array__ would also match foreign arrays
+    # (e.g. a numpy array passed to the cupy or dask wrapper) and hand them
+    # back unconverted, so check the namespace's own array class instead.
+    native_type = xp.Array if namespace == 'dask' else getattr(xp, "ndarray", None)
+    if native_type is not None and isinstance(obj, native_type):
         if dtype is not None and obj.dtype != dtype:
             copy = True
         if copy in COPY_TRUE:
diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py
index 61aee4be..94cbb902 100644
--- a/array_api_compat/common/_helpers.py
+++ b/array_api_compat/common/_helpers.py
@@ -40,6 +40,16 @@ def _is_torch_array(x):
     # TODO: Should we reject ndarray subclasses?
     return isinstance(x, torch.Tensor)
 
+def _is_dask_array(x):
+    # Avoid importing dask if it isn't already
+    if 'dask.array' not in sys.modules:
+        return False
+
+    import dask.array
+
+    # TODO: Should we reject dask.array.Array subclasses?
+    return isinstance(x, dask.array.Array)
+
 def is_array_api_obj(x):
     """
     Check if x is an array API compatible array object.
@@ -97,6 +107,13 @@ def your_function(x, y): else: import torch namespaces.add(torch) + elif _is_dask_array(x): + _check_api_version(api_version) + if _use_compat: + from .. import dask as dask_namespace + namespaces.add(dask_namespace) + else: + raise TypeError("_use_compat cannot be False if input array is a dask array!") else: # TODO: Support Python scalars? raise TypeError("The input is not a supported array type") @@ -219,6 +236,13 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A return _cupy_to_device(x, device, stream=stream) elif _is_torch_array(x): return _torch_to_device(x, device, stream=stream) + elif _is_dask_array(x): + if stream is not None: + raise ValueError("The stream argument to to_device() is not supported") + # TODO: What if our array is on the GPU already? + if device == 'cpu': + return x + raise ValueError(f"Unsupported device {device!r}") return x.to_device(device, stream=stream) def size(x): diff --git a/array_api_compat/dask/__init__.py b/array_api_compat/dask/__init__.py new file mode 100644 index 00000000..12e3a92d --- /dev/null +++ b/array_api_compat/dask/__init__.py @@ -0,0 +1,6 @@ +from dask.array import * + +# These imports may overwrite names from the import * above. 
+from ._aliases import * + +__array_api_version__ = '2022.12' diff --git a/array_api_compat/dask/_aliases.py b/array_api_compat/dask/_aliases.py new file mode 100644 index 00000000..93fe4594 --- /dev/null +++ b/array_api_compat/dask/_aliases.py @@ -0,0 +1,88 @@ +from ..common import _aliases + +from .._internal import get_xp + +import numpy as np +from numpy import ( + # Constants + e, + inf, + nan, + pi, + newaxis, + # Dtypes + bool_ as bool, + float32, + float64, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + complex64, + complex128, + iinfo, + finfo, + can_cast, + result_type, +) + +import dask.array as da + +isdtype = get_xp(np)(_aliases.isdtype) +astype = _aliases.astype + +# Common aliases +arange = get_xp(da)(_aliases.arange) + +from functools import partial +asarray = partial(_aliases._asarray, namespace='dask') +asarray.__doc__ = _aliases._asarray.__doc__ + +linspace = get_xp(da)(_aliases.linspace) +eye = get_xp(da)(_aliases.eye) +UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult) +UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult) +UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult) +unique_all = get_xp(da)(_aliases.unique_all) +unique_counts = get_xp(da)(_aliases.unique_counts) +unique_inverse = get_xp(da)(_aliases.unique_inverse) +unique_values = get_xp(da)(_aliases.unique_values) +permute_dims = get_xp(da)(_aliases.permute_dims) +std = get_xp(da)(_aliases.std) +var = get_xp(da)(_aliases.var) +empty = get_xp(da)(_aliases.empty) +empty_like = get_xp(da)(_aliases.empty_like) +full = get_xp(da)(_aliases.full) +full_like = get_xp(da)(_aliases.full_like) +ones = get_xp(da)(_aliases.ones) +ones_like = get_xp(da)(_aliases.ones_like) +zeros = get_xp(da)(_aliases.zeros) +zeros_like = get_xp(da)(_aliases.zeros_like) +reshape = get_xp(da)(_aliases.reshape) +matrix_transpose = get_xp(da)(_aliases.matrix_transpose) +vecdot = get_xp(da)(_aliases.vecdot) + + + +from dask.array import ( + # Element wise aliases + 
arccos as acos, + arccosh as acosh, + arcsin as asin, + arcsinh as asinh, + arctan as atan, + arctan2 as atan2, + arctanh as atanh, + left_shift as bitwise_left_shift, + right_shift as bitwise_right_shift, + invert as bitwise_invert, + power as pow, + # Other + concatenate as concat, + +) + diff --git a/dask-skips.txt b/dask-skips.txt new file mode 100644 index 00000000..e69de29b diff --git a/dask-xfails.txt b/dask-xfails.txt new file mode 100644 index 00000000..cc8b39d0 --- /dev/null +++ b/dask-xfails.txt @@ -0,0 +1,47 @@ +# finfo(float32).eps returns float32 but should return float +array_api_tests/test_data_type_functions.py::test_finfo[float32] + +# No sorting in dask +array_api_tests/test_has_names.py::test_has_names[sorting-argsort] +array_api_tests/test_has_names.py::test_has_names[sorting-sort] + +# Array methods and attributes not already on np.ndarray cannot be wrapped +array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] +array_api_tests/test_has_names.py::test_has_names[array_method-to_device] +array_api_tests/test_has_names.py::test_has_names[array_attribute-device] +array_api_tests/test_has_names.py::test_has_names[array_attribute-mT] + +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)] - AssertionError: out[0]=0, but should be (x1 + x2[0])=65536 [__add__()] +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_bitwise_and(ctx=BinaryParamContext(<__and__(x1, x2)>), data=data(...)) produces unreliable results: Falsified on the ... +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__rshift__(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_bitwise_right_shift(ctx=BinaryParamContext(<__rshift__(x1, x2)>), data=data(...)) produces unreliable results: Falsif... 
+#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[bitwise_xor(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 994.44ms, which exceeds the deadline of 800.00ms +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x1, x2)] - ValueError: Inferred dtype from function 'xor' was 'uint64' but got 'int16', which can't be cast using casting='same_kind' +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_ceil - exceptiongroup.ExceptionGroup: Hypothesis found 2 distinct failures. (2 sub-exceptions) +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_divide[divide(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_divide(ctx=BinaryParamContext(), data=data(...)) produces unreliable results: Falsified on the first ... +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_floor - exceptiongroup.ExceptionGroup: Hypothesis found 2 distinct failures. (2 sub-exceptions) +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 1015.43ms, which exceeds the deadline of 800.00ms +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 900.99ms, which exceeds the deadline of 800.00ms +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 1106.49ms, which exceeds the deadline of 800.00ms +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_greater(ctx=BinaryParamContext(<__gt__(x1, x2)>), data=data(...)) produces unreliable results: Falsified on the first... 
+#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_greater_equal(ctx=BinaryParamContext(), data=data(...)) produces unreliable results: Falsified... +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_greater_equal(ctx=BinaryParamContext(<__ge__(x1, x2)>), data=data(...)) produces unreliable results: Falsified on the... +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_less[less(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 961.84ms, which exceeds the deadline of 800.00ms +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 980.68ms, which exceeds the deadline of 800.00ms +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[less_equal(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 1043.47ms, which exceeds the deadline of 800.00ms +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[__le__(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 1011.76ms, which exceeds the deadline of 800.00ms +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[multiply(x1, x2)] - AssertionError: out[0]=0, but should be (x1 * x2[0])=256 [multiply()] +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)] - AssertionError: out[0]=2, but should be (x1 * x2[0])=258 [__mul__()] +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)] - exceptiongroup.ExceptionGroup: Hypothesis found 2 distinct failures. 
(2 sub-exceptions) +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 1034.64ms, which exceeds the deadline of 800.00ms +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_trunc - exceptiongroup.ExceptionGroup: Hypothesis found 2 distinct failures. (2 sub-exceptions) + +#FAILED array_api_tests/test_searching_functions.py::test_nonzero_zerodim_error - Failed: DID NOT RAISE + +# Fails because shape is NaN since we don't materialize it yet +#FAILED array_api_tests/test_searching_functions.py::test_nonzero - AssertionError: prod(out[0].shape)=nan, but should be prod(out[0].shape)=nan +#FAILED array_api_tests/test_set_functions.py::test_unique_all - AssertionError: out.indices.shape=(nan,), but should be out.values.shape=(nan,) +#FAILED array_api_tests/test_set_functions.py::test_unique_counts - AssertionError: out.counts.shape=(nan,), but should be out.values.shape=(nan,) + +# Needs investigation +#FAILED array_api_tests/test_set_functions.py::test_unique_inverse - TypeError: 'float' object cannot be interpreted as an integer +#FAILED array_api_tests/test_set_functions.py::test_unique_values - TypeError: 'float' object cannot be interpreted as an integer diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 1675377d..4012f212 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -5,8 +5,7 @@ import pytest - -@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"]) +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) @pytest.mark.parametrize("api_version", [None, '2021.12']) def test_array_namespace(library, api_version): lib = import_(library) @@ -17,6 +16,8 @@ def test_array_namespace(library, api_version): if 'array_api' in library: assert namespace == lib else: + if library == "dask.array": + library = "dask" assert namespace == getattr(array_api_compat, 
library) diff --git a/tests/test_common.py b/tests/test_common.py index 86886b7f..45d8030a 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -5,7 +5,7 @@ import numpy as np from numpy.testing import assert_allclose -@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"]) +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask"]) def test_to_device_host(library): # different libraries have different semantics # for DtoH transfers; ensure that we support a portable diff --git a/tests/test_isdtype.py b/tests/test_isdtype.py index d164699e..6a40140d 100644 --- a/tests/test_isdtype.py +++ b/tests/test_isdtype.py @@ -64,7 +64,7 @@ def isdtype_(dtype_, kind): assert type(res) is bool return res -@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"]) +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask"]) def test_isdtype_spec_dtypes(library): xp = import_('array_api_compat.' + library) @@ -98,7 +98,7 @@ def test_isdtype_spec_dtypes(library): 'bfloat16', ] -@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"]) +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask"]) @pytest.mark.parametrize("dtype_", additional_dtypes) def test_isdtype_additional_dtypes(library, dtype_): xp = import_('array_api_compat.' 
+ library)
diff --git a/tests/test_vendoring.py b/tests/test_vendoring.py
index 93a961aa..44b89a74 100644
--- a/tests/test_vendoring.py
+++ b/tests/test_vendoring.py
@@ -17,3 +17,8 @@ def test_vendoring_cupy():
 def test_vendoring_torch():
     from vendor_test import uses_torch
     uses_torch._test_torch()
+
+def test_vendoring_dask():
+    # Must not reuse the name test_vendoring_torch, or it would shadow that test
+    from vendor_test import uses_dask
+    uses_dask._test_dask()
diff --git a/vendor_test/uses_dask.py b/vendor_test/uses_dask.py
new file mode 100644
index 00000000..d376bdf1
--- /dev/null
+++ b/vendor_test/uses_dask.py
@@ -0,0 +1,20 @@
+# Basic test that vendoring works
+
+from .vendored._compat import dask as dask_compat
+
+import dask.array as da
+import numpy as np
+
+def _test_dask():
+    """Check that the vendored dask wrapper builds dask arrays and computes pow()."""
+    a = dask_compat.asarray([1., 2., 3.])
+    b = dask_compat.arange(3, dtype=dask_compat.float32)
+
+    # np.pow does not exist. Update this to use something else if it is added
+    res = dask_compat.pow(a, b)
+    assert res.dtype == dask_compat.float64 == np.float64
+    # da.Array is the array class; da.array is a function and breaks isinstance()
+    assert isinstance(a, da.Array)
+    assert isinstance(b, da.Array)
+    assert isinstance(res, da.Array)
+
+    np.testing.assert_allclose(res, [1., 2., 9.])