-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add .transpose method to Quantity (#128)
A user would like to be able to change the dimension ordering of Quantity objects, since the wrapper and new dynamical core use different dimension ordering for many (all?) values. Currently it requires significantly complicated user code to do this. This PR adds a `.transpose(dims)` method to Quantity which allows performing this re-ordering much more easily. If you know you are working with cell-centered variables, you can do: ```python3 from fv3gfs.util import X_DIM, Y_DIM, Z_DIM transposed_quantity = quantity.transpose([X_DIM, Y_DIM, Z_DIM]) ``` To support re-ordering without checking whether quantities are on cell centers or interfaces, the API supports giving a list of dimension names for dimensions. For example, to re-order to X-Y-Z dimensions regardless of the grid the variable is on, one could do: ```python3 from fv3gfs.util import X_DIMS, Y_DIMS, Z_DIMS transposed_quantity = quantity.transpose([X_DIMS, Y_DIMS, Z_DIMS]) ```
- Loading branch information
Jeremy McGibbon
authored
Aug 28, 2020
1 parent
caac227
commit ae6dae8
Showing
3 changed files
with
272 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
import fv3gfs.util | ||
import pytest | ||
|
||
|
||
@pytest.fixture | ||
def initial_dims(request): | ||
return request.param | ||
|
||
|
||
@pytest.fixture | ||
def initial_shape(request): | ||
return request.param | ||
|
||
|
||
@pytest.fixture | ||
def initial_data(initial_shape, numpy): | ||
return numpy.random.randn(*initial_shape) | ||
|
||
|
||
@pytest.fixture | ||
def quantity_data_input(initial_data, numpy, backend): | ||
if "gt4py" in backend: | ||
array = numpy.empty(initial_data.shape) | ||
array[:] = initial_data | ||
else: | ||
array = initial_data | ||
print(type(array)) | ||
return array | ||
|
||
|
||
@pytest.fixture | ||
def initial_origin(request): | ||
return request.param | ||
|
||
|
||
@pytest.fixture | ||
def initial_extent(request): | ||
return request.param | ||
|
||
|
||
@pytest.fixture | ||
def transpose_order(request): | ||
return request.param | ||
|
||
|
||
@pytest.fixture | ||
def target_dims(request): | ||
return request.param | ||
|
||
|
||
@pytest.fixture | ||
def final_dims(initial_dims, transpose_order): | ||
return tuple(initial_dims[index] for index in transpose_order) | ||
|
||
|
||
@pytest.fixture | ||
def final_origin(initial_origin, transpose_order): | ||
return tuple(initial_origin[index] for index in transpose_order) | ||
|
||
|
||
@pytest.fixture | ||
def final_extent(initial_extent, transpose_order): | ||
return tuple(initial_extent[index] for index in transpose_order) | ||
|
||
|
||
@pytest.fixture | ||
def final_data(initial_data, transpose_order, numpy): | ||
return numpy.transpose(initial_data, transpose_order) | ||
|
||
|
||
@pytest.fixture | ||
def quantity(quantity_data_input, initial_dims, initial_origin, initial_extent): | ||
return fv3gfs.util.Quantity( | ||
quantity_data_input, | ||
dims=initial_dims, | ||
units="unit_string", | ||
origin=initial_origin, | ||
extent=initial_extent, | ||
) | ||
|
||
|
||
def param_product(*param_lists): | ||
return_list = [] | ||
if len(param_lists) == 0: | ||
return [[]] | ||
else: | ||
for item in param_lists[0]: | ||
for later_items in param_product(*param_lists[1:]): | ||
return_list.append([item] + later_items) | ||
return return_list | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"initial_dims, initial_shape, initial_origin, initial_extent, target_dims, transpose_order", | ||
[ | ||
pytest.param( | ||
(fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM), | ||
(6, 7), | ||
(1, 2), | ||
(2, 3), | ||
(fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM), | ||
(0, 1), | ||
id="2d_keep_order", | ||
), | ||
pytest.param( | ||
(fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM), | ||
(6, 7), | ||
(1, 2), | ||
(2, 3), | ||
(fv3gfs.util.Y_DIM, fv3gfs.util.X_DIM), | ||
(1, 0), | ||
id="2d_transpose", | ||
), | ||
pytest.param( | ||
(fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM, fv3gfs.util.Z_DIM), | ||
(6, 7, 8), | ||
(1, 2, 3), | ||
(2, 3, 4), | ||
(fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM, fv3gfs.util.Z_DIM), | ||
(0, 1, 2), | ||
id="3d_keep_order", | ||
), | ||
pytest.param( | ||
(fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM, fv3gfs.util.Z_DIM), | ||
(6, 7, 8), | ||
(1, 2, 3), | ||
(2, 3, 4), | ||
(fv3gfs.util.X_DIMS, fv3gfs.util.Y_DIMS, fv3gfs.util.Z_DIMS), | ||
(0, 1, 2), | ||
id="3d_keep_order_list_dims", | ||
), | ||
pytest.param( | ||
(fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM, fv3gfs.util.Z_DIM), | ||
(6, 7, 8), | ||
(1, 2, 3), | ||
(2, 3, 4), | ||
(fv3gfs.util.Z_DIM, fv3gfs.util.Y_DIM, fv3gfs.util.X_DIM), | ||
(2, 1, 0), | ||
id="3d_transpose", | ||
), | ||
pytest.param( | ||
(fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM, fv3gfs.util.Z_DIM), | ||
(6, 7, 8), | ||
(1, 2, 3), | ||
(2, 3, 4), | ||
(fv3gfs.util.Z_DIMS, fv3gfs.util.Y_DIMS, fv3gfs.util.X_DIMS), | ||
(2, 1, 0), | ||
id="3d_transpose_list_dims", | ||
), | ||
], | ||
indirect=True, | ||
) | ||
@pytest.mark.parametrize("backend", ["gt4py_numpy", "gt4py_cupy"], indirect=True) | ||
def test_transpose( | ||
quantity, target_dims, final_data, final_dims, final_origin, final_extent, numpy | ||
): | ||
result = quantity.transpose(target_dims) | ||
numpy.testing.assert_array_equal(result.data, final_data) | ||
assert result.dims == final_dims | ||
assert result.origin == final_origin | ||
assert result.extent == final_extent | ||
assert result.units == quantity.units | ||
assert result.gt4py_backend == quantity.gt4py_backend | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"initial_dims, initial_shape, initial_origin, initial_extent, target_dims, transpose_order", | ||
[ | ||
pytest.param( | ||
(fv3gfs.util.X_DIM,), (6,), (1,), (2,), (fv3gfs.util.Y_DIM,), (0,), id="1d" | ||
), | ||
pytest.param( | ||
(fv3gfs.util.X_DIM, fv3gfs.util.Y_INTERFACE_DIM), | ||
(6, 7), | ||
(1, 2), | ||
(2, 3), | ||
(fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM), | ||
(0, 1), | ||
id="2d_switch_stagger", | ||
), | ||
pytest.param( | ||
(fv3gfs.util.X_DIM, fv3gfs.util.Y_DIM), | ||
(6, 7), | ||
(1, 2), | ||
(2, 3), | ||
(fv3gfs.util.Y_DIM, fv3gfs.util.X_INTERFACE_DIM), | ||
(1, 0), | ||
id="2d_transpose_switch_stagger", | ||
), | ||
], | ||
indirect=True, | ||
) | ||
def test_transpose_invalid_cases( | ||
quantity, target_dims, final_data, final_dims, final_origin, final_extent, numpy | ||
): | ||
with pytest.raises(ValueError): | ||
quantity.transpose(target_dims) |