Skip to content

Commit

Permalink
a bunch of refactoring that really came down to: if we're gonna warn,…
Browse files Browse the repository at this point in the history
… we have to turn off ignore_warnings
  • Loading branch information
keflavich committed Jan 6, 2025
1 parent beb1233 commit 4009ed1
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 16 deletions.
31 changes: 16 additions & 15 deletions spectral_cube/dask_spectral_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import inspect
import warnings
import tempfile
import textwrap

from functools import wraps
from contextlib import contextmanager
Expand Down Expand Up @@ -67,6 +68,18 @@ def wrapper(self, *args, **kwargs):
return wrapper


def _warn_slow_dask(functionname):
    """
    Emit a :class:`PossiblySlowWarning` for an operation that forces the
    full cube into memory under dask.

    Dask has a different 'slow' warning than non-dask. It is only expected to
    be slow for statistics that require sorting (and possibly cube arithmetic).

    Parameters
    ----------
    functionname : str
        Name of the statistic (e.g. ``'median'``) interpolated into the
        warning message.
    """
    # Two physical lines joined by a single newline — identical to the
    # dedent+strip of the original triple-quoted message.
    msg = (f"Dask requires loading the whole cube into memory for {functionname}\n"
           "calculations. This may result in slow computation.")
    warnings.warn(message=msg, category=PossiblySlowWarning)


def add_save_to_tmp_dir_option(function):

@wraps(function)
Expand Down Expand Up @@ -246,15 +259,6 @@ def _data(self, value):
raise TypeError('_data should be set to a dask array')
self.__data = value

def _warn_slow(self, functionname):
"""
Dask has a different 'slow' warning than non-dask.
"""
warnings.warn(f"""
Dask requires loading the whole cube into memory for {functionname}
calculations. This may result in slow computation.
""", PossiblySlowWarning)

def use_dask_scheduler(self, scheduler, num_workers=None):
"""
Set the dask scheduler to use.
Expand Down Expand Up @@ -635,7 +639,6 @@ def mean(self, axis=None, **kwargs):
return self._compute(da.nanmean(self._get_filled_data(fill=np.nan), axis=axis, **kwargs))

@projection_if_needed
@ignore_warnings
def median(self, axis=None, **kwargs):
"""
Return the median of the cube, optionally over an axis.
Expand All @@ -645,13 +648,12 @@ def median(self, axis=None, **kwargs):
if axis is None:
# da.nanmedian raises NotImplementedError since it is not possible
# to do efficiently, so we use Numpy instead.
self._warn_slow('median')
_warn_slow_dask('median')
return np.nanmedian(self._compute(data), **kwargs)
else:
return self._compute(da.nanmedian(self._get_filled_data(fill=np.nan), axis=axis, **kwargs))

@projection_if_needed
@ignore_warnings
def percentile(self, q, axis=None, **kwargs):
"""
Return percentiles of the data.
Expand All @@ -669,7 +671,7 @@ def percentile(self, q, axis=None, **kwargs):
if axis is None:
# There is no way to compute the percentile of the whole array in
# chunks.
self._warn_slow('percentile')
_warn_slow_dask('percentile')
return np.nanpercentile(data, q, **kwargs)
else:
# Rechunk so that there is only one chunk along the desired axis
Expand All @@ -692,7 +694,6 @@ def std(self, axis=None, ddof=0, **kwargs):
return self._compute(da.nanstd(self._get_filled_data(fill=np.nan), axis=axis, ddof=ddof, **kwargs))

@projection_if_needed
@ignore_warnings
def mad_std(self, axis=None, ignore_nan=True, **kwargs):
"""
Use astropy's mad_std to compute the standard deviation
Expand All @@ -703,7 +704,7 @@ def mad_std(self, axis=None, ignore_nan=True, **kwargs):
if axis is None:
# In this case we have to load the full data - even dask's
# nanmedian doesn't work efficiently over the whole array.
self._warn_slow('mad_std')
_warn_slow_dask('mad_std')
return stats.mad_std(data, ignore_nan=ignore_nan, **kwargs)
else:
# Rechunk so that there is only one chunk along the desired axis
Expand Down
3 changes: 2 additions & 1 deletion spectral_cube/tests/test_spectral_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,9 @@ def test_huge_disallowed(data_vda_jybeam_lower, use_dask):
assert cube._is_huge

if use_dask:
with pytest.raises(ValueError, match='whole cube into memory'):
with warnings.catch_warnings(record=True) as ww:
cube.mad_std()
assert 'whole cube into memory' in str(ww[0].message)
else:
with pytest.raises(ValueError, match='entire cube into memory'):
cube + 5*cube.unit
Expand Down

0 comments on commit 4009ed1

Please sign in to comment.