Skip to content

Commit

Permalink
Extend regridder saving/loading to all regridders (#357)
Browse files Browse the repository at this point in the history
* extend regridder saving/loading

* fix bug

* flake8

* add tests and documentation

* generalise existing tests

* review comments, test for _managed_var_name

* address review comment

* add changelog

* flake 8
  • Loading branch information
stephenworsley authored May 30, 2024
1 parent cf86c1d commit d4f4c2d
Show file tree
Hide file tree
Showing 8 changed files with 526 additions and 93 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](http://semver.org/).

## [Unreleased]

### Added

- [PR#357](https://github.com/SciTools-incubator/iris-esmf-regrid/pull/357)
Added support for saving and loading of `ESMFAreaWeighted`, `ESMFBilinear`
and `ESMFNearest` regridders.
[@stephenworsley](https://github.com/stephenworsley)

### Changed

- [PR#361](https://github.com/SciTools-incubator/iris-esmf-regrid/pull/361)
Expand Down
4 changes: 2 additions & 2 deletions docs/src/userguide/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ Saving and Loading a Regridder
A regridder can be set up for reuse, this saves time performing the
computationally expensive initialisation process::

from esmf_regrid.experimental.unstructured_scheme import MeshToGridESMFRegridder
from esmf_regrid.experimental.unstructured_scheme import ESMFAreaWeighted

# Initialise the regridder with a source mesh and target grid.
regridder = MeshToGridESMFRegridder(source_mesh_cube, target_grid_cube)
regridder = ESMFAreaWeighted().regridder(source_mesh_cube, target_grid_cube)

# use the initialised regridder to regrid the data from the source cube
# onto a cube with the same grid as `target_grid_cube`.
Expand Down
10 changes: 6 additions & 4 deletions docs/src/userguide/scheme_comparison.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,12 @@ These were formerly the only way to do regridding with a source or
target cube defined on an unstructured mesh. These are less flexible and
require that the source/target be defined on a grid/mesh. Unlike the above
regridders whose method is fixed, these regridders take a ``method`` keyword
of ``conservative``, ``bilinear`` or ``nearest``. While most of the
functionality in these regridders have been ported into the above schemes and
regridders, these remain the only regridders capable of being saved and loaded by
:mod:`esmf_regrid.experimental.io`.
of ``conservative``, ``bilinear`` or ``nearest``. All the
functionality in these regridders has now been ported into the above schemes and
regridders. Before version 0.10, these were the only regridders capable of being
saved and loaded by :mod:`esmf_regrid.experimental.io`, so while the above generic
regridders are recomended, these regridders are still available for the sake of
consistency with regridders saved from older versions.


Overview: Miscellaneous Functions
Expand Down
176 changes: 145 additions & 31 deletions esmf_regrid/experimental/io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Provides load/save functions for regridders."""

from contextlib import contextmanager

import iris
from iris.coords import AuxCoord
from iris.cube import Cube, CubeList
Expand All @@ -13,9 +15,19 @@
GridToMeshESMFRegridder,
MeshToGridESMFRegridder,
)
from esmf_regrid.schemes import (
ESMFAreaWeightedRegridder,
ESMFBilinearRegridder,
ESMFNearestRegridder,
GridRecord,
MeshRecord,
)


SUPPORTED_REGRIDDERS = [
ESMFAreaWeightedRegridder,
ESMFBilinearRegridder,
ESMFNearestRegridder,
GridToMeshESMFRegridder,
MeshToGridESMFRegridder,
]
Expand All @@ -34,6 +46,8 @@
MDTOL = "mdtol"
METHOD = "method"
RESOLUTION = "resolution"
SOURCE_RESOLUTION = "src_resolution"
TARGET_RESOLUTION = "tgt_resolution"


def _add_mask_to_cube(mask, cube, name):
Expand All @@ -43,18 +57,63 @@ def _add_mask_to_cube(mask, cube, name):
cube.add_aux_coord(mask_coord, list(range(cube.ndim)))


@contextmanager
def _managed_var_name(src_cube, tgt_cube):
src_coord_names = []
src_mesh_coords = []
if src_cube.mesh is not None:
src_mesh = src_cube.mesh
src_mesh_coords = src_mesh.coords()
for coord in src_mesh_coords:
src_coord_names.append(coord.var_name)
tgt_coord_names = []
tgt_mesh_coords = []
if tgt_cube.mesh is not None:
tgt_mesh = tgt_cube.mesh
tgt_mesh_coords = tgt_mesh.coords()
for coord in tgt_mesh_coords:
tgt_coord_names.append(coord.var_name)

try:
for coord in src_mesh_coords:
coord.var_name = "_".join([SOURCE_NAME, "mesh", coord.name()])
for coord in tgt_mesh_coords:
coord.var_name = "_".join([TARGET_NAME, "mesh", coord.name()])
yield None
finally:
for coord, var_name in zip(src_mesh_coords, src_coord_names):
coord.var_name = var_name
for coord, var_name in zip(tgt_mesh_coords, tgt_coord_names):
coord.var_name = var_name


def _clean_var_names(cube):
cube.var_name = None
for coord in cube.coords():
coord.var_name = None
if cube.mesh is not None:
cube.mesh.var_name = None
for coord in cube.mesh.coords():
coord.var_name = None
for con in cube.mesh.connectivities():
con.var_name = None


def save_regridder(rg, filename):
"""
Save a regridder scheme instance.
Saves either a
:class:`~esmf_regrid.experimental.unstructured_scheme.GridToMeshESMFRegridder`
or a
:class:`~esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder`.
Saves any of the regridder classes, i.e.
:class:`~esmf_regrid.experimental.unstructured_scheme.GridToMeshESMFRegridder`,
:class:`~esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder`,
:class:`~esmf_regrid.schemes.ESMFAreaWeightedRegridder`,
:class:`~esmf_regrid.schemes.ESMFBilinearRegridder` or
:class:`~esmf_regrid.schemes.ESMFNearestRegridder`.
.
Parameters
----------
rg : :class:`~esmf_regrid.experimental.unstructured_scheme.GridToMeshESMFRegridder` or :class:`~esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder`
rg : :class:`~esmf_regrid.schemes._ESMFRegridder`
The regridder instance to save.
filename : str
The file name to save to.
Expand All @@ -76,28 +135,56 @@ def _standard_grid_cube(grid, name):
cube.add_aux_coord(grid[1], [0, 1])
return cube

if regridder_type == "GridToMeshESMFRegridder":
def _standard_mesh_cube(mesh, location, name):
mesh_coords = mesh.to_MeshCoords(location)
data = np.zeros(mesh_coords[0].points.shape[0])
cube = Cube(data, var_name=name, long_name=name)
for coord in mesh_coords:
cube.add_aux_coord(coord, 0)
return cube

if regridder_type in [
"ESMFAreaWeightedRegridder",
"ESMFBilinearRegridder",
"ESMFNearestRegridder",
]:
src_grid = rg._src
if isinstance(src_grid, GridRecord):
src_cube = _standard_grid_cube(
(src_grid.grid_y, src_grid.grid_x), SOURCE_NAME
)
elif isinstance(src_grid, MeshRecord):
src_mesh, src_location = src_grid
src_cube = _standard_mesh_cube(src_mesh, src_location, SOURCE_NAME)
else:
raise ValueError("Improper type for `rg._src`.")
_add_mask_to_cube(rg.src_mask, src_cube, SOURCE_MASK_NAME)

tgt_grid = rg._tgt
if isinstance(tgt_grid, GridRecord):
tgt_cube = _standard_grid_cube(
(tgt_grid.grid_y, tgt_grid.grid_x), TARGET_NAME
)
elif isinstance(tgt_grid, MeshRecord):
tgt_mesh, tgt_location = tgt_grid
tgt_cube = _standard_mesh_cube(tgt_mesh, tgt_location, TARGET_NAME)
else:
raise ValueError("Improper type for `rg._tgt`.")
_add_mask_to_cube(rg.tgt_mask, tgt_cube, TARGET_MASK_NAME)
elif regridder_type == "GridToMeshESMFRegridder":
src_grid = (rg.grid_y, rg.grid_x)
src_cube = _standard_grid_cube(src_grid, SOURCE_NAME)
_add_mask_to_cube(rg.src_mask, src_cube, SOURCE_MASK_NAME)

tgt_mesh = rg.mesh
tgt_location = rg.location
tgt_mesh_coords = tgt_mesh.to_MeshCoords(tgt_location)
tgt_data = np.zeros(tgt_mesh_coords[0].points.shape[0])
tgt_cube = Cube(tgt_data, var_name=TARGET_NAME, long_name=TARGET_NAME)
for coord in tgt_mesh_coords:
tgt_cube.add_aux_coord(coord, 0)
tgt_cube = _standard_mesh_cube(tgt_mesh, tgt_location, TARGET_NAME)
_add_mask_to_cube(rg.tgt_mask, tgt_cube, TARGET_MASK_NAME)

elif regridder_type == "MeshToGridESMFRegridder":
src_mesh = rg.mesh
src_location = rg.location
src_mesh_coords = src_mesh.to_MeshCoords(src_location)
src_data = np.zeros(src_mesh_coords[0].points.shape[0])
src_cube = Cube(src_data, var_name=SOURCE_NAME, long_name=SOURCE_NAME)
for coord in src_mesh_coords:
src_cube.add_aux_coord(coord, 0)
src_cube = _standard_mesh_cube(src_mesh, src_location, SOURCE_NAME)
_add_mask_to_cube(rg.src_mask, src_cube, SOURCE_MASK_NAME)

tgt_grid = (rg.grid_y, rg.grid_x)
Expand All @@ -112,7 +199,18 @@ def _standard_grid_cube(grid, name):

method = str(check_method(rg.method).name)

resolution = rg.resolution
if regridder_type in ["GridToMeshESMFRegridder", "MeshToGridESMFRegridder"]:
resolution = rg.resolution
src_resolution = None
tgt_resolution = None
elif regridder_type == "ESMFAreaWeightedRegridder":
resolution = None
src_resolution = rg.src_resolution
tgt_resolution = rg.tgt_resolution
else:
resolution = None
src_resolution = None
tgt_resolution = None

weight_matrix = rg.regridder.weight_matrix
reformatted_weight_matrix = scipy.sparse.coo_matrix(weight_matrix)
Expand Down Expand Up @@ -141,6 +239,10 @@ def _standard_grid_cube(grid, name):
}
if resolution is not None:
attributes[RESOLUTION] = resolution
if src_resolution is not None:
attributes[SOURCE_RESOLUTION] = src_resolution
if tgt_resolution is not None:
attributes[TARGET_RESOLUTION] = tgt_resolution

weights_cube = Cube(weight_data, var_name=WEIGHTS_NAME, long_name=WEIGHTS_NAME)
row_coord = AuxCoord(
Expand All @@ -158,17 +260,14 @@ def _standard_grid_cube(grid, name):
long_name=WEIGHTS_SHAPE_NAME,
)

# Avoid saving bug by placing the mesh cube second.
# TODO: simplify this when this bug is fixed in iris.
if regridder_type == "GridToMeshESMFRegridder":
# Save cubes while ensuring var_names do not conflict for the sake of consistency.
with _managed_var_name(src_cube, tgt_cube):
cube_list = CubeList([src_cube, tgt_cube, weights_cube, weight_shape_cube])
elif regridder_type == "MeshToGridESMFRegridder":
cube_list = CubeList([tgt_cube, src_cube, weights_cube, weight_shape_cube])

for cube in cube_list:
cube.attributes = attributes
for cube in cube_list:
cube.attributes = attributes

iris.fileformats.netcdf.save(cube_list, filename)
iris.fileformats.netcdf.save(cube_list, filename)


def load_regridder(filename):
Expand All @@ -194,7 +293,9 @@ def load_regridder(filename):

# Extract the source, target and metadata information.
src_cube = cubes.extract_cube(SOURCE_NAME)
_clean_var_names(src_cube)
tgt_cube = cubes.extract_cube(TARGET_NAME)
_clean_var_names(tgt_cube)
weights_cube = cubes.extract_cube(WEIGHTS_NAME)
weight_shape_cube = cubes.extract_cube(WEIGHTS_SHAPE_NAME)

Expand All @@ -210,8 +311,14 @@ def load_regridder(filename):
)

resolution = weights_cube.attributes.get(RESOLUTION, None)
src_resolution = weights_cube.attributes.get(SOURCE_RESOLUTION, None)
tgt_resolution = weights_cube.attributes.get(TARGET_RESOLUTION, None)
if resolution is not None:
resolution = int(resolution)
if src_resolution is not None:
src_resolution = int(src_resolution)
if tgt_resolution is not None:
tgt_resolution = int(tgt_resolution)

# Reconstruct the weight matrix.
weight_data = weights_cube.data
Expand All @@ -234,18 +341,25 @@ def load_regridder(filename):
use_tgt_mask = False

if scheme is GridToMeshESMFRegridder:
resolution_keyword = "src_resolution"
resolution_keyword = SOURCE_RESOLUTION
kwargs = {resolution_keyword: resolution, "method": method, "mdtol": mdtol}
elif scheme is MeshToGridESMFRegridder:
resolution_keyword = "tgt_resolution"
resolution_keyword = TARGET_RESOLUTION
kwargs = {resolution_keyword: resolution, "method": method, "mdtol": mdtol}
elif scheme is ESMFAreaWeightedRegridder:
kwargs = {
SOURCE_RESOLUTION: src_resolution,
TARGET_RESOLUTION: tgt_resolution,
"mdtol": mdtol,
}
elif scheme is ESMFBilinearRegridder:
kwargs = {"mdtol": mdtol}
else:
raise NotImplementedError
kwargs = {resolution_keyword: resolution}
kwargs = {}

regridder = scheme(
src_cube,
tgt_cube,
mdtol=mdtol,
method=method,
precomputed_weights=weight_matrix,
use_src_mask=use_src_mask,
use_tgt_mask=use_tgt_mask,
Expand Down
11 changes: 11 additions & 0 deletions esmf_regrid/schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,6 +984,8 @@ def regridder(
self,
src_grid,
tgt_grid,
src_resolution=None,
tgt_resolution=None,
use_src_mask=None,
use_tgt_mask=None,
tgt_location="face",
Expand All @@ -998,6 +1000,11 @@ def regridder(
tgt_grid : :class:`iris.cube.Cube` or :class:`iris.experimental.ugrid.Mesh`
The unstructured :class:`~iris.cube.Cube`or
:class:`~iris.experimental.ugrid.Mesh` defining the target.
src_resolution, tgt_resolution : int, optional
If present, represents the amount of latitude slices per source/target cell
given to ESMF for calculation. If resolution is set, ``src`` and ``tgt``
respectively must have strictly increasing bounds (bounds may be transposed
plus or minus 360 degrees to make the bounds strictly increasing).
use_src_mask : :obj:`~numpy.typing.ArrayLike` or bool, optional
Array describing which elements :mod:`esmpy` will ignore on the src_grid.
If True, the mask will be derived from src_grid.
Expand Down Expand Up @@ -1035,6 +1042,8 @@ def regridder(
src_grid,
tgt_grid,
mdtol=self.mdtol,
src_resolution=src_resolution,
tgt_resolution=tgt_resolution,
use_src_mask=use_src_mask,
use_tgt_mask=use_tgt_mask,
tgt_location="face",
Expand Down Expand Up @@ -1483,8 +1492,10 @@ def __init__(
if tgt_location is not "face".
"""
kwargs = dict()
self.src_resolution = src_resolution
if src_resolution is not None:
kwargs["src_resolution"] = src_resolution
self.tgt_resolution = tgt_resolution
if tgt_resolution is not None:
kwargs["tgt_resolution"] = tgt_resolution
if tgt_location is not None and tgt_location != "face":
Expand Down
Loading

0 comments on commit d4f4c2d

Please sign in to comment.