Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

updated for enums #293

Merged
merged 53 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from 49 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
87f2d8e
updated for enums
ESadek-MO Jul 30, 2023
49fb580
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 30, 2023
ec88bbf
tried another form of import
ESadek-MO Jul 30, 2023
e4d6b16
Merge branch 'test' of github.com:ESadek-MO/iris-esmf-regrid into test
ESadek-MO Jul 30, 2023
9fc2d6b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 30, 2023
f5b9994
moved import constants to bottom; it's something like this
ESadek-MO Jul 31, 2023
2fea929
Merge branch 'test' of github.com:ESadek-MO/iris-esmf-regrid into test
ESadek-MO Jul 31, 2023
8c6079c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 31, 2023
651556b
I've fixed this before, but I can't remember how
ESadek-MO Aug 7, 2023
cff1b0a
Merge branch 'test' of github.com:ESadek-MO/iris-esmf-regrid into test
ESadek-MO Aug 7, 2023
90e46f6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 7, 2023
1565c44
I added a note to my old one, hopefully this does the trick
ESadek-MO Aug 7, 2023
0ab4c83
Merge branch 'test' of github.com:ESadek-MO/iris-esmf-regrid into test
ESadek-MO Aug 7, 2023
5358245
Merge branch 'main' into test
ESadek-MO Aug 7, 2023
cc9d1b9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 7, 2023
f63a848
testing this works
ESadek-MO Aug 8, 2023
65b80e6
Merge branch 'test' of github.com:ESadek-MO/iris-esmf-regrid into test
ESadek-MO Aug 8, 2023
ae5533b
hopefully fixed the type error problems
ESadek-MO Aug 9, 2023
d0e2310
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 9, 2023
9eb9c40
updated test to fix one error
ESadek-MO Aug 9, 2023
73ec742
Merge branch 'test' of github.com:ESadek-MO/iris-esmf-regrid into test
ESadek-MO Aug 9, 2023
ac9ab86
figuring it out
ESadek-MO Aug 15, 2023
9cccd85
fixed no method error
ESadek-MO Aug 17, 2023
cca232d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 17, 2023
9c3fe01
removed NotImplementedError
ESadek-MO Aug 21, 2023
a5e0add
removed commented out code
ESadek-MO Aug 21, 2023
42c87bf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 21, 2023
a829446
missed one
ESadek-MO Aug 21, 2023
83136f5
Merge branch 'test' of github.com:ESadek-MO/iris-esmf-regrid into test
ESadek-MO Aug 21, 2023
e5b8e85
fixed flake8
ESadek-MO Aug 21, 2023
b6eda5f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 21, 2023
a42f3ee
fixed flake8 again
ESadek-MO Aug 21, 2023
f005702
fixed flake8 merge error
ESadek-MO Aug 21, 2023
e373441
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 21, 2023
ea0a8ed
fixed flake8 the third
ESadek-MO Aug 21, 2023
8df0593
Merge branch 'test' of github.com:ESadek-MO/iris-esmf-regrid into test
ESadek-MO Aug 21, 2023
091b11b
updated to include a checker function in constants, with necessary tests
ESadek-MO Oct 5, 2023
4f5ba2d
fixed merge conflicts
ESadek-MO Nov 24, 2023
16b340e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 24, 2023
dd3833f
fixed flake8 conflicts
ESadek-MO Nov 24, 2023
a020fb0
missed one: fixed another flake8 conflict
ESadek-MO Nov 24, 2023
10b241b
Merge branch 'test' of github.com:ESadek-MO/iris-esmf-regrid into test
ESadek-MO Nov 24, 2023
404aa26
more flake8
ESadek-MO Nov 24, 2023
525793b
flake8 2: electric imperativemoodaloo
ESadek-MO Nov 24, 2023
6833915
more flake8
ESadek-MO Nov 24, 2023
2891c40
fixed test failure
ESadek-MO Nov 24, 2023
0a08219
fixed test failure
ESadek-MO Nov 24, 2023
bac084c
fixed error messages to use enums
ESadek-MO Nov 24, 2023
18524aa
hasn't this been a journey. Fixed error message
ESadek-MO Nov 24, 2023
f1f1039
actioned review comments
ESadek-MO Nov 30, 2023
b3c8cf3
fixed precommit and error message
ESadek-MO Nov 30, 2023
cf617c3
corrected constants error message
ESadek-MO Nov 30, 2023
91be20c
corrected docstring
ESadek-MO Nov 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions esmf_regrid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
except ImportError:
raise exc

# constants needs to be above schemes, as it is used within
from .constants import Constants, check_method, check_norm
from .schemes import *


Expand Down
60 changes: 60 additions & 0 deletions esmf_regrid/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""Holds all enums created for esmf-regrid."""

from enum import Enum

from . import esmpy


class Constants:
"""Encompassing class for best practice import."""
stephenworsley marked this conversation as resolved.
Show resolved Hide resolved

class Method(Enum):
"""holds enums for Method values."""

CONSERVATIVE = esmpy.RegridMethod.CONSERVE
BILINEAR = esmpy.RegridMethod.BILINEAR
NEAREST = esmpy.RegridMethod.NEAREST_STOD

class NormType(Enum):
"""holds enums for norm types."""

FRACAREA = esmpy.api.constants.NormType.FRACAREA
DSTAREA = esmpy.api.constants.NormType.DSTAREA


method_dict = {
"conservative": Constants.Method.CONSERVATIVE,
"bilinear": Constants.Method.BILINEAR,
"nearest": Constants.Method.NEAREST,
}

norm_dict = {
"fracarea": Constants.NormType.FRACAREA,
"dstarea": Constants.NormType.DSTAREA,
}


def check_method(method):
"""Check that method is a member of the `Constants.Method` enum or raise an error."""
if method in method_dict.keys():
result = method_dict[method]
elif method in method_dict.values():
result = method
else:
raise ValueError(
f"Method must be a member of `Constants.Method` enum, instead got {method}"
)
return result


def check_norm(norm):
"""Check that normtype is a member of the `Constants.NormType` enum or raise an error."""
if norm in norm_dict.keys():
result = norm_dict[norm]
elif norm in norm_dict.values():
result = norm
else:
raise ValueError(
f"NormType must be a member of `Constants.NormType` enum, instead got {norm}"
)
return result
47 changes: 18 additions & 29 deletions esmf_regrid/esmf_regridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import scipy.sparse

import esmf_regrid
from esmf_regrid import check_method, check_norm, Constants
from . import esmpy
from ._esmf_sdo import GridInfo, RefinedGridInfo

Expand Down Expand Up @@ -57,7 +58,9 @@ def _weights_dict_to_sparse_array(weights, shape, index_offsets):
class Regridder:
"""Regridder for directly interfacing with :mod:`esmpy`."""

def __init__(self, src, tgt, method="conservative", precomputed_weights=None):
def __init__(
self, src, tgt, method=Constants.Method.CONSERVATIVE, precomputed_weights=None
):
"""
Create a regridder from descriptions of horizontal grids/meshes.

Expand All @@ -76,38 +79,25 @@ def __init__(self, src, tgt, method="conservative", precomputed_weights=None):
Describes the target mesh/grid.
Data output by this regridder will be a :class:`numpy.ndarray` whose
shape is compatible with ``tgt``.
method : str
Either "conservative", "bilinear" or "nearest". Corresponds to the :mod:`esmpy` methods
:attr:`~esmpy.api.constants.RegridMethod.CONSERVE`,
:attr:`~esmpy.api.constants.RegridMethod.BILINEAR` or
:attr:`~esmpy.api.constants.RegridMethod.NEAREST_STOD` used to calculate weights.
method : :class:`Constants.Method`
The method to be used to calculate weights.
precomputed_weights : :class:`scipy.sparse.spmatrix`, optional
If ``None``, :mod:`esmpy` will be used to
calculate regridding weights. Otherwise, :mod:`esmpy` will be bypassed
and ``precomputed_weights`` will be used as the regridding weights.
"""
self.src = src
self.tgt = tgt

if method == "conservative":
esmf_regrid_method = esmpy.RegridMethod.CONSERVE
elif method == "bilinear":
esmf_regrid_method = esmpy.RegridMethod.BILINEAR
elif method == "nearest":
esmf_regrid_method = esmpy.RegridMethod.NEAREST_STOD
else:
raise ValueError(
f"method must be either 'bilinear', 'conservative' or 'nearest', got '{method}'."
)
self.method = method
# type checks method
self.method = check_method(method)

self.esmf_regrid_version = esmf_regrid.__version__
if precomputed_weights is None:
self.esmf_version = esmpy.__version__
weights_dict = _get_regrid_weights_dict(
src.make_esmf_field(),
tgt.make_esmf_field(),
regrid_method=esmf_regrid_method,
regrid_method=method.value,
)
self.weight_matrix = _weights_dict_to_sparse_array(
weights_dict,
Expand Down Expand Up @@ -144,19 +134,17 @@ def __init__(self, src, tgt, method="conservative", precomputed_weights=None):
self.esmf_version = None
self.weight_matrix = precomputed_weights

def regrid(self, src_array, norm_type="fracarea", mdtol=1):
def regrid(self, src_array, norm_type=Constants.NormType.FRACAREA, mdtol=1):
"""
Perform regridding on an array of data.

Parameters
----------
src_array : :obj:`~numpy.typing.ArrayLike`
Array whose shape is compatible with ``self.src``
norm_type : str
Either ``fracarea`` or ``dstarea``, defaults to ``fracarea``. Determines the
type of normalisation applied to the weights. Normalisations correspond
to :mod:`esmpy` constants :attr:`~esmpy.api.constants.NormType.FRACAREA` and
:attr:`~esmpy.api.constants.NormType.DSTAREA`.
norm_type : :class:`Constants.NormType`
Either ``Constants.NormType.FRACAREA`` or ``Constants.NormType.DSTAREA``.
Determines the type of normalisation applied to the weights.
mdtol : float, default=1
A number between 0 and 1 describing the missing data tolerance.
Depending on the value of ``mdtol``, if a cell in the target grid is not
Expand All @@ -173,6 +161,9 @@ def regrid(self, src_array, norm_type="fracarea", mdtol=1):
An array whose shape is compatible with ``self.tgt``.

"""
# Sets default value, as this can't be done with class attributes within method call
norm_type = check_norm(norm_type)

array_shape = src_array.shape
main_shape = array_shape[-self.src.dims :]
if main_shape != self.src.shape:
Expand All @@ -190,12 +181,10 @@ def regrid(self, src_array, norm_type="fracarea", mdtol=1):
tgt_mask = weight_sums > 1 - mdtol
masked_weight_sums = weight_sums * tgt_mask
normalisations = np.ones([self.tgt.size, extra_size])
if norm_type == "fracarea":
if norm_type == Constants.NormType.FRACAREA:
normalisations[tgt_mask] /= masked_weight_sums[tgt_mask]
elif norm_type == "dstarea":
elif norm_type == Constants.NormType.DSTAREA:
pass
else:
raise ValueError(f'Normalisation type "{norm_type}" is not supported')
normalisations = ma.array(normalisations, mask=np.logical_not(tgt_mask))

flat_src = self.src._array_to_matrix(ma.filled(src_array, 0.0))
Expand Down
9 changes: 7 additions & 2 deletions esmf_regrid/experimental/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import scipy.sparse

import esmf_regrid
from esmf_regrid import check_method, Constants
from esmf_regrid.experimental.unstructured_scheme import (
GridToMeshESMFRegridder,
MeshToGridESMFRegridder,
Expand Down Expand Up @@ -109,7 +110,8 @@ def _standard_grid_cube(grid, name):
)
raise TypeError(msg)

method = rg.method
method = str(check_method(rg.method).name)

resolution = rg.resolution

weight_matrix = rg.regridder.weight_matrix
Expand Down Expand Up @@ -203,7 +205,10 @@ def load_regridder(filename):

# Determine the regridding method, allowing for files created when
# conservative regridding was the only method.
method = weights_cube.attributes.get(METHOD, "conservative")
method = getattr(
Constants.Method, weights_cube.attributes.get(METHOD, "CONSERVATIVE")
)

resolution = weights_cube.attributes.get(RESOLUTION, None)
if resolution is not None:
resolution = int(resolution)
Expand Down
40 changes: 16 additions & 24 deletions esmf_regrid/experimental/unstructured_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from iris.experimental.ugrid import Mesh

from esmf_regrid import check_method, Constants
from esmf_regrid.schemes import (
_ESMFRegridder,
_get_mask,
Expand All @@ -18,7 +19,7 @@ def regrid_unstructured_to_rectilinear(
src_cube,
grid_cube,
mdtol=0,
method="conservative",
method=Constants.Method.CONSERVATIVE,
tgt_resolution=None,
use_src_mask=False,
use_tgt_mask=False,
Expand Down Expand Up @@ -59,11 +60,8 @@ def regrid_unstructured_to_rectilinear(
target cell. ``mdtol=0`` means no missing data is tolerated while ``mdtol=1``
will mean the resulting element will be masked if and only if all the
overlapping cells of ``src_cube`` are masked.
method : str, default="conservative"
Either "conservative", "bilinear" or "nearest". Corresponds to the :mod:`esmpy` methods
:attr:`~esmpy.api.constants.RegridMethod.CONSERVE` or
:attr:`~esmpy.api.constants.RegridMethod.BILINEAR` or
:attr:`~esmpy.api.constants.RegridMethod.NEAREST` used to calculate weights.
method : :class:`Constants.Method`, default=Constants.Method.CONSERVATIVE
The method used to calculate weights.
tgt_resolution : int, optional
If present, represents the amount of latitude slices per cell
given to ESMF for calculation.
Expand All @@ -88,6 +86,7 @@ def regrid_unstructured_to_rectilinear(
raise ValueError("src_cube has no mesh.")
src_mask = _get_mask(src_cube, use_src_mask)
tgt_mask = _get_mask(grid_cube, use_tgt_mask)
method = check_method(method)

regrid_info = _regrid_unstructured_to_rectilinear__prepare(
src_cube,
Expand All @@ -109,7 +108,7 @@ def __init__(
src,
tgt,
mdtol=None,
method="conservative",
method=Constants.Method.CONSERVATIVE,
precomputed_weights=None,
tgt_resolution=None,
use_src_mask=False,
Expand All @@ -131,11 +130,8 @@ def __init__(
``mdtol=1`` will mean the resulting element will be masked if and only
if all the contributing elements of data are masked. Defaults to 1
for conservative regregridding and 0 for bilinear regridding.
method : str, default="conservative"
Either "conservative", "bilinear" or "nearest". Corresponds to the :mod:`esmpy`
methods :attr:`~esmpy.api.constants.RegridMethod.CONSERVE` or
:attr:`~esmpy.api.constants.RegridMethod.BILINEAR` or
:attr:`~esmpy.api.constants.RegridMethod.NEAREST` used to calculate weights.
method : :class:`Constants.Method`, default=Constants.Method.CONSERVATIVE
The method used to calculate weights.
precomputed_weights : :class:`scipy.sparse.spmatrix`, optional
If ``None``, :mod:`esmpy` will be used to
calculate regridding weights. Otherwise, :mod:`esmpy` will be bypassed
Expand Down Expand Up @@ -185,7 +181,7 @@ def regrid_rectilinear_to_unstructured(
src_cube,
mesh_cube,
mdtol=0,
method="conservative",
method=Constants.Method.CONSERVATIVE,
src_resolution=None,
use_src_mask=False,
use_tgt_mask=False,
Expand Down Expand Up @@ -230,11 +226,8 @@ def regrid_rectilinear_to_unstructured(
target cell. ``mdtol=0`` means no missing data is tolerated while ``mdtol=1``
will mean the resulting element will be masked if and only if all the
overlapping cells of the ``src_cube`` are masked.
method : str, default="conservative"
Either "conservative", "bilinear" or "nearest". Corresponds to the :mod:`esmpy` methods
:attr:`~esmpy.api.constants.RegridMethod.CONSERVE` or
:attr:`~esmpy.api.constants.RegridMethod.BILINEAR` or
:attr:`~esmpy.api.constants.RegridMethod.NEAREST` used to calculate weights.
method : :class:`Constants.Method`, default=Constants.Method.CONSERVATIVE
The method used to calculate weights.
src_resolution : int, optional
If present, represents the amount of latitude slices per cell
given to ESMF for calculation.
Expand All @@ -259,6 +252,7 @@ def regrid_rectilinear_to_unstructured(
raise ValueError("mesh_cube has no mesh.")
src_mask = _get_mask(src_cube, use_src_mask)
tgt_mask = _get_mask(mesh_cube, use_tgt_mask)
method = check_method(method)

regrid_info = _regrid_rectilinear_to_unstructured__prepare(
src_cube,
Expand All @@ -280,7 +274,7 @@ def __init__(
src,
tgt,
mdtol=None,
method="conservative",
method=Constants.Method.CONSERVATIVE,
precomputed_weights=None,
src_resolution=None,
use_src_mask=False,
Expand All @@ -304,11 +298,8 @@ def __init__(
``mdtol=1`` will mean the resulting element will be masked if and only
if all the contributing elements of data are masked. Defaults to 1
for conservative regregridding and 0 for bilinear regridding.
method : str, default="conservative"
Either "conservative", "bilinear" or "nearest". Corresponds to the :mod:`esmpy`
methods :attr:`~esmpy.api.constants.RegridMethod.CONSERVE` or
:attr:`~esmpy.api.constants.RegridMethod.BILINEAR` or
:attr:`~esmpy.api.constants.RegridMethod.NEAREST` used to calculate weights.
method : :class:`Constants.Method`, default=Constants.Method.CONSERVATIVE
The method used to calculate weights.
precomputed_weights : :class:`scipy.sparse.spmatrix`, optional
If ``None``, :mod:`esmpy` will be used to
calculate regridding weights. Otherwise, :mod:`esmpy` will be bypassed
Expand Down Expand Up @@ -413,6 +404,7 @@ def regrid_unstructured_to_unstructured(
A new :class:`~iris.cube.Cube` instance.

"""
stephenworsley marked this conversation as resolved.
Show resolved Hide resolved
method = check_method(method)
if tgt_mesh_cube.mesh is None:
raise ValueError("mesh_cube has no mesh.")
src_mask = _get_mask(src_mesh_cube, use_src_mask)
Expand Down
Loading