-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add dask to array-api-compat #76
Merged
Merged
Changes from all commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
8ee1613
Add dask to array-api-compat
lithomas1 3cc2857
fix workflows?
lithomas1 ad6bf56
typo
lithomas1 f178e28
green
lithomas1 8c9c784
go for green
lithomas1 6305d7e
try again
lithomas1 06afd75
fix arange test
lithomas1 6e0ef29
use max-examples=1
lithomas1 9abe56d
Update array-api-tests-dask.yml
lithomas1 20622ba
fix missing check
lithomas1 a484d5a
Merge branch 'add-dask' of github.com:lithomas1/array-api-compat into…
lithomas1 0d66160
change namespace
lithomas1 df69086
changes left behind
lithomas1 1f9799a
finish rename and fix linalg tests
lithomas1 424f25d
fix ci?
lithomas1 64d20c4
rename files again
lithomas1 762a03c
some fixes
lithomas1 69cc93b
fix astype bug
lithomas1 f52b3d5
Merge branch 'main' into add-dask
lithomas1 5edc5ec
remove file
lithomas1 565666a
Merge branch 'add-dask' of github.com:lithomas1/array-api-compat into…
lithomas1 6841758
address more comments
lithomas1 4be5517
fix more tests
lithomas1 cd381a0
address more feedback
lithomas1 54f4838
fix?
lithomas1 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
name: Array API Tests (Dask) | ||
|
||
on: [push, pull_request] | ||
|
||
jobs: | ||
array-api-tests-dask: | ||
uses: ./.github/workflows/array-api-tests.yml | ||
with: | ||
package-name: dask | ||
module-name: dask.array | ||
extra-requires: numpy | ||
pytest-extra-args: --disable-deadline --max-examples=5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from dask.array import * | ||
|
||
# These imports may overwrite names from the import * above. | ||
from ._aliases import * | ||
|
||
__array_api_version__ = '2022.12' | ||
|
||
__import__(__package__ + '.linalg') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
from __future__ import annotations | ||
|
||
from ...common import _aliases | ||
from ...common._helpers import _check_device | ||
|
||
from ..._internal import get_xp | ||
|
||
import numpy as np | ||
from numpy import ( | ||
# Constants | ||
e, | ||
inf, | ||
nan, | ||
pi, | ||
newaxis, | ||
# Dtypes | ||
bool_ as bool, | ||
float32, | ||
float64, | ||
int8, | ||
int16, | ||
int32, | ||
int64, | ||
uint8, | ||
uint16, | ||
uint32, | ||
uint64, | ||
complex64, | ||
complex128, | ||
iinfo, | ||
finfo, | ||
can_cast, | ||
result_type, | ||
) | ||
|
||
from typing import TYPE_CHECKING | ||
if TYPE_CHECKING: | ||
from typing import Optional, Union | ||
from ...common._typing import ndarray, Device, Dtype | ||
|
||
import dask.array as da | ||
|
||
isdtype = get_xp(np)(_aliases.isdtype) | ||
astype = _aliases.astype | ||
|
||
# Common aliases | ||
|
||
# This arange func is modified from the common one to | ||
# not pass stop/step as keyword arguments, which will cause | ||
# an error with dask | ||
|
||
# TODO: delete the xp stuff, it shouldn't be necessary | ||
def dask_arange( | ||
start: Union[int, float], | ||
/, | ||
stop: Optional[Union[int, float]] = None, | ||
step: Union[int, float] = 1, | ||
*, | ||
xp, | ||
dtype: Optional[Dtype] = None, | ||
device: Optional[Device] = None, | ||
**kwargs | ||
) -> ndarray: | ||
_check_device(xp, device) | ||
args = [start] | ||
if stop is not None: | ||
args.append(stop) | ||
else: | ||
# stop is None, so start is actually stop | ||
# prepend the default value for start which is 0 | ||
args.insert(0, 0) | ||
args.append(step) | ||
return xp.arange(*args, dtype=dtype, **kwargs) | ||
|
||
arange = get_xp(da)(dask_arange) | ||
eye = get_xp(da)(_aliases.eye) | ||
|
||
from functools import partial | ||
asarray = partial(_aliases._asarray, namespace='dask.array') | ||
asarray.__doc__ = _aliases._asarray.__doc__ | ||
|
||
linspace = get_xp(da)(_aliases.linspace) | ||
eye = get_xp(da)(_aliases.eye) | ||
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult) | ||
UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult) | ||
UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult) | ||
unique_all = get_xp(da)(_aliases.unique_all) | ||
unique_counts = get_xp(da)(_aliases.unique_counts) | ||
unique_inverse = get_xp(da)(_aliases.unique_inverse) | ||
unique_values = get_xp(da)(_aliases.unique_values) | ||
permute_dims = get_xp(da)(_aliases.permute_dims) | ||
std = get_xp(da)(_aliases.std) | ||
var = get_xp(da)(_aliases.var) | ||
empty = get_xp(da)(_aliases.empty) | ||
empty_like = get_xp(da)(_aliases.empty_like) | ||
full = get_xp(da)(_aliases.full) | ||
full_like = get_xp(da)(_aliases.full_like) | ||
ones = get_xp(da)(_aliases.ones) | ||
ones_like = get_xp(da)(_aliases.ones_like) | ||
zeros = get_xp(da)(_aliases.zeros) | ||
zeros_like = get_xp(da)(_aliases.zeros_like) | ||
reshape = get_xp(da)(_aliases.reshape) | ||
matrix_transpose = get_xp(da)(_aliases.matrix_transpose) | ||
vecdot = get_xp(da)(_aliases.vecdot) | ||
|
||
nonzero = get_xp(da)(_aliases.nonzero) | ||
sum = get_xp(np)(_aliases.sum) | ||
prod = get_xp(np)(_aliases.prod) | ||
ceil = get_xp(np)(_aliases.ceil) | ||
floor = get_xp(np)(_aliases.floor) | ||
trunc = get_xp(np)(_aliases.trunc) | ||
matmul = get_xp(np)(_aliases.matmul) | ||
tensordot = get_xp(np)(_aliases.tensordot) | ||
|
||
from dask.array import ( | ||
# Element wise aliases | ||
arccos as acos, | ||
arccosh as acosh, | ||
arcsin as asin, | ||
arcsinh as asinh, | ||
arctan as atan, | ||
arctan2 as atan2, | ||
arctanh as atanh, | ||
left_shift as bitwise_left_shift, | ||
right_shift as bitwise_right_shift, | ||
invert as bitwise_invert, | ||
power as pow, | ||
# Other | ||
concatenate as concat, | ||
) | ||
|
||
# exclude these from all since | ||
_da_unsupported = ['sort', 'argsort'] | ||
|
||
common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported] | ||
|
||
__all__ = common_aliases + ['asarray', 'bool', 'acos', | ||
'acosh', 'asin', 'asinh', 'atan', 'atan2', | ||
'atanh', 'bitwise_left_shift', 'bitwise_invert', | ||
'bitwise_right_shift', 'concat', 'pow', | ||
'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8', | ||
'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64', | ||
'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type'] | ||
|
||
del da, partial, common_aliases, _da_unsupported, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from __future__ import annotations | ||
|
||
from dask.array.linalg import * | ||
from ...common import _linalg | ||
from ..._internal import get_xp | ||
from dask.array import matmul, tensordot, trace, outer | ||
from ._aliases import matrix_transpose, vecdot | ||
|
||
import dask.array as da | ||
|
||
from typing import TYPE_CHECKING | ||
if TYPE_CHECKING: | ||
from typing import Union, Tuple | ||
from ...common._typing import ndarray | ||
|
||
# cupy.linalg doesn't have __all__. If it is added, replace this with | ||
# | ||
# from cupy.linalg import __all__ as linalg_all | ||
_n = {} | ||
exec('from dask.array.linalg import *', _n) | ||
del _n['__builtins__'] | ||
linalg_all = list(_n) | ||
del _n | ||
|
||
EighResult = _linalg.EighResult | ||
QRResult = _linalg.QRResult | ||
SlogdetResult = _linalg.SlogdetResult | ||
SVDResult = _linalg.SVDResult | ||
qr = get_xp(da)(_linalg.qr) | ||
cholesky = get_xp(da)(_linalg.cholesky) | ||
matrix_rank = get_xp(da)(_linalg.matrix_rank) | ||
matrix_norm = get_xp(da)(_linalg.matrix_norm) | ||
|
||
def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]: | ||
# TODO: can't avoid computing U or V for dask | ||
_, s, _ = svd(x) | ||
return s | ||
|
||
vector_norm = get_xp(da)(_linalg.vector_norm) | ||
diagonal = get_xp(da)(_linalg.diagonal) | ||
|
||
__all__ = linalg_all + ["EighResult", "QRResult", "SlogdetResult", | ||
"SVDResult", "qr", "cholesky", "matrix_rank", "matrix_norm", | ||
"svdvals", "vector_norm", "diagonal"] | ||
|
||
del get_xp | ||
del da | ||
del _linalg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# FFT isn't conformant | ||
array_api_tests/test_fft.py | ||
|
||
# slow and not implemented in dask | ||
array_api_tests/test_linalg.py::test_matrix_power |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think dask supports cupy on the GPU.
Is this something we also need to take into consideration?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes indeed, it does support CuPy. I"m not sure that a
stream
input can be supported though. If a Dask array spans multiple machines, I think there'd be multiple streams with numbers that are unrelated to each other. Which therefore can't be supported at the Dask level in this API.That's probably fine - you'd only move arrays in a single process to another device like this I think, so maybe the whole
to_device
method doesn't quite work for Dask? @jakirkham any thoughts on this?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Even if dask can't really support device transfers or anything I would not call the default device "cpu" as that's misleading. We could just create a proxy
DaskDevice
object that serves as the device for dask arrays.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@lithomas1 what's the status of this comment? Does the
device()
helper above need to be updated for dask?