Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for planar periodic meshes #23

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions mosaic/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@
"mpasli.AIS8to30": {
"lcrc_path": "inputdata/glc/mpasli/mpas.ais8to30km/ais_8to30km.20221027.nc",
"sha256_hash": "sha256:932a1989ff8e51223413ef3ff0056d6737a1fc7f53e440359884a567a93413d2"
}
},

"doubly_periodic_4x4": {
"lcrc_path": "mpas_standalonedata/mpas-ocean/mesh_database/doubly_periodic_1920km_7680x7680km.151124.nc",
"sha256_hash": "sha256:5409d760845fb682ec56e30d9c6aa6dfe16b5d0e0e74f5da989cdaddbf4303c7"
},
}

# create a parsable registry for pooch from human friendly one
Expand Down Expand Up @@ -55,8 +60,9 @@ def open_dataset(

* ``"QU.960km"`` : Quasi-uniform spherical mesh, with approximately 960km horizontal resolution
* ``"QU.240km"`` : Quasi-uniform spherical mesh, with approximately 240km horizontal resolution
* ``"mpaso.IcoswISC30E3r5"`` : Icosahedral 30 km MPAS-Ocean mesh with ice shelf cavaties
* ``"mpaso.IcoswISC30E3r5"`` : Icosahedral 30 km MPAS-Ocean mesh with ice shelf cavities
* ``"mpasli.AIS8to30"`` : 8-30 km resolution planar non-periodic MALI mesh of Antarctica
* ``"doubly_periodic_4x4"``: Doubly periodic planar mesh that is four cells wide in both the x and y dimensions.

Parameters
----------
Expand Down
213 changes: 154 additions & 59 deletions mosaic/descriptor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from functools import cached_property
from typing import Literal

import cartopy.crs as ccrs
import numpy as np
import xarray as xr
from cartopy.crs import CRS
Expand All @@ -19,6 +21,13 @@
"verticesOnCell",
"edgesOnVertex"]

SUPPORTED_PROJECTIONS = (ccrs._RectangularProjection,
ccrs._WarpedRectangularProjection,
ccrs.Stereographic,
ccrs.Mercator,
ccrs._CylindricalProjection,
ccrs.InterruptedGoodeHomolosine)


def attr_to_bool(attr: str):
""" Format attribute strings and return a boolean value """
Expand Down Expand Up @@ -105,6 +114,8 @@ def __init__(
#: ``projection`` kwargs are provided.
self.transform = transform

#: Boolean whether parent mesh is (planar) periodic in at least one dim
self.is_periodic = attr_to_bool(mesh_ds.is_periodic)
#: Boolean whether parent mesh is spherical
self.is_spherical = attr_to_bool(mesh_ds.on_a_sphere)

Expand All @@ -120,6 +131,10 @@ def __init__(
# if both a projection and transform were provided to the constructor
self.projection = projection

# method ensures is periodic, avoiding AttributeErrors if non-periodic
self.x_period = self._parse_period(mesh_ds, "x")
self.y_period = self._parse_period(mesh_ds, "y")

def _create_minimal_dataset(self, ds: Dataset) -> Dataset:
"""
Create a xarray.Dataset that contains the minimal subset of
Expand Down Expand Up @@ -190,6 +205,13 @@ def projection(self) -> CRS:

@projection.setter
def projection(self, projection: CRS) -> None:
# We don't support all map projections for spherical meshes, yet...
if (projection is not None and self.is_spherical and
not isinstance(projection, SUPPORTED_PROJECTIONS)):

raise ValueError(f"Invalid projection: {type(projection).__name__}"
f" is not supported - consider using "
f"a rectangular projection.")
# Issue warning if changing the projection after initialization
# TODO: Add heuristic size (i.e. ``self.ds.nbytes``) above which the
# warning is raised
Expand Down Expand Up @@ -228,6 +250,57 @@ def latlon(self, value) -> None:

self._latlon = value

def _parse_period(self, ds, dim: Literal["x", "y"]):
""" Parse period attribute, return None for non-periodic meshes """

attr = f"{dim}_period"
try:
period = float(ds.attrs[attr])
except KeyError:
period = None

# in the off chance mesh is periodic but does not have period attribute
if self.is_periodic and attr not in ds.attrs:
raise AttributeError((f"Mesh file: \"{ds.encoding['source']}\""
f"does not have attribute `{attr}` despite"
f"being a planar periodic mesh."))
if period == 0.0:
return None
else:
return period

@property
def x_period(self) -> float | None:
""" Period along x-dimension, is ``None`` for non-periodic meshes """
return self._x_period

@x_period.setter
def x_period(self, value) -> None:
# needed to avoid AttributeError for non-periodic meshes
if not (self.is_periodic and self.is_spherical):
self._x_period = None
if not self.is_periodic and self.is_spherical:
x_limits = self.projection.x_limits
self._x_period = np.abs(x_limits[1] - x_limits[0])
else:
self._x_period = value

@property
def y_period(self) -> float | None:
""" Period along y-dimension, is ``None`` for non-periodic meshes """
return self._y_period

@y_period.setter
def y_period(self, value) -> None:
# needed to avoid AttributeError for non-periodic meshes
if not (self.is_periodic and self.is_spherical):
self._y_period = None
if not self.is_periodic and self.is_spherical:
y_limits = self.projection.y_limits
self._y_period = np.abs(y_limits[1] - y_limits[0])
else:
self._y_period = value

@cached_property
def cell_patches(self) -> ndarray:
""":py:class:`~numpy.ndarray` of patch coordinates for cell centered
Expand All @@ -244,7 +317,13 @@ def cell_patches(self) -> ndarray:
patch. Nodes are ordered counter clockwise around the cell center.
"""
patches = _compute_cell_patches(self.ds)
patches = self._fix_antimeridian(patches, "Cell")
patches = self._wrap_patches(patches, "Cell")

# cartopy doesn't handle nans in patches, so store a mask of the
# invalid patches to set the dataarray at those locations to nan.
if self.projection:
self._cell_pole_mask = self._compute_pole_mask("Cell")

return patches

@cached_property
Expand All @@ -263,7 +342,13 @@ def edge_patches(self) -> ndarray:
corresponding node will be collapsed to the edge coordinate.
"""
patches = _compute_edge_patches(self.ds)
patches = self._fix_antimeridian(patches, "Edge")
patches = self._wrap_patches(patches, "Edge")

# cartopy doesn't handle nans in patches, so store a mask of the
# invalid patches to set the dataarray at those locations to nan.
if self.projection:
self._edge_pole_mask = self._compute_pole_mask("Edge")

return patches

@cached_property
Expand All @@ -289,7 +374,13 @@ def vertex_patches(self) -> ndarray:
position.
"""
patches = _compute_vertex_patches(self.ds)
patches = self._fix_antimeridian(patches, "Vertex")
patches = self._wrap_patches(patches, "Vertex")

# cartopy doesn't handle nans in patches, so store a mask of the
# invalid patches to set the dataarray at those locations to nan.
if self.projection:
self._vertex_pole_mask = self._compute_pole_mask("Vertex")

return patches

def _transform_coordinates(self, projection, transform):
Expand All @@ -304,70 +395,74 @@ def _transform_coordinates(self, projection, transform):
self.ds[f"x{loc}"].values = transformed_coords[:, 0]
self.ds[f"y{loc}"].values = transformed_coords[:, 1]

def _fix_antimeridian(self, patches, loc, projection=None):
"""Correct vertices of patches that cross the antimeridian.

NOTE: Can this be a decorator?
def _wrap_patches(self, patches, loc):
"""Wrap patches for spherical and planar-periodic meshes
"""
# coordinate arrays are transformed at initalization, so using the
# transform size limit, not the projection
if not projection:
projection = self.projection

# should be able to come up with a default size limit here, or maybe
# it's already an attribute(?) Should also factor in a precomputed
# axis period, as set in the attributes of the input dataset
if projection:
# convert to numpy array to that broadcasting below will work
x_center = np.array(self.ds[f"x{loc}"])

# get distance b/w the center and vertices of the patches
# NOTE: using data from masked patches array so that we compute
# mask only corresponds to patches that cross the boundary,
# (i.e. NOT a mask of all invalid cells). May need to be
# carefull about the fillvalue depending on the transform
half_distance = x_center[:, np.newaxis] - patches[..., 0].data

# get the size limit of the projection;
size_limit = np.abs(projection.x_limits[1] -
projection.x_limits[0]) / (2 * np.sqrt(2))

# left and right mask, with same number of dims as the patches
l_mask = (half_distance > size_limit)[..., np.newaxis]
r_mask = (half_distance < -size_limit)[..., np.newaxis]

def _find_boundary_patches(patches, loc, coord, period):
"""
# Old approach masks out all patches that cross the antimeridian.
# This is unnessarily restrictive. New approach corrects
# the x-coordinates of vertices that lie outside the projections
# bounds, which isn't perfect either

patches.mask |= l_mask
patches.mask |= r_mask
Find the patches that cross the periodic boundary and what
direction they cross the boundary (i.e. their ``sign``). This
method assumes the patch centroids are not periodic
"""
# get axis index we are inquiring over
axis = 0 if coord == "x" else 1
# get requested coordinate of patch centroids
center = self.ds[f"{coord}{loc.title()}"].values.reshape(-1, 1)
# get difference b/w centroid and nodes of patches
diff = patches[..., axis] - center

mask = np.abs(diff) > np.abs(period) / (2. * np.sqrt(2.))
sign = np.sign(diff)

return mask, sign

def _wrap_1D(patches, mask, sign, axis, period):
"""Correct patch periodicity along a single dimension"""
patches[..., axis][mask] -= np.sign(sign[mask]) * period

# TODO: clip spherical wrapped patches to projection limits
return patches

l_boundary_mask = ~np.any(l_mask, axis=1) | l_mask[..., 0]
r_boundary_mask = ~np.any(r_mask, axis=1) | r_mask[..., 0]
# get valid half distances for the patches that cross boundary
l_offset = np.ma.MaskedArray(half_distance, l_boundary_mask)
r_offset = np.ma.MaskedArray(half_distance, r_boundary_mask)

# For vertices that cross the antimeridian reset the x-coordinate
# of invalid vertex to be the center of the patch plus the
# mean valid half distance.
#
# NOTE: this only fixes patches on the side of plot where they
# cross the antimeridian, leaving an empty zipper like pattern
# mirrored over the y-axis.
patches[..., 0] = np.ma.where(
~l_mask[..., 0], patches[..., 0],
x_center[:, np.newaxis] + l_offset.mean(1)[..., np.newaxis])
patches[..., 0] = np.ma.where(
~r_mask[..., 0], patches[..., 0],
x_center[:, np.newaxis] + r_offset.mean(1)[..., np.newaxis])
# Stereographic projections do not need wrapping, so exit early
if isinstance(self.projection, ccrs.Stereographic):
return patches

if self.x_period:
# find the patch that are periodic in x-direction
x_mask, x_sign = _find_boundary_patches(
patches, loc, "x", self.x_period
)

if np.any(x_mask):
# using the sign of the difference correct patches x coordinate
patches = _wrap_1D(patches, x_mask, x_sign, 0, self.x_period)

if self.y_period:
# find the patch that are periodic in y-direction
y_mask, y_sign = _find_boundary_patches(
patches, loc, "y", self.y_period
)

if np.any(x_mask):
# using the sign of the difference correct patches y coordinate
patches = _wrap_1D(patches, y_mask, y_sign, 1, self.y_period)

return patches

def _compute_pole_mask(self, loc) -> ndarray:
""" """
limits = self.projection.y_limits
centers = self.ds[f"y{loc.title()}"].values

# TODO: determine threshold for ``isclose`` computation
at_pole = np.any(
np.isclose(centers.reshape(-1, 1), limits, rtol=1e-2), axis=1
)
past_pole = np.abs(centers) > np.abs(limits[1])

return (at_pole | past_pole)


def _compute_cell_patches(ds: Dataset) -> ndarray:
"""Create cell patches (i.e. Primary cells) for an MPAS mesh."""
Expand Down
7 changes: 7 additions & 0 deletions mosaic/polypcolor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
from cartopy.mpl.geoaxes import GeoAxes
from matplotlib.axes import Axes
from matplotlib.collections import PolyCollection
Expand Down Expand Up @@ -48,12 +49,18 @@ def polypcolor(

if "nCells" in c.dims:
verts = descriptor.cell_patches
if descriptor.projection and np.any(descriptor._cell_pole_mask):
c = c.where(~descriptor._cell_pole_mask, np.nan)

elif "nEdges" in c.dims:
verts = descriptor.edge_patches
if descriptor.projection and np.any(descriptor._edge_pole_mask):
c = c.where(~descriptor._edge_pole_mask, np.nan)

elif "nVertices" in c.dims:
verts = descriptor.vertex_patches
if descriptor.projection and np.any(descriptor._vertex_pole_mask):
c = c.where(~descriptor._vertex_pole_mask, np.nan)

transform = descriptor.transform

Expand Down
Loading