Merge pull request #353 from neutrinoceros/require_dask_for_unyt.dask_array

ENH: make dask a hard requirement for unyt.dask_array
ngoldbaum authored Jan 17, 2023
2 parents ebba99b + a39dd45 commit 44d0509
Showing 1 changed file with 20 additions and 16 deletions.
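For context, a hedged sketch of the user-visible effect of this change, inferred from the diff below: the old code fell back to a unit-less stand-in class when dask was missing, while the new code imports dask.array.core directly, so a missing dask now surfaces at import time.

    # Illustrative only; names and behavior taken from the diff below.
    try:
        from unyt.dask_array import unyt_from_dask
    except ImportError:
        # raised by "from dask.array.core import ..." when dask is not installed
        print("unyt.dask_array now requires dask")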
unyt/dask_array.py: 36 changes (20 additions & 16 deletions)
@@ -5,23 +5,25 @@
"""

import sys
from functools import wraps

import numpy as np

import unyt.array as ua
from unyt._on_demand_imports import _dask as dask

__doctest_requires__ = {
("unyt_from_dask", "reduce_with_units", "unyt_dask_array.to_dask"): ["dask"],
}
if "pytest" in sys.modules:
# should only happen if pytest is installed *and* already imported,
# so we can skip collecting doctests from this module when dask isn't installed
# while avoiding making pytest itself a hard dependency to this module.
# This check is constructed to work with direct invocation (pytest unyt)
# as well as through python -m pytest
import pytest

pytest.importorskip("dask")
del pytest

if dask.__is_available__:
_dask_Array = dask.array.core.Array
_dask_finalize = dask.array.core.finalize
else:
_dask_Array, _dask_finalize = object, None
from dask.array.core import Array as DaskArray, finalize as dask_finalize # noqa: E402

# the following attributes hang off of dask.array.core.Array and do not modify units
_use_unary_decorator = {
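The comments in the block added above describe a reusable guard; a minimal hedged sketch of the same pattern, outside of unyt:

    import sys

    if "pytest" in sys.modules:
        # only true when pytest launched this interpreter (via "pytest unyt"
        # or "python -m pytest"), so importing pytest here adds no new dependency
        import pytest

        # raises pytest's Skipped exception during collection if dask is missing,
        # so this module's doctests are skipped instead of erroring out
        pytest.importorskip("dask")
        del pytest  # keep the module namespace clean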
@@ -191,7 +193,7 @@ def wrapper(self, *args, **kwargs):
return wrapper


class unyt_dask_array(_dask_Array):
class unyt_dask_array(DaskArray):
"""
a dask.array.core.Array subclass that tracks units. This class is only
recommended for advanced usage, most cases should use the unyt_from_dask
@@ -335,11 +337,12 @@ def to_dask(self):
>>> x = da.random.random((10000, 10000), chunks=(1000, 1000))
>>> x_da = dask_array.unyt_from_dask(x, 'm')
>>> x_da.to_dask()
... # doctest: +NORMALIZE_WHITESPACE
dask.array<random_sample, shape=(10000, 10000), dtype=float64,
chunksize=(1000, 1000), chunktype=numpy.ndarray>
"""
(_, args) = super().__reduce__()
return _dask_Array(*args)
return DaskArray(*args)

def __reduce__(self):
(_, args) = super().__reduce__()
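Because to_dask() hands back a plain dask array (as the doctest output above shows), going the other way means reattaching units with unyt_from_dask. A short hedged sketch, reusing x_da and the dask_array import assumed from that docstring:

    plain = x_da.to_dask()                            # plain dask.array.core.Array, units dropped
    x_again = dask_array.unyt_from_dask(plain, "m")   # wrap it again as a unyt_dask_array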
@@ -498,7 +501,7 @@ def _finalize_unyt(results, unit_name):
# here, we first call the standard finalize function for a dask array
# and then return a standard unyt_array from the now in-memory result if
# the result is an array, otherwise return a unyt_quantity.
result = _dask_finalize(results)
result = dask_finalize(results)

if type(result) == np.ndarray:
return ua.unyt_array(result, unit_name)
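The comment block above spells out the wrap-after-compute step; a simplified, hedged sketch of the same dispatch (the function name is illustrative, not the exact library code):

    import numpy as np
    import unyt.array as ua

    def _wrap_finalized(result, unit_name):
        # in-memory arrays become unyt_array, anything else becomes unyt_quantity
        if isinstance(result, np.ndarray):
            return ua.unyt_array(result, unit_name)
        return ua.unyt_quantity(result, unit_name)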
@@ -656,10 +659,11 @@ def reduce_with_units(dask_func, unyt_dask_in, *args, **kwargs):
Examples
--------
>>> from unyt import dask_array
>>> a = dask_array.dask.array.ones((10000,), chunks=(100,))
>>> a = dask_array.unyt_from_dask(a, 'm')
>>> b = dask_array.reduce_with_units(dask_array.dask.array.median, a, axis=0)
>>> import dask.array
>>> from unyt.dask_array import unyt_from_dask, reduce_with_units
>>> a = dask.array.ones((10000,), chunks=(100,))
>>> a = unyt_from_dask(a, 'm')
>>> b = reduce_with_units(dask.array.median, a, axis=0)
>>> b.compute()
unyt_quantity(1., 'm')
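A hedged variant of the doctest above, reusing a and assuming reduce_with_units reattaches the input units the same way it does for the median case:

    c = reduce_with_units(dask.array.min, a)
    c.compute()   # expected under that assumption: unyt_quantity(1., 'm')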
