Add dask to array-api-compat #76

Merged on Feb 6, 2024 (25 commits)
Changes from 11 commits
11 changes: 11 additions & 0 deletions .github/workflows/array-api-tests-dask.yml
@@ -0,0 +1,11 @@
name: Array API Tests (Dask)

on: [push, pull_request]

jobs:
array-api-tests-dask:
uses: ./.github/workflows/array-api-tests.yml
with:
package-name: dask
extra-requires: numpy
pytest-extra-args: --disable-deadline --max-examples=5
5 changes: 4 additions & 1 deletion .github/workflows/array-api-tests.yml
@@ -6,6 +6,9 @@ on:
package-name:
required: true
type: string
extra-requires:
required: false
type: string
package-version:
required: false
type: string
@@ -54,7 +57,7 @@ jobs:
if: "! (matrix.python-version == '3.11' && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))"
run: |
python -m pip install --upgrade pip
- python -m pip install '${{ inputs.package-name }} ${{ inputs.package-version }}'
+ python -m pip install '${{ inputs.package-name }} ${{ inputs.package-version }}' ${{ inputs.extra-requires }}
python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt
- name: Run the array API testsuite (${{ inputs.package-name }})
if: "! (matrix.python-version == '3.11' && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))"
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -15,7 +15,7 @@ jobs:
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
- python -m pip install pytest numpy torch
+ python -m pip install pytest numpy torch dask[array]

- name: Run Tests
run: |
15 changes: 15 additions & 0 deletions array_api_compat/_dask_ci_shim.py
@@ -0,0 +1,15 @@
"""
A little CI shim for the dask backend that
disables the dask scheduler

It also lets you see the dask dashboard for debugging
at http://127.0.0.1:8787/status
"""
import dask
dask.config.set(scheduler='synchronous')
Contributor Author

I don't know if this is needed anymore (it didn't seem to help with the test that OOMed - which is now disabled).

Leaving this in just in case I haven't found the correct options.


from dask.distributed import Client
_client = Client()

from .dask import *
from .dask import __array_api_version__
9 changes: 7 additions & 2 deletions array_api_compat/common/_aliases.py
@@ -303,6 +303,8 @@ def _asarray(
import numpy as xp
elif namespace == 'cupy':
import cupy as xp
elif namespace == 'dask':
Contributor Author

Technically, the array module of dask is called dask.array.

I guess we could match that here, but it feels kind of weird to me.

Member

I think (but am not 100% sure) that using namespace == 'dask.array' is preferred here, since it's meant to identify the actual namespace rather than the package.

Member

I suggest waiting for @asmeurer to weigh in here.

Member

This keyword is only used internally so it doesn't really matter, but I like dask.array.

import dask.array as xp
else:
raise ValueError("Unrecognized namespace argument to asarray()")

@@ -322,10 +324,13 @@ def _asarray(
if copy in COPY_FALSE:
# copy=False is not yet implemented in xp.asarray
raise NotImplementedError("copy=False is not yet implemented")
- if isinstance(obj, xp.ndarray):
# TODO: This feels wrong (__array__ is not in the standard)
# Dask doesn't support DLPack, though, so, this'll do
Member

This function is only used for numpy and cupy (and now dask). The internals of the wrappers can use specific functionality that isn't in the standard because they can assume that they are operating on the specific array library. So this isn't a problem.

+ if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)) or hasattr(obj, "__array__"):
if dtype is not None and obj.dtype != dtype:
copy = True
- if copy in COPY_TRUE:
# Dask arrays are immutable, so copy doesn't do anything
Member

I think this is probably fine but maybe this sort of thing should be clarified in the standard.

Contributor Author

Yeah, I think what I was mostly worried about was whether I need to return a new dask array object (I guess this would kinda be like a view on the old object) or not.

+ if copy in COPY_TRUE and namespace != "dask":
return xp.array(obj, copy=True, dtype=dtype)
return obj

24 changes: 24 additions & 0 deletions array_api_compat/common/_helpers.py
@@ -40,13 +40,23 @@ def _is_torch_array(x):
# TODO: Should we reject ndarray subclasses?
return isinstance(x, torch.Tensor)

def _is_dask_array(x):
# Avoid importing dask if it isn't already
if 'dask.array' not in sys.modules:
return False

import dask.array

return isinstance(x, dask.array.Array)

def is_array_api_obj(x):
"""
Check if x is an array API compatible array object.
"""
return _is_numpy_array(x) \
or _is_cupy_array(x) \
or _is_torch_array(x) \
or _is_dask_array(x) \
or hasattr(x, '__array_namespace__')

def _check_api_version(api_version):
@@ -97,6 +107,13 @@ def your_function(x, y):
else:
import torch
namespaces.add(torch)
elif _is_dask_array(x):
_check_api_version(api_version)
if _use_compat:
from .. import dask as dask_namespace
namespaces.add(dask_namespace)
else:
raise TypeError("_use_compat cannot be False if input array is a dask array!")
else:
# TODO: Support Python scalars?
raise TypeError("The input is not a supported array type")
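The `_is_dask_array` helper in this hunk looks the module up in `sys.modules` so that the check never imports dask as a side effect. A minimal stdlib-only sketch of that pattern is below. The names `is_instance_if_imported` and `fake_array_lib` are hypothetical and exist only for illustration; they are not part of array-api-compat or dask.

```python
import sys
import types


def is_instance_if_imported(x, module_name, class_name):
    """Return True only if `module_name` is already imported and `x` is an
    instance of its `class_name`. Never triggers a fresh import."""
    module = sys.modules.get(module_name)
    if module is None:
        # The library was never imported, so x cannot be one of its objects.
        return False
    cls = getattr(module, class_name, None)
    return cls is not None and isinstance(x, cls)


# Hypothetical stand-in for an optional array library (not a real package).
fake = types.ModuleType("fake_array_lib")


class Array:
    pass


fake.Array = Array
obj = Array()

# Before the module is registered, the check short-circuits to False.
before = is_instance_if_imported(obj, "fake_array_lib", "Array")

# Simulate "the user has imported the library".
sys.modules["fake_array_lib"] = fake
after = is_instance_if_imported(obj, "fake_array_lib", "Array")
not_array = is_instance_if_imported(3.0, "fake_array_lib", "Array")
```

The design point is that an object can only be an instance of a class from a library that has already been imported, so skipping the import loses nothing and keeps optional dependencies optional.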
@@ -219,6 +236,13 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A
return _cupy_to_device(x, device, stream=stream)
elif _is_torch_array(x):
return _torch_to_device(x, device, stream=stream)
elif _is_dask_array(x):
if stream is not None:
raise ValueError("The stream argument to to_device() is not supported")
# TODO: What if our array is on the GPU already?
Contributor Author

I think dask supports cupy on the GPU.

Is this something we also need to take into consideration?

Member

Yes indeed, it does support CuPy. I'm not sure that a stream input can be supported, though. If a Dask array spans multiple machines, I think there would be multiple streams whose numbers are unrelated to each other, so a single stream argument can't be supported at the Dask level in this API.

That's probably fine - you'd only move arrays in a single process to another device like this I think, so maybe the whole to_device method doesn't quite work for Dask? @jakirkham any thoughts on this?

Member

Even if dask can't really support device transfers or anything, I would not call the default device "cpu", as that's misleading. We could just create a proxy DaskDevice object that serves as the device for dask arrays.

Member

@lithomas1 what's the status of this comment? Does the device() helper above need to be updated for dask?

if device == 'cpu':
return x
raise ValueError(f"Unsupported device {device!r}")
return x.to_device(device, stream=stream)
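One possible reading of the proxy-device suggestion in the thread above is sketched here. `DaskDevice`, `FakeLazyArray`, and `to_device_sketch` are hypothetical names, not existing dask or array-api-compat APIs. The `'cpu'` string is accepted only to mirror the branch shown in the diff.

```python
class DaskDevice:
    """Hypothetical proxy device: all instances compare equal, so it acts
    as the single 'device' that every lazily evaluated array lives on."""

    def __repr__(self):
        return "DaskDevice()"

    def __eq__(self, other):
        return isinstance(other, DaskDevice)

    def __hash__(self):
        return hash(DaskDevice)


class FakeLazyArray:
    """Stand-in for a dask array; only object identity matters here."""


def to_device_sketch(x, device, /, *, stream=None):
    # A stream number has no single meaning across workers, so reject it.
    if stream is not None:
        raise ValueError("The stream argument to to_device() is not supported")
    # No data movement happens: the array is returned unchanged.
    if device == "cpu" or isinstance(device, DaskDevice):
        return x
    raise ValueError(f"Unsupported device {device!r}")


x = FakeLazyArray()
same_via_cpu = to_device_sketch(x, "cpu")
same_via_proxy = to_device_sketch(x, DaskDevice())
```

Whether the default device should be reported as `"cpu"` or as a proxy object is exactly the open question in the thread; the sketch does not settle it.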

def size(x):
6 changes: 6 additions & 0 deletions array_api_compat/dask/__init__.py
@@ -0,0 +1,6 @@
from dask.array import *

# These imports may overwrite names from the import * above.
from ._aliases import *

__array_api_version__ = '2022.12'
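The comment in this file notes that the second star import may overwrite names from the first. A small stdlib-only sketch of that ordering rule follows. The modules `base_ns` and `compat_ns` are hypothetical stand-ins for `dask.array` and the compat `_aliases` module; they are built in memory for the demonstration.

```python
import sys
import types

# Hypothetical stand-in for the wrapped library's namespace.
base_ns = types.ModuleType("base_ns")
base_ns.arange = lambda: "base arange"
base_ns.sin = lambda: "base sin"

# Hypothetical stand-in for the compat aliases, which redefine one name.
compat_ns = types.ModuleType("compat_ns")
compat_ns.arange = lambda: "compat arange"

sys.modules["base_ns"] = base_ns
sys.modules["compat_ns"] = compat_ns

# Run the two star imports in the same order as the package __init__.
namespace = {}
exec("from base_ns import *\nfrom compat_ns import *", namespace)

arange_source = namespace["arange"]()  # comes from the later import
sin_source = namespace["sin"]()        # untouched, still from the first
```

Because the later import wins, the compat wrappers shadow the library's own functions only where a wrapper is defined, and everything else passes through unchanged.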
120 changes: 120 additions & 0 deletions array_api_compat/dask/_aliases.py
@@ -0,0 +1,120 @@
from __future__ import annotations

from ..common import _aliases
from ..common._helpers import _check_device

from .._internal import get_xp

import numpy as np
from numpy import (
# Constants
e,
inf,
nan,
pi,
newaxis,
# Dtypes
bool_ as bool,
float32,
float64,
int8,
int16,
int32,
int64,
uint8,
uint16,
uint32,
uint64,
complex64,
complex128,
iinfo,
finfo,
can_cast,
result_type,
)

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Union
from ..common._typing import ndarray, Device, Dtype

import dask.array as da

isdtype = get_xp(np)(_aliases.isdtype)
astype = _aliases.astype

# Common aliases

# This arange func is modified from the common one to
# not pass stop/step as keyword arguments, which will cause
# an error with dask
def dask_arange(
start: Union[int, float],
/,
stop: Optional[Union[int, float]] = None,
step: Union[int, float] = 1,
*,
xp,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs
) -> ndarray:
_check_device(xp, device)
args = [start]
if stop is not None:
args.append(stop)
else:
# stop is None, so start is actually stop
# prepend the default value for start which is 0
args.insert(0, 0)
args.append(step)
return xp.arange(*args, dtype=dtype, **kwargs)

arange = get_xp(da)(dask_arange)
eye = get_xp(da)(_aliases.eye)
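The argument handling in `dask_arange` above can be checked in isolation. The sketch below copies only the positional-argument normalization into a hypothetical helper, `normalize_arange_args`, and uses the builtin `range` as a stand-in for `xp.arange`, so no array library is needed.

```python
def normalize_arange_args(start, stop=None, step=1):
    """Build a purely positional (start, stop, step) argument list."""
    args = [start]
    if stop is not None:
        args.append(stop)
    else:
        # Only one bound was given, so it is really `stop`;
        # prepend the default start of 0.
        args.insert(0, 0)
    args.append(step)
    return args


one_arg = normalize_arange_args(5)
two_args = normalize_arange_args(2, 5)
three_args = normalize_arange_args(2, 10, 3)

# `range` accepts the same positional layout, so it can verify the result.
values = list(range(*three_args))
```

Passing everything positionally sidesteps the keyword-argument error mentioned in the comment above the function.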

from functools import partial
asarray = partial(_aliases._asarray, namespace='dask')
asarray.__doc__ = _aliases._asarray.__doc__

linspace = get_xp(da)(_aliases.linspace)
eye = get_xp(da)(_aliases.eye)
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult)
UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult)
unique_all = get_xp(da)(_aliases.unique_all)
unique_counts = get_xp(da)(_aliases.unique_counts)
unique_inverse = get_xp(da)(_aliases.unique_inverse)
unique_values = get_xp(da)(_aliases.unique_values)
permute_dims = get_xp(da)(_aliases.permute_dims)
std = get_xp(da)(_aliases.std)
var = get_xp(da)(_aliases.var)
empty = get_xp(da)(_aliases.empty)
empty_like = get_xp(da)(_aliases.empty_like)
full = get_xp(da)(_aliases.full)
full_like = get_xp(da)(_aliases.full_like)
ones = get_xp(da)(_aliases.ones)
ones_like = get_xp(da)(_aliases.ones_like)
zeros = get_xp(da)(_aliases.zeros)
zeros_like = get_xp(da)(_aliases.zeros_like)
reshape = get_xp(da)(_aliases.reshape)
matrix_transpose = get_xp(da)(_aliases.matrix_transpose)
vecdot = get_xp(da)(_aliases.vecdot)
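The `get_xp(da)(...)` calls above bind a generic implementation to one concrete array namespace. The following is a simplified sketch of that idea, assuming the helper does nothing more than inject `xp` as a keyword argument. The real `get_xp` in `array_api_compat._internal` may do more (for example, adjusting signatures or docstrings). `get_xp_sketch` and `hypot_generic` are hypothetical names, and the stdlib `math` module stands in for an array namespace.

```python
import math
from functools import wraps


def get_xp_sketch(xp):
    """Return a decorator that fixes the `xp` keyword of a generic function."""
    def decorator(f):
        @wraps(f)
        def wrapped(*args, **kwargs):
            # Inject the bound namespace; callers never pass xp themselves.
            return f(*args, xp=xp, **kwargs)
        return wrapped
    return decorator


def hypot_generic(a, b, *, xp):
    # Written once against an abstract namespace `xp`.
    return xp.sqrt(a * a + b * b)


# Bind the generic function to a concrete namespace, as the aliases above do.
hypot = get_xp_sketch(math)(hypot_generic)
result = hypot(3.0, 4.0)
```

This is why each alias is a one-liner: the shared implementation lives in the common module, and only the namespace changes per backend.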

from dask.array import (
# Element wise aliases
arccos as acos,
arccosh as acosh,
arcsin as asin,
arcsinh as asinh,
arctan as atan,
arctan2 as atan2,
arctanh as atanh,
left_shift as bitwise_left_shift,
right_shift as bitwise_right_shift,
invert as bitwise_invert,
power as pow,
# Other
concatenate as concat,
)

34 changes: 34 additions & 0 deletions array_api_compat/dask/linalg.py
@@ -0,0 +1,34 @@
from dask.array.linalg import *
from dask.array.linalg import __all__ as linalg_all

from ..common import _linalg
from .._internal import get_xp
from ._aliases import (matmul, matrix_transpose, tensordot, vecdot)

import dask.array as da

cross = get_xp(da)(_linalg.cross)
outer = get_xp(da)(_linalg.outer)
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
eigh = get_xp(da)(_linalg.eigh)
qr = get_xp(da)(_linalg.qr)
slogdet = get_xp(da)(_linalg.slogdet)
svd = get_xp(da)(_linalg.svd)
cholesky = get_xp(da)(_linalg.cholesky)
matrix_rank = get_xp(da)(_linalg.matrix_rank)
pinv = get_xp(da)(_linalg.pinv)
matrix_norm = get_xp(da)(_linalg.matrix_norm)
svdvals = get_xp(da)(_linalg.svdvals)
vector_norm = get_xp(da)(_linalg.vector_norm)
diagonal = get_xp(da)(_linalg.diagonal)
trace = get_xp(da)(_linalg.trace)

__all__ = linalg_all + _linalg.__all__

del get_xp
del da
del linalg_all
del _linalg
2 changes: 2 additions & 0 deletions dask-skips.txt
@@ -0,0 +1,2 @@
# FFT isn't conformant
array_api_tests/test_fft.py