diff --git a/esmvalcore/preprocessor/_multimodel.py b/esmvalcore/preprocessor/_multimodel.py index 8c80731352..7f944016cd 100644 --- a/esmvalcore/preprocessor/_multimodel.py +++ b/esmvalcore/preprocessor/_multimodel.py @@ -10,6 +10,7 @@ import iris import numpy as np from iris.util import equalise_attributes + from esmvalcore.preprocessor import remove_fx_variables logger = logging.getLogger(__name__) @@ -221,14 +222,31 @@ def _combine(cubes): return merged_cube +def _compute_slices(cubes): + """Create cube slices resulting in a combined cube of about 1 GB.""" + gigabyte = 2**30 + total_bytes = cubes[0].data.nbytes * len(cubes) + n_slices = int(np.ceil(total_bytes / gigabyte)) + + n_timesteps = cubes[0].shape[0] + slice_len = int(np.ceil(n_timesteps / n_slices)) + + for i in range(n_slices): + start = i * slice_len + end = (i + 1) * slice_len + if end > n_timesteps: + end = n_timesteps + yield slice(start, end) + + def _compute_eager(cubes: list, *, operator: iris.analysis.Aggregator, **kwargs): """Compute statistics one slice at a time.""" _ = [cube.data for cube in cubes] # make sure the cubes' data are realized result_slices = [] - for i in range(cubes[0].shape[0]): - single_model_slices = [cube[i] for cube in cubes] + for chunk in _compute_slices(cubes): + single_model_slices = [cube[chunk] for cube in cubes] combined_slice = _combine(single_model_slices) with warnings.catch_warnings(): warnings.filterwarnings( @@ -250,7 +268,7 @@ def _compute_eager(cubes: list, *, operator: iris.analysis.Aggregator, result_slices.append(collapsed_slice) try: - result_cube = iris.cube.CubeList(result_slices).merge_cube() + result_cube = iris.cube.CubeList(result_slices).concatenate_cube() except Exception as excinfo: raise ValueError( "Multi-model statistics failed to concatenate results into a" diff --git a/tests/unit/preprocessor/_multimodel/test_multimodel.py b/tests/unit/preprocessor/_multimodel/test_multimodel.py index 98d17b0582..a7d2ed9552 100644 --- a/tests/unit/preprocessor/_multimodel/test_multimodel.py +++ b/tests/unit/preprocessor/_multimodel/test_multimodel.py @@ -11,8 +11,8 @@ from iris.cube import Cube import esmvalcore.preprocessor._multimodel as mm -from esmvalcore.preprocessor._ancillary_vars import add_ancillary_variable from esmvalcore.preprocessor import multi_model_statistics +from esmvalcore.preprocessor._ancillary_vars import add_ancillary_variable SPAN_OPTIONS = ('overlap', 'full') @@ -131,7 +131,7 @@ def get_cubes_for_validation_test(frequency, lazy=False): ('full', 'median', (5, 5, 3)), ('full', 'p50', (5, 5, 3)), ('full', 'p99.5', (8.96, 8.96, 4.98)), - ('full', 'peak', ([9], [9], [5])), + ('full', 'peak', (9, 9, 5)), ('overlap', 'mean', (5, 5)), ('overlap', 'std_dev', (5.656854249492381, 4)), ('overlap', 'std', (5.656854249492381, 4)), @@ -140,13 +140,31 @@ def get_cubes_for_validation_test(frequency, lazy=False): ('overlap', 'median', (5, 5)), ('overlap', 'p50', (5, 5)), ('overlap', 'p99.5', (8.96, 8.96)), - ('overlap', 'peak', ([9], [9])), + ('overlap', 'peak', (9, 9)), # test multiple statistics ('overlap', ('min', 'max'), ((1, 1), (9, 9))), ('full', ('min', 'max'), ((1, 1, 1), (9, 9, 5))), ) +@pytest.mark.parametrize( + 'length,slices', + [ + (1, [slice(0, 1)]), + (25000, [slice(0, 8334), + slice(8334, 16668), + slice(16668, 25000)]), + ], +) +def test_compute_slices(length, slices): + """Test cube `_compute_slices`.""" + cubes = [ + Cube(da.empty([length, 50, 100], dtype=np.float32)) for _ in range(5) + ] + result = list(mm._compute_slices(cubes)) + assert result == slices + + @pytest.mark.parametrize('frequency', FREQUENCY_OPTIONS) @pytest.mark.parametrize('span, statistics, expected', VALIDATION_DATA_SUCCESS) def test_multimodel_statistics(frequency, span, statistics, expected): @@ -494,6 +512,7 @@ def test_unify_time_coordinates(): class PreprocessorFile: """Mockup to test output of multimodel.""" + def __init__(self, cube=None): if cube: self.cubes = [cube] @@ -552,9 +571,8 @@ def test_ignore_tas_scalar_height_coord(): cube.add_aux_coord( iris.coords.AuxCoord([height], var_name="height", units="m")) - result = mm.multi_model_statistics([tas_2m, tas_2m.copy(), tas_1p5m], - statistics=['mean'], - span='full') + result = mm.multi_model_statistics( + [tas_2m, tas_2m.copy(), tas_1p5m], statistics=['mean'], span='full') # iris automatically averages the value of the scalar coordinate. assert len(result['mean'].coords("height")) == 1