diff --git a/mosaic/datasets.py b/mosaic/datasets.py index de8936e..cf20409 100644 --- a/mosaic/datasets.py +++ b/mosaic/datasets.py @@ -26,7 +26,12 @@ "mpasli.AIS8to30": { "lcrc_path": "inputdata/glc/mpasli/mpas.ais8to30km/ais_8to30km.20221027.nc", "sha256_hash": "sha256:932a1989ff8e51223413ef3ff0056d6737a1fc7f53e440359884a567a93413d2" - } + }, + + "doubly_periodic_4x4": { + "lcrc_path": "mpas_standalonedata/mpas-ocean/mesh_database/doubly_periodic_1920km_7680x7680km.151124.nc", + "sha256_hash": "sha256:5409d760845fb682ec56e30d9c6aa6dfe16b5d0e0e74f5da989cdaddbf4303c7" + }, } # create a parsable registry for pooch from human friendly one @@ -55,8 +60,9 @@ def open_dataset( * ``"QU.960km"`` : Quasi-uniform spherical mesh, with approximately 960km horizontal resolution * ``"QU.240km"`` : Quasi-uniform spherical mesh, with approximately 240km horizontal resolution - * ``"mpaso.IcoswISC30E3r5"`` : Icosahedral 30 km MPAS-Ocean mesh with ice shelf cavaties + * ``"mpaso.IcoswISC30E3r5"`` : Icosahedral 30 km MPAS-Ocean mesh with ice shelf cavities * ``"mpasli.AIS8to30"`` : 8-30 km resolution planar non-periodic MALI mesh of Antarctica + * ``"doubly_periodic_4x4"``: Doubly periodic planar mesh that is four cells wide in both the x and y dimensions. Parameters ---------- diff --git a/mosaic/descriptor.py b/mosaic/descriptor.py index 5cfed64..1afe977 100644 --- a/mosaic/descriptor.py +++ b/mosaic/descriptor.py @@ -1,5 +1,7 @@ from functools import cached_property +from typing import Literal +import cartopy.crs as ccrs import numpy as np import xarray as xr from cartopy.crs import CRS @@ -19,6 +21,13 @@ "verticesOnCell", "edgesOnVertex"] +SUPPORTED_PROJECTIONS = (ccrs._RectangularProjection, + ccrs._WarpedRectangularProjection, + ccrs.Stereographic, + ccrs.Mercator, + ccrs._CylindricalProjection, + ccrs.InterruptedGoodeHomolosine) + def attr_to_bool(attr: str): """ Format attribute strings and return a boolean value """ @@ -105,6 +114,8 @@ def __init__( #: ``projection`` kwargs are provided. self.transform = transform + #: Boolean whether parent mesh is (planar) periodic in at least one dim + self.is_periodic = attr_to_bool(mesh_ds.is_periodic) #: Boolean whether parent mesh is spherical self.is_spherical = attr_to_bool(mesh_ds.on_a_sphere) @@ -120,6 +131,10 @@ def __init__( # if both a projection and transform were provided to the constructor self.projection = projection + # method ensures is periodic, avoiding AttributeErrors if non-periodic + self.x_period = self._parse_period(mesh_ds, "x") + self.y_period = self._parse_period(mesh_ds, "y") + def _create_minimal_dataset(self, ds: Dataset) -> Dataset: """ Create a xarray.Dataset that contains the minimal subset of @@ -190,6 +205,13 @@ def projection(self) -> CRS: @projection.setter def projection(self, projection: CRS) -> None: + # We don't support all map projections for spherical meshes, yet... + if (projection is not None and self.is_spherical and + not isinstance(projection, SUPPORTED_PROJECTIONS)): + + raise ValueError(f"Invalid projection: {type(projection).__name__}" + f" is not supported - consider using " + f"a rectangular projection.") # Issue warning if changing the projection after initialization # TODO: Add heuristic size (i.e. ``self.ds.nbytes``) above which the # warning is raised @@ -228,6 +250,57 @@ def latlon(self, value) -> None: self._latlon = value + def _parse_period(self, ds, dim: Literal["x", "y"]): + """ Parse period attribute, return None for non-periodic meshes """ + + attr = f"{dim}_period" + try: + period = float(ds.attrs[attr]) + except KeyError: + period = None + + # in the off chance mesh is periodic but does not have period attribute + if self.is_periodic and attr not in ds.attrs: + raise AttributeError((f"Mesh file: \"{ds.encoding['source']}\"" + f"does not have attribute `{attr}` despite" + f"being a planar periodic mesh.")) + if period == 0.0: + return None + else: + return period + + @property + def x_period(self) -> float | None: + """ Period along x-dimension, is ``None`` for non-periodic meshes """ + return self._x_period + + @x_period.setter + def x_period(self, value) -> None: + # needed to avoid AttributeError for non-periodic meshes + if not (self.is_periodic and self.is_spherical): + self._x_period = None + if not self.is_periodic and self.is_spherical: + x_limits = self.projection.x_limits + self._x_period = np.abs(x_limits[1] - x_limits[0]) + else: + self._x_period = value + + @property + def y_period(self) -> float | None: + """ Period along y-dimension, is ``None`` for non-periodic meshes """ + return self._y_period + + @y_period.setter + def y_period(self, value) -> None: + # needed to avoid AttributeError for non-periodic meshes + if not (self.is_periodic and self.is_spherical): + self._y_period = None + if not self.is_periodic and self.is_spherical: + y_limits = self.projection.y_limits + self._y_period = np.abs(y_limits[1] - y_limits[0]) + else: + self._y_period = value + @cached_property def cell_patches(self) -> ndarray: """:py:class:`~numpy.ndarray` of patch coordinates for cell centered @@ -244,7 +317,13 @@ def cell_patches(self) -> ndarray: patch. Nodes are ordered counter clockwise around the cell center. """ patches = _compute_cell_patches(self.ds) - patches = self._fix_antimeridian(patches, "Cell") + patches = self._wrap_patches(patches, "Cell") + + # cartopy doesn't handle nans in patches, so store a mask of the + # invalid patches to set the dataarray at those locations to nan. + if self.projection: + self._cell_pole_mask = self._compute_pole_mask("Cell") + return patches @cached_property @@ -263,7 +342,13 @@ def edge_patches(self) -> ndarray: corresponding node will be collapsed to the edge coordinate. """ patches = _compute_edge_patches(self.ds) - patches = self._fix_antimeridian(patches, "Edge") + patches = self._wrap_patches(patches, "Edge") + + # cartopy doesn't handle nans in patches, so store a mask of the + # invalid patches to set the dataarray at those locations to nan. + if self.projection: + self._edge_pole_mask = self._compute_pole_mask("Edge") + return patches @cached_property @@ -289,7 +374,13 @@ def vertex_patches(self) -> ndarray: position. """ patches = _compute_vertex_patches(self.ds) - patches = self._fix_antimeridian(patches, "Vertex") + patches = self._wrap_patches(patches, "Vertex") + + # cartopy doesn't handle nans in patches, so store a mask of the + # invalid patches to set the dataarray at those locations to nan. + if self.projection: + self._vertex_pole_mask = self._compute_pole_mask("Vertex") + return patches def _transform_coordinates(self, projection, transform): @@ -304,70 +395,74 @@ def _transform_coordinates(self, projection, transform): self.ds[f"x{loc}"].values = transformed_coords[:, 0] self.ds[f"y{loc}"].values = transformed_coords[:, 1] - def _fix_antimeridian(self, patches, loc, projection=None): - """Correct vertices of patches that cross the antimeridian. - - NOTE: Can this be a decorator? + def _wrap_patches(self, patches, loc): + """Wrap patches for spherical and planar-periodic meshes """ - # coordinate arrays are transformed at initalization, so using the - # transform size limit, not the projection - if not projection: - projection = self.projection - - # should be able to come up with a default size limit here, or maybe - # it's already an attribute(?) Should also factor in a precomputed - # axis period, as set in the attributes of the input dataset - if projection: - # convert to numpy array to that broadcasting below will work - x_center = np.array(self.ds[f"x{loc}"]) - - # get distance b/w the center and vertices of the patches - # NOTE: using data from masked patches array so that we compute - # mask only corresponds to patches that cross the boundary, - # (i.e. NOT a mask of all invalid cells). May need to be - # carefull about the fillvalue depending on the transform - half_distance = x_center[:, np.newaxis] - patches[..., 0].data - - # get the size limit of the projection; - size_limit = np.abs(projection.x_limits[1] - - projection.x_limits[0]) / (2 * np.sqrt(2)) - - # left and right mask, with same number of dims as the patches - l_mask = (half_distance > size_limit)[..., np.newaxis] - r_mask = (half_distance < -size_limit)[..., np.newaxis] + def _find_boundary_patches(patches, loc, coord, period): """ - # Old approach masks out all patches that cross the antimeridian. - # This is unnessarily restrictive. New approach corrects - # the x-coordinates of vertices that lie outside the projections - # bounds, which isn't perfect either - - patches.mask |= l_mask - patches.mask |= r_mask + Find the patches that cross the periodic boundary and what + direction they cross the boundary (i.e. their ``sign``). This + method assumes the patch centroids are not periodic """ + # get axis index we are inquiring over + axis = 0 if coord == "x" else 1 + # get requested coordinate of patch centroids + center = self.ds[f"{coord}{loc.title()}"].values.reshape(-1, 1) + # get difference b/w centroid and nodes of patches + diff = patches[..., axis] - center + + mask = np.abs(diff) > np.abs(period) / (2. * np.sqrt(2.)) + sign = np.sign(diff) + + return mask, sign + + def _wrap_1D(patches, mask, sign, axis, period): + """Correct patch periodicity along a single dimension""" + patches[..., axis][mask] -= np.sign(sign[mask]) * period + + # TODO: clip spherical wrapped patches to projection limits + return patches - l_boundary_mask = ~np.any(l_mask, axis=1) | l_mask[..., 0] - r_boundary_mask = ~np.any(r_mask, axis=1) | r_mask[..., 0] - # get valid half distances for the patches that cross boundary - l_offset = np.ma.MaskedArray(half_distance, l_boundary_mask) - r_offset = np.ma.MaskedArray(half_distance, r_boundary_mask) - - # For vertices that cross the antimeridian reset the x-coordinate - # of invalid vertex to be the center of the patch plus the - # mean valid half distance. - # - # NOTE: this only fixes patches on the side of plot where they - # cross the antimeridian, leaving an empty zipper like pattern - # mirrored over the y-axis. - patches[..., 0] = np.ma.where( - ~l_mask[..., 0], patches[..., 0], - x_center[:, np.newaxis] + l_offset.mean(1)[..., np.newaxis]) - patches[..., 0] = np.ma.where( - ~r_mask[..., 0], patches[..., 0], - x_center[:, np.newaxis] + r_offset.mean(1)[..., np.newaxis]) + # Stereographic projections do not need wrapping, so exit early + if isinstance(self.projection, ccrs.Stereographic): + return patches + + if self.x_period: + # find the patch that are periodic in x-direction + x_mask, x_sign = _find_boundary_patches( + patches, loc, "x", self.x_period + ) + + if np.any(x_mask): + # using the sign of the difference correct patches x coordinate + patches = _wrap_1D(patches, x_mask, x_sign, 0, self.x_period) + + if self.y_period: + # find the patch that are periodic in y-direction + y_mask, y_sign = _find_boundary_patches( + patches, loc, "y", self.y_period + ) + + if np.any(x_mask): + # using the sign of the difference correct patches y coordinate + patches = _wrap_1D(patches, y_mask, y_sign, 1, self.y_period) return patches + def _compute_pole_mask(self, loc) -> ndarray: + """ """ + limits = self.projection.y_limits + centers = self.ds[f"y{loc.title()}"].values + + # TODO: determine threshold for ``isclose`` computation + at_pole = np.any( + np.isclose(centers.reshape(-1, 1), limits, rtol=1e-2), axis=1 + ) + past_pole = np.abs(centers) > np.abs(limits[1]) + + return (at_pole | past_pole) + def _compute_cell_patches(ds: Dataset) -> ndarray: """Create cell patches (i.e. Primary cells) for an MPAS mesh.""" diff --git a/mosaic/polypcolor.py b/mosaic/polypcolor.py index d17d9f1..17111e7 100644 --- a/mosaic/polypcolor.py +++ b/mosaic/polypcolor.py @@ -1,3 +1,4 @@ +import numpy as np from cartopy.mpl.geoaxes import GeoAxes from matplotlib.axes import Axes from matplotlib.collections import PolyCollection @@ -48,12 +49,18 @@ def polypcolor( if "nCells" in c.dims: verts = descriptor.cell_patches + if descriptor.projection and np.any(descriptor._cell_pole_mask): + c = c.where(~descriptor._cell_pole_mask, np.nan) elif "nEdges" in c.dims: verts = descriptor.edge_patches + if descriptor.projection and np.any(descriptor._edge_pole_mask): + c = c.where(~descriptor._edge_pole_mask, np.nan) elif "nVertices" in c.dims: verts = descriptor.vertex_patches + if descriptor.projection and np.any(descriptor._vertex_pole_mask): + c = c.where(~descriptor._vertex_pole_mask, np.nan) transform = descriptor.transform