From 71456ec31a6f87322abf3fe792c4717a88ad946d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Sun, 15 Jan 2023 12:37:11 +0100 Subject: [PATCH 1/3] ENH: make dask a hard requirement for unyt.dask_array --- unyt/dask_array.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/unyt/dask_array.py b/unyt/dask_array.py index 224b9a12..f4ca6e26 100644 --- a/unyt/dask_array.py +++ b/unyt/dask_array.py @@ -8,20 +8,13 @@ from functools import wraps import numpy as np +import pytest import unyt.array as ua -from unyt._on_demand_imports import _dask as dask -__doctest_requires__ = { - ("unyt_from_dask", "reduce_with_units", "unyt_dask_array.to_dask"): ["dask"], -} - - -if dask.__is_available__: - _dask_Array = dask.array.core.Array - _dask_finalize = dask.array.core.finalize -else: - _dask_Array, _dask_finalize = object, None +pytest.importorskip("dask") +del pytest +from dask.array.core import Array as DaskArray, finalize as dask_finalize # noqa: E402 # the following attributes hang off of dask.array.core.Array and do not modify units _use_unary_decorator = { @@ -191,7 +184,7 @@ def wrapper(self, *args, **kwargs): return wrapper -class unyt_dask_array(_dask_Array): +class unyt_dask_array(DaskArray): """ a dask.array.core.Array subclass that tracks units. This class is only recommended for advanced usage, most cases should use the unyt_from_dask @@ -339,7 +332,7 @@ def to_dask(self): chunksize=(1000, 1000), chunktype=numpy.ndarray> """ (_, args) = super().__reduce__() - return _dask_Array(*args) + return DaskArray(*args) def __reduce__(self): (_, args) = super().__reduce__() @@ -498,7 +491,7 @@ def _finalize_unyt(results, unit_name): # here, we first call the standard finalize function for a dask array # and then return a standard unyt_array from the now in-memory result if # the result is an array, otherwise return a unyt_quantity. - result = _dask_finalize(results) + result = dask_finalize(results) if type(result) == np.ndarray: return ua.unyt_array(result, unit_name) @@ -656,10 +649,11 @@ def reduce_with_units(dask_func, unyt_dask_in, *args, **kwargs): Examples -------- - >>> from unyt import dask_array - >>> a = dask_array.dask.array.ones((10000,), chunks=(100,)) - >>> a = dask_array.unyt_from_dask(a, 'm') - >>> b = dask_array.reduce_with_units(dask_array.dask.array.median, a, axis=0) + >>> import dask.array + >>> from unyt.dask_array import unyt_from_dask, reduce_with_units + >>> a = dask.array.ones((10000,), chunks=(100,)) + >>> a = unyt_from_dask(a, 'm') + >>> b = reduce_with_units(dask.array.median, a, axis=0) >>> b.compute() unyt_quantity(1., 'm') From 44a0a27cffaf58fda2340f3393f5644388e1dccb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Sun, 15 Jan 2023 12:55:18 +0100 Subject: [PATCH 2/3] TST: add missing doctest flag --- unyt/dask_array.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unyt/dask_array.py b/unyt/dask_array.py index f4ca6e26..a8f2f531 100644 --- a/unyt/dask_array.py +++ b/unyt/dask_array.py @@ -328,6 +328,7 @@ def to_dask(self): >>> x = da.random.random((10000, 10000), chunks=(1000, 1000)) >>> x_da = dask_array.unyt_from_dask(x, 'm') >>> x_da.to_dask() + ... # doctest: +NORMALIZE_WHITESPACE dask.array """ From a39dd45116602ff52d6cbc208ecf0f7d8c7b6e0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Tue, 17 Jan 2023 17:38:54 +0100 Subject: [PATCH 3/3] TST: avoid making pytest a hard dependency to unyt.dask_array --- unyt/dask_array.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/unyt/dask_array.py b/unyt/dask_array.py index a8f2f531..5085a3ca 100644 --- a/unyt/dask_array.py +++ b/unyt/dask_array.py @@ -5,15 +5,24 @@ """ +import sys from functools import wraps import numpy as np -import pytest import unyt.array as ua -pytest.importorskip("dask") -del pytest +if "pytest" in sys.modules: + # should only happen if pytest is installed *and* already imported, + # so we can skip collecting doctests from this module when dask isn't installed + # while avoiding making pytest itself a hard dependency to this module. + # This check is constructed to work with direct invocation (pytest unyt) + # as well as through python -m pytest + import pytest + + pytest.importorskip("dask") + del pytest + from dask.array.core import Array as DaskArray, finalize as dask_finalize # noqa: E402 # the following attributes hang off of dask.array.core.Array and do not modify units