Skip to content

Commit

Permalink
updated for enums (#293)
Browse files Browse the repository at this point in the history
* updated for enums

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tried another form of import

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* moved import constants to bottom; it's something like this

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* I've fixed this before, but I can't remember how

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* I added a note to my old one, hopefully this does the trick

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* testing this works

* hopefully fixed the type error problems

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* updated test to fix one error

* figuring it out

* fixed no method error

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* removed NotImplementedError

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* missed one

* fixed flake8

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixed flake8 again

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixed flake8 the third

* updated to include a checker function in constants, with necessary tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixed flake8 conflicts

* missed one: fixed another flake8 conflict

* more flake8

* flake8 2: electric imperativemoodaloo

* more flake8

* fixed test failure

* fixed test failure

* fixed error messages to use enums

* hasn't this been a journey. Fixed error message

* actioned review comments

* fixed precommit and error message

* corrected constants error message

* corrected docstring

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
ESadek-MO and pre-commit-ci[bot] authored Nov 30, 2023
1 parent be33a81 commit e83b7be
Show file tree
Hide file tree
Showing 18 changed files with 303 additions and 153 deletions.
2 changes: 2 additions & 0 deletions esmf_regrid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
except ImportError:
raise exc

# constants needs to be above schemes, as it is used within
from .constants import Constants, check_method, check_norm
from .schemes import *


Expand Down
60 changes: 60 additions & 0 deletions esmf_regrid/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""Holds all enums created for esmf-regrid."""

from enum import Enum

from . import esmpy


class Constants:
"""Encompassing class for best practice import."""

class Method(Enum):
"""holds enums for Method values."""

CONSERVATIVE = esmpy.RegridMethod.CONSERVE
BILINEAR = esmpy.RegridMethod.BILINEAR
NEAREST = esmpy.RegridMethod.NEAREST_STOD

class NormType(Enum):
"""holds enums for norm types."""

FRACAREA = esmpy.api.constants.NormType.FRACAREA
DSTAREA = esmpy.api.constants.NormType.DSTAREA


method_dict = {
"conservative": Constants.Method.CONSERVATIVE,
"bilinear": Constants.Method.BILINEAR,
"nearest": Constants.Method.NEAREST,
}

norm_dict = {
"fracarea": Constants.NormType.FRACAREA,
"dstarea": Constants.NormType.DSTAREA,
}


def check_method(method):
"""Check that method is a member of the `Constants.Method` enum or raise an error."""
if method in method_dict.keys():
result = method_dict[method]
elif method in method_dict.values():
result = method
else:
raise ValueError(
f"Method must be a member of `Constants.Method` enum, instead got {method}"
)
return result


def check_norm(norm):
"""Check that normtype is a member of the `Constants.NormType` enum or raise an error."""
if norm in norm_dict.keys():
result = norm_dict[norm]
elif norm in norm_dict.values():
result = norm
else:
raise ValueError(
f"NormType must be a member of `Constants.NormType` enum, instead got {norm}"
)
return result
47 changes: 18 additions & 29 deletions esmf_regrid/esmf_regridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import scipy.sparse

import esmf_regrid
from esmf_regrid import check_method, check_norm, Constants
from . import esmpy
from ._esmf_sdo import GridInfo, RefinedGridInfo

Expand Down Expand Up @@ -57,7 +58,9 @@ def _weights_dict_to_sparse_array(weights, shape, index_offsets):
class Regridder:
"""Regridder for directly interfacing with :mod:`esmpy`."""

def __init__(self, src, tgt, method="conservative", precomputed_weights=None):
def __init__(
self, src, tgt, method=Constants.Method.CONSERVATIVE, precomputed_weights=None
):
"""
Create a regridder from descriptions of horizontal grids/meshes.
Expand All @@ -76,38 +79,25 @@ def __init__(self, src, tgt, method="conservative", precomputed_weights=None):
Describes the target mesh/grid.
Data output by this regridder will be a :class:`numpy.ndarray` whose
shape is compatible with ``tgt``.
method : str
Either "conservative", "bilinear" or "nearest". Corresponds to the :mod:`esmpy` methods
:attr:`~esmpy.api.constants.RegridMethod.CONSERVE`,
:attr:`~esmpy.api.constants.RegridMethod.BILINEAR` or
:attr:`~esmpy.api.constants.RegridMethod.NEAREST_STOD` used to calculate weights.
method : :class:`Constants.Method`
The method to be used to calculate weights.
precomputed_weights : :class:`scipy.sparse.spmatrix`, optional
If ``None``, :mod:`esmpy` will be used to
calculate regridding weights. Otherwise, :mod:`esmpy` will be bypassed
and ``precomputed_weights`` will be used as the regridding weights.
"""
self.src = src
self.tgt = tgt

if method == "conservative":
esmf_regrid_method = esmpy.RegridMethod.CONSERVE
elif method == "bilinear":
esmf_regrid_method = esmpy.RegridMethod.BILINEAR
elif method == "nearest":
esmf_regrid_method = esmpy.RegridMethod.NEAREST_STOD
else:
raise ValueError(
f"method must be either 'bilinear', 'conservative' or 'nearest', got '{method}'."
)
self.method = method
# type checks method
self.method = check_method(method)

self.esmf_regrid_version = esmf_regrid.__version__
if precomputed_weights is None:
self.esmf_version = esmpy.__version__
weights_dict = _get_regrid_weights_dict(
src.make_esmf_field(),
tgt.make_esmf_field(),
regrid_method=esmf_regrid_method,
regrid_method=method.value,
)
self.weight_matrix = _weights_dict_to_sparse_array(
weights_dict,
Expand Down Expand Up @@ -144,19 +134,17 @@ def __init__(self, src, tgt, method="conservative", precomputed_weights=None):
self.esmf_version = None
self.weight_matrix = precomputed_weights

def regrid(self, src_array, norm_type="fracarea", mdtol=1):
def regrid(self, src_array, norm_type=Constants.NormType.FRACAREA, mdtol=1):
"""
Perform regridding on an array of data.
Parameters
----------
src_array : :obj:`~numpy.typing.ArrayLike`
Array whose shape is compatible with ``self.src``
norm_type : str
Either ``fracarea`` or ``dstarea``, defaults to ``fracarea``. Determines the
type of normalisation applied to the weights. Normalisations correspond
to :mod:`esmpy` constants :attr:`~esmpy.api.constants.NormType.FRACAREA` and
:attr:`~esmpy.api.constants.NormType.DSTAREA`.
norm_type : :class:`Constants.NormType`
Either ``Constants.NormType.FRACAREA`` or ``Constants.NormType.DSTAREA``.
Determines the type of normalisation applied to the weights.
mdtol : float, default=1
A number between 0 and 1 describing the missing data tolerance.
Depending on the value of ``mdtol``, if a cell in the target grid is not
Expand All @@ -173,6 +161,9 @@ def regrid(self, src_array, norm_type="fracarea", mdtol=1):
An array whose shape is compatible with ``self.tgt``.
"""
# Sets default value, as this can't be done with class attributes within method call
norm_type = check_norm(norm_type)

array_shape = src_array.shape
main_shape = array_shape[-self.src.dims :]
if main_shape != self.src.shape:
Expand All @@ -190,12 +181,10 @@ def regrid(self, src_array, norm_type="fracarea", mdtol=1):
tgt_mask = weight_sums > 1 - mdtol
masked_weight_sums = weight_sums * tgt_mask
normalisations = np.ones([self.tgt.size, extra_size])
if norm_type == "fracarea":
if norm_type == Constants.NormType.FRACAREA:
normalisations[tgt_mask] /= masked_weight_sums[tgt_mask]
elif norm_type == "dstarea":
elif norm_type == Constants.NormType.DSTAREA:
pass
else:
raise ValueError(f'Normalisation type "{norm_type}" is not supported')
normalisations = ma.array(normalisations, mask=np.logical_not(tgt_mask))

flat_src = self.src._array_to_matrix(ma.filled(src_array, 0.0))
Expand Down
9 changes: 7 additions & 2 deletions esmf_regrid/experimental/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import scipy.sparse

import esmf_regrid
from esmf_regrid import check_method, Constants
from esmf_regrid.experimental.unstructured_scheme import (
GridToMeshESMFRegridder,
MeshToGridESMFRegridder,
Expand Down Expand Up @@ -109,7 +110,8 @@ def _standard_grid_cube(grid, name):
)
raise TypeError(msg)

method = rg.method
method = str(check_method(rg.method).name)

resolution = rg.resolution

weight_matrix = rg.regridder.weight_matrix
Expand Down Expand Up @@ -203,7 +205,10 @@ def load_regridder(filename):

# Determine the regridding method, allowing for files created when
# conservative regridding was the only method.
method = weights_cube.attributes.get(METHOD, "conservative")
method = getattr(
Constants.Method, weights_cube.attributes.get(METHOD, "CONSERVATIVE")
)

resolution = weights_cube.attributes.get(RESOLUTION, None)
if resolution is not None:
resolution = int(resolution)
Expand Down
49 changes: 19 additions & 30 deletions esmf_regrid/experimental/unstructured_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from iris.experimental.ugrid import Mesh

from esmf_regrid import check_method, Constants
from esmf_regrid.schemes import (
_ESMFRegridder,
_get_mask,
Expand All @@ -18,7 +19,7 @@ def regrid_unstructured_to_rectilinear(
src_cube,
grid_cube,
mdtol=0,
method="conservative",
method=Constants.Method.CONSERVATIVE,
tgt_resolution=None,
use_src_mask=False,
use_tgt_mask=False,
Expand Down Expand Up @@ -59,11 +60,8 @@ def regrid_unstructured_to_rectilinear(
target cell. ``mdtol=0`` means no missing data is tolerated while ``mdtol=1``
will mean the resulting element will be masked if and only if all the
overlapping cells of ``src_cube`` are masked.
method : str, default="conservative"
Either "conservative", "bilinear" or "nearest". Corresponds to the :mod:`esmpy` methods
:attr:`~esmpy.api.constants.RegridMethod.CONSERVE` or
:attr:`~esmpy.api.constants.RegridMethod.BILINEAR` or
:attr:`~esmpy.api.constants.RegridMethod.NEAREST` used to calculate weights.
method : :class:`Constants.Method`, default=Constants.Method.CONSERVATIVE
The method used to calculate weights.
tgt_resolution : int, optional
If present, represents the amount of latitude slices per cell
given to ESMF for calculation.
Expand All @@ -88,6 +86,7 @@ def regrid_unstructured_to_rectilinear(
raise ValueError("src_cube has no mesh.")
src_mask = _get_mask(src_cube, use_src_mask)
tgt_mask = _get_mask(grid_cube, use_tgt_mask)
method = check_method(method)

regrid_info = _regrid_unstructured_to_rectilinear__prepare(
src_cube,
Expand All @@ -109,7 +108,7 @@ def __init__(
src,
tgt,
mdtol=None,
method="conservative",
method=Constants.Method.CONSERVATIVE,
precomputed_weights=None,
tgt_resolution=None,
use_src_mask=False,
Expand All @@ -131,11 +130,8 @@ def __init__(
``mdtol=1`` will mean the resulting element will be masked if and only
if all the contributing elements of data are masked. Defaults to 1
for conservative regregridding and 0 for bilinear regridding.
method : str, default="conservative"
Either "conservative", "bilinear" or "nearest". Corresponds to the :mod:`esmpy`
methods :attr:`~esmpy.api.constants.RegridMethod.CONSERVE` or
:attr:`~esmpy.api.constants.RegridMethod.BILINEAR` or
:attr:`~esmpy.api.constants.RegridMethod.NEAREST` used to calculate weights.
method : :class:`Constants.Method`, default=Constants.Method.CONSERVATIVE
The method used to calculate weights.
precomputed_weights : :class:`scipy.sparse.spmatrix`, optional
If ``None``, :mod:`esmpy` will be used to
calculate regridding weights. Otherwise, :mod:`esmpy` will be bypassed
Expand Down Expand Up @@ -185,7 +181,7 @@ def regrid_rectilinear_to_unstructured(
src_cube,
mesh_cube,
mdtol=0,
method="conservative",
method=Constants.Method.CONSERVATIVE,
src_resolution=None,
use_src_mask=False,
use_tgt_mask=False,
Expand Down Expand Up @@ -230,11 +226,8 @@ def regrid_rectilinear_to_unstructured(
target cell. ``mdtol=0`` means no missing data is tolerated while ``mdtol=1``
will mean the resulting element will be masked if and only if all the
overlapping cells of the ``src_cube`` are masked.
method : str, default="conservative"
Either "conservative", "bilinear" or "nearest". Corresponds to the :mod:`esmpy` methods
:attr:`~esmpy.api.constants.RegridMethod.CONSERVE` or
:attr:`~esmpy.api.constants.RegridMethod.BILINEAR` or
:attr:`~esmpy.api.constants.RegridMethod.NEAREST` used to calculate weights.
method : :class:`Constants.Method`, default=Constants.Method.CONSERVATIVE
The method used to calculate weights.
src_resolution : int, optional
If present, represents the amount of latitude slices per cell
given to ESMF for calculation.
Expand All @@ -259,6 +252,7 @@ def regrid_rectilinear_to_unstructured(
raise ValueError("mesh_cube has no mesh.")
src_mask = _get_mask(src_cube, use_src_mask)
tgt_mask = _get_mask(mesh_cube, use_tgt_mask)
method = check_method(method)

regrid_info = _regrid_rectilinear_to_unstructured__prepare(
src_cube,
Expand All @@ -280,7 +274,7 @@ def __init__(
src,
tgt,
mdtol=None,
method="conservative",
method=Constants.Method.CONSERVATIVE,
precomputed_weights=None,
src_resolution=None,
use_src_mask=False,
Expand All @@ -304,11 +298,8 @@ def __init__(
``mdtol=1`` will mean the resulting element will be masked if and only
if all the contributing elements of data are masked. Defaults to 1
for conservative regregridding and 0 for bilinear regridding.
method : str, default="conservative"
Either "conservative", "bilinear" or "nearest". Corresponds to the :mod:`esmpy`
methods :attr:`~esmpy.api.constants.RegridMethod.CONSERVE` or
:attr:`~esmpy.api.constants.RegridMethod.BILINEAR` or
:attr:`~esmpy.api.constants.RegridMethod.NEAREST` used to calculate weights.
method : :class:`Constants.Method`, default=Constants.Method.CONSERVATIVE
The method used to calculate weights.
precomputed_weights : :class:`scipy.sparse.spmatrix`, optional
If ``None``, :mod:`esmpy` will be used to
calculate regridding weights. Otherwise, :mod:`esmpy` will be bypassed
Expand Down Expand Up @@ -361,7 +352,7 @@ def regrid_unstructured_to_unstructured(
src_mesh_cube,
tgt_mesh_cube,
mdtol=0,
method="conservative",
method=Constants.Method.CONSERVATIVE,
use_src_mask=False,
use_tgt_mask=False,
):
Expand Down Expand Up @@ -391,11 +382,8 @@ def regrid_unstructured_to_unstructured(
target cell. ``mdtol=0`` means no missing data is tolerated while ``mdtol=1``
will mean the resulting element will be masked if and only if all the
overlapping cells of the ``src_cube`` are masked.
method : str, default="conservative"
Either "conservative", "bilinear" or "nearest". Corresponds to the :mod:`esmpy` methods
:attr:`~esmpy.api.constants.RegridMethod.CONSERVE` or
:attr:`~esmpy.api.constants.RegridMethod.BILINEAR` or
:attr:`~esmpy.api.constants.RegridMethod.NEAREST` used to calculate weights.
method : :class:`Constants.Method`, default=Constants.Method.CONSERVATIVE
The method used to calculate weights.
use_src_mask : :obj:`~numpy.typing.ArrayLike` or bool, default=False
Either an array representing the cells in the source to ignore, or else
a boolean value. If True, this array is taken from the mask on the data
Expand All @@ -413,6 +401,7 @@ def regrid_unstructured_to_unstructured(
A new :class:`~iris.cube.Cube` instance.
"""
method = check_method(method)
if tgt_mesh_cube.mesh is None:
raise ValueError("mesh_cube has no mesh.")
src_mask = _get_mask(src_mesh_cube, use_src_mask)
Expand Down
Loading

0 comments on commit e83b7be

Please sign in to comment.