Skip to content

Commit

Permalink
Use planar periodic functions for spherical wrapping.
Browse files Browse the repository at this point in the history
WIP: There's debugging that's needed for vertex patches
  • Loading branch information
andrewdnolan committed Dec 22, 2024
1 parent df0de3b commit 94a20c9
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 82 deletions.
157 changes: 75 additions & 82 deletions mosaic/descriptor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import cached_property
from typing import Literal

import cartopy.crs as ccrs
import numpy as np
import xarray as xr
from cartopy.crs import CRS
Expand All @@ -20,6 +21,13 @@
"verticesOnCell",
"edgesOnVertex"]

SUPPORTED_PROJECTIONS = (ccrs._RectangularProjection,
ccrs._WarpedRectangularProjection,
ccrs.Stereographic,
ccrs.Mercator,
ccrs._CylindricalProjection,
ccrs.InterruptedGoodeHomolosine)


def attr_to_bool(attr: str):
""" Format attribute strings and return a boolean value """
Expand Down Expand Up @@ -111,10 +119,6 @@ def __init__(
#: Boolean whether parent mesh is spherical
self.is_spherical = attr_to_bool(mesh_ds.on_a_sphere)

# method ensures is periodic, avoiding AttributeErrors if non-periodic
self.x_period = self._parse_period(mesh_ds, "x")
self.y_period = self._parse_period(mesh_ds, "y")

# calls attribute setter method
self.latlon = use_latlon

Expand All @@ -127,6 +131,10 @@ def __init__(
# if both a projection and transform were provided to the constructor
self.projection = projection

# method ensures is periodic, avoiding AttributeErrors if non-periodic
self.x_period = self._parse_period(mesh_ds, "x")
self.y_period = self._parse_period(mesh_ds, "y")

def _create_minimal_dataset(self, ds: Dataset) -> Dataset:
"""
Create a xarray.Dataset that contains the minimal subset of
Expand Down Expand Up @@ -197,6 +205,13 @@ def projection(self) -> CRS:

@projection.setter
def projection(self, projection: CRS) -> None:
# We don't support all map projections for spherical meshes, yet...
if (projection is not None and self.is_spherical and
not isinstance(projection, SUPPORTED_PROJECTIONS)):

raise ValueError(f"Invalid projection: {type(projection).__name__}"
f" is not supported - consider using "
f"a rectangular projection.")
# Issue warning if changing the projection after initialization
# TODO: Add heuristic size (i.e. ``self.ds.nbytes``) above which the
# warning is raised
Expand Down Expand Up @@ -262,8 +277,11 @@ def x_period(self) -> float | None:
@x_period.setter
def x_period(self, value) -> None:
# needed to avoid AttributeError for non-periodic meshes
if not self.is_periodic:
if not (self.is_periodic and self.is_spherical):
self._x_period = None
if not self.is_periodic and self.is_spherical:
x_limits = self.projection.x_limits
self._x_period = np.abs(x_limits[1] - x_limits[0])
else:
self._x_period = value

Expand All @@ -275,8 +293,11 @@ def y_period(self) -> float | None:
@y_period.setter
def y_period(self, value) -> None:
# needed to avoid AttributeError for non-periodic meshes
if not self.is_periodic:
if not (self.is_periodic and self.is_spherical):
self._y_period = None
if not self.is_periodic and self.is_spherical:
y_limits = self.projection.y_limits
self._y_period = np.abs(y_limits[1] - y_limits[0])
else:
self._y_period = value

Expand All @@ -297,6 +318,12 @@ def cell_patches(self) -> ndarray:
"""
patches = _compute_cell_patches(self.ds)
patches = self._wrap_patches(patches, "Cell")

# cartopy doesn't handle nans in patches, so store a mask of the
# invalid patches to set the dataarray at those locations to nan.
if self.projection:
self._cell_pole_mask = self._compute_pole_mask("Cell")

return patches

@cached_property
Expand All @@ -316,6 +343,12 @@ def edge_patches(self) -> ndarray:
"""
patches = _compute_edge_patches(self.ds)
patches = self._wrap_patches(patches, "Edge")

# cartopy doesn't handle nans in patches, so store a mask of the
# invalid patches to set the dataarray at those locations to nan.
if self.projection:
self._edge_pole_mask = self._compute_pole_mask("Edge")

return patches

@cached_property
Expand All @@ -342,6 +375,12 @@ def vertex_patches(self) -> ndarray:
"""
patches = _compute_vertex_patches(self.ds)
patches = self._wrap_patches(patches, "Vertex")

# cartopy doesn't handle nans in patches, so store a mask of the
# invalid patches to set the dataarray at those locations to nan.
if self.projection:
self._vertex_pole_mask = self._compute_pole_mask("Vertex")

return patches

def _transform_coordinates(self, projection, transform):
Expand All @@ -358,10 +397,9 @@ def _transform_coordinates(self, projection, transform):

def _wrap_patches(self, patches, loc):
"""Wrap patches for spherical and planar-periodic meshes
"""

def _find_boundary_patches(patches, loc, coord):
def _find_boundary_patches(patches, loc, coord, period):
"""
Find the patches that cross the periodic boundary and what
direction they cross the boundary (i.e. their ``sign``). This
Expand All @@ -374,10 +412,6 @@ def _find_boundary_patches(patches, loc, coord):
# get difference b/w centroid and nodes of patches
diff = patches[..., axis] - center

#
if self.__getattribute__(f"{coord}_period"):
period = self.__getattribute__(f"{coord}_period")

mask = np.abs(diff) > np.abs(period) / (2. * np.sqrt(2.))
sign = np.sign(diff)

Expand All @@ -386,89 +420,48 @@ def _find_boundary_patches(patches, loc, coord):
def _wrap_1D(patches, mask, sign, axis, period):
"""Correct patch periodicity along a single dimension"""
patches[..., axis][mask] -= np.sign(sign[mask]) * period

# TODO: clip spherical wrapped patches to projection limits
return patches

# Stereographic projections do not need wrapping, so exit early
if isinstance(self.projection, ccrs.Stereographic):
return patches

if self.x_period:
# find the patch that are periodic in x-direction
x_mask, x_sign = _find_boundary_patches(patches, loc, "x")
# using the sign of the difference correct patches x coordinate
patches = _wrap_1D(patches, x_mask, x_sign, 0, self.x_period)
x_mask, x_sign = _find_boundary_patches(
patches, loc, "x", self.x_period
)

if np.any(x_mask):
# using the sign of the difference correct patches x coordinate
patches = _wrap_1D(patches, x_mask, x_sign, 0, self.x_period)

if self.y_period:
# find the patch that are periodic in y-direction
y_mask, y_sign = _find_boundary_patches(patches, loc, "y")
# using the sign of the difference correct patches y coordinate
patches = _wrap_1D(patches, y_mask, y_sign, 1, self.y_period)
y_mask, y_sign = _find_boundary_patches(
patches, loc, "y", self.y_period
)

if self.is_spherical:
# call current spherical wrapping function for now
patches = self._fix_antimeridian(patches, loc)
if np.any(x_mask):
# using the sign of the difference correct patches y coordinate
patches = _wrap_1D(patches, y_mask, y_sign, 1, self.y_period)

return patches

def _fix_antimeridian(self, patches, loc, projection=None):
"""Correct vertices of patches that cross the antimeridian.
def _compute_pole_mask(self, loc) -> ndarray:
""" """
limits = self.projection.y_limits
centers = self.ds[f"y{loc.title()}"].values

NOTE: Can this be a decorator?
"""
# coordinate arrays are transformed at initalization, so using the
# transform size limit, not the projection
if not projection:
projection = self.projection

# should be able to come up with a default size limit here, or maybe
# it's already an attribute(?) Should also factor in a precomputed
# axis period, as set in the attributes of the input dataset
if projection:
# convert to numpy array to that broadcasting below will work
x_center = np.array(self.ds[f"x{loc}"])

# get distance b/w the center and vertices of the patches
# NOTE: using data from masked patches array so that we compute
# mask only corresponds to patches that cross the boundary,
# (i.e. NOT a mask of all invalid cells). May need to be
# carefull about the fillvalue depending on the transform
half_distance = x_center[:, np.newaxis] - patches[..., 0].data

# get the size limit of the projection;
size_limit = np.abs(projection.x_limits[1] -
projection.x_limits[0]) / (2 * np.sqrt(2))

# left and right mask, with same number of dims as the patches
l_mask = (half_distance > size_limit)[..., np.newaxis]
r_mask = (half_distance < -size_limit)[..., np.newaxis]

"""
# Old approach masks out all patches that cross the antimeridian.
# This is unnessarily restrictive. New approach corrects
# the x-coordinates of vertices that lie outside the projections
# bounds, which isn't perfect either
# TODO: determine threshold for ``isclose`` computation
at_pole = np.any(
np.isclose(centers.reshape(-1, 1), limits, rtol=1e-2), axis=1
)
past_pole = np.abs(centers) > np.abs(limits[1])

patches.mask |= l_mask
patches.mask |= r_mask
"""

l_boundary_mask = ~np.any(l_mask, axis=1) | l_mask[..., 0]
r_boundary_mask = ~np.any(r_mask, axis=1) | r_mask[..., 0]
# get valid half distances for the patches that cross boundary
l_offset = np.ma.MaskedArray(half_distance, l_boundary_mask)
r_offset = np.ma.MaskedArray(half_distance, r_boundary_mask)

# For vertices that cross the antimeridian reset the x-coordinate
# of invalid vertex to be the center of the patch plus the
# mean valid half distance.
#
# NOTE: this only fixes patches on the side of plot where they
# cross the antimeridian, leaving an empty zipper like pattern
# mirrored over the y-axis.
patches[..., 0] = np.ma.where(
~l_mask[..., 0], patches[..., 0],
x_center[:, np.newaxis] + l_offset.mean(1)[..., np.newaxis])
patches[..., 0] = np.ma.where(
~r_mask[..., 0], patches[..., 0],
x_center[:, np.newaxis] + r_offset.mean(1)[..., np.newaxis])

return patches
return (at_pole | past_pole)


def _compute_cell_patches(ds: Dataset) -> ndarray:
Expand Down
7 changes: 7 additions & 0 deletions mosaic/polypcolor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
from cartopy.mpl.geoaxes import GeoAxes
from matplotlib.axes import Axes
from matplotlib.collections import PolyCollection
Expand Down Expand Up @@ -48,12 +49,18 @@ def polypcolor(

if "nCells" in c.dims:
verts = descriptor.cell_patches
if descriptor.projection and np.any(descriptor._cell_pole_mask):
c = c.where(~descriptor._cell_pole_mask, np.nan)

elif "nEdges" in c.dims:
verts = descriptor.edge_patches
if descriptor.projection and np.any(descriptor._edge_pole_mask):
c = c.where(~descriptor._edge_pole_mask, np.nan)

elif "nVertices" in c.dims:
verts = descriptor.vertex_patches
if descriptor.projection and np.any(descriptor._vertex_pole_mask):
c = c.where(~descriptor._vertex_pole_mask, np.nan)

transform = descriptor.transform

Expand Down

0 comments on commit 94a20c9

Please sign in to comment.