From 38bc29386ab2e78ab414a85117ce1bbcc2d6be68 Mon Sep 17 00:00:00 2001 From: Andrew Nolan Date: Thu, 19 Dec 2024 12:55:43 -0700 Subject: [PATCH 1/5] Parse attributes related to periodicity from mesh --- mosaic/descriptor.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/mosaic/descriptor.py b/mosaic/descriptor.py index 5cfed64..95f037f 100644 --- a/mosaic/descriptor.py +++ b/mosaic/descriptor.py @@ -31,6 +31,14 @@ def attr_to_bool(attr: str): raise ValueError("f{attr} was unable to be parsed as YES/NO") +def parse_period(period): + """ Parse period attribute, return None if period is zero """ + if float(period) == 0.0: + return None + else: + return float(period) + + class Descriptor: """Data structure describing unstructured MPAS meshes. @@ -105,8 +113,14 @@ def __init__( #: ``projection`` kwargs are provided. self.transform = transform + #: Boolean whether parent mesh is (planar) periodic in at least one dim + self.is_periodic = attr_to_bool(mesh_ds.is_periodic) #: Boolean whether parent mesh is spherical self.is_spherical = attr_to_bool(mesh_ds.on_a_sphere) + #: Period along x-dimension, is ``None`` for non-periodic meshes + self.x_period = parse_period(mesh_ds.x_period) + #: Period along y-dimension, is ``None`` for non-periodic meshes + self.y_period = parse_period(mesh_ds.y_period) # calls attribute setter method self.latlon = use_latlon From d0fa8876e5538ffa59b998117c9e40ce8f7639aa Mon Sep 17 00:00:00 2001 From: Andrew Nolan Date: Thu, 19 Dec 2024 15:14:12 -0700 Subject: [PATCH 2/5] Add doubly periodic dataset for testing --- mosaic/datasets.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/mosaic/datasets.py b/mosaic/datasets.py index de8936e..cf20409 100644 --- a/mosaic/datasets.py +++ b/mosaic/datasets.py @@ -26,7 +26,12 @@ "mpasli.AIS8to30": { "lcrc_path": "inputdata/glc/mpasli/mpas.ais8to30km/ais_8to30km.20221027.nc", "sha256_hash": "sha256:932a1989ff8e51223413ef3ff0056d6737a1fc7f53e440359884a567a93413d2" - } + }, + + "doubly_periodic_4x4": { + "lcrc_path": "mpas_standalonedata/mpas-ocean/mesh_database/doubly_periodic_1920km_7680x7680km.151124.nc", + "sha256_hash": "sha256:5409d760845fb682ec56e30d9c6aa6dfe16b5d0e0e74f5da989cdaddbf4303c7" + }, } # create a parsable registry for pooch from human friendly one @@ -55,8 +60,9 @@ def open_dataset( * ``"QU.960km"`` : Quasi-uniform spherical mesh, with approximately 960km horizontal resolution * ``"QU.240km"`` : Quasi-uniform spherical mesh, with approximately 240km horizontal resolution - * ``"mpaso.IcoswISC30E3r5"`` : Icosahedral 30 km MPAS-Ocean mesh with ice shelf cavaties + * ``"mpaso.IcoswISC30E3r5"`` : Icosahedral 30 km MPAS-Ocean mesh with ice shelf cavities * ``"mpasli.AIS8to30"`` : 8-30 km resolution planar non-periodic MALI mesh of Antarctica + * ``"doubly_periodic_4x4"``: Doubly periodic planar mesh that is four cells wide in both the x and y dimensions. Parameters ---------- From 277c363b981004f862f0ab68e6cdaf814f86dad1 Mon Sep 17 00:00:00 2001 From: Andrew Nolan Date: Fri, 20 Dec 2024 13:53:22 -0700 Subject: [PATCH 3/5] Make x/y period properties to improve error handeling --- mosaic/descriptor.py | 62 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 50 insertions(+), 12 deletions(-) diff --git a/mosaic/descriptor.py b/mosaic/descriptor.py index 95f037f..3811067 100644 --- a/mosaic/descriptor.py +++ b/mosaic/descriptor.py @@ -1,4 +1,5 @@ from functools import cached_property +from typing import Literal import numpy as np import xarray as xr @@ -31,14 +32,6 @@ def attr_to_bool(attr: str): raise ValueError("f{attr} was unable to be parsed as YES/NO") -def parse_period(period): - """ Parse period attribute, return None if period is zero """ - if float(period) == 0.0: - return None - else: - return float(period) - - class Descriptor: """Data structure describing unstructured MPAS meshes. @@ -117,10 +110,10 @@ def __init__( self.is_periodic = attr_to_bool(mesh_ds.is_periodic) #: Boolean whether parent mesh is spherical self.is_spherical = attr_to_bool(mesh_ds.on_a_sphere) - #: Period along x-dimension, is ``None`` for non-periodic meshes - self.x_period = parse_period(mesh_ds.x_period) - #: Period along y-dimension, is ``None`` for non-periodic meshes - self.y_period = parse_period(mesh_ds.y_period) + + # method ensures is periodic, avoiding AttributeErrors if non-periodic + self.x_period = self._parse_period(mesh_ds, "x") + self.y_period = self._parse_period(mesh_ds, "y") # calls attribute setter method self.latlon = use_latlon @@ -242,6 +235,51 @@ def latlon(self, value) -> None: self._latlon = value + def _parse_period(self, ds, dim: Literal["x", "y"]): + """ Parse period attribute, return None for non-periodic meshes """ + + attr = f"{dim}_period" + try: + period = float(ds.attrs[attr]) + except KeyError: + period = None + + # in the off chance mesh is periodic but does not have period attribute + if self.is_periodic and attr not in ds.attrs: + raise AttributeError((f"Mesh file: \"{ds.encoding['source']}\"" + f"does not have attribute `{attr}` despite" + f"being a planar periodic mesh.")) + if period == 0.0: + return None + else: + return period + + @property + def x_period(self) -> float | None: + """ Period along x-dimension, is ``None`` for non-periodic meshes """ + return self._x_period + + @x_period.setter + def x_period(self, value) -> None: + # needed to avoid AttributeError for non-periodic meshes + if not self.is_periodic: + self._x_period = None + else: + self._x_period = value + + @property + def y_period(self) -> float | None: + """ Period along y-dimension, is ``None`` for non-periodic meshes """ + return self._y_period + + @y_period.setter + def y_period(self, value) -> None: + # needed to avoid AttributeError for non-periodic meshes + if not self.is_periodic: + self._y_period = None + else: + self._y_period = value + @cached_property def cell_patches(self) -> ndarray: """:py:class:`~numpy.ndarray` of patch coordinates for cell centered From df0de3b186d12a68f35a58f3b99ac64aed3e6c4e Mon Sep 17 00:00:00 2001 From: Andrew Nolan Date: Fri, 20 Dec 2024 14:31:27 -0700 Subject: [PATCH 4/5] Add support for wrapping planar periodic meshes. Further testing is need to move the spherical wrapping into the new method used by planar periodic meshes. --- mosaic/descriptor.py | 56 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 3 deletions(-) diff --git a/mosaic/descriptor.py b/mosaic/descriptor.py index 3811067..5f3e770 100644 --- a/mosaic/descriptor.py +++ b/mosaic/descriptor.py @@ -296,7 +296,7 @@ def cell_patches(self) -> ndarray: patch. Nodes are ordered counter clockwise around the cell center. """ patches = _compute_cell_patches(self.ds) - patches = self._fix_antimeridian(patches, "Cell") + patches = self._wrap_patches(patches, "Cell") return patches @cached_property @@ -315,7 +315,7 @@ def edge_patches(self) -> ndarray: corresponding node will be collapsed to the edge coordinate. """ patches = _compute_edge_patches(self.ds) - patches = self._fix_antimeridian(patches, "Edge") + patches = self._wrap_patches(patches, "Edge") return patches @cached_property @@ -341,7 +341,7 @@ def vertex_patches(self) -> ndarray: position. """ patches = _compute_vertex_patches(self.ds) - patches = self._fix_antimeridian(patches, "Vertex") + patches = self._wrap_patches(patches, "Vertex") return patches def _transform_coordinates(self, projection, transform): @@ -356,6 +356,56 @@ def _transform_coordinates(self, projection, transform): self.ds[f"x{loc}"].values = transformed_coords[:, 0] self.ds[f"y{loc}"].values = transformed_coords[:, 1] + def _wrap_patches(self, patches, loc): + """Wrap patches for spherical and planar-periodic meshes + + """ + + def _find_boundary_patches(patches, loc, coord): + """ + Find the patches that cross the periodic boundary and what + direction they cross the boundary (i.e. their ``sign``). This + method assumes the patch centroids are not periodic + """ + # get axis index we are inquiring over + axis = 0 if coord == "x" else 1 + # get requested coordinate of patch centroids + center = self.ds[f"{coord}{loc.title()}"].values.reshape(-1, 1) + # get difference b/w centroid and nodes of patches + diff = patches[..., axis] - center + + # + if self.__getattribute__(f"{coord}_period"): + period = self.__getattribute__(f"{coord}_period") + + mask = np.abs(diff) > np.abs(period) / (2. * np.sqrt(2.)) + sign = np.sign(diff) + + return mask, sign + + def _wrap_1D(patches, mask, sign, axis, period): + """Correct patch periodicity along a single dimension""" + patches[..., axis][mask] -= np.sign(sign[mask]) * period + return patches + + if self.x_period: + # find the patch that are periodic in x-direction + x_mask, x_sign = _find_boundary_patches(patches, loc, "x") + # using the sign of the difference correct patches x coordinate + patches = _wrap_1D(patches, x_mask, x_sign, 0, self.x_period) + + if self.y_period: + # find the patch that are periodic in y-direction + y_mask, y_sign = _find_boundary_patches(patches, loc, "y") + # using the sign of the difference correct patches y coordinate + patches = _wrap_1D(patches, y_mask, y_sign, 1, self.y_period) + + if self.is_spherical: + # call current spherical wrapping function for now + patches = self._fix_antimeridian(patches, loc) + + return patches + def _fix_antimeridian(self, patches, loc, projection=None): """Correct vertices of patches that cross the antimeridian. From 94a20c9974d6163ef9a86f12eaaa8ed9aa9116bd Mon Sep 17 00:00:00 2001 From: Andrew Nolan Date: Sat, 21 Dec 2024 22:39:53 -0700 Subject: [PATCH 5/5] Use planar periodic functions for spherical wrapping. WIP: There's debugging that's needed for vertex patches --- mosaic/descriptor.py | 157 +++++++++++++++++++++---------------------- mosaic/polypcolor.py | 7 ++ 2 files changed, 82 insertions(+), 82 deletions(-) diff --git a/mosaic/descriptor.py b/mosaic/descriptor.py index 5f3e770..1afe977 100644 --- a/mosaic/descriptor.py +++ b/mosaic/descriptor.py @@ -1,6 +1,7 @@ from functools import cached_property from typing import Literal +import cartopy.crs as ccrs import numpy as np import xarray as xr from cartopy.crs import CRS @@ -20,6 +21,13 @@ "verticesOnCell", "edgesOnVertex"] +SUPPORTED_PROJECTIONS = (ccrs._RectangularProjection, + ccrs._WarpedRectangularProjection, + ccrs.Stereographic, + ccrs.Mercator, + ccrs._CylindricalProjection, + ccrs.InterruptedGoodeHomolosine) + def attr_to_bool(attr: str): """ Format attribute strings and return a boolean value """ @@ -111,10 +119,6 @@ def __init__( #: Boolean whether parent mesh is spherical self.is_spherical = attr_to_bool(mesh_ds.on_a_sphere) - # method ensures is periodic, avoiding AttributeErrors if non-periodic - self.x_period = self._parse_period(mesh_ds, "x") - self.y_period = self._parse_period(mesh_ds, "y") - # calls attribute setter method self.latlon = use_latlon @@ -127,6 +131,10 @@ def __init__( # if both a projection and transform were provided to the constructor self.projection = projection + # method ensures is periodic, avoiding AttributeErrors if non-periodic + self.x_period = self._parse_period(mesh_ds, "x") + self.y_period = self._parse_period(mesh_ds, "y") + def _create_minimal_dataset(self, ds: Dataset) -> Dataset: """ Create a xarray.Dataset that contains the minimal subset of @@ -197,6 +205,13 @@ def projection(self) -> CRS: @projection.setter def projection(self, projection: CRS) -> None: + # We don't support all map projections for spherical meshes, yet... + if (projection is not None and self.is_spherical and + not isinstance(projection, SUPPORTED_PROJECTIONS)): + + raise ValueError(f"Invalid projection: {type(projection).__name__}" + f" is not supported - consider using " + f"a rectangular projection.") # Issue warning if changing the projection after initialization # TODO: Add heuristic size (i.e. ``self.ds.nbytes``) above which the # warning is raised @@ -262,8 +277,11 @@ def x_period(self) -> float | None: @x_period.setter def x_period(self, value) -> None: # needed to avoid AttributeError for non-periodic meshes - if not self.is_periodic: + if not (self.is_periodic and self.is_spherical): self._x_period = None + if not self.is_periodic and self.is_spherical: + x_limits = self.projection.x_limits + self._x_period = np.abs(x_limits[1] - x_limits[0]) else: self._x_period = value @@ -275,8 +293,11 @@ def y_period(self) -> float | None: @y_period.setter def y_period(self, value) -> None: # needed to avoid AttributeError for non-periodic meshes - if not self.is_periodic: + if not (self.is_periodic and self.is_spherical): self._y_period = None + if not self.is_periodic and self.is_spherical: + y_limits = self.projection.y_limits + self._y_period = np.abs(y_limits[1] - y_limits[0]) else: self._y_period = value @@ -297,6 +318,12 @@ def cell_patches(self) -> ndarray: """ patches = _compute_cell_patches(self.ds) patches = self._wrap_patches(patches, "Cell") + + # cartopy doesn't handle nans in patches, so store a mask of the + # invalid patches to set the dataarray at those locations to nan. + if self.projection: + self._cell_pole_mask = self._compute_pole_mask("Cell") + return patches @cached_property @@ -316,6 +343,12 @@ def edge_patches(self) -> ndarray: """ patches = _compute_edge_patches(self.ds) patches = self._wrap_patches(patches, "Edge") + + # cartopy doesn't handle nans in patches, so store a mask of the + # invalid patches to set the dataarray at those locations to nan. + if self.projection: + self._edge_pole_mask = self._compute_pole_mask("Edge") + return patches @cached_property @@ -342,6 +375,12 @@ def vertex_patches(self) -> ndarray: """ patches = _compute_vertex_patches(self.ds) patches = self._wrap_patches(patches, "Vertex") + + # cartopy doesn't handle nans in patches, so store a mask of the + # invalid patches to set the dataarray at those locations to nan. + if self.projection: + self._vertex_pole_mask = self._compute_pole_mask("Vertex") + return patches def _transform_coordinates(self, projection, transform): @@ -358,10 +397,9 @@ def _transform_coordinates(self, projection, transform): def _wrap_patches(self, patches, loc): """Wrap patches for spherical and planar-periodic meshes - """ - def _find_boundary_patches(patches, loc, coord): + def _find_boundary_patches(patches, loc, coord, period): """ Find the patches that cross the periodic boundary and what direction they cross the boundary (i.e. their ``sign``). This @@ -374,10 +412,6 @@ def _find_boundary_patches(patches, loc, coord): # get difference b/w centroid and nodes of patches diff = patches[..., axis] - center - # - if self.__getattribute__(f"{coord}_period"): - period = self.__getattribute__(f"{coord}_period") - mask = np.abs(diff) > np.abs(period) / (2. * np.sqrt(2.)) sign = np.sign(diff) @@ -386,89 +420,48 @@ def _find_boundary_patches(patches, loc, coord): def _wrap_1D(patches, mask, sign, axis, period): """Correct patch periodicity along a single dimension""" patches[..., axis][mask] -= np.sign(sign[mask]) * period + + # TODO: clip spherical wrapped patches to projection limits + return patches + + # Stereographic projections do not need wrapping, so exit early + if isinstance(self.projection, ccrs.Stereographic): return patches if self.x_period: # find the patch that are periodic in x-direction - x_mask, x_sign = _find_boundary_patches(patches, loc, "x") - # using the sign of the difference correct patches x coordinate - patches = _wrap_1D(patches, x_mask, x_sign, 0, self.x_period) + x_mask, x_sign = _find_boundary_patches( + patches, loc, "x", self.x_period + ) + + if np.any(x_mask): + # using the sign of the difference correct patches x coordinate + patches = _wrap_1D(patches, x_mask, x_sign, 0, self.x_period) if self.y_period: # find the patch that are periodic in y-direction - y_mask, y_sign = _find_boundary_patches(patches, loc, "y") - # using the sign of the difference correct patches y coordinate - patches = _wrap_1D(patches, y_mask, y_sign, 1, self.y_period) + y_mask, y_sign = _find_boundary_patches( + patches, loc, "y", self.y_period + ) - if self.is_spherical: - # call current spherical wrapping function for now - patches = self._fix_antimeridian(patches, loc) + if np.any(x_mask): + # using the sign of the difference correct patches y coordinate + patches = _wrap_1D(patches, y_mask, y_sign, 1, self.y_period) return patches - def _fix_antimeridian(self, patches, loc, projection=None): - """Correct vertices of patches that cross the antimeridian. + def _compute_pole_mask(self, loc) -> ndarray: + """ """ + limits = self.projection.y_limits + centers = self.ds[f"y{loc.title()}"].values - NOTE: Can this be a decorator? - """ - # coordinate arrays are transformed at initalization, so using the - # transform size limit, not the projection - if not projection: - projection = self.projection - - # should be able to come up with a default size limit here, or maybe - # it's already an attribute(?) Should also factor in a precomputed - # axis period, as set in the attributes of the input dataset - if projection: - # convert to numpy array to that broadcasting below will work - x_center = np.array(self.ds[f"x{loc}"]) - - # get distance b/w the center and vertices of the patches - # NOTE: using data from masked patches array so that we compute - # mask only corresponds to patches that cross the boundary, - # (i.e. NOT a mask of all invalid cells). May need to be - # carefull about the fillvalue depending on the transform - half_distance = x_center[:, np.newaxis] - patches[..., 0].data - - # get the size limit of the projection; - size_limit = np.abs(projection.x_limits[1] - - projection.x_limits[0]) / (2 * np.sqrt(2)) - - # left and right mask, with same number of dims as the patches - l_mask = (half_distance > size_limit)[..., np.newaxis] - r_mask = (half_distance < -size_limit)[..., np.newaxis] - - """ - # Old approach masks out all patches that cross the antimeridian. - # This is unnessarily restrictive. New approach corrects - # the x-coordinates of vertices that lie outside the projections - # bounds, which isn't perfect either + # TODO: determine threshold for ``isclose`` computation + at_pole = np.any( + np.isclose(centers.reshape(-1, 1), limits, rtol=1e-2), axis=1 + ) + past_pole = np.abs(centers) > np.abs(limits[1]) - patches.mask |= l_mask - patches.mask |= r_mask - """ - - l_boundary_mask = ~np.any(l_mask, axis=1) | l_mask[..., 0] - r_boundary_mask = ~np.any(r_mask, axis=1) | r_mask[..., 0] - # get valid half distances for the patches that cross boundary - l_offset = np.ma.MaskedArray(half_distance, l_boundary_mask) - r_offset = np.ma.MaskedArray(half_distance, r_boundary_mask) - - # For vertices that cross the antimeridian reset the x-coordinate - # of invalid vertex to be the center of the patch plus the - # mean valid half distance. - # - # NOTE: this only fixes patches on the side of plot where they - # cross the antimeridian, leaving an empty zipper like pattern - # mirrored over the y-axis. - patches[..., 0] = np.ma.where( - ~l_mask[..., 0], patches[..., 0], - x_center[:, np.newaxis] + l_offset.mean(1)[..., np.newaxis]) - patches[..., 0] = np.ma.where( - ~r_mask[..., 0], patches[..., 0], - x_center[:, np.newaxis] + r_offset.mean(1)[..., np.newaxis]) - - return patches + return (at_pole | past_pole) def _compute_cell_patches(ds: Dataset) -> ndarray: diff --git a/mosaic/polypcolor.py b/mosaic/polypcolor.py index d17d9f1..17111e7 100644 --- a/mosaic/polypcolor.py +++ b/mosaic/polypcolor.py @@ -1,3 +1,4 @@ +import numpy as np from cartopy.mpl.geoaxes import GeoAxes from matplotlib.axes import Axes from matplotlib.collections import PolyCollection @@ -48,12 +49,18 @@ def polypcolor( if "nCells" in c.dims: verts = descriptor.cell_patches + if descriptor.projection and np.any(descriptor._cell_pole_mask): + c = c.where(~descriptor._cell_pole_mask, np.nan) elif "nEdges" in c.dims: verts = descriptor.edge_patches + if descriptor.projection and np.any(descriptor._edge_pole_mask): + c = c.where(~descriptor._edge_pole_mask, np.nan) elif "nVertices" in c.dims: verts = descriptor.vertex_patches + if descriptor.projection and np.any(descriptor._vertex_pole_mask): + c = c.where(~descriptor._vertex_pole_mask, np.nan) transform = descriptor.transform