Skip to content

Commit

Permalink
Initial work on minimal reader support
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronzedwick committed Dec 6, 2024
1 parent 41efd5c commit a05fb78
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 121 deletions.
6 changes: 6 additions & 0 deletions test/test_mpas.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,9 @@ def test_face_area(self):

assert "face_areas" in uxgrid_primal._ds
assert "face_areas" in uxgrid_dual._ds

def test_minimal(self):
"""Tests the minimal grid reader"""
uxgrid = ux.open_grid(self.mpas_grid_path, minimal=True)

assert "node_x" not in uxgrid._ds
7 changes: 5 additions & 2 deletions uxarray/core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def open_grid(
],
latlon: Optional[bool] = False,
use_dual: Optional[bool] = False,
minimal: Optional[bool] = False,
**kwargs: Dict[str, Any],
) -> Grid:
"""Constructs and returns a ``Grid`` from a grid file.
Expand All @@ -34,11 +35,13 @@ def open_grid(
object to define the grid.
latlon : bool, optional
Specify if the grid is lat/lon based
Specify if the grid is lat/lon based
use_dual: bool, optional
Specify whether to use the primal (use_dual=False) or dual (use_dual=True) mesh if the file type is mpas
minimal: bool, optional
Specify whether to read the minimal information (`nodes` and `face_node_connectivity`) needed for a grid
**kwargs : Dict[str, Any]
Additional arguments passed on to ``xarray.open_dataset``. Refer to the
[xarray
Expand Down Expand Up @@ -84,7 +87,7 @@ def open_grid(
try:
grid_ds = xr.open_dataset(grid_filename_or_obj, **kwargs)

uxgrid = Grid.from_dataset(grid_ds, use_dual=use_dual)
uxgrid = Grid.from_dataset(grid_ds, use_dual=use_dual, minimal=minimal)
except ValueError:
raise ValueError("Inputted grid_filename_or_obj not supported.")

Expand Down
24 changes: 17 additions & 7 deletions uxarray/grid/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,11 @@ def __init__(

@classmethod
def from_dataset(
cls, dataset: xr.Dataset, use_dual: Optional[bool] = False, **kwargs
cls,
dataset: xr.Dataset,
use_dual: Optional[bool] = False,
minimal: Optional[bool] = False,
**kwargs,
):
"""Constructs a ``Grid`` object from an ``xarray.Dataset``.
Expand All @@ -251,6 +255,8 @@ def from_dataset(
``xarray.Dataset`` containing unstructured grid coordinates and connectivity variables
use_dual : bool, default=False
When reading in MPAS formatted datasets, indicates whether to use the Dual Mesh
minimal : bool, default=False
Specify whether to read the minimal information (`nodes` and `face_node_connectivity`) needed for a grid
"""
if not isinstance(dataset, xr.Dataset):
raise ValueError("Input must be an xarray.Dataset")
Expand All @@ -264,17 +270,21 @@ def from_dataset(
if source_grid_spec == "Exodus":
grid_ds, source_dims_dict = _read_exodus(dataset)
elif source_grid_spec == "Scrip":
grid_ds, source_dims_dict = _read_scrip(dataset)
grid_ds, source_dims_dict = _read_scrip(dataset, minimal)
elif source_grid_spec == "UGRID":
grid_ds, source_dims_dict = _read_ugrid(dataset)
grid_ds, source_dims_dict = _read_ugrid(dataset, minimal=minimal)
elif source_grid_spec == "MPAS":
grid_ds, source_dims_dict = _read_mpas(dataset, use_dual=use_dual)
grid_ds, source_dims_dict = _read_mpas(
dataset, use_dual=use_dual, minimal=minimal
)
elif source_grid_spec == "ESMF":
grid_ds, source_dims_dict = _read_esmf(dataset)
grid_ds, source_dims_dict = _read_esmf(dataset, minimal=minimal)
elif source_grid_spec == "GEOS-CS":
grid_ds, source_dims_dict = _read_geos_cs(dataset)
grid_ds, source_dims_dict = _read_geos_cs(dataset, minimal=minimal)
elif source_grid_spec == "ICON":
grid_ds, source_dims_dict = _read_icon(dataset, use_dual=use_dual)
grid_ds, source_dims_dict = _read_icon(
dataset, use_dual=use_dual, minimal=minimal
)
elif source_grid_spec == "Structured":
grid_ds = _read_structured_grid(dataset[lon_name], dataset[lat_name])
source_dims_dict = {"n_face": (lon_name, lat_name)}
Expand Down
6 changes: 4 additions & 2 deletions uxarray/io/_esmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from uxarray.conventions import ugrid


def _read_esmf(in_ds):
def _read_esmf(in_ds, minimal=False):
"""Reads in an Xarray dataset containing an ESMF formatted Grid dataset and
encodes it in the UGRID conventions.
Expand All @@ -27,6 +27,8 @@ def _read_esmf(in_ds):
----------
in_ds: xr.Dataset
ESMF Grid Dataset
minimal : bool, optional
Specify whether to read the minimal information (`nodes` and `face_node_connectivity`) needed for a grid
Returns
-------
Expand Down Expand Up @@ -56,7 +58,7 @@ def _read_esmf(in_ds):
node_lat, dims=[ugrid.NODE_DIM], attrs=ugrid.NODE_LAT_ATTRS
)

if "centerCoords" in in_ds:
if "centerCoords" in in_ds and not minimal:
# parse center coords (face centers) if available
face_lon = in_ds["centerCoords"].isel(coordDim=0).values
out_ds[ugrid.FACE_COORDINATES[0]] = xr.DataArray(
Expand Down
15 changes: 13 additions & 2 deletions uxarray/io/_geos.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,19 @@
from uxarray.conventions import ugrid


def _read_geos_cs(in_ds: xr.Dataset):
def _read_geos_cs(in_ds: xr.Dataset, minimal=False):
"""Reads and encodes a GEOS Cube-Sphere grid into the UGRID conventions.
Parameters
----------
in_ds: xr.Dataset
GEOS_CS Grid Dataset
minimal : bool, optional
Specify whether to read the minimal information (`nodes` and `face_node_connectivity`) needed for a grid
Returns
-------
out_ds: xr.Dataset
GEOS_CS Grid encoder in the UGRID conventions
https://gmao.gsfc.nasa.gov/gmaoftp/ops/GEOSIT_sample/doc/CS_Description_c180_v1.pdf
"""
Expand All @@ -23,7 +34,7 @@ def _read_geos_cs(in_ds: xr.Dataset):
data=node_lat, dims=ugrid.NODE_DIM, attrs=ugrid.NODE_LAT_ATTRS
)

if "lons" in in_ds:
if "lons" in in_ds and not minimal:
face_lon = in_ds["lons"].values.ravel()
face_lat = in_ds["lats"].values.ravel()

Expand Down
125 changes: 69 additions & 56 deletions uxarray/io/_icon.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,20 @@
import numpy as np


def _primal_to_ugrid(in_ds, out_ds):
"""Encodes the Primal Mesh of an ICON Grid into the UGRID conventions."""
def _primal_to_ugrid(in_ds, out_ds, minimal=False):
"""Encodes the Primal Mesh of an ICON Grid into the UGRID conventions.
Parameters
----------
in_ds: xr.Dataset
ICON Grid Dataset
minimal : bool, optional
Specify whether to read the minimal information (`nodes` and `face_node_connectivity`) needed for a grid
Returns
-------
out_ds: xr.Dataset
ICON Grid encoder in the UGRID conventions
"""
source_dims_dict = {"vertex": "n_node", "edge": "n_edge", "cell": "n_face"}

# rename dimensions to match ugrid conventions
Expand All @@ -21,28 +33,6 @@ def _primal_to_ugrid(in_ds, out_ds):
data=node_lat, dims=ugrid.NODE_DIM, attrs=ugrid.NODE_LAT_ATTRS
)

# edge coordinates
edge_lon = np.rad2deg(in_ds["elon"])
edge_lat = np.rad2deg(in_ds["elat"])

out_ds["edge_lon"] = xr.DataArray(
data=edge_lon, dims=ugrid.EDGE_DIM, attrs=ugrid.EDGE_LON_ATTRS
)
out_ds["edge_lat"] = xr.DataArray(
data=edge_lat, dims=ugrid.EDGE_DIM, attrs=ugrid.EDGE_LAT_ATTRS
)

# face coordinates
face_lon = np.rad2deg(in_ds["clon"])
face_lat = np.rad2deg(in_ds["clat"])

out_ds["face_lon"] = xr.DataArray(
data=face_lon, dims=ugrid.FACE_DIM, attrs=ugrid.FACE_LON_ATTRS
)
out_ds["face_lat"] = xr.DataArray(
data=face_lat, dims=ugrid.FACE_DIM, attrs=ugrid.FACE_LAT_ATTRS
)

face_node_connectivity = in_ds["vertex_of_cell"].T - 1

out_ds["face_node_connectivity"] = xr.DataArray(
Expand All @@ -51,45 +41,68 @@ def _primal_to_ugrid(in_ds, out_ds):
attrs=ugrid.FACE_NODE_CONNECTIVITY_ATTRS,
)

face_edge_connectivity = in_ds["edge_of_cell"].T - 1

out_ds["face_edge_connectivity"] = xr.DataArray(
data=face_edge_connectivity,
dims=ugrid.FACE_EDGE_CONNECTIVITY_DIMS,
attrs=ugrid.FACE_EDGE_CONNECTIVITY_ATTRS,
)

face_face_connectivity = in_ds["neighbor_cell_index"].T - 1

out_ds["face_face_connectivity"] = xr.DataArray(
data=face_face_connectivity,
dims=ugrid.FACE_FACE_CONNECTIVITY_DIMS,
attrs=ugrid.FACE_FACE_CONNECTIVITY_ATTRS,
)

edge_face_connectivity = in_ds["adjacent_cell_of_edge"].T - 1

out_ds["edge_face_connectivity"] = xr.DataArray(
data=edge_face_connectivity,
dims=ugrid.EDGE_FACE_CONNECTIVITY_DIMS,
attrs=ugrid.EDGE_FACE_CONNECTIVITY_ATTRS,
)

edge_node_connectivity = in_ds["edge_vertices"].T - 1
out_ds["edge_node_connectivity"] = xr.DataArray(
data=edge_node_connectivity,
dims=ugrid.EDGE_NODE_CONNECTIVITY_DIMS,
attrs=ugrid.EDGE_NODE_CONNECTIVITY_ATTRS,
)
if not minimal:
# edge coordinates
edge_lon = np.rad2deg(in_ds["elon"])
edge_lat = np.rad2deg(in_ds["elat"])

out_ds["edge_lon"] = xr.DataArray(
data=edge_lon, dims=ugrid.EDGE_DIM, attrs=ugrid.EDGE_LON_ATTRS
)
out_ds["edge_lat"] = xr.DataArray(
data=edge_lat, dims=ugrid.EDGE_DIM, attrs=ugrid.EDGE_LAT_ATTRS
)

# face coordinates
face_lon = np.rad2deg(in_ds["clon"])
face_lat = np.rad2deg(in_ds["clat"])

out_ds["face_lon"] = xr.DataArray(
data=face_lon, dims=ugrid.FACE_DIM, attrs=ugrid.FACE_LON_ATTRS
)
out_ds["face_lat"] = xr.DataArray(
data=face_lat, dims=ugrid.FACE_DIM, attrs=ugrid.FACE_LAT_ATTRS
)

face_edge_connectivity = in_ds["edge_of_cell"].T - 1

out_ds["face_edge_connectivity"] = xr.DataArray(
data=face_edge_connectivity,
dims=ugrid.FACE_EDGE_CONNECTIVITY_DIMS,
attrs=ugrid.FACE_EDGE_CONNECTIVITY_ATTRS,
)

face_face_connectivity = in_ds["neighbor_cell_index"].T - 1

out_ds["face_face_connectivity"] = xr.DataArray(
data=face_face_connectivity,
dims=ugrid.FACE_FACE_CONNECTIVITY_DIMS,
attrs=ugrid.FACE_FACE_CONNECTIVITY_ATTRS,
)

edge_face_connectivity = in_ds["adjacent_cell_of_edge"].T - 1

out_ds["edge_face_connectivity"] = xr.DataArray(
data=edge_face_connectivity,
dims=ugrid.EDGE_FACE_CONNECTIVITY_DIMS,
attrs=ugrid.EDGE_FACE_CONNECTIVITY_ATTRS,
)

edge_node_connectivity = in_ds["edge_vertices"].T - 1
out_ds["edge_node_connectivity"] = xr.DataArray(
data=edge_node_connectivity,
dims=ugrid.EDGE_NODE_CONNECTIVITY_DIMS,
attrs=ugrid.EDGE_NODE_CONNECTIVITY_ATTRS,
)

return out_ds, source_dims_dict


def _read_icon(ext_ds, use_dual=False):
def _read_icon(ext_ds, use_dual=False, minimal=False):
"""Reads and encodes an ICON mesh into the UGRID conventions."""
out_ds = xr.Dataset()

if not use_dual:
return _primal_to_ugrid(ext_ds, out_ds)
return _primal_to_ugrid(ext_ds, out_ds, minimal)
else:
raise ValueError("Conversion of the ICON Dual mesh is not yet supported.")
Loading

0 comments on commit a05fb78

Please sign in to comment.