Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: make dask a hard requirement for unyt.dask_array #353

Merged
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 13 additions & 18 deletions unyt/dask_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,13 @@
from functools import wraps

import numpy as np
import pytest
chrishavlin marked this conversation as resolved.
Show resolved Hide resolved

import unyt.array as ua
from unyt._on_demand_imports import _dask as dask

__doctest_requires__ = {
("unyt_from_dask", "reduce_with_units", "unyt_dask_array.to_dask"): ["dask"],
}


if dask.__is_available__:
_dask_Array = dask.array.core.Array
_dask_finalize = dask.array.core.finalize
else:
_dask_Array, _dask_finalize = object, None
pytest.importorskip("dask")
del pytest
from dask.array.core import Array as DaskArray, finalize as dask_finalize # noqa: E402

# the following attributes hang off of dask.array.core.Array and do not modify units
_use_unary_decorator = {
Expand Down Expand Up @@ -191,7 +184,7 @@ def wrapper(self, *args, **kwargs):
return wrapper


class unyt_dask_array(_dask_Array):
class unyt_dask_array(DaskArray):
"""
a dask.array.core.Array subclass that tracks units. This class is only
recommended for advanced usage, most cases should use the unyt_from_dask
Expand Down Expand Up @@ -335,11 +328,12 @@ def to_dask(self):
>>> x = da.random.random((10000, 10000), chunks=(1000, 1000))
>>> x_da = dask_array.unyt_from_dask(x, 'm')
>>> x_da.to_dask()
... # doctest: +NORMALIZE_WHITESPACE
dask.array<random_sample, shape=(10000, 10000), dtype=float64,
chunksize=(1000, 1000), chunktype=numpy.ndarray>
"""
(_, args) = super().__reduce__()
return _dask_Array(*args)
return DaskArray(*args)

def __reduce__(self):
(_, args) = super().__reduce__()
Expand Down Expand Up @@ -498,7 +492,7 @@ def _finalize_unyt(results, unit_name):
# here, we first call the standard finalize function for a dask array
# and then return a standard unyt_array from the now in-memory result if
# the result is an array, otherwise return a unyt_quantity.
result = _dask_finalize(results)
result = dask_finalize(results)

if type(result) == np.ndarray:
return ua.unyt_array(result, unit_name)
Expand Down Expand Up @@ -656,10 +650,11 @@ def reduce_with_units(dask_func, unyt_dask_in, *args, **kwargs):

Examples
--------
>>> from unyt import dask_array
>>> a = dask_array.dask.array.ones((10000,), chunks=(100,))
>>> a = dask_array.unyt_from_dask(a, 'm')
>>> b = dask_array.reduce_with_units(dask_array.dask.array.median, a, axis=0)
>>> import dask.array
>>> from unyt.dask_array import unyt_from_dask, reduce_with_units
>>> a = dask.array.ones((10000,), chunks=(100,))
>>> a = unyt_from_dask(a, 'm')
>>> b = reduce_with_units(dask.array.median, a, axis=0)
>>> b.compute()
unyt_quantity(1., 'm')

Expand Down