Skip to content

Commit

Permalink
Add .transpose method to Quantity (#128)
Browse files Browse the repository at this point in the history
A user would like to be able to change the dimension ordering of Quantity objects, since the wrapper and new dynamical core use different dimension ordering for many (all?) values. Currently it requires significantly complicated user code to do this. This PR adds a `.transpose(dims)` method to Quantity which allows performing this re-ordering much more easily.

If you know you are working with cell-centered variables, you can do:

```python3
from fv3gfs.util import X_DIM, Y_DIM, Z_DIM
transposed_quantity = quantity.transpose([X_DIM, Y_DIM, Z_DIM])
```

To support re-ordering without checking whether quantities are on cell centers or interfaces, the API supports giving a list of dimension names for dimensions. For example, to re-order to X-Y-Z dimensions regardless of the grid the variable is on, one could do:
```python3
from fv3gfs.util import X_DIMS, Y_DIMS, Z_DIMS
transposed_quantity = quantity.transpose([X_DIMS, Y_DIMS, Z_DIMS])
```
  • Loading branch information
Jeremy McGibbon authored Aug 28, 2020
1 parent caac227 commit ae6dae8
Show file tree
Hide file tree
Showing 3 changed files with 272 additions and 9 deletions.
74 changes: 72 additions & 2 deletions external/fv3gfs-util/fv3gfs/util/quantity.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple, Iterable, Dict, Union
from typing import Tuple, Iterable, Dict, Union, Sequence
from types import ModuleType
import warnings
import dataclasses
Expand Down Expand Up @@ -298,7 +298,7 @@ def from_data_array(
data_array: xr.DataArray,
origin: Iterable[int] = None,
extent: Iterable[int] = None,
):
) -> "Quantity":
"""
Initialize a Quantity from an xarray.DataArray.
Expand Down Expand Up @@ -427,6 +427,76 @@ def data_array(self) -> xr.DataArray:
def np(self) -> ModuleType:
return self.metadata.np

def transpose(self, target_dims: Sequence[Union[str, Iterable[str]]]) -> "Quantity":
"""Change the dimension order of this Quantity.
If you know you are working with cell-centered variables, you can do:
>>> from fv3gfs.util import X_DIM, Y_DIM, Z_DIM
>>> transposed_quantity = quantity.transpose([X_DIM, Y_DIM, Z_DIM])
To support re-ordering without checking whether quantities are on
cell centers or interfaces, the API supports giving a list of dimension names
for dimensions. For example, to re-order to X-Y-Z dimensions regardless of the
grid the variable is on, one could do:
>>> from fv3gfs.util import X_DIMS, Y_DIMS, Z_DIMS
>>> transposed_quantity = quantity.transpose([X_DIMS, Y_DIMS, Z_DIMS])
Args:
target_dims: a list of output dimensions. Instead of a single dimension
name, an iterable of dimensions can be used instead for any entries.
For example, you may want to use fv3gfs.util.X_DIMS to place an
x-dimension without knowing whether it is on cell centers or interfaces.
Returns:
transposed: Quantity with the requested output dimension order
Raises:
ValueError: if any of the target dimensions do not exist on this Quantity,
or if this Quantity contains multiple values from an iterable entry
"""
target_dims = _collapse_dims(target_dims, self.dims)
transpose_order = [self.dims.index(dim) for dim in target_dims]
return Quantity(
self.np.transpose(self.data, transpose_order),
dims=transpose_sequence(self.dims, transpose_order),
units=self.units,
origin=transpose_sequence(self.origin, transpose_order),
extent=transpose_sequence(self.extent, transpose_order),
gt4py_backend=self.gt4py_backend,
)


def transpose_sequence(sequence, order):
return sequence.__class__(sequence[i] for i in order)


def _collapse_dims(target_dims, dims):
return_list = []
for target in target_dims:
if isinstance(target, str):
if target in dims:
return_list.append(target)
else:
raise ValueError(
f"requested dimension {target} is not defined in "
f"quantity dimensions {dims}"
)
elif isinstance(target, Iterable):
matches = [d for d in target if d in dims]
if len(matches) > 1:
raise ValueError(
f"multiple matches for {target} found in quantity dimensions {dims}"
)
elif len(matches) == 0:
raise ValueError(
f"no matches for {target} found in quantity dimensions {dims}"
)
else:
return_list.append(matches[0])
return return_list


def fill_index(index, length):
return tuple(index) + (slice(None, None, None),) * (length - len(index))
Expand Down
10 changes: 3 additions & 7 deletions external/fv3gfs-util/tests/quantity/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,10 @@ def test_numpy(quantity, backend):

@pytest.mark.parametrize("backend", ["gt4py_numpy", "gt4py_cupy"], indirect=True)
def test_storage_exists(quantity, backend):
if gt4py is None:
with pytest.raises(ImportError):
quantity.storage
if "numpy" in backend:
assert isinstance(quantity.storage, gt4py.storage.storage.CPUStorage)
else:
if "numpy" in backend:
assert isinstance(quantity.storage, gt4py.storage.storage.CPUStorage)
else:
assert isinstance(quantity.storage, gt4py.storage.storage.GPUStorage)
assert isinstance(quantity.storage, gt4py.storage.storage.GPUStorage)


@pytest.mark.parametrize("backend", ["numpy", "cupy"], indirect=True)
Expand Down
197 changes: 197 additions & 0 deletions external/fv3gfs-util/tests/quantity/test_transpose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
import fv3gfs.util
import pytest


@pytest.fixture
def initial_dims(request):
return request.param


@pytest.fixture
def initial_shape(request):
return request.param


@pytest.fixture
def initial_data(initial_shape, numpy):
return numpy.random.randn(*initial_shape)


@pytest.fixture
def quantity_data_input(initial_data, numpy, backend):
if "gt4py" in backend:
array = numpy.empty(initial_data.shape)
array[:] = initial_data
else:
array = initial_data
print(type(array))
return array


@pytest.fixture
def initial_origin(request):
return request.param


@pytest.fixture
def initial_extent(request):
return request.param


@pytest.fixture
def transpose_order(request):
return request.param


@pytest.fixture
def target_dims(request):
return request.param


@pytest.fixture
def final_dims(initial_dims, transpose_order):
return tuple(initial_dims[index] for index in transpose_order)


@pytest.fixture
def final_origin(initial_origin, transpose_order):
return tuple(initial_origin[index] for index in transpose_order)


@pytest.fixture
def final_extent(initial_extent, transpose_order):
return tuple(initial_extent[index] for index in transpose_order)


@pytest.fixture
def final_data(initial_data, transpose_order, numpy):
return numpy.transpose(initial_data, transpose_order)


@pytest.fixture
def quantity(quantity_data_input, initial_dims, initial_origin, initial_extent):
return fv3gfs.util.Quantity(
quantity_data_input,
dims=initial_dims,
units="unit_string",
origin=initial_origin,
extent=initial_extent,
)


def param_product(*param_lists):
return_list = []
if len(param_lists) == 0:
return [[]]
else:
for item in param_lists[0]:
for later_items in param_product(*param_lists[1:]):
return_list.append([item] + later_items)
return return_list


@pytest.mark.parametrize(
"initial_dims, initial_shape, initial_origin, initial_extent, target_dims, transpose_order",
[
pytest.param(
(fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM),
(6, 7),
(1, 2),
(2, 3),
(fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM),
(0, 1),
id="2d_keep_order",
),
pytest.param(
(fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM),
(6, 7),
(1, 2),
(2, 3),
(fv3gfs.util.Y_DIM, fv3gfs.util.X_DIM),
(1, 0),
id="2d_transpose",
),
pytest.param(
(fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM, fv3gfs.util.Z_DIM),
(6, 7, 8),
(1, 2, 3),
(2, 3, 4),
(fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM, fv3gfs.util.Z_DIM),
(0, 1, 2),
id="3d_keep_order",
),
pytest.param(
(fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM, fv3gfs.util.Z_DIM),
(6, 7, 8),
(1, 2, 3),
(2, 3, 4),
(fv3gfs.util.X_DIMS, fv3gfs.util.Y_DIMS, fv3gfs.util.Z_DIMS),
(0, 1, 2),
id="3d_keep_order_list_dims",
),
pytest.param(
(fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM, fv3gfs.util.Z_DIM),
(6, 7, 8),
(1, 2, 3),
(2, 3, 4),
(fv3gfs.util.Z_DIM, fv3gfs.util.Y_DIM, fv3gfs.util.X_DIM),
(2, 1, 0),
id="3d_transpose",
),
pytest.param(
(fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM, fv3gfs.util.Z_DIM),
(6, 7, 8),
(1, 2, 3),
(2, 3, 4),
(fv3gfs.util.Z_DIMS, fv3gfs.util.Y_DIMS, fv3gfs.util.X_DIMS),
(2, 1, 0),
id="3d_transpose_list_dims",
),
],
indirect=True,
)
@pytest.mark.parametrize("backend", ["gt4py_numpy", "gt4py_cupy"], indirect=True)
def test_transpose(
quantity, target_dims, final_data, final_dims, final_origin, final_extent, numpy
):
result = quantity.transpose(target_dims)
numpy.testing.assert_array_equal(result.data, final_data)
assert result.dims == final_dims
assert result.origin == final_origin
assert result.extent == final_extent
assert result.units == quantity.units
assert result.gt4py_backend == quantity.gt4py_backend


@pytest.mark.parametrize(
"initial_dims, initial_shape, initial_origin, initial_extent, target_dims, transpose_order",
[
pytest.param(
(fv3gfs.util.X_DIM,), (6,), (1,), (2,), (fv3gfs.util.Y_DIM,), (0,), id="1d"
),
pytest.param(
(fv3gfs.util.X_DIM, fv3gfs.util.Y_INTERFACE_DIM),
(6, 7),
(1, 2),
(2, 3),
(fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM),
(0, 1),
id="2d_switch_stagger",
),
pytest.param(
(fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM),
(6, 7),
(1, 2),
(2, 3),
(fv3gfs.util.Y_DIM, fv3gfs.util.X_INTERFACE_DIM),
(1, 0),
id="2d_transpose_switch_stagger",
),
],
indirect=True,
)
def test_transpose_invalid_cases(
quantity, target_dims, final_data, final_dims, final_origin, final_extent, numpy
):
with pytest.raises(ValueError):
quantity.transpose(target_dims)

0 comments on commit ae6dae8

Please sign in to comment.