Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dask to array-api-compat #76

Merged
merged 25 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .github/workflows/array-api-tests-dask.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
name: Array API Tests (Dask)

on: [push, pull_request]

jobs:
array-api-tests-dask:
uses: ./.github/workflows/array-api-tests.yml
with:
package-name: dask
module-name: dask.array
extra-requires: numpy
pytest-extra-args: --disable-deadline --max-examples=5
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest numpy torch
python -m pip install pytest numpy torch dask[array]

- name: Run Tests
run: |
Expand Down
22 changes: 20 additions & 2 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,8 @@ def _asarray(
import numpy as xp
elif namespace == 'cupy':
import cupy as xp
elif namespace == 'dask.array':
import dask.array as xp
else:
raise ValueError("Unrecognized namespace argument to asarray()")

Expand All @@ -322,11 +324,27 @@ def _asarray(
if copy in COPY_FALSE:
# copy=False is not yet implemented in xp.asarray
raise NotImplementedError("copy=False is not yet implemented")
if isinstance(obj, xp.ndarray):
if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)) or hasattr(obj, "__array__"):
#print('hit me')
if dtype is not None and obj.dtype != dtype:
copy = True
#print(copy)
if copy in COPY_TRUE:
return xp.array(obj, copy=True, dtype=dtype)
copy_kwargs = {}
if namespace != "dask.array":
copy_kwargs["copy"] = True
else:
# No copy kw in dask.asarray so we go thorugh np.asarray first
# (like dask also does) but copy after
if dtype is None:
# Same dtype copy is no-op in dask
#print("in here?")
return obj.copy()
import numpy as np
#print(obj)
obj = np.asarray(obj).copy()
#print(obj)
return xp.array(obj, dtype=dtype, **copy_kwargs)
return obj

return xp.asarray(obj, dtype=dtype, **kwargs)
Expand Down
24 changes: 24 additions & 0 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,23 @@ def _is_torch_array(x):
# TODO: Should we reject ndarray subclasses?
return isinstance(x, torch.Tensor)

def _is_dask_array(x):
# Avoid importing dask if it isn't already
if 'dask.array' not in sys.modules:
return False

import dask.array

return isinstance(x, dask.array.Array)

def is_array_api_obj(x):
"""
Check if x is an array API compatible array object.
"""
return _is_numpy_array(x) \
or _is_cupy_array(x) \
or _is_torch_array(x) \
or _is_dask_array(x) \
or hasattr(x, '__array_namespace__')

def _check_api_version(api_version):
Expand Down Expand Up @@ -95,6 +105,13 @@ def your_function(x, y):
else:
import torch
namespaces.add(torch)
elif _is_dask_array(x):
_check_api_version(api_version)
if _use_compat:
from ..dask import array as dask_namespace
namespaces.add(dask_namespace)
else:
raise TypeError("_use_compat cannot be False if input array is a dask array!")
elif hasattr(x, '__array_namespace__'):
namespaces.add(x.__array_namespace__(api_version=api_version))
else:
Expand Down Expand Up @@ -219,6 +236,13 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A
return _cupy_to_device(x, device, stream=stream)
elif _is_torch_array(x):
return _torch_to_device(x, device, stream=stream)
elif _is_dask_array(x):
if stream is not None:
raise ValueError("The stream argument to to_device() is not supported")
# TODO: What if our array is on the GPU already?
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think dask supports cupy on the GPU.

Is this something we also need to take into consideration?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes indeed, it does support CuPy. I"m not sure that a stream input can be supported though. If a Dask array spans multiple machines, I think there'd be multiple streams with numbers that are unrelated to each other. Which therefore can't be supported at the Dask level in this API.

That's probably fine - you'd only move arrays in a single process to another device like this I think, so maybe the whole to_device method doesn't quite work for Dask? @jakirkham any thoughts on this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even if dask can't really support device transfers or anything I would not call the default device "cpu" as that's misleading. We could just create a proxy DaskDevice object that serves as the device for dask arrays.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lithomas1 what's the status of this comment? Does the device() helper above need to be updated for dask?

if device == 'cpu':
return x
raise ValueError(f"Unsupported device {device!r}")
return x.to_device(device, stream=stream)

def size(x):
Expand Down
2 changes: 1 addition & 1 deletion array_api_compat/common/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def matrix_rank(x: ndarray,
# dimensional arrays.
if x.ndim < 2:
raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
S = xp.linalg.svd(x, compute_uv=False, **kwargs)
S = get_xp(xp)(svdvals)(x, **kwargs)
if rtol is None:
tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps
else:
Expand Down
8 changes: 8 additions & 0 deletions array_api_compat/dask/array/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from dask.array import *

# These imports may overwrite names from the import * above.
from ._aliases import *

__array_api_version__ = '2022.12'

__import__(__package__ + '.linalg')
145 changes: 145 additions & 0 deletions array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from __future__ import annotations

from ...common import _aliases
from ...common._helpers import _check_device

from ..._internal import get_xp

import numpy as np
from numpy import (
# Constants
e,
inf,
nan,
pi,
newaxis,
# Dtypes
bool_ as bool,
float32,
float64,
int8,
int16,
int32,
int64,
uint8,
uint16,
uint32,
uint64,
complex64,
complex128,
iinfo,
finfo,
can_cast,
result_type,
)

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Union
from ...common._typing import ndarray, Device, Dtype

import dask.array as da

isdtype = get_xp(np)(_aliases.isdtype)
astype = _aliases.astype

# Common aliases

# This arange func is modified from the common one to
# not pass stop/step as keyword arguments, which will cause
# an error with dask

# TODO: delete the xp stuff, it shouldn't be necessary
def dask_arange(
start: Union[int, float],
/,
stop: Optional[Union[int, float]] = None,
step: Union[int, float] = 1,
*,
xp,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs
) -> ndarray:
_check_device(xp, device)
args = [start]
if stop is not None:
args.append(stop)
else:
# stop is None, so start is actually stop
# prepend the default value for start which is 0
args.insert(0, 0)
args.append(step)
return xp.arange(*args, dtype=dtype, **kwargs)

arange = get_xp(da)(dask_arange)
eye = get_xp(da)(_aliases.eye)

from functools import partial
asarray = partial(_aliases._asarray, namespace='dask.array')
asarray.__doc__ = _aliases._asarray.__doc__

linspace = get_xp(da)(_aliases.linspace)
eye = get_xp(da)(_aliases.eye)
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult)
UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult)
unique_all = get_xp(da)(_aliases.unique_all)
unique_counts = get_xp(da)(_aliases.unique_counts)
unique_inverse = get_xp(da)(_aliases.unique_inverse)
unique_values = get_xp(da)(_aliases.unique_values)
permute_dims = get_xp(da)(_aliases.permute_dims)
std = get_xp(da)(_aliases.std)
var = get_xp(da)(_aliases.var)
empty = get_xp(da)(_aliases.empty)
empty_like = get_xp(da)(_aliases.empty_like)
full = get_xp(da)(_aliases.full)
full_like = get_xp(da)(_aliases.full_like)
ones = get_xp(da)(_aliases.ones)
ones_like = get_xp(da)(_aliases.ones_like)
zeros = get_xp(da)(_aliases.zeros)
zeros_like = get_xp(da)(_aliases.zeros_like)
reshape = get_xp(da)(_aliases.reshape)
matrix_transpose = get_xp(da)(_aliases.matrix_transpose)
vecdot = get_xp(da)(_aliases.vecdot)

nonzero = get_xp(da)(_aliases.nonzero)
sum = get_xp(np)(_aliases.sum)
prod = get_xp(np)(_aliases.prod)
ceil = get_xp(np)(_aliases.ceil)
floor = get_xp(np)(_aliases.floor)
trunc = get_xp(np)(_aliases.trunc)
matmul = get_xp(np)(_aliases.matmul)
tensordot = get_xp(np)(_aliases.tensordot)

from dask.array import (
# Element wise aliases
arccos as acos,
arccosh as acosh,
arcsin as asin,
arcsinh as asinh,
arctan as atan,
arctan2 as atan2,
arctanh as atanh,
left_shift as bitwise_left_shift,
right_shift as bitwise_right_shift,
invert as bitwise_invert,
power as pow,
# Other
concatenate as concat,
)

# exclude these from all since
_da_unsupported = ['sort', 'argsort']

common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]

__all__ = common_aliases + ['asarray', 'bool', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow',
'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8',
'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64',
'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type']

del da, partial, common_aliases, _da_unsupported,
48 changes: 48 additions & 0 deletions array_api_compat/dask/array/linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from __future__ import annotations

from dask.array.linalg import *
from ...common import _linalg
from ..._internal import get_xp
from dask.array import matmul, tensordot, trace, outer
from ._aliases import matrix_transpose, vecdot

import dask.array as da

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Union, Tuple
from ...common._typing import ndarray

# cupy.linalg doesn't have __all__. If it is added, replace this with
#
# from cupy.linalg import __all__ as linalg_all
_n = {}
exec('from dask.array.linalg import *', _n)
del _n['__builtins__']
linalg_all = list(_n)
del _n

EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
qr = get_xp(da)(_linalg.qr)
cholesky = get_xp(da)(_linalg.cholesky)
matrix_rank = get_xp(da)(_linalg.matrix_rank)
matrix_norm = get_xp(da)(_linalg.matrix_norm)

def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]:
# TODO: can't avoid computing U or V for dask
_, s, _ = svd(x)
return s

vector_norm = get_xp(da)(_linalg.vector_norm)
diagonal = get_xp(da)(_linalg.diagonal)

__all__ = linalg_all + ["EighResult", "QRResult", "SlogdetResult",
"SVDResult", "qr", "cholesky", "matrix_rank", "matrix_norm",
"svdvals", "vector_norm", "diagonal"]

del get_xp
del da
del _linalg
5 changes: 5 additions & 0 deletions dask-skips.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# FFT isn't conformant
array_api_tests/test_fft.py

# slow and not implemented in dask
array_api_tests/test_linalg.py::test_matrix_power
Loading