-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add dask to array-api-compat #76
Changes from 11 commits
8ee1613
3cc2857
ad6bf56
f178e28
8c9c784
6305d7e
06afd75
6e0ef29
9abe56d
20622ba
a484d5a
0d66160
df69086
1f9799a
424f25d
64d20c4
762a03c
69cc93b
f52b3d5
5edc5ec
565666a
6841758
4be5517
cd381a0
54f4838
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
name: Array API Tests (Dask) | ||
|
||
on: [push, pull_request] | ||
|
||
jobs: | ||
array-api-tests-dask: | ||
uses: ./.github/workflows/array-api-tests.yml | ||
with: | ||
package-name: dask | ||
extra-requires: numpy | ||
pytest-extra-args: --disable-deadline --max-examples=5 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
""" | ||
A little CI shim for the dask backend that | ||
disables the dask scheduler | ||
|
||
It also lets you see the dask dashboard for debugging | ||
at http://127.0.0.1:8787/status | ||
""" | ||
import dask | ||
dask.config.set(scheduler='synchronous') | ||
|
||
from dask.distributed import Client | ||
_client = Client() | ||
|
||
from .dask import * | ||
from .dask import __array_api_version__ |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -303,6 +303,8 @@ def _asarray( | |
import numpy as xp | ||
elif namespace == 'cupy': | ||
import cupy as xp | ||
elif namespace == 'dask': | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Technically, the array module of dask is called I guess we could match that here, but it feels kind of weird to me. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think (but am not 100% sure) that using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suggest waiting for @asmeurer to weigh in here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This keyword is only used internally so it doesn't really matter, but I like |
||
import dask.array as xp | ||
else: | ||
raise ValueError("Unrecognized namespace argument to asarray()") | ||
|
||
|
@@ -322,10 +324,13 @@ def _asarray( | |
if copy in COPY_FALSE: | ||
# copy=False is not yet implemented in xp.asarray | ||
raise NotImplementedError("copy=False is not yet implemented") | ||
if isinstance(obj, xp.ndarray): | ||
# TODO: This feels wrong (__array__ is not in the standard) | ||
# Dask doesn't support DLPack, though, so, this'll do | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function is only used for numpy and cupy (and now dask). The internals of the wrappers can use specific functionality that isn't in the standard because they can assume that they are operating on the specific array library. So this isn't a problem. |
||
if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)) or hasattr(obj, "__array__"): | ||
if dtype is not None and obj.dtype != dtype: | ||
copy = True | ||
if copy in COPY_TRUE: | ||
# Dask arrays are immutable, so copy doesn't do anything | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is probably fine but maybe this sort of thing should be clarified in the standard. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I think what I was mostly worried about was whether I need to return a new dask array object (I guess this would kinda be like a view on the old object) or not. |
||
if copy in COPY_TRUE and namespace != "dask": | ||
return xp.array(obj, copy=True, dtype=dtype) | ||
return obj | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,13 +40,23 @@ def _is_torch_array(x): | |
# TODO: Should we reject ndarray subclasses? | ||
return isinstance(x, torch.Tensor) | ||
|
||
def _is_dask_array(x): | ||
# Avoid importing dask if it isn't already | ||
if 'dask.array' not in sys.modules: | ||
return False | ||
|
||
import dask.array | ||
|
||
return isinstance(x, dask.array.Array) | ||
|
||
def is_array_api_obj(x): | ||
""" | ||
Check if x is an array API compatible array object. | ||
""" | ||
return _is_numpy_array(x) \ | ||
or _is_cupy_array(x) \ | ||
or _is_torch_array(x) \ | ||
or _is_dask_array(x) \ | ||
or hasattr(x, '__array_namespace__') | ||
|
||
def _check_api_version(api_version): | ||
|
@@ -97,6 +107,13 @@ def your_function(x, y): | |
else: | ||
import torch | ||
namespaces.add(torch) | ||
elif _is_dask_array(x): | ||
_check_api_version(api_version) | ||
if _use_compat: | ||
from .. import dask as dask_namespace | ||
namespaces.add(dask_namespace) | ||
else: | ||
raise TypeError("_use_compat cannot be False if input array is a dask array!") | ||
else: | ||
# TODO: Support Python scalars? | ||
raise TypeError("The input is not a supported array type") | ||
|
@@ -219,6 +236,13 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A | |
return _cupy_to_device(x, device, stream=stream) | ||
elif _is_torch_array(x): | ||
return _torch_to_device(x, device, stream=stream) | ||
elif _is_dask_array(x): | ||
if stream is not None: | ||
raise ValueError("The stream argument to to_device() is not supported") | ||
# TODO: What if our array is on the GPU already? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think dask supports cupy on the GPU. Is this something we also need to take into consideration? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes indeed, it does support CuPy. I"m not sure that a That's probably fine - you'd only move arrays in a single process to another device like this I think, so maybe the whole There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Even if dask can't really support device transfers or anything I would not call the default device "cpu" as that's misleading. We could just create a proxy There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @lithomas1 what's the status of this comment? Does the |
||
if device == 'cpu': | ||
return x | ||
raise ValueError(f"Unsupported device {device!r}") | ||
return x.to_device(device, stream=stream) | ||
|
||
def size(x): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from dask.array import * | ||
|
||
# These imports may overwrite names from the import * above. | ||
from ._aliases import * | ||
|
||
__array_api_version__ = '2022.12' |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
from __future__ import annotations | ||
|
||
from ..common import _aliases | ||
from ..common._helpers import _check_device | ||
|
||
from .._internal import get_xp | ||
|
||
import numpy as np | ||
from numpy import ( | ||
# Constants | ||
e, | ||
inf, | ||
nan, | ||
pi, | ||
newaxis, | ||
# Dtypes | ||
bool_ as bool, | ||
float32, | ||
float64, | ||
int8, | ||
int16, | ||
int32, | ||
int64, | ||
uint8, | ||
uint16, | ||
uint32, | ||
uint64, | ||
complex64, | ||
complex128, | ||
iinfo, | ||
finfo, | ||
can_cast, | ||
result_type, | ||
) | ||
|
||
from typing import TYPE_CHECKING | ||
if TYPE_CHECKING: | ||
from typing import Optional, Union | ||
from ..common._typing import ndarray, Device, Dtype | ||
|
||
import dask.array as da | ||
|
||
isdtype = get_xp(np)(_aliases.isdtype) | ||
astype = _aliases.astype | ||
|
||
# Common aliases | ||
|
||
# This arange func is modified from the common one to | ||
# not pass stop/step as keyword arguments, which will cause | ||
# an error with dask | ||
def dask_arange( | ||
start: Union[int, float], | ||
/, | ||
stop: Optional[Union[int, float]] = None, | ||
step: Union[int, float] = 1, | ||
*, | ||
xp, | ||
dtype: Optional[Dtype] = None, | ||
device: Optional[Device] = None, | ||
**kwargs | ||
) -> ndarray: | ||
_check_device(xp, device) | ||
args = [start] | ||
if stop is not None: | ||
args.append(stop) | ||
else: | ||
# stop is None, so start is actually stop | ||
# prepend the default value for start which is 0 | ||
args.insert(0, 0) | ||
args.append(step) | ||
return xp.arange(*args, dtype=dtype, **kwargs) | ||
|
||
arange = get_xp(da)(dask_arange) | ||
eye = get_xp(da)(_aliases.eye) | ||
|
||
from functools import partial | ||
asarray = partial(_aliases._asarray, namespace='dask') | ||
asarray.__doc__ = _aliases._asarray.__doc__ | ||
|
||
linspace = get_xp(da)(_aliases.linspace) | ||
eye = get_xp(da)(_aliases.eye) | ||
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult) | ||
UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult) | ||
UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult) | ||
unique_all = get_xp(da)(_aliases.unique_all) | ||
unique_counts = get_xp(da)(_aliases.unique_counts) | ||
unique_inverse = get_xp(da)(_aliases.unique_inverse) | ||
unique_values = get_xp(da)(_aliases.unique_values) | ||
permute_dims = get_xp(da)(_aliases.permute_dims) | ||
std = get_xp(da)(_aliases.std) | ||
var = get_xp(da)(_aliases.var) | ||
empty = get_xp(da)(_aliases.empty) | ||
empty_like = get_xp(da)(_aliases.empty_like) | ||
full = get_xp(da)(_aliases.full) | ||
full_like = get_xp(da)(_aliases.full_like) | ||
ones = get_xp(da)(_aliases.ones) | ||
ones_like = get_xp(da)(_aliases.ones_like) | ||
zeros = get_xp(da)(_aliases.zeros) | ||
zeros_like = get_xp(da)(_aliases.zeros_like) | ||
reshape = get_xp(da)(_aliases.reshape) | ||
matrix_transpose = get_xp(da)(_aliases.matrix_transpose) | ||
vecdot = get_xp(da)(_aliases.vecdot) | ||
|
||
from dask.array import ( | ||
# Element wise aliases | ||
arccos as acos, | ||
arccosh as acosh, | ||
arcsin as asin, | ||
arcsinh as asinh, | ||
arctan as atan, | ||
arctan2 as atan2, | ||
arctanh as atanh, | ||
left_shift as bitwise_left_shift, | ||
right_shift as bitwise_right_shift, | ||
invert as bitwise_invert, | ||
power as pow, | ||
# Other | ||
concatenate as concat, | ||
) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from dask.array.linalg import * | ||
from dask.array.linalg import __all__ as linalg_all | ||
|
||
from ..common import _linalg | ||
from .._internal import get_xp | ||
from ._aliases import (matmul, matrix_transpose, tensordot, vecdot) | ||
|
||
import dask.array as da | ||
|
||
cross = get_xp(da)(_linalg.cross) | ||
outer = get_xp(da)(_linalg.outer) | ||
EighResult = _linalg.EighResult | ||
QRResult = _linalg.QRResult | ||
SlogdetResult = _linalg.SlogdetResult | ||
SVDResult = _linalg.SVDResult | ||
eigh = get_xp(da)(_linalg.eigh) | ||
qr = get_xp(da)(_linalg.qr) | ||
slogdet = get_xp(da)(_linalg.slogdet) | ||
svd = get_xp(da)(_linalg.svd) | ||
cholesky = get_xp(da)(_linalg.cholesky) | ||
matrix_rank = get_xp(da)(_linalg.matrix_rank) | ||
pinv = get_xp(da)(_linalg.pinv) | ||
matrix_norm = get_xp(da)(_linalg.matrix_norm) | ||
svdvals = get_xp(da)(_linalg.svdvals) | ||
vector_norm = get_xp(da)(_linalg.vector_norm) | ||
diagonal = get_xp(da)(_linalg.diagonal) | ||
trace = get_xp(da)(_linalg.trace) | ||
|
||
__all__ = linalg_all + _linalg.__all__ | ||
|
||
del get_xp | ||
del da | ||
del linalg_all | ||
del _linalg |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# FFT isn't conformant | ||
array_api_tests/test_fft.py |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know if this is needed anymore (it didn't seem to help with the test that OOMed - which is now disabled).
Leaving this in just in case I haven't found the correct options.