Skip to content

Commit

Permalink
Add dask to array-api-compat
Browse files Browse the repository at this point in the history
  • Loading branch information
lithomas1 committed Dec 19, 2023
1 parent 874c2ff commit 8ee1613
Show file tree
Hide file tree
Showing 12 changed files with 208 additions and 6 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/array-api-tests-dask.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
name: Array API Tests (Dask)

on: [push, pull_request]

jobs:
array-api-tests-dask:
uses: ./.github/workflows/array-api-tests.yml
with:
package-name: dask
6 changes: 5 additions & 1 deletion array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,8 @@ def _asarray(
import numpy as xp
elif namespace == 'cupy':
import cupy as xp
elif namespace == 'dask':
import dask.array as xp
else:
raise ValueError("Unrecognized namespace argument to asarray()")

Expand All @@ -322,7 +324,9 @@ def _asarray(
if copy in COPY_FALSE:
# copy=False is not yet implemented in xp.asarray
raise NotImplementedError("copy=False is not yet implemented")
if isinstance(obj, xp.ndarray):
# TODO: This feels wrong (__array__ is not in the standard)
# Dask doesn't support DLPack, though, so, this'll do
if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)) or hasattr(obj, "__array__"):
if dtype is not None and obj.dtype != dtype:
copy = True
if copy in COPY_TRUE:
Expand Down
24 changes: 24 additions & 0 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@ def _is_torch_array(x):
# TODO: Should we reject ndarray subclasses?
return isinstance(x, torch.Tensor)

def _is_dask_array(x):
# Avoid importing dask if it isn't already
if 'dask.array' not in sys.modules:
return False

import dask.array

# TODO: Should we reject ndarray subclasses?
return isinstance(x, dask.array.Array)

def is_array_api_obj(x):
"""
Check if x is an array API compatible array object.
Expand Down Expand Up @@ -97,6 +107,13 @@ def your_function(x, y):
else:
import torch
namespaces.add(torch)
elif _is_dask_array(x):
_check_api_version(api_version)
if _use_compat:
from .. import dask as dask_namespace
namespaces.add(dask_namespace)
else:
raise TypeError("_use_compat cannot be False if input array is a dask array!")
else:
# TODO: Support Python scalars?
raise TypeError("The input is not a supported array type")
Expand Down Expand Up @@ -219,6 +236,13 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A
return _cupy_to_device(x, device, stream=stream)
elif _is_torch_array(x):
return _torch_to_device(x, device, stream=stream)
elif _is_dask_array(x):
if stream is not None:
raise ValueError("The stream argument to to_device() is not supported")
# TODO: What if our array is on the GPU already?
if device == 'cpu':
return x
raise ValueError(f"Unsupported device {device!r}")
return x.to_device(device, stream=stream)

def size(x):
Expand Down
6 changes: 6 additions & 0 deletions array_api_compat/dask/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from dask.array import *

# These imports may overwrite names from the import * above.
from ._aliases import *

__array_api_version__ = '2022.12'
88 changes: 88 additions & 0 deletions array_api_compat/dask/_aliases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from ..common import _aliases

from .._internal import get_xp

import numpy as np
from numpy import (
# Constants
e,
inf,
nan,
pi,
newaxis,
# Dtypes
bool_ as bool,
float32,
float64,
int8,
int16,
int32,
int64,
uint8,
uint16,
uint32,
uint64,
complex64,
complex128,
iinfo,
finfo,
can_cast,
result_type,
)

import dask.array as da

isdtype = get_xp(np)(_aliases.isdtype)
astype = _aliases.astype

# Common aliases
arange = get_xp(da)(_aliases.arange)

from functools import partial
asarray = partial(_aliases._asarray, namespace='dask')
asarray.__doc__ = _aliases._asarray.__doc__

linspace = get_xp(da)(_aliases.linspace)
eye = get_xp(da)(_aliases.eye)
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult)
UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult)
unique_all = get_xp(da)(_aliases.unique_all)
unique_counts = get_xp(da)(_aliases.unique_counts)
unique_inverse = get_xp(da)(_aliases.unique_inverse)
unique_values = get_xp(da)(_aliases.unique_values)
permute_dims = get_xp(da)(_aliases.permute_dims)
std = get_xp(da)(_aliases.std)
var = get_xp(da)(_aliases.var)
empty = get_xp(da)(_aliases.empty)
empty_like = get_xp(da)(_aliases.empty_like)
full = get_xp(da)(_aliases.full)
full_like = get_xp(da)(_aliases.full_like)
ones = get_xp(da)(_aliases.ones)
ones_like = get_xp(da)(_aliases.ones_like)
zeros = get_xp(da)(_aliases.zeros)
zeros_like = get_xp(da)(_aliases.zeros_like)
reshape = get_xp(da)(_aliases.reshape)
matrix_transpose = get_xp(da)(_aliases.matrix_transpose)
vecdot = get_xp(da)(_aliases.vecdot)



from dask.array import (
# Element wise aliases
arccos as acos,
arccosh as acosh,
arcsin as asin,
arcsinh as asinh,
arctan as atan,
arctan2 as atan2,
arctanh as atanh,
left_shift as bitwise_left_shift,
right_shift as bitwise_right_shift,
invert as bitwise_invert,
power as pow,
# Other
concatenate as concat,

)

Empty file added dask-skips.txt
Empty file.
47 changes: 47 additions & 0 deletions dask-xfails.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# finfo(float32).eps returns float32 but should return float
array_api_tests/test_data_type_functions.py::test_finfo[float32]

# No sorting in dask
array_api_tests/test_has_names.py::test_has_names[sorting-argsort]
array_api_tests/test_has_names.py::test_has_names[sorting-sort]

# Array methods and attributes not already on np.ndarray cannot be wrapped
array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
array_api_tests/test_has_names.py::test_has_names[array_attribute-device]
array_api_tests/test_has_names.py::test_has_names[array_attribute-mT]

#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)] - AssertionError: out[0]=0, but should be (x1 + x2[0])=65536 [__add__()]
#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_bitwise_and(ctx=BinaryParamContext(<__and__(x1, x2)>), data=data(...)) produces unreliable results: Falsified on the ...
#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__rshift__(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_bitwise_right_shift(ctx=BinaryParamContext(<__rshift__(x1, x2)>), data=data(...)) produces unreliable results: Falsif...
#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[bitwise_xor(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 994.44ms, which exceeds the deadline of 800.00ms
#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x1, x2)] - ValueError: Inferred dtype from function 'xor' was 'uint64' but got 'int16', which can't be cast using casting='same_kind'
#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_ceil - exceptiongroup.ExceptionGroup: Hypothesis found 2 distinct failures. (2 sub-exceptions)
#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_divide[divide(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_divide(ctx=BinaryParamContext(<divide(x1, x2)>), data=data(...)) produces unreliable results: Falsified on the first ...
#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_floor - exceptiongroup.ExceptionGroup: Hypothesis found 2 distinct failures. (2 sub-exceptions)
#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 1015.43ms, which exceeds the deadline of 800.00ms
#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 900.99ms, which exceeds the deadline of 800.00ms
#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 1106.49ms, which exceeds the deadline of 800.00ms
#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_greater(ctx=BinaryParamContext(<__gt__(x1, x2)>), data=data(...)) produces unreliable results: Falsified on the first...
#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_greater_equal(ctx=BinaryParamContext(<greater_equal(x1, x2)>), data=data(...)) produces unreliable results: Falsified...
#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_greater_equal(ctx=BinaryParamContext(<__ge__(x1, x2)>), data=data(...)) produces unreliable results: Falsified on the...
#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_less[less(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 961.84ms, which exceeds the deadline of 800.00ms
#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 980.68ms, which exceeds the deadline of 800.00ms
#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[less_equal(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 1043.47ms, which exceeds the deadline of 800.00ms
#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[__le__(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 1011.76ms, which exceeds the deadline of 800.00ms
#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[multiply(x1, x2)] - AssertionError: out[0]=0, but should be (x1 * x2[0])=256 [multiply()]
#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)] - AssertionError: out[0]=2, but should be (x1 * x2[0])=258 [__mul__()]
#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)] - exceptiongroup.ExceptionGroup: Hypothesis found 2 distinct failures. (2 sub-exceptions)
#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 1034.64ms, which exceeds the deadline of 800.00ms
#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_trunc - exceptiongroup.ExceptionGroup: Hypothesis found 2 distinct failures. (2 sub-exceptions)

#FAILED array_api_tests/test_searching_functions.py::test_nonzero_zerodim_error - Failed: DID NOT RAISE <class 'Exception'>

# Fails because shape is NaN since we don't materialize it yet
#FAILED array_api_tests/test_searching_functions.py::test_nonzero - AssertionError: prod(out[0].shape)=nan, but should be prod(out[0].shape)=nan
#FAILED array_api_tests/test_set_functions.py::test_unique_all - AssertionError: out.indices.shape=(nan,), but should be out.values.shape=(nan,)
#FAILED array_api_tests/test_set_functions.py::test_unique_counts - AssertionError: out.counts.shape=(nan,), but should be out.values.shape=(nan,)

# Needs investigation
#FAILED array_api_tests/test_set_functions.py::test_unique_inverse - TypeError: 'float' object cannot be interpreted as an integer
#FAILED array_api_tests/test_set_functions.py::test_unique_values - TypeError: 'float' object cannot be interpreted as an integer
5 changes: 3 additions & 2 deletions tests/test_array_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@

import pytest


@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
@pytest.mark.parametrize("api_version", [None, '2021.12'])
def test_array_namespace(library, api_version):
lib = import_(library)
Expand All @@ -17,6 +16,8 @@ def test_array_namespace(library, api_version):
if 'array_api' in library:
assert namespace == lib
else:
if library == "dask.array":
library = "dask"
assert namespace == getattr(array_api_compat, library)


Expand Down
2 changes: 1 addition & 1 deletion tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from numpy.testing import assert_allclose

@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask"])
def test_to_device_host(library):
# different libraries have different semantics
# for DtoH transfers; ensure that we support a portable
Expand Down
4 changes: 2 additions & 2 deletions tests/test_isdtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def isdtype_(dtype_, kind):
assert type(res) is bool
return res

@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask"])
def test_isdtype_spec_dtypes(library):
xp = import_('array_api_compat.' + library)

Expand Down Expand Up @@ -98,7 +98,7 @@ def test_isdtype_spec_dtypes(library):
'bfloat16',
]

@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask"])
@pytest.mark.parametrize("dtype_", additional_dtypes)
def test_isdtype_additional_dtypes(library, dtype_):
xp = import_('array_api_compat.' + library)
Expand Down
4 changes: 4 additions & 0 deletions tests/test_vendoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,7 @@ def test_vendoring_cupy():
def test_vendoring_torch():
from vendor_test import uses_torch
uses_torch._test_torch()

def test_vendoring_torch():
from vendor_test import uses_torch
uses_torch._test_torch()
19 changes: 19 additions & 0 deletions vendor_test/uses_dask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Basic test that vendoring works

from .vendored._compat import dask as dask_compat

import dask.array as da
import numpy as np

def _test_numpy():
a = dask_compat.asarray([1., 2., 3.])
b = dask_compat.arange(3, dtype=dask_compat.float32)

# np.pow does not exist. Update this to use something else if it is added
res = dask_compat.pow(a, b)
assert res.dtype == dask_compat.float64 == np.float64
assert isinstance(a, da.array)
assert isinstance(b, da.array)
assert isinstance(res, da.array)

np.testing.assert_allclose(res, [1., 2., 9.])

0 comments on commit 8ee1613

Please sign in to comment.