Skip to content

Commit

Permalink
Merge pull request #76 from lithomas1/add-dask
Browse files Browse the repository at this point in the history
Add dask to array-api-compat
  • Loading branch information
asmeurer authored Feb 6, 2024
2 parents 916a84b + 54f4838 commit 9cb5a13
Show file tree
Hide file tree
Showing 15 changed files with 427 additions and 10 deletions.
12 changes: 12 additions & 0 deletions .github/workflows/array-api-tests-dask.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
name: Array API Tests (Dask)

on: [push, pull_request]

jobs:
array-api-tests-dask:
uses: ./.github/workflows/array-api-tests.yml
with:
package-name: dask
module-name: dask.array
extra-requires: numpy
pytest-extra-args: --disable-deadline --max-examples=5
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest numpy torch
python -m pip install pytest numpy torch dask[array]
- name: Run Tests
run: |
Expand Down
22 changes: 20 additions & 2 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,8 @@ def _asarray(
import numpy as xp
elif namespace == 'cupy':
import cupy as xp
elif namespace == 'dask.array':
import dask.array as xp
else:
raise ValueError("Unrecognized namespace argument to asarray()")

Expand All @@ -322,11 +324,27 @@ def _asarray(
if copy in COPY_FALSE:
# copy=False is not yet implemented in xp.asarray
raise NotImplementedError("copy=False is not yet implemented")
if isinstance(obj, xp.ndarray):
if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)) or hasattr(obj, "__array__"):
#print('hit me')
if dtype is not None and obj.dtype != dtype:
copy = True
#print(copy)
if copy in COPY_TRUE:
return xp.array(obj, copy=True, dtype=dtype)
copy_kwargs = {}
if namespace != "dask.array":
copy_kwargs["copy"] = True
else:
# No copy kw in dask.asarray so we go thorugh np.asarray first
# (like dask also does) but copy after
if dtype is None:
# Same dtype copy is no-op in dask
#print("in here?")
return obj.copy()
import numpy as np
#print(obj)
obj = np.asarray(obj).copy()
#print(obj)
return xp.array(obj, dtype=dtype, **copy_kwargs)
return obj

return xp.asarray(obj, dtype=dtype, **kwargs)
Expand Down
24 changes: 24 additions & 0 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,23 @@ def _is_torch_array(x):
# TODO: Should we reject ndarray subclasses?
return isinstance(x, torch.Tensor)

def _is_dask_array(x):
# Avoid importing dask if it isn't already
if 'dask.array' not in sys.modules:
return False

import dask.array

return isinstance(x, dask.array.Array)

def is_array_api_obj(x):
"""
Check if x is an array API compatible array object.
"""
return _is_numpy_array(x) \
or _is_cupy_array(x) \
or _is_torch_array(x) \
or _is_dask_array(x) \
or hasattr(x, '__array_namespace__')

def _check_api_version(api_version):
Expand Down Expand Up @@ -95,6 +105,13 @@ def your_function(x, y):
else:
import torch
namespaces.add(torch)
elif _is_dask_array(x):
_check_api_version(api_version)
if _use_compat:
from ..dask import array as dask_namespace
namespaces.add(dask_namespace)
else:
raise TypeError("_use_compat cannot be False if input array is a dask array!")
elif hasattr(x, '__array_namespace__'):
namespaces.add(x.__array_namespace__(api_version=api_version))
else:
Expand Down Expand Up @@ -219,6 +236,13 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A
return _cupy_to_device(x, device, stream=stream)
elif _is_torch_array(x):
return _torch_to_device(x, device, stream=stream)
elif _is_dask_array(x):
if stream is not None:
raise ValueError("The stream argument to to_device() is not supported")
# TODO: What if our array is on the GPU already?
if device == 'cpu':
return x
raise ValueError(f"Unsupported device {device!r}")
return x.to_device(device, stream=stream)

def size(x):
Expand Down
2 changes: 1 addition & 1 deletion array_api_compat/common/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def matrix_rank(x: ndarray,
# dimensional arrays.
if x.ndim < 2:
raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
S = xp.linalg.svd(x, compute_uv=False, **kwargs)
S = get_xp(xp)(svdvals)(x, **kwargs)
if rtol is None:
tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps
else:
Expand Down
8 changes: 8 additions & 0 deletions array_api_compat/dask/array/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from dask.array import *

# These imports may overwrite names from the import * above.
from ._aliases import *

__array_api_version__ = '2022.12'

__import__(__package__ + '.linalg')
145 changes: 145 additions & 0 deletions array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from __future__ import annotations

from ...common import _aliases
from ...common._helpers import _check_device

from ..._internal import get_xp

import numpy as np
from numpy import (
# Constants
e,
inf,
nan,
pi,
newaxis,
# Dtypes
bool_ as bool,
float32,
float64,
int8,
int16,
int32,
int64,
uint8,
uint16,
uint32,
uint64,
complex64,
complex128,
iinfo,
finfo,
can_cast,
result_type,
)

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Union
from ...common._typing import ndarray, Device, Dtype

import dask.array as da

isdtype = get_xp(np)(_aliases.isdtype)
astype = _aliases.astype

# Common aliases

# This arange func is modified from the common one to
# not pass stop/step as keyword arguments, which will cause
# an error with dask

# TODO: delete the xp stuff, it shouldn't be necessary
def dask_arange(
start: Union[int, float],
/,
stop: Optional[Union[int, float]] = None,
step: Union[int, float] = 1,
*,
xp,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs
) -> ndarray:
_check_device(xp, device)
args = [start]
if stop is not None:
args.append(stop)
else:
# stop is None, so start is actually stop
# prepend the default value for start which is 0
args.insert(0, 0)
args.append(step)
return xp.arange(*args, dtype=dtype, **kwargs)

arange = get_xp(da)(dask_arange)
eye = get_xp(da)(_aliases.eye)

from functools import partial
asarray = partial(_aliases._asarray, namespace='dask.array')
asarray.__doc__ = _aliases._asarray.__doc__

linspace = get_xp(da)(_aliases.linspace)
eye = get_xp(da)(_aliases.eye)
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult)
UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult)
unique_all = get_xp(da)(_aliases.unique_all)
unique_counts = get_xp(da)(_aliases.unique_counts)
unique_inverse = get_xp(da)(_aliases.unique_inverse)
unique_values = get_xp(da)(_aliases.unique_values)
permute_dims = get_xp(da)(_aliases.permute_dims)
std = get_xp(da)(_aliases.std)
var = get_xp(da)(_aliases.var)
empty = get_xp(da)(_aliases.empty)
empty_like = get_xp(da)(_aliases.empty_like)
full = get_xp(da)(_aliases.full)
full_like = get_xp(da)(_aliases.full_like)
ones = get_xp(da)(_aliases.ones)
ones_like = get_xp(da)(_aliases.ones_like)
zeros = get_xp(da)(_aliases.zeros)
zeros_like = get_xp(da)(_aliases.zeros_like)
reshape = get_xp(da)(_aliases.reshape)
matrix_transpose = get_xp(da)(_aliases.matrix_transpose)
vecdot = get_xp(da)(_aliases.vecdot)

nonzero = get_xp(da)(_aliases.nonzero)
sum = get_xp(np)(_aliases.sum)
prod = get_xp(np)(_aliases.prod)
ceil = get_xp(np)(_aliases.ceil)
floor = get_xp(np)(_aliases.floor)
trunc = get_xp(np)(_aliases.trunc)
matmul = get_xp(np)(_aliases.matmul)
tensordot = get_xp(np)(_aliases.tensordot)

from dask.array import (
# Element wise aliases
arccos as acos,
arccosh as acosh,
arcsin as asin,
arcsinh as asinh,
arctan as atan,
arctan2 as atan2,
arctanh as atanh,
left_shift as bitwise_left_shift,
right_shift as bitwise_right_shift,
invert as bitwise_invert,
power as pow,
# Other
concatenate as concat,
)

# exclude these from all since
_da_unsupported = ['sort', 'argsort']

common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]

__all__ = common_aliases + ['asarray', 'bool', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow',
'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8',
'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64',
'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type']

del da, partial, common_aliases, _da_unsupported,
48 changes: 48 additions & 0 deletions array_api_compat/dask/array/linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from __future__ import annotations

from dask.array.linalg import *
from ...common import _linalg
from ..._internal import get_xp
from dask.array import matmul, tensordot, trace, outer
from ._aliases import matrix_transpose, vecdot

import dask.array as da

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Union, Tuple
from ...common._typing import ndarray

# cupy.linalg doesn't have __all__. If it is added, replace this with
#
# from cupy.linalg import __all__ as linalg_all
_n = {}
exec('from dask.array.linalg import *', _n)
del _n['__builtins__']
linalg_all = list(_n)
del _n

EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
qr = get_xp(da)(_linalg.qr)
cholesky = get_xp(da)(_linalg.cholesky)
matrix_rank = get_xp(da)(_linalg.matrix_rank)
matrix_norm = get_xp(da)(_linalg.matrix_norm)

def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]:
# TODO: can't avoid computing U or V for dask
_, s, _ = svd(x)
return s

vector_norm = get_xp(da)(_linalg.vector_norm)
diagonal = get_xp(da)(_linalg.diagonal)

__all__ = linalg_all + ["EighResult", "QRResult", "SlogdetResult",
"SVDResult", "qr", "cholesky", "matrix_rank", "matrix_norm",
"svdvals", "vector_norm", "diagonal"]

del get_xp
del da
del _linalg
5 changes: 5 additions & 0 deletions dask-skips.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# FFT isn't conformant
array_api_tests/test_fft.py

# slow and not implemented in dask
array_api_tests/test_linalg.py::test_matrix_power
Loading

0 comments on commit 9cb5a13

Please sign in to comment.