diff --git a/external/fv3gfs-util/fv3gfs/util/quantity.py b/external/fv3gfs-util/fv3gfs/util/quantity.py index b5a73edda..812e8897f 100644 --- a/external/fv3gfs-util/fv3gfs/util/quantity.py +++ b/external/fv3gfs-util/fv3gfs/util/quantity.py @@ -1,4 +1,4 @@ -from typing import Tuple, Iterable, Dict, Union +from typing import Tuple, Iterable, Dict, Union, Sequence from types import ModuleType import warnings import dataclasses @@ -298,7 +298,7 @@ def from_data_array( data_array: xr.DataArray, origin: Iterable[int] = None, extent: Iterable[int] = None, - ): + ) -> "Quantity": """ Initialize a Quantity from an xarray.DataArray. @@ -427,6 +427,76 @@ def data_array(self) -> xr.DataArray: def np(self) -> ModuleType: return self.metadata.np + def transpose(self, target_dims: Sequence[Union[str, Iterable[str]]]) -> "Quantity": + """Change the dimension order of this Quantity. + + If you know you are working with cell-centered variables, you can do: + + >>> from fv3gfs.util import X_DIM, Y_DIM, Z_DIM + >>> transposed_quantity = quantity.transpose([X_DIM, Y_DIM, Z_DIM]) + + To support re-ordering without checking whether quantities are on + cell centers or interfaces, the API supports giving a list of dimension names + for dimensions. For example, to re-order to X-Y-Z dimensions regardless of the + grid the variable is on, one could do: + + >>> from fv3gfs.util import X_DIMS, Y_DIMS, Z_DIMS + >>> transposed_quantity = quantity.transpose([X_DIMS, Y_DIMS, Z_DIMS]) + + Args: + target_dims: a list of output dimensions. Instead of a single dimension + name, an iterable of dimensions can be used instead for any entries. + For example, you may want to use fv3gfs.util.X_DIMS to place an + x-dimension without knowing whether it is on cell centers or interfaces. + + Returns: + transposed: Quantity with the requested output dimension order + + Raises: + ValueError: if any of the target dimensions do not exist on this Quantity, + or if this Quantity contains multiple values from an iterable entry + """ + target_dims = _collapse_dims(target_dims, self.dims) + transpose_order = [self.dims.index(dim) for dim in target_dims] + return Quantity( + self.np.transpose(self.data, transpose_order), + dims=transpose_sequence(self.dims, transpose_order), + units=self.units, + origin=transpose_sequence(self.origin, transpose_order), + extent=transpose_sequence(self.extent, transpose_order), + gt4py_backend=self.gt4py_backend, + ) + + +def transpose_sequence(sequence, order): + return sequence.__class__(sequence[i] for i in order) + + +def _collapse_dims(target_dims, dims): + return_list = [] + for target in target_dims: + if isinstance(target, str): + if target in dims: + return_list.append(target) + else: + raise ValueError( + f"requested dimension {target} is not defined in " + f"quantity dimensions {dims}" + ) + elif isinstance(target, Iterable): + matches = [d for d in target if d in dims] + if len(matches) > 1: + raise ValueError( + f"multiple matches for {target} found in quantity dimensions {dims}" + ) + elif len(matches) == 0: + raise ValueError( + f"no matches for {target} found in quantity dimensions {dims}" + ) + else: + return_list.append(matches[0]) + return return_list + def fill_index(index, length): return tuple(index) + (slice(None, None, None),) * (length - len(index)) diff --git a/external/fv3gfs-util/tests/quantity/test_storage.py b/external/fv3gfs-util/tests/quantity/test_storage.py index d939bdba0..c10871185 100644 --- a/external/fv3gfs-util/tests/quantity/test_storage.py +++ b/external/fv3gfs-util/tests/quantity/test_storage.py @@ -74,14 +74,10 @@ def test_numpy(quantity, backend): @pytest.mark.parametrize("backend", ["gt4py_numpy", "gt4py_cupy"], indirect=True) def test_storage_exists(quantity, backend): - if gt4py is None: - with pytest.raises(ImportError): - quantity.storage + if "numpy" in backend: + assert isinstance(quantity.storage, gt4py.storage.storage.CPUStorage) else: - if "numpy" in backend: - assert isinstance(quantity.storage, gt4py.storage.storage.CPUStorage) - else: - assert isinstance(quantity.storage, gt4py.storage.storage.GPUStorage) + assert isinstance(quantity.storage, gt4py.storage.storage.GPUStorage) @pytest.mark.parametrize("backend", ["numpy", "cupy"], indirect=True) diff --git a/external/fv3gfs-util/tests/quantity/test_transpose.py b/external/fv3gfs-util/tests/quantity/test_transpose.py new file mode 100644 index 000000000..357730977 --- /dev/null +++ b/external/fv3gfs-util/tests/quantity/test_transpose.py @@ -0,0 +1,197 @@ +import fv3gfs.util +import pytest + + +@pytest.fixture +def initial_dims(request): + return request.param + + +@pytest.fixture +def initial_shape(request): + return request.param + + +@pytest.fixture +def initial_data(initial_shape, numpy): + return numpy.random.randn(*initial_shape) + + +@pytest.fixture +def quantity_data_input(initial_data, numpy, backend): + if "gt4py" in backend: + array = numpy.empty(initial_data.shape) + array[:] = initial_data + else: + array = initial_data + print(type(array)) + return array + + +@pytest.fixture +def initial_origin(request): + return request.param + + +@pytest.fixture +def initial_extent(request): + return request.param + + +@pytest.fixture +def transpose_order(request): + return request.param + + +@pytest.fixture +def target_dims(request): + return request.param + + +@pytest.fixture +def final_dims(initial_dims, transpose_order): + return tuple(initial_dims[index] for index in transpose_order) + + +@pytest.fixture +def final_origin(initial_origin, transpose_order): + return tuple(initial_origin[index] for index in transpose_order) + + +@pytest.fixture +def final_extent(initial_extent, transpose_order): + return tuple(initial_extent[index] for index in transpose_order) + + +@pytest.fixture +def final_data(initial_data, transpose_order, numpy): + return numpy.transpose(initial_data, transpose_order) + + +@pytest.fixture +def quantity(quantity_data_input, initial_dims, initial_origin, initial_extent): + return fv3gfs.util.Quantity( + quantity_data_input, + dims=initial_dims, + units="unit_string", + origin=initial_origin, + extent=initial_extent, + ) + + +def param_product(*param_lists): + return_list = [] + if len(param_lists) == 0: + return [[]] + else: + for item in param_lists[0]: + for later_items in param_product(*param_lists[1:]): + return_list.append([item] + later_items) + return return_list + + +@pytest.mark.parametrize( + "initial_dims, initial_shape, initial_origin, initial_extent, target_dims, transpose_order", + [ + pytest.param( + (fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM), + (6, 7), + (1, 2), + (2, 3), + (fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM), + (0, 1), + id="2d_keep_order", + ), + pytest.param( + (fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM), + (6, 7), + (1, 2), + (2, 3), + (fv3gfs.util.Y_DIM, fv3gfs.util.X_DIM), + (1, 0), + id="2d_transpose", + ), + pytest.param( + (fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM, fv3gfs.util.Z_DIM), + (6, 7, 8), + (1, 2, 3), + (2, 3, 4), + (fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM, fv3gfs.util.Z_DIM), + (0, 1, 2), + id="3d_keep_order", + ), + pytest.param( + (fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM, fv3gfs.util.Z_DIM), + (6, 7, 8), + (1, 2, 3), + (2, 3, 4), + (fv3gfs.util.X_DIMS, fv3gfs.util.Y_DIMS, fv3gfs.util.Z_DIMS), + (0, 1, 2), + id="3d_keep_order_list_dims", + ), + pytest.param( + (fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM, fv3gfs.util.Z_DIM), + (6, 7, 8), + (1, 2, 3), + (2, 3, 4), + (fv3gfs.util.Z_DIM, fv3gfs.util.Y_DIM, fv3gfs.util.X_DIM), + (2, 1, 0), + id="3d_transpose", + ), + pytest.param( + (fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM, fv3gfs.util.Z_DIM), + (6, 7, 8), + (1, 2, 3), + (2, 3, 4), + (fv3gfs.util.Z_DIMS, fv3gfs.util.Y_DIMS, fv3gfs.util.X_DIMS), + (2, 1, 0), + id="3d_transpose_list_dims", + ), + ], + indirect=True, +) +@pytest.mark.parametrize("backend", ["gt4py_numpy", "gt4py_cupy"], indirect=True) +def test_transpose( + quantity, target_dims, final_data, final_dims, final_origin, final_extent, numpy +): + result = quantity.transpose(target_dims) + numpy.testing.assert_array_equal(result.data, final_data) + assert result.dims == final_dims + assert result.origin == final_origin + assert result.extent == final_extent + assert result.units == quantity.units + assert result.gt4py_backend == quantity.gt4py_backend + + +@pytest.mark.parametrize( + "initial_dims, initial_shape, initial_origin, initial_extent, target_dims, transpose_order", + [ + pytest.param( + (fv3gfs.util.X_DIM,), (6,), (1,), (2,), (fv3gfs.util.Y_DIM,), (0,), id="1d" + ), + pytest.param( + (fv3gfs.util.X_DIM, fv3gfs.util.Y_INTERFACE_DIM), + (6, 7), + (1, 2), + (2, 3), + (fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM), + (0, 1), + id="2d_switch_stagger", + ), + pytest.param( + (fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM), + (6, 7), + (1, 2), + (2, 3), + (fv3gfs.util.Y_DIM, fv3gfs.util.X_INTERFACE_DIM), + (1, 0), + id="2d_transpose_switch_stagger", + ), + ], + indirect=True, +) +def test_transpose_invalid_cases( + quantity, target_dims, final_data, final_dims, final_origin, final_extent, numpy +): + with pytest.raises(ValueError): + quantity.transpose(target_dims)